解决 torch.cat(): input types can‘t be cast to the desired output type Byte

2024-01-25 18:20

本文主要是介绍解决 torch.cat(): input types can‘t be cast to the desired output type Byte,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

最近使用 U2Net 训练模型的时候,遇到了下面的错误:


RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte

错误堆栈信息如下

Original Traceback (most recent call last): File "/datapython3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop data = fetcher.fetch(index) ^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch return self.collate_fn(data) ^^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 268, in default_collate return collate(batch, collate_fn_map=default_collate_fn_map) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
File "/datapython3.11/site-packages/torch/utils/data/_utils/collate.py", line 165, in collate_tensor_fn return torch.stack(batch, 0, out=out) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 
RuntimeError: torch.cat(): input types can't be cast to the desired output type Byte

原因:个人猜想是多个 worker  在一起工作引起的并发问题。

解决方法一:

在构建 DataLoader 实例的时候,把 workers 设置为 0 即可。

缺点:会导致训练速度变慢
 

train_dataset = SalObjDataset(img_name_list=train_img_name_list,lbl_name_list=train_label_name_list,transform=transforms.Compose([RescaleT(320),# RandomCrop(288),ToTensorLab(flag=0)]))
train_dataset.auto_collation = True
train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)

解决方法二:

修改 pytorch 里面的代码。

在上面的堆栈中,显示 /datapython3.11/site-packages/torch/utils/data/_utils/collate.py:165 报错了

我们打开 collate.py 文件,找到 collate_tensor_fn 这个函数,把 if 语句的内容注释掉就可以了

缺点:需要修改pytorch的代码,会增加多一次内存拷贝 

def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):elem = batch[0]out = None# if torch.utils.data.get_worker_info() is not None:#     # If we're in a background process, concatenate directly into a#     # shared memory tensor to avoid an extra copy#     numel = sum(x.numel() for x in batch)#     storage = elem._typed_storage()._new_shared(numel, device=elem.device)#     out = elem.new(storage).resize_(len(batch), *list(elem.size()))return torch.stack(batch, 0, out=out)

.

.

这篇关于解决 torch.cat(): input types can‘t be cast to the desired output type Byte的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/644175

相关文章

Java NoClassDefFoundError运行时错误分析解决

《JavaNoClassDefFoundError运行时错误分析解决》在Java开发中,NoClassDefFoundError是一种常见的运行时错误,它通常表明Java虚拟机在尝试加载一个类时未能... 目录前言一、问题分析二、报错原因三、解决思路检查类路径配置检查依赖库检查类文件调试类加载器问题四、常见

解决IDEA报错:编码GBK的不可映射字符问题

《解决IDEA报错:编码GBK的不可映射字符问题》:本文主要介绍解决IDEA报错:编码GBK的不可映射字符问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录IDEA报错:编码GBK的不可映射字符终端软件问题描述原因分析解决方案方法1:将命令改为方法2:右下jav

MyBatis模糊查询报错:ParserException: not supported.pos 问题解决

《MyBatis模糊查询报错:ParserException:notsupported.pos问题解决》本文主要介绍了MyBatis模糊查询报错:ParserException:notsuppo... 目录问题描述问题根源错误SQL解析逻辑深层原因分析三种解决方案方案一:使用CONCAT函数(推荐)方案二:

IntelliJ IDEA 中配置 Spring MVC 环境的详细步骤及问题解决

《IntelliJIDEA中配置SpringMVC环境的详细步骤及问题解决》:本文主要介绍IntelliJIDEA中配置SpringMVC环境的详细步骤及问题解决,本文分步骤结合实例给大... 目录步骤 1:创建 Maven Web 项目步骤 2:添加 Spring MVC 依赖1、保存后执行2、将新的依赖

Spring 中的循环引用问题解决方法

《Spring中的循环引用问题解决方法》:本文主要介绍Spring中的循环引用问题解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录什么是循环引用?循环依赖三级缓存解决循环依赖二级缓存三级缓存本章来聊聊Spring 中的循环引用问题该如何解决。这里聊

关于MongoDB图片URL存储异常问题以及解决

《关于MongoDB图片URL存储异常问题以及解决》:本文主要介绍关于MongoDB图片URL存储异常问题以及解决方案,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐... 目录MongoDB图片URL存储异常问题项目场景问题描述原因分析解决方案预防措施js总结MongoDB图

SpringBoot项目中报错The field screenShot exceeds its maximum permitted size of 1048576 bytes.的问题及解决

《SpringBoot项目中报错ThefieldscreenShotexceedsitsmaximumpermittedsizeof1048576bytes.的问题及解决》这篇文章... 目录项目场景问题描述原因分析解决方案总结项目场景javascript提示:项目相关背景:项目场景:基于Spring

解决Maven项目idea找不到本地仓库jar包问题以及使用mvn install:install-file

《解决Maven项目idea找不到本地仓库jar包问题以及使用mvninstall:install-file》:本文主要介绍解决Maven项目idea找不到本地仓库jar包问题以及使用mvnin... 目录Maven项目idea找不到本地仓库jar包以及使用mvn install:install-file基

最详细安装 PostgreSQL方法及常见问题解决

《最详细安装PostgreSQL方法及常见问题解决》:本文主要介绍最详细安装PostgreSQL方法及常见问题解决,介绍了在Windows系统上安装PostgreSQL及Linux系统上安装Po... 目录一、在 Windows 系统上安装 PostgreSQL1. 下载 PostgreSQL 安装包2.

Mysql如何解决死锁问题

《Mysql如何解决死锁问题》:本文主要介绍Mysql如何解决死锁问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录【一】mysql中锁分类和加锁情况【1】按锁的粒度分类全局锁表级锁行级锁【2】按锁的模式分类【二】加锁方式的影响因素【三】Mysql的死锁情况【1