PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录

本文主要是介绍PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前言

在这里插入图片描述

在PyTorch框架下使用F.cross_entropy()函数时,偶尔会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed

错误信息类似下面打印信息:

/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [0,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [1,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [2,0,0] Assertion `t >= 0 && t < n_classes` failed.
/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/ClassNLLCriterion.cu:52: void cunn_ClassNLLCriterion_updateOutput_kernel(Dtype *, Dtype *, Dtype *, long *, Dtype *, int, int, int, int) [with Dtype = float, Acctype = float]: block: [0,0,0], thread: [3,0,0] Assertion `t >= 0 && t < n_classes` failed.
THCudaCheck FAIL file=/py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu line=83 error=59 : device-side assert triggered
Traceback (most recent call last):File "tutorial.py", line 100, in <module>model = train_model(model, criterion, optim_scheduler_ft, num_epochs=25)File "tutorial.py", line 80, in train_modelloss = criterion(outputs, labels)File "python3.7/site-packages/torch/nn/modules/module.py", line 206, in __call__result = self.forward(*input, **kwargs)File "python3.7/site-packages/torch/nn/modules/loss.py", line 313, in forwardself.weight, self.size_average)File "python3.7/site-packages/torch/nn/functional.py", line 509, in cross_entropyreturn nll_loss(log_softmax(input), target, weight, size_average)File "python3.7/site-packages/torch/nn/functional.py", line 477, in nll_lossreturn f(input, target)File "python3.7/site-packages/torch/nn/_functions/thnn/auto.py", line 41, in forwardoutput, *self.additional_args)
RuntimeError: cuda runtime error (59) : device-side assert triggered at /py/conda-bld/pytorch_1490981920203/work/torch/lib/THCUNN/generic/ClassNLLCriterion.cu:83

通常情况下,这是由于求交叉熵函数在计算时遇到了类别错误的问题,即不满足t >= 0 && t < n_classes条件。

t >= 0 && t < n_classes条件

在分类任务中,需要调用torch.nn.functional.cross_entropy()函数求交叉熵,从PyTorch官网可以看到该函数定义:
在这里插入图片描述

torch.nn.functional.cross_entropy(input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')

可以注意到有一个key-value是ignore_index=-100。这是在交叉熵计算时被跳过的部分。通常是在数据增强中的填充值。

而在代码运行中报错ClassNLLCriterion Assertion `t >= 0 && t < n_classes ` failed,大部分都是由于没有正确处理好label(ground truth)导致的。例如在数据增强中,填充数据使用了负数,或者使用了某大正数(如255),而在调用torch.nn.functional.cross_entropy()方法时却没有传入正确的ignore_index。这就会导致运行过程中的Assertion Error。

在这里插入图片描述

代码示例

数据增强部分

import torchvision.transforms.functional as tftf.pad(cropped_img, padding_tuple, padding_mode="reflect"),
tf.affine(mask, translate=(-x_offset, -y_offset), scale=1.0, angle=0.0, shear=0.0,fillcolor=250,)

求交叉熵部分

import torch
import torch.nn.functional as F
import torch.nn as nndef cross_entropy2d(input, target, weight=None, reduction='none'):n, c, h, w = input.size()nt, ht, wt = target.size()if h != ht or w != wt:input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)input = input.transpose(1, 2).transpose(2, 3).contiguous().view(-1, c)target = target.view(-1)loss = F.cross_entropy(input, target, weight=weight, reduction=reduction, ignore_index=255)return loss

分析

可以看到在数据增强时的填充值为250(fillcolor=250),但在求交叉熵时却传入了ignore_index=255。因此在代码运行时,F.cross_entropy部分便会报错ClassNLLCriterion ··· Assertion `t >= 0 && t < n_classes ` failed。只需要统一好label部分填充数据和计算交叉熵时需要忽略的class就可以避免出现这一问题。

其他

在PyTorch框架下,使用无用label值进行填充和处理时,要注意在使用scatter_函数时也需要注意对无用label进行提前处理,否则在使用data.scatter_()时同样也会报类似类别index错误。

