CLIP算法的Loss详解 和 交叉熵CrossEntropy实现

2024-01-08 07:40

本文主要是介绍CLIP算法的Loss详解 和 交叉熵CrossEntropy实现,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

CLIP:Contrastive Language–Image Pre-training(可对比语言-图像预训练算法)是OpenAI提出的多模态预训练的算法,在各种各样的**样本对(图像、文本)**上训练的神经网络。

具体参考:CLIP、OpenCLIP

image-20220601180224080

其中,流程:

image-20220601180639145

loss_iloss_t的具体源码如下,参考 model.py:

    def forward(self, image, text):image_features = self.encode_image(image)text_features = self.encode_text(text)# normalized featuresimage_features = image_features / image_features.norm(dim=1, keepdim=True)text_features = text_features / text_features.norm(dim=1, keepdim=True)# cosine similarity as logitslogit_scale = self.logit_scale.exp()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()# shape = [global_batch_size, global_batch_size]return logits_per_image, logits_per_text

其中,labels是torch.arange(batch_size, device=device).long(),参考train.py,具体如下

        with torch.no_grad():for i, batch in enumerate(dataloader):images, texts = batchimages = images.to(device=device, non_blocking=True)texts = texts.to(device=device, non_blocking=True)with autocast():image_features, text_features, logit_scale = model(images, texts)# features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly# however, system RAM is easily exceeded and compute time becomes problematicall_image_features.append(image_features.cpu())all_text_features.append(text_features.cpu())logit_scale = logit_scale.mean()logits_per_image = logit_scale * image_features @ text_features.t()logits_per_text = logits_per_image.t()batch_size = images.shape[0]labels = torch.arange(batch_size, device=device).long()total_loss = (F.cross_entropy(logits_per_image, labels) +F.cross_entropy(logits_per_text, labels)) / 2

交叉熵函数:y就是label,x_softmax[i][y[i]],表示在x_softmax中筛选第i个sample的第y[i]个值,作为log的输入,全部log负向求和,再求均值。

  • y所对应的就是CLIP的np.arange(n),也就是依次是第0个位置~第n-1个位置,计算log。
# 定义softmax函数
def softmax(x):return np.exp(x) / np.sum(np.exp(x))# 利用numpy计算
def cross_entropy_np(x, y):x_softmax = [softmax(x[i]) for i in range(len(x))]x_log = [np.log(x_softmax[i][y[i]]) for i in range(len(y))]loss = - np.sum(x_log) / len(y)return loss# 测试逻辑
x = [[1.9269, 1.4873, 0.9007, -2.1055]]
y = [[2]]
v1 = cross_entropy_np(x, y)
print(f"v1: {v1}")x = torch.unsqueeze(torch.Tensor(x), dim=0)
x = x.transpose(1, 2)  # CrossEntropy输入期望: Class放在第2维,Batch放在第1维y = torch.Tensor(y)
y = y.to(torch.long)  # label的类型为longv2 = F.cross_entropy(x, y, reduction="none")
print(f"v2: {v2}")

输出:

v1: 1.729491540989093
v2: tensor([[1.7295]])

参考:

  • arxiv文章下载很慢怎么办?
  • CLIP-对比图文多模态预训练的读后感
  • CrossEntropy的numpy实现和Pytorch调用

这篇关于CLIP算法的Loss详解 和 交叉熵CrossEntropy实现的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/582742

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

MySQL中的分组和多表连接详解

《MySQL中的分组和多表连接详解》:本文主要介绍MySQL中的分组和多表连接的相关操作,本文通过实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录mysql中的分组和多表连接一、MySQL的分组(group javascriptby )二、多表连接(表连接会产生大量的数据垃圾)MySQL中的

Java 实用工具类Spring 的 AnnotationUtils详解

《Java实用工具类Spring的AnnotationUtils详解》Spring框架提供了一个强大的注解工具类org.springframework.core.annotation.Annot... 目录前言一、AnnotationUtils 的常用方法二、常见应用场景三、与 JDK 原生注解 API 的

Python实现微信自动锁定工具

《Python实现微信自动锁定工具》在数字化办公时代,微信已成为职场沟通的重要工具,但临时离开时忘记锁屏可能导致敏感信息泄露,下面我们就来看看如何使用Python打造一个微信自动锁定工具吧... 目录引言:当微信隐私遇到自动化守护效果展示核心功能全景图技术亮点深度解析1. 无操作检测引擎2. 微信路径智能获

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Python中pywin32 常用窗口操作的实现

《Python中pywin32常用窗口操作的实现》本文主要介绍了Python中pywin32常用窗口操作的实现,pywin32主要的作用是供Python开发者快速调用WindowsAPI的一个... 目录获取窗口句柄获取最前端窗口句柄获取指定坐标处的窗口根据窗口的完整标题匹配获取句柄根据窗口的类别匹配获取句

在 Spring Boot 中实现异常处理最佳实践

《在SpringBoot中实现异常处理最佳实践》本文介绍如何在SpringBoot中实现异常处理,涵盖核心概念、实现方法、与先前查询的集成、性能分析、常见问题和最佳实践,感兴趣的朋友一起看看吧... 目录一、Spring Boot 异常处理的背景与核心概念1.1 为什么需要异常处理?1.2 Spring B

SpringBoot3.4配置校验新特性的用法详解

《SpringBoot3.4配置校验新特性的用法详解》SpringBoot3.4对配置校验支持进行了全面升级,这篇文章为大家详细介绍了一下它们的具体使用,文中的示例代码讲解详细,感兴趣的小伙伴可以参考... 目录基本用法示例定义配置类配置 application.yml注入使用嵌套对象与集合元素深度校验开发

Python中的Walrus运算符分析示例详解

《Python中的Walrus运算符分析示例详解》Python中的Walrus运算符(:=)是Python3.8引入的一个新特性,允许在表达式中同时赋值和返回值,它的核心作用是减少重复计算,提升代码简... 目录1. 在循环中避免重复计算2. 在条件判断中同时赋值变量3. 在列表推导式或字典推导式中简化逻辑

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1