pytorch 函数DataLoader

2024-04-02 13:08
文章标签 函数 pytorch dataloader

本文主要是介绍pytorch 函数DataLoader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Dataset(https://blog.csdn.net/TH_NUM/article/details/80877196)只负责数据的抽象,一次调用getitem只返回一个样本。前面提到过,在训练神经网络时,最好是对一个batch的数据进行操作,同时还需要对数据进行shuffle和并行加速等。对此,PyTorch提供了DataLoader帮助我们实现这些功能。

DataLoader的函数定义如下:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, 
num_workers=0, collate_fn=default_collate, pin_memory=False, 
drop_last=False)

dataset:加载的数据集(Dataset对象)
batch_size:batch size
shuffle::是否将数据打乱
sampler: 样本抽样,后续会详细介绍
num_workers:使用多进程加载的进程数,0代表不使用多进程
collate_fn: 如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms
from torch.utils.data import DataLoader#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([transforms.RandomSizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]normalize
])dataset=ImageFolder('data/dogcat_2/',transform=transform)#dataloader是一个可迭代的对象,意味着我们可以像使用迭代器一样使用它 或者 or batch_datas, batch_labels in dataloader:
dataloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, drop_last=False)dataiter = iter(dataloader)
imgs, labels = next(dataiter)
print(imgs.size()) # batch_size, channel, height, weight
#输出 torch.Size([3, 3, 224, 224])

在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在_ getitem _函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。

'''
在数据处理中,有时会出现某个样本无法读取等问题,比如某张图片损坏。这时在__getitem__函数中将出现异常,此时最好的解决方案即是将出错的样本剔除。如果实在是遇到这种情况无法处理,则可以返回None对象,然后在Dataloader中实现自定义的collate_fn,将空对象过滤掉。但要注意,在这种情况下dataloader返回的batch数目会少于batch_size。
'''
from dataSet import *
import random
class NewDogCat(DogCat): # 继承前面实现的DogCat数据集def __getitem__(self, index):try:# 调用父类的获取函数,即 DogCat.__getitem__(self, index)return super(NewDogCat,self).__getitem__(index)except:#对于诸如样本损坏或数据集加载异常等情况,还可以通过其它方式解决。例如但凡遇到异常情况,就随机取一张图片代替:new_index = random.randint(0, len(self) - 1)return self[new_index]from torch.utils.data.dataloader import default_collate # 导入默认的拼接方式
from torch.utils.data import DataLoader
def my_collate_fn(batch):'''batch中每个元素形如(data, label)'''# 过滤为None的数据batch = list(filter(lambda x:x[0] is not None, batch))if len(batch) == 0: return torch.Tensor()return default_collate(batch) # 用默认方式拼接过滤后的batch数据transform=transforms.Compose([transforms.Resize(224), #缩放图片,保持长宽比不变,最短边的长为224像素,transforms.CenterCrop(224), #从中间切出 224*224的图片transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至[-1,1]
])dataset = NewDogCat(root='data/dogcat_wrong/', transform=transform)#print(dataSet[11])
dataloader = DataLoader(dataset, 2, collate_fn=my_collate_fn, num_workers=1,shuffle=True)
for batch_datas, batch_labels in dataloader:print(batch_datas.size(),batch_labels.size())

github 地址:https://github.com/WebLearning17/CommonTool

参考:https://github.com/chenyuntc/pytorch-book/blob/master/chapter5-%E5%B8%B8%E7%94%A8%E5%B7%A5%E5%85%B7/chapter5.ipynb

这篇关于pytorch 函数DataLoader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/870066

相关文章

MySQL count()聚合函数详解

《MySQLcount()聚合函数详解》MySQL中的COUNT()函数,它是SQL中最常用的聚合函数之一,用于计算表中符合特定条件的行数,本文给大家介绍MySQLcount()聚合函数,感兴趣的朋... 目录核心功能语法形式重要特性与行为如何选择使用哪种形式?总结深入剖析一下 mysql 中的 COUNT

MySQL 中 ROW_NUMBER() 函数最佳实践

《MySQL中ROW_NUMBER()函数最佳实践》MySQL中ROW_NUMBER()函数,作为窗口函数为每行分配唯一连续序号,区别于RANK()和DENSE_RANK(),特别适合分页、去重... 目录mysql 中 ROW_NUMBER() 函数详解一、基础语法二、核心特点三、典型应用场景1. 数据分

MySQL数据库的内嵌函数和联合查询实例代码

《MySQL数据库的内嵌函数和联合查询实例代码》联合查询是一种将多个查询结果组合在一起的方法,通常使用UNION、UNIONALL、INTERSECT和EXCEPT关键字,下面:本文主要介绍MyS... 目录一.数据库的内嵌函数1.1聚合函数COUNT([DISTINCT] expr)SUM([DISTIN

Python get()函数用法案例详解

《Pythonget()函数用法案例详解》在Python中,get()是字典(dict)类型的内置方法,用于安全地获取字典中指定键对应的值,它的核心作用是避免因访问不存在的键而引发KeyError错... 目录简介基本语法一、用法二、案例:安全访问未知键三、案例:配置参数默认值简介python是一种高级编

python 常见数学公式函数使用详解(最新推荐)

《python常见数学公式函数使用详解(最新推荐)》文章介绍了Python的数学计算工具,涵盖内置函数、math/cmath标准库及numpy/scipy/sympy第三方库,支持从基础算术到复杂数... 目录python 数学公式与函数大全1. 基本数学运算1.1 算术运算1.2 分数与小数2. 数学函数

Python中help()和dir()函数的使用

《Python中help()和dir()函数的使用》我们经常需要查看某个对象(如模块、类、函数等)的属性和方法,Python提供了两个内置函数help()和dir(),它们可以帮助我们快速了解代... 目录1. 引言2. help() 函数2.1 作用2.2 使用方法2.3 示例(1) 查看内置函数的帮助(

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

Python中bisect_left 函数实现高效插入与有序列表管理

《Python中bisect_left函数实现高效插入与有序列表管理》Python的bisect_left函数通过二分查找高效定位有序列表插入位置,与bisect_right的区别在于处理重复元素时... 目录一、bisect_left 基本介绍1.1 函数定义1.2 核心功能二、bisect_left 与

java中BigDecimal里面的subtract函数介绍及实现方法

《java中BigDecimal里面的subtract函数介绍及实现方法》在Java中实现减法操作需要根据数据类型选择不同方法,主要分为数值型减法和字符串减法两种场景,本文给大家介绍java中BigD... 目录Java中BigDecimal里面的subtract函数的意思?一、数值型减法(高精度计算)1.

C++/类与对象/默认成员函数@构造函数的用法

《C++/类与对象/默认成员函数@构造函数的用法》:本文主要介绍C++/类与对象/默认成员函数@构造函数的用法,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录名词概念默认成员函数构造函数概念函数特征显示构造函数隐式构造函数总结名词概念默认构造函数:不用传参就可以