python使用DataLoader对数据集进行批处理

2024-02-10 04:58

本文主要是介绍python使用DataLoader对数据集进行批处理,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

使用DataLoader对数据集进行批处理,转自

https://www.cnblogs.com/JeasonIsCoding/p/10168753.html

第一步:创建torch能够识别的数据集类型

首先建立两个向量X和Y,一个作为输入的数据,一个作为正确的结果:

import torch
import torch.utils.data as DataBATCH_SIZE = 3 		# 批训练的数据个数x = torch.linspace(1,9,9)  # x data (torch tensor)
y = torch.linspace(9,1,9)  # y data (torch tensor)

随后把X和Y组成一个完整的数据集,并转化为pytorch能识别的数据集类型:

# 先转换成 torch 能够识别的 Dataset
torch_dataset = Data.TensorDataset( x, y )

现在来看一下这些数据的数据类型:

In [1]:  type(torch_dataset)
out[1]:  torch.utils.data.dataset.TensorDatasetIn [2]:  type(x)
out[2]:  torch.Tensor

可以看出X和Y通过Data.TensorDataset() 这个函数拼装成了一个数据集,数据集的类型是TensorDataset

第二步:把上一步的数据集放入Data.DataLoader中,生成一个迭代器,从而方便进行批处理

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = True, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

DataLoader中也有很多其他参数:

dataset:		Dataset类型,从其中加载数据 
batch_size:	int,可选。每个batch加载多少样本 
shuffle:		bool,可选。为True时表示每个epoch都对数据进行洗牌 
sampler:		Sampler,可选。从数据集中采样样本的方法。 
num_workers:	int,可选。加载数据时使用多少子进程。默认值为0,表示在主进程中加载数据。 
collate_fn:	callable,可选。 
pin_memory:	bool,可选 
drop_last:		bool,可选。True表示如果最后剩下不完全的batch,丢弃。False表示不丢弃。

第三步:用上面定义好的迭代器进行训练

这里利用print来模拟训练过程:

for epoch in range(5): 		# 训练所有数据5次i = 0for batch_x,batch_y in loader:i = i+1print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

为了便于观察分批的结果,这里设置:

shuffle = False, # 是否打乱数据

即:

# 把 dataset 放入 Dataloader
loader = Data.DataLoader(dataset = torch_dataset,# torch TensorDataset formatbatch_size = BATCH_SIZE,#mini batch sizeshuffle = False, # 是否打乱数据num_workers = 2, # 多线程来读数据
)

输出结果是:

Epoch:0|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:0|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:0|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:1|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:1|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:1|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:2|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:2|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:2|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:3|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:3|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:3|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])
Epoch:4|num:1|batch_x:tensor([1., 2., 3.])|batch_y:tensor([9., 8., 7.])
Epoch:4|num:2|batch_x:tensor([4., 5., 6.])|batch_y:tensor([6., 5., 4.])
Epoch:4|num:3|batch_x:tensor([7., 8., 9.])|batch_y:tensor([3., 2., 1.])

可以看到,所有数据一共训练了5次。数据中一共9组,设置的mini-batch是3,即每一次训练网络的时候送入3组数据。

此外,还可以利用python中的enumerate(),是对所有可以迭代的数据类型(含有很多东西的list等等)进行取操作的函数,用法如下:

for epoch in range(5): 		# 训练所有数据5次i = 0for step,(batch_x,batch_y) in enumerate(loader):# 假设这里在进行训练i = i+1# 打印一些数据print('Epoch:{}|num:{}|batch_x:{}|batch_y:{}'.format(epoch,i,batch_x,batch_y))

这篇关于python使用DataLoader对数据集进行批处理的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/696177

相关文章

使用animation.css库快速实现CSS3旋转动画效果

《使用animation.css库快速实现CSS3旋转动画效果》随着Web技术的不断发展,动画效果已经成为了网页设计中不可或缺的一部分,本文将深入探讨animation.css的工作原理,如何使用以及... 目录1. css3动画技术简介2. animation.css库介绍2.1 animation.cs

Java进行日期解析与格式化的实现代码

《Java进行日期解析与格式化的实现代码》使用Java搭配ApacheCommonsLang3和Natty库,可以实现灵活高效的日期解析与格式化,本文将通过相关示例为大家讲讲具体的实践操作,需要的可以... 目录一、背景二、依赖介绍1. Apache Commons Lang32. Natty三、核心实现代

使用雪花算法产生id导致前端精度缺失问题解决方案

《使用雪花算法产生id导致前端精度缺失问题解决方案》雪花算法由Twitter提出,设计目的是生成唯一的、递增的ID,下面:本文主要介绍使用雪花算法产生id导致前端精度缺失问题的解决方案,文中通过代... 目录一、问题根源二、解决方案1. 全局配置Jackson序列化规则2. 实体类必须使用Long封装类3.

Python文件操作与IO流的使用方式

《Python文件操作与IO流的使用方式》:本文主要介绍Python文件操作与IO流的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、python文件操作基础1. 打开文件2. 关闭文件二、文件读写操作1.www.chinasem.cn 读取文件2. 写

SpringBoot实现接口数据加解密的三种实战方案

《SpringBoot实现接口数据加解密的三种实战方案》在金融支付、用户隐私信息传输等场景中,接口数据若以明文传输,极易被中间人攻击窃取,SpringBoot提供了多种优雅的加解密实现方案,本文将从原... 目录一、为什么需要接口数据加解密?二、核心加解密算法选择1. 对称加密(AES)2. 非对称加密(R

详解如何在SpringBoot控制器中处理用户数据

《详解如何在SpringBoot控制器中处理用户数据》在SpringBoot应用开发中,控制器(Controller)扮演着至关重要的角色,它负责接收用户请求、处理数据并返回响应,本文将深入浅出地讲解... 目录一、获取请求参数1.1 获取查询参数1.2 获取路径参数二、处理表单提交2.1 处理表单数据三、

PyQt6中QMainWindow组件的使用详解

《PyQt6中QMainWindow组件的使用详解》QMainWindow是PyQt6中用于构建桌面应用程序的基础组件,本文主要介绍了PyQt6中QMainWindow组件的使用,具有一定的参考价值,... 目录1. QMainWindow 组php件概述2. 使用 QMainWindow3. QMainW

使用Python自动化生成PPT并结合LLM生成内容的代码解析

《使用Python自动化生成PPT并结合LLM生成内容的代码解析》PowerPoint是常用的文档工具,但手动设计和排版耗时耗力,本文将展示如何通过Python自动化提取PPT样式并生成新PPT,同时... 目录核心代码解析1. 提取 PPT 样式到 jsON关键步骤:代码片段:2. 应用 JSON 样式到

python通过curl实现访问deepseek的API

《python通过curl实现访问deepseek的API》这篇文章主要为大家详细介绍了python如何通过curl实现访问deepseek的API,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编... API申请和充值下面是deepeek的API网站https://platform.deepsee

java变量内存中存储的使用方式

《java变量内存中存储的使用方式》:本文主要介绍java变量内存中存储的使用方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、介绍2、变量的定义3、 变量的类型4、 变量的作用域5、 内存中的存储方式总结1、介绍在 Java 中,变量是用于存储程序中数据