DataLoader 的 collate_fn 解释与示例教程

2024-04-09 04:28

本文主要是介绍DataLoader 的 collate_fn 解释与示例教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 导包
    • 数据
    • Dataloader
    • collate_fn

导包

import torch
from torch.utils.data import Dataset
from typing import Any

数据

class CustomDataset(Dataset):def __init__(self, length) -> None:super().__init__()self.length = lengthdef __getitem__(self, index=None):w1 = 3.14w2 = 4.27w = torch.tensor([w1, w2])feature = torch.rand(2) * 10noise = torch.randn_like(feature) * 0.01label = torch.matmul(w, feature.t())feature += noise# return feature, label.view(1)return feature, labeldef __len__(self):return self.lengthdataset = CustomDataset(4)

Dataloader

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, )for feature, label in dataloader:print(feature.shape, label.shape)

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):feature, label = zip(*item)feature = torch.stack(feature)label = torch.stack(label)label = label.view(-1, 1)return feature, labeldataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)for feature, label in dataloader:print(feature.shape, label.shape)

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

这篇关于DataLoader 的 collate_fn 解释与示例教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/887174

相关文章

OpenCV实现实时颜色检测的示例

《OpenCV实现实时颜色检测的示例》本文主要介绍了OpenCV实现实时颜色检测的示例,通过HSV色彩空间转换和色调范围判断实现红黄绿蓝颜色检测,包含视频捕捉、区域标记、颜色分析等功能,具有一定的参考... 目录一、引言二、系统概述三、代码解析1. 导入库2. 颜色识别函数3. 主程序循环四、HSV色彩空间

C++ 函数 strftime 和时间格式示例详解

《C++函数strftime和时间格式示例详解》strftime是C/C++标准库中用于格式化日期和时间的函数,定义在ctime头文件中,它将tm结构体中的时间信息转换为指定格式的字符串,是处理... 目录C++ 函数 strftipythonme 详解一、函数原型二、功能描述三、格式字符串说明四、返回值五

LiteFlow轻量级工作流引擎使用示例详解

《LiteFlow轻量级工作流引擎使用示例详解》:本文主要介绍LiteFlow是一个灵活、简洁且轻量的工作流引擎,适合用于中小型项目和微服务架构中的流程编排,本文给大家介绍LiteFlow轻量级工... 目录1. LiteFlow 主要特点2. 工作流定义方式3. LiteFlow 流程示例4. LiteF

Nexus安装和启动的实现教程

《Nexus安装和启动的实现教程》:本文主要介绍Nexus安装和启动的实现教程,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、Nexus下载二、Nexus安装和启动三、关闭Nexus总结一、Nexus下载官方下载链接:DownloadWindows系统根

MyBatis ResultMap 的基本用法示例详解

《MyBatisResultMap的基本用法示例详解》在MyBatis中,resultMap用于定义数据库查询结果到Java对象属性的映射关系,本文给大家介绍MyBatisResultMap的基本... 目录MyBATis 中的 resultMap1. resultMap 的基本语法2. 简单的 resul

Mybatis Plus Join使用方法示例详解

《MybatisPlusJoin使用方法示例详解》:本文主要介绍MybatisPlusJoin使用方法示例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,... 目录1、pom文件2、yaml配置文件3、分页插件4、示例代码:5、测试代码6、和PageHelper结合6

MySQL JSON 查询中的对象与数组技巧及查询示例

《MySQLJSON查询中的对象与数组技巧及查询示例》MySQL中JSON对象和JSON数组查询的详细介绍及带有WHERE条件的查询示例,本文给大家介绍的非常详细,mysqljson查询示例相关知... 目录jsON 对象查询1. JSON_CONTAINS2. JSON_EXTRACT3. JSON_TA

使用SpringBoot整合Sharding Sphere实现数据脱敏的示例

《使用SpringBoot整合ShardingSphere实现数据脱敏的示例》ApacheShardingSphere数据脱敏模块,通过SQL拦截与改写实现敏感信息加密存储,解决手动处理繁琐及系统改... 目录痛点一:痛点二:脱敏配置Quick Start——Spring 显示配置:1.引入依赖2.创建脱敏

SpringBoot 中 CommandLineRunner的作用示例详解

《SpringBoot中CommandLineRunner的作用示例详解》SpringBoot提供的一种简单的实现方案就是添加一个model并实现CommandLineRunner接口,实现功能的... 目录1、CommandLineRunnerSpringBoot中CommandLineRunner的作用

Java死锁问题解决方案及示例详解

《Java死锁问题解决方案及示例详解》死锁是指两个或多个线程因争夺资源而相互等待,导致所有线程都无法继续执行的一种状态,本文给大家详细介绍了Java死锁问题解决方案详解及实践样例,需要的朋友可以参考下... 目录1、简述死锁的四个必要条件:2、死锁示例代码3、如何检测死锁?3.1 使用 jstack3.2