labels = labels[:, :, :].view(size[0], 1, size[1], size[2])
oneHot_size = (size[0], classes, size[1], size[2])
labels_real = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
# ignore_index=255
# labels[labels.data[::] == ignore_index] = 0
labels_real = labels_real.scatter_(1, labels.data.long().cuda(), 1.0)

在这里插入图片描述

参考资料

[1] torch.nn.functional — PyTorch 1.8.0 documentation
[2] Pytorch里的CrossEntropyLoss详解 - marsggbo - 博客园
[3] RuntimeError: cuda runtime error (59) : device-side assert triggered when running transfer_learning_tutorial · Issue #1204 · pytorch/pytorch
[4] PyTorch 中,nn 与 nn.functional 有什么区别? - 知乎
[5] FaceParsing.PyTorch/augmentations.py at master · TracelessLe/FaceParsing.PyTorch

这篇关于PyTorch使用F.cross_entropy报错Assertion `t >= 0 t < n_classes` failed问题记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/922145

相关文章

使用Python和OpenCV库实现实时颜色识别系统

《使用Python和OpenCV库实现实时颜色识别系统》:本文主要介绍使用Python和OpenCV库实现的实时颜色识别系统,这个系统能够通过摄像头捕捉视频流,并在视频中指定区域内识别主要颜色(红... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间详解

Windows下C++使用SQLitede的操作过程

《Windows下C++使用SQLitede的操作过程》本文介绍了Windows下C++使用SQLite的安装配置、CppSQLite库封装优势、核心功能(如数据库连接、事务管理)、跨平台支持及性能优... 目录Windows下C++使用SQLite1、安装2、代码示例CppSQLite:C++轻松操作SQ

qt5cored.dll报错怎么解决? 电脑qt5cored.dll文件丢失修复技巧

《qt5cored.dll报错怎么解决?电脑qt5cored.dll文件丢失修复技巧》在进行软件安装或运行程序时,有时会遇到由于找不到qt5core.dll,无法继续执行代码,这个问题可能是由于该文... 遇到qt5cored.dll文件错误时,可能会导致基于 Qt 开发的应用程序无法正常运行或启动。这种错

Python常用命令提示符使用方法详解

《Python常用命令提示符使用方法详解》在学习python的过程中,我们需要用到命令提示符(CMD)进行环境的配置,:本文主要介绍Python常用命令提示符使用方法的相关资料,文中通过代码介绍的... 目录一、python环境基础命令【Windows】1、检查Python是否安装2、 查看Python的安

Python UV安装、升级、卸载详细步骤记录

《PythonUV安装、升级、卸载详细步骤记录》:本文主要介绍PythonUV安装、升级、卸载的详细步骤,uv是Astral推出的下一代Python包与项目管理器,主打单一可执行文件、极致性能... 目录安装检查升级设置自动补全卸载UV 命令总结 官方文档详见:https://docs.astral.sh/

Python并行处理实战之如何使用ProcessPoolExecutor加速计算

《Python并行处理实战之如何使用ProcessPoolExecutor加速计算》Python提供了多种并行处理的方式,其中concurrent.futures模块的ProcessPoolExecu... 目录简介完整代码示例代码解释1. 导入必要的模块2. 定义处理函数3. 主函数4. 生成数字列表5.

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

Linux脚本(shell)的使用方式

《Linux脚本(shell)的使用方式》:本文主要介绍Linux脚本(shell)的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录概述语法详解数学运算表达式Shell变量变量分类环境变量Shell内部变量自定义变量:定义、赋值自定义变量:引用、修改、删

Java使用HttpClient实现图片下载与本地保存功能

《Java使用HttpClient实现图片下载与本地保存功能》在当今数字化时代,网络资源的获取与处理已成为软件开发中的常见需求,其中,图片作为网络上最常见的资源之一,其下载与保存功能在许多应用场景中都... 目录引言一、Apache HttpClient简介二、技术栈与环境准备三、实现图片下载与保存功能1.

Python中使用uv创建环境及原理举例详解

《Python中使用uv创建环境及原理举例详解》uv是Astral团队开发的高性能Python工具,整合包管理、虚拟环境、Python版本控制等功能,:本文主要介绍Python中使用uv创建环境及... 目录一、uv工具简介核心特点:二、安装uv1. 通过pip安装2. 通过脚本安装验证安装:配置镜像源(可