DataLoader基础用法

2024-06-09 19:36
文章标签 基础 用法 dataloader

本文主要是介绍DataLoader基础用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DataLoader 是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader 的常见用法和功能介绍:

基本用法

  1. 创建数据集
    首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

    import torch
    import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
    
  2. 创建 DataLoader
    DataLoader 用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。

    enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    
  3. 迭代数据
    使用 DataLoader 的迭代器来访问批次数据。

    for batch in loader:enc_batch, dec_batch, output_batch = batchprint(enc_batch)print(dec_batch)print(output_batch)
    

常见参数

  1. dataset

    • 数据集对象,必须继承 torch.utils.data.Dataset 类。
  2. batch_size

    • 每个批次的大小,默认为 1。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据,默认为 False
  4. num_workers

    • 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。对于大型数据集,增加 num_workers 可以加快数据加载速度。
  5. drop_last

    • 如果设置为 True,则丢弃不能整除 batch_size 的最后一个不完整的批次。
  6. pin_memory

    • 如果设置为 True,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。

进阶用法

  1. 自定义 collate_fn

    • collate_fn 用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader 将使用 default_collate,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
    def my_collate_fn(batch):enc_inputs, dec_inputs, dec_outputs = zip(*batch)return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
    
  2. 使用 Sampler

    • Sampler 用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如 RandomSamplerSequentialSampler
    from torch.utils.data.sampler import RandomSamplersampler = RandomSampler(dataset)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
    

完整示例

import torch
import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)for batch in loader:enc_batch, dec_batch, output_batch = batchprint("Encoder batch:", enc_batch)print("Decoder batch:", dec_batch)print("Output batch:", output_batch)

通过使用 DataLoader,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。

这篇关于DataLoader基础用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1046105

相关文章

Android协程高级用法大全

《Android协程高级用法大全》这篇文章给大家介绍Android协程高级用法大全,本文结合实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友跟随小编一起学习吧... 目录1️⃣ 协程作用域(CoroutineScope)与生命周期绑定Activity/Fragment 中手

Python异步编程之await与asyncio基本用法详解

《Python异步编程之await与asyncio基本用法详解》在Python中,await和asyncio是异步编程的核心工具,用于高效处理I/O密集型任务(如网络请求、文件读写、数据库操作等),接... 目录一、核心概念二、使用场景三、基本用法1. 定义协程2. 运行协程3. 并发执行多个任务四、关键

从基础到进阶详解Python条件判断的实用指南

《从基础到进阶详解Python条件判断的实用指南》本文将通过15个实战案例,带你大家掌握条件判断的核心技巧,并从基础语法到高级应用一网打尽,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一... 目录​引言:条件判断为何如此重要一、基础语法:三行代码构建决策系统二、多条件分支:elif的魔法三、

Python WebSockets 库从基础到实战使用举例

《PythonWebSockets库从基础到实战使用举例》WebSocket是一种全双工、持久化的网络通信协议,适用于需要低延迟的应用,如实时聊天、股票行情推送、在线协作、多人游戏等,本文给大家介... 目录1. 引言2. 为什么使用 WebSocket?3. 安装 WebSockets 库4. 使用 We

Python中yield的用法和实际应用示例

《Python中yield的用法和实际应用示例》在Python中,yield关键字主要用于生成器函数(generatorfunctions)中,其目的是使函数能够像迭代器一样工作,即可以被遍历,但不会... 目录python中yield的用法详解一、引言二、yield的基本用法1、yield与生成器2、yi

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?

Python库 Django 的简介、安装、用法入门教程

《Python库Django的简介、安装、用法入门教程》Django是Python最流行的Web框架之一,它帮助开发者快速、高效地构建功能强大的Web应用程序,接下来我们将从简介、安装到用法详解,... 目录一、Django 简介 二、Django 的安装教程 1. 创建虚拟环境2. 安装Django三、创

python中update()函数的用法和一些例子

《python中update()函数的用法和一些例子》update()方法是字典对象的方法,用于将一个字典中的键值对更新到另一个字典中,:本文主要介绍python中update()函数的用法和一些... 目录前言用法注意事项示例示例 1: 使用另一个字典来更新示例 2: 使用可迭代对象来更新示例 3: 使用

python连接sqlite3简单用法完整例子

《python连接sqlite3简单用法完整例子》SQLite3是一个内置的Python模块,可以通过Python的标准库轻松地使用,无需进行额外安装和配置,:本文主要介绍python连接sqli... 目录1. 连接到数据库2. 创建游标对象3. 创建表4. 插入数据5. 查询数据6. 更新数据7. 删除

Python中的sort()和sorted()用法示例解析

《Python中的sort()和sorted()用法示例解析》本文给大家介绍Python中list.sort()和sorted()的使用区别,详细介绍其参数功能及Timsort排序算法特性,涵盖自适应... 目录一、list.sort()参数说明常用内置函数基本用法示例自定义函数示例lambda表达式示例o