「torch.cosine_smilarity() = 0」引发的关于cpu与gpu精度问题的探讨

2023-11-22 05:12

本文主要是介绍「torch.cosine_smilarity() = 0」引发的关于cpu与gpu精度问题的探讨,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言:2023年11月21日下午16:00 许,本篇博客记录由「torch.cosine_smilarity()计算余弦相似度计算结果为0」现象引发的关于 CPU 与 GPU 计算精度的探索。

事情的起因是,本人在使用 torch.cosine_smilarity() 函数计算GPU上两个特征的余弦相似度时,发现得出的结果为 0,百思不得其解。首先排出特征维度的问题,然后尝试5种不同的相似度计算方法:

  • scipy.spatial.distance.cosine
  • torch.cosine_similarity
  • F.cosine_similarity
  • torch.nn.CosineSimilarity
  • 基于余弦相似度公式的torch代码

整体代码如下:

import torch
torch.set_printoptions(profile="full")
import torch.nn.functional as F
from scipy.spatial.distance import cosine
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"import clip
from PIL import Image
clip_model, processor = clip.load("ViT-L/14", device=device)
srcpath = '/newdata/SD/DEFAKE/data_test/9709_glide.png'
despath = '/newdata/SD/outputs/0_9_sd1.5.png'
src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device)).squeeze(0)  
des_feature = clip_model.encode_image(processor(Image.open(despath)).unsqueeze(0).to(device)).squeeze(0)sim_1 = 1 - cosine(src_feature.cpu(), des_feature.cpu())sim_2 = torch.cosine_similarity(src_feature, des_feature, dim=0).item()sim_3 = F.cosine_similarity(src_feature, des_feature, dim=0).item()cos = torch.nn.CosineSimilarity(dim=0)
sim_4 = cos(src_feature, des_feature).item()sim_5 = torch.div(torch.sum(src_feature * des_feature,0),torch.sqrt(torch.sum(torch.pow(src_feature,2),0))* torch.sqrt(torch.sum(torch.pow(des_feature,2),0))).item()print(sim_1, sim_2, sim_3, sim_4, sim_5)
# 0.5302734375 0.0 0.0 0.0 0.53076171875

发现,上述代码在CPU和GPU上运行结果不一致:

上述代码在 CPU 上的运行结果为:

0.5301393270492554
0.5301393270492554
0.5301393270492554
0.5301393270492554
0.5301393270492554

上述代码在 GPU 上的运行结果为:

0.5302734375
0.0
0.0
0.0
0.53076171875

这是一个很有意思的现象,在CPU上计算出的5种相似度结果惊人一致,而在GPU上计算出的5种相似度结果中,中间三种基于torch函数调用的方式计算结果均为0,而第一种首先将特征搬运到CPU上然后使用scipy.spatial.distance.cosine()函数的计算结果和第五种直接在GPU上使用基于余弦相似度公式的torch代码的计算结果又与CPU上计算出的结果各有不同。


然后,我把由 CLIP 预训练模型提取的两张图像的特征(维度为768维) src_featuredes_feature 换成两个随机初始化的张量 ,其余代码不变:


# import clip
# from PIL import Image
# clip_model, processor = clip.load("ViT-L/14", device=device)
# srcpath = '/newdata/SD/DEFAKE/data_test/9709_glide.png'
# despath = '/newdata/SD/outputs/0_9_sd1.5.png'
# src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device)).squeeze(0)  
# des_feature = clip_model.encode_image(processor(Image.open(despath)).unsqueeze(0).to(device)).squeeze(0)# 将上面代码注释掉,换为:src_featre = torch.tensor([1.0, 2.0, 3.0])
des_feature = torch.tensor([4.0, 5.0, 6.0])

可见上述示例代码在CPU和GPU上运行结果是一致的:

上述代码在 CPU 上的运行结果为:

0.9746318459510803
0.9746317863464355
0.9746317863464355
0.9746317863464355
0.9746317863464355

上述代码在 GPU 上的运行结果为:

0.9746318459510803
0.9746317863464355
0.9746317863464355
0.9746317863464355
0.9746317863464355

由上述结果可以发现,第一种基于scipy.spatial.distance.cosine()函数的计算结果与其余四组基于torch的计算结果略有不同,说明后四种方法实现的底层逻辑应该是类似的,但由于给定特征的某些不可知原因,有时会出现中间三种基于torch函数调用的方法结果为0的情况,所以保险起见,如果要使用基于torch的计算方法,首选第5种相似度计算方法,当然,时间允许的情况下,直接在CPU上使用第一种方法无疑是精度最高的计算方法。


接下来放一个时间对比图(如下),可见在GPU上使用最后一种计算方法效率最高,在CPU上使用第一种方法效率最低。

