DataLoader 的 collate_fn 解释与示例教程

2024-04-09 04:28

本文主要是介绍DataLoader 的 collate_fn 解释与示例教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 导包
    • 数据
    • Dataloader
    • collate_fn

导包

import torch
from torch.utils.data import Dataset
from typing import Any

数据

class CustomDataset(Dataset):def __init__(self, length) -> None:super().__init__()self.length = lengthdef __getitem__(self, index=None):w1 = 3.14w2 = 4.27w = torch.tensor([w1, w2])feature = torch.rand(2) * 10noise = torch.randn_like(feature) * 0.01label = torch.matmul(w, feature.t())feature += noise# return feature, label.view(1)return feature, labeldef __len__(self):return self.lengthdataset = CustomDataset(4)

Dataloader

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, )for feature, label in dataloader:print(feature.shape, label.shape)

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):feature, label = zip(*item)feature = torch.stack(feature)label = torch.stack(label)label = label.view(-1, 1)return feature, labeldataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)for feature, label in dataloader:print(feature.shape, label.shape)

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

这篇关于DataLoader 的 collate_fn 解释与示例教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/887174

相关文章

详解SpringBoot+Ehcache使用示例

《详解SpringBoot+Ehcache使用示例》本文介绍了SpringBoot中配置Ehcache、自定义get/set方式,并实际使用缓存的过程,文中通过示例代码介绍的非常详细,对大家的学习或者... 目录摘要概念内存与磁盘持久化存储:配置灵活性:编码示例引入依赖:配置ehcache.XML文件:配置

Java高效实现PowerPoint转PDF的示例详解

《Java高效实现PowerPoint转PDF的示例详解》在日常开发或办公场景中,经常需要将PowerPoint演示文稿(PPT/PPTX)转换为PDF,本文将介绍从基础转换到高级设置的多种用法,大家... 目录为什么要将 PowerPoint 转换为 PDF安装 Spire.Presentation fo

全网最全Tomcat完全卸载重装教程小结

《全网最全Tomcat完全卸载重装教程小结》windows系统卸载Tomcat重新通过ZIP方式安装Tomcat,优点是灵活可控,适合开发者自定义配置,手动配置环境变量后,可通过命令行快速启动和管理... 目录一、完全卸载Tomcat1. 停止Tomcat服务2. 通过控制面板卸载3. 手动删除残留文件4.

Python中isinstance()函数原理解释及详细用法示例

《Python中isinstance()函数原理解释及详细用法示例》isinstance()是Python内置的一个非常有用的函数,用于检查一个对象是否属于指定的类型或类型元组中的某一个类型,它是Py... 目录python中isinstance()函数原理解释及详细用法指南一、isinstance()函数

python中的高阶函数示例详解

《python中的高阶函数示例详解》在Python中,高阶函数是指接受函数作为参数或返回函数作为结果的函数,下面:本文主要介绍python中高阶函数的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录1.定义2.map函数3.filter函数4.reduce函数5.sorted函数6.自定义高阶函数

Python的pandas库基础知识超详细教程

《Python的pandas库基础知识超详细教程》Pandas是Python数据处理核心库,提供Series和DataFrame结构,支持CSV/Excel/SQL等数据源导入及清洗、合并、统计等功能... 目录一、配置环境二、序列和数据表2.1 初始化2.2  获取数值2.3 获取索引2.4 索引取内容2

Vue实现路由守卫的示例代码

《Vue实现路由守卫的示例代码》Vue路由守卫是控制页面导航的钩子函数,主要用于鉴权、数据预加载等场景,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着... 目录一、概念二、类型三、实战一、概念路由守卫(Navigation Guards)本质上就是 在路

JAVA实现Token自动续期机制的示例代码

《JAVA实现Token自动续期机制的示例代码》本文主要介绍了JAVA实现Token自动续期机制的示例代码,通过动态调整会话生命周期平衡安全性与用户体验,解决固定有效期Token带来的风险与不便,感兴... 目录1. 固定有效期Token的内在局限性2. 自动续期机制:兼顾安全与体验的解决方案3. 总结PS

python依赖管理工具UV的安装和使用教程

《python依赖管理工具UV的安装和使用教程》UV是一个用Rust编写的Python包安装和依赖管理工具,比传统工具(如pip)有着更快、更高效的体验,:本文主要介绍python依赖管理工具UV... 目录前言一、命令安装uv二、手动编译安装2.1在archlinux安装uv的依赖工具2.2从github

C#中通过Response.Headers设置自定义参数的代码示例

《C#中通过Response.Headers设置自定义参数的代码示例》:本文主要介绍C#中通过Response.Headers设置自定义响应头的方法,涵盖基础添加、安全校验、生产实践及调试技巧,强... 目录一、基础设置方法1. 直接添加自定义头2. 批量设置模式二、高级配置技巧1. 安全校验机制2. 类型