DataLoader基础用法

2024-06-09 19:36
文章标签 基础 用法 dataloader

本文主要是介绍DataLoader基础用法,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

DataLoader 是 PyTorch 中一个非常有用的工具,用于将数据集进行批处理,并提供一个迭代器来简化模型训练和评估过程。以下是 DataLoader 的常见用法和功能介绍:

基本用法

  1. 创建数据集
    首先,需要一个数据集。数据集可以是 PyTorch 提供的内置数据集,也可以是自定义的数据集。数据集需要继承 torch.utils.data.Dataset 并实现 __len____getitem__ 方法。

    import torch
    import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]
    
  2. 创建 DataLoader
    DataLoader 用于将数据集封装成批次,并提供一个迭代器来进行数据的加载。常见的参数包括数据集、批量大小、是否打乱数据、使用的进程数等。

    enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
    dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)
    
  3. 迭代数据
    使用 DataLoader 的迭代器来访问批次数据。

    for batch in loader:enc_batch, dec_batch, output_batch = batchprint(enc_batch)print(dec_batch)print(output_batch)
    

常见参数

  1. dataset

    • 数据集对象,必须继承 torch.utils.data.Dataset 类。
  2. batch_size

    • 每个批次的大小,默认为 1。
  3. shuffle

    • 是否在每个 epoch 开始时打乱数据,默认为 False
  4. num_workers

    • 使用多少个子进程来加载数据。0 表示数据将在主进程中加载。对于大型数据集,增加 num_workers 可以加快数据加载速度。
  5. drop_last

    • 如果设置为 True,则丢弃不能整除 batch_size 的最后一个不完整的批次。
  6. pin_memory

    • 如果设置为 True,DataLoader 将在返回前将张量复制到 CUDA 固定内存中。这对 GPU 训练有所帮助。

进阶用法

  1. 自定义 collate_fn

    • collate_fn 用于指定如何将多个样本合并成一个批次。默认情况下,DataLoader 将使用 default_collate,它会将相同类型的数据合并在一起。例如,所有张量数据将合并成一个张量。
    def my_collate_fn(batch):enc_inputs, dec_inputs, dec_outputs = zip(*batch)return torch.stack(enc_inputs, 0), torch.stack(dec_inputs, 0), torch.stack(dec_outputs, 0)loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=my_collate_fn)
    
  2. 使用 Sampler

    • Sampler 用于指定如何抽样数据。PyTorch 提供了一些内置的采样器,如 RandomSamplerSequentialSampler
    from torch.utils.data.sampler import RandomSamplersampler = RandomSampler(dataset)
    loader = Data.DataLoader(dataset=dataset, batch_size=2, sampler=sampler)
    

完整示例

import torch
import torch.utils.data as Dataclass MyDataSet(Data.Dataset):def __init__(self, enc_inputs, dec_inputs, dec_outputs):self.enc_inputs = enc_inputsself.dec_inputs = dec_inputsself.dec_outputs = dec_outputsdef __len__(self):return len(self.enc_inputs)def __getitem__(self, idx):return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]enc_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_inputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
dec_outputs = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])dataset = MyDataSet(enc_inputs, dec_inputs, dec_outputs)
loader = Data.DataLoader(dataset=dataset, batch_size=2, shuffle=True)for batch in loader:enc_batch, dec_batch, output_batch = batchprint("Encoder batch:", enc_batch)print("Decoder batch:", dec_batch)print("Output batch:", output_batch)

通过使用 DataLoader,我们可以轻松地处理和批量化我们的数据,这对于大型数据集和深度学习模型的训练是非常重要的。

这篇关于DataLoader基础用法的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1046105

相关文章

JDK21对虚拟线程的几种用法实践指南

《JDK21对虚拟线程的几种用法实践指南》虚拟线程是Java中的一种轻量级线程,由JVM管理,特别适合于I/O密集型任务,:本文主要介绍JDK21对虚拟线程的几种用法,文中通过代码介绍的非常详细,... 目录一、参考官方文档二、什么是虚拟线程三、几种用法1、Thread.ofVirtual().start(

从基础到高级详解Go语言中错误处理的实践指南

《从基础到高级详解Go语言中错误处理的实践指南》Go语言采用了一种独特而明确的错误处理哲学,与其他主流编程语言形成鲜明对比,本文将为大家详细介绍Go语言中错误处理详细方法,希望对大家有所帮助... 目录1 Go 错误处理哲学与核心机制1.1 错误接口设计1.2 错误与异常的区别2 错误创建与检查2.1 基础

Java8 Collectors.toMap() 的两种用法

《Java8Collectors.toMap()的两种用法》Collectors.toMap():JDK8中提供,用于将Stream流转换为Map,本文给大家介绍Java8Collector... 目录一、简单介绍用法1:根据某一属性,对对象的实例或属性做映射用法2:根据某一属性,对对象集合进行去重二、Du

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

Python中的sort方法、sorted函数与lambda表达式及用法详解

《Python中的sort方法、sorted函数与lambda表达式及用法详解》文章对比了Python中list.sort()与sorted()函数的区别,指出sort()原地排序返回None,sor... 目录1. sort()方法1.1 sort()方法1.2 基本语法和参数A. reverse参数B.

Spring的基础事务注解@Transactional作用解读

《Spring的基础事务注解@Transactional作用解读》文章介绍了Spring框架中的事务管理,核心注解@Transactional用于声明事务,支持传播机制、隔离级别等配置,结合@Tran... 目录一、事务管理基础1.1 Spring事务的核心注解1.2 注解属性详解1.3 实现原理二、事务事

vue监听属性watch的用法及使用场景详解

《vue监听属性watch的用法及使用场景详解》watch是vue中常用的监听器,它主要用于侦听数据的变化,在数据发生变化的时候执行一些操作,:本文主要介绍vue监听属性watch的用法及使用场景... 目录1. 监听属性 watch2. 常规用法3. 监听对象和route变化4. 使用场景附Watch 的

Java Instrumentation从概念到基本用法详解

《JavaInstrumentation从概念到基本用法详解》JavaInstrumentation是java.lang.instrument包提供的API,允许开发者在类被JVM加载时对其进行修改... 目录一、什么是 Java Instrumentation主要用途二、核心概念1. Java Agent

Java 中 Optional 的用法及最佳实践

《Java中Optional的用法及最佳实践》在Java开发中,空指针异常(NullPointerException)是开发者最常遇到的问题之一,本篇文章将详细讲解Optional的用法、常用方... 目录前言1. 什么是 Optional?主要特性:2. Optional 的基本用法2.1 创建 Opti

Java中最全最基础的IO流概述和简介案例分析

《Java中最全最基础的IO流概述和简介案例分析》JavaIO流用于程序与外部设备的数据交互,分为字节流(InputStream/OutputStream)和字符流(Reader/Writer),处理... 目录IO流简介IO是什么应用场景IO流的分类流的超类类型字节文件流应用简介核心API文件输出流应用文