time for spicy(gpu):  0.004714250564575195, [0.5361, 0.5303, 0.5220, 0.5059, 0.5430, 0.5469, 0.5078, 0.5293, 0.5283, 0.5337]
time for torch(gpu): 0.0005323886871337891, [0.5361, 0.5308, 0.5225, 0.5063, 0.5435, 0.5469, 0.5078, 0.5298, 0.5283, 0.5332]
time for spicy(cpu):  0.009323358535766602, [0.5358, 0.5301, 0.5222, 0.5060, 0.5426, 0.5467, 0.5077, 0.5298, 0.5279, 0.5333]
time for torch(cpu): 0.0025298595428466797, [0.5358, 0.5301, 0.5222, 0.5060, 0.5426, 0.5467, 0.5077, 0.5298, 0.5279, 0.5333]

PS:鉴于第一种方法无法进行余弦相似度的批量计算(1vN计算),追求速度的话,还是选择第五种方法吧~👀 附赠批量计算方法

src_feature = clip_model.encode_image(processor(Image.open(srcpath)).unsqueeze(0).to(device))  # [1,768]
des_features = torch.stack([clip_model.encode_image(processor(Image.open(path)).unsqueeze(0).to(device)) for path in despaths]).squeeze(1)  # [N,768]
sims = torch.div(torch.sum(src_feature * des_features,1),torch.sqrt(torch.sum(torch.pow(src_feature,2),1))* torch.sqrt(torch.sum(torch.pow(des_features,2),1)))  # 长度为N的张量
sims = sims.cpu().detach().numpy().tolist()  # 转化为列表,方便计算

参考资料

  1. GPU和CPU计算上的精度差异_cpu和gpu训练结果不同-CSDN博客

这篇关于「torch.cosine_smilarity() = 0」引发的关于cpu与gpu精度问题的探讨的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/407758

相关文章

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.

Java 线程安全与 volatile与单例模式问题及解决方案

《Java线程安全与volatile与单例模式问题及解决方案》文章主要讲解线程安全问题的五个成因(调度随机、变量修改、非原子操作、内存可见性、指令重排序)及解决方案,强调使用volatile关键字... 目录什么是线程安全线程安全问题的产生与解决方案线程的调度是随机的多个线程对同一个变量进行修改线程的修改操

Redis出现中文乱码的问题及解决

《Redis出现中文乱码的问题及解决》:本文主要介绍Redis出现中文乱码的问题及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1. 问题的产生2China编程. 问题的解决redihttp://www.chinasem.cns数据进制问题的解决中文乱码问题解决总结

全面解析MySQL索引长度限制问题与解决方案

《全面解析MySQL索引长度限制问题与解决方案》MySQL对索引长度设限是为了保持高效的数据检索性能,这个限制不是MySQL的缺陷,而是数据库设计中的权衡结果,下面我们就来看看如何解决这一问题吧... 目录引言:为什么会有索引键长度问题?一、问题根源深度解析mysql索引长度限制原理实际场景示例二、五大解决

Springboot如何正确使用AOP问题

《Springboot如何正确使用AOP问题》:本文主要介绍Springboot如何正确使用AOP问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录​一、AOP概念二、切点表达式​execution表达式案例三、AOP通知四、springboot中使用AOP导出

Python中Tensorflow无法调用GPU问题的解决方法

《Python中Tensorflow无法调用GPU问题的解决方法》文章详解如何解决TensorFlow在Windows无法识别GPU的问题,需降级至2.10版本,安装匹配CUDA11.2和cuDNN... 当用以下代码查看GPU数量时,gpuspython返回的是一个空列表,说明tensorflow没有找到

解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题

《解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘问题》:本文主要介绍解决未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4... 目录未解析的依赖项:‘net.sf.json-lib:json-lib:jar:2.4‘打开pom.XM

IDEA Maven提示:未解析的依赖项的问题及解决

《IDEAMaven提示:未解析的依赖项的问题及解决》:本文主要介绍IDEAMaven提示:未解析的依赖项的问题及解决,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝... 目录IDEA Maven提示:未解析的依编程赖项例如总结IDEA Maven提示:未解析的依赖项例如

Redis分片集群、数据读写规则问题小结

《Redis分片集群、数据读写规则问题小结》本文介绍了Redis分片集群的原理,通过数据分片和哈希槽机制解决单机内存限制与写瓶颈问题,实现分布式存储和高并发处理,但存在通信开销大、维护复杂及对事务支持... 目录一、分片集群解android决的问题二、分片集群图解 分片集群特征如何解决的上述问题?(与哨兵模

SpringBoot+Redis防止接口重复提交问题

《SpringBoot+Redis防止接口重复提交问题》:本文主要介绍SpringBoot+Redis防止接口重复提交问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不... 目录前言实现思路代码示例测试总结前言在项目的使用使用过程中,经常会出现某些操作在短时间内频繁提交。例