DataLoader 的 collate_fn 解释与示例教程

2024-04-09 04:28

本文主要是介绍DataLoader 的 collate_fn 解释与示例教程,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 导包
    • 数据
    • Dataloader
    • collate_fn

导包

import torch
from torch.utils.data import Dataset
from typing import Any

数据

class CustomDataset(Dataset):def __init__(self, length) -> None:super().__init__()self.length = lengthdef __getitem__(self, index=None):w1 = 3.14w2 = 4.27w = torch.tensor([w1, w2])feature = torch.rand(2) * 10noise = torch.randn_like(feature) * 0.01label = torch.matmul(w, feature.t())feature += noise# return feature, label.view(1)return feature, labeldef __len__(self):return self.lengthdataset = CustomDataset(4)

Dataloader

dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, )for feature, label in dataloader:print(feature.shape, label.shape)

下述展示了,默认的 Dataload 的处理结果:
通过 torch.stack(feature),构建出 batch 数据;

torch.Size([2, 2]) torch.Size([2])
torch.Size([2, 2]) torch.Size([2])

常量直接拼接;
向量则会在前面添加一个 batch 纬度;

collate_fn

collate_fn:返回值为最终构建的batch数据;在这一步中处理dataset的数据,将其调整成我们期望的数据格式;

如上述默认的输出结果所示:label.shape 为 torch.Size([2]),笔者想通过 collate_fn 修改 label.shapetorch.Size([2, 1]),下述代码实现这个功能:

def collate_fn(item):feature, label = zip(*item)feature = torch.stack(feature)label = torch.stack(label)label = label.view(-1, 1)return feature, labeldataloader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=collate_fn)for feature, label in dataloader:print(feature.shape, label.shape)

输出如下:

torch.Size([2, 2]) torch.Size([2, 1])
torch.Size([2, 2]) torch.Size([2, 1])

collate_fn(item),传入的item的数据为:

[(tensor([[6.9436, 7.2040]]), tensor([[52.6007]])), (tensor([[7.1495, 2.8882]]), tensor([[34.7427]]))]
[(tensor([[1.5311, 9.9278]]), tensor([[47.1995]])), (tensor([[4.9614, 8.6232]]), tensor([[52.3849]]))]

feature, label = zip(*item) 故通过zip(*item)的方式,拆分出 feature 和 label 各自的数据,再借助torch.stack方法将其拼接出 batch 形状的数据。

这篇关于DataLoader 的 collate_fn 解释与示例教程的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/887174

相关文章

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

Python中的Walrus运算符分析示例详解

《Python中的Walrus运算符分析示例详解》Python中的Walrus运算符(:=)是Python3.8引入的一个新特性,允许在表达式中同时赋值和返回值,它的核心作用是减少重复计算,提升代码简... 目录1. 在循环中避免重复计算2. 在条件判断中同时赋值变量3. 在列表推导式或字典推导式中简化逻辑

Python位移操作和位运算的实现示例

《Python位移操作和位运算的实现示例》本文主要介绍了Python位移操作和位运算的实现示例,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一... 目录1. 位移操作1.1 左移操作 (<<)1.2 右移操作 (>>)注意事项:2. 位运算2.1

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删

pandas中位数填充空值的实现示例

《pandas中位数填充空值的实现示例》中位数填充是一种简单而有效的方法,用于填充数据集中缺失的值,本文就来介绍一下pandas中位数填充空值的实现,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是中位数填充?为什么选择中位数填充?示例数据结果分析完整代码总结在数据分析和机器学习过程中,处理缺失数

Pandas统计每行数据中的空值的方法示例

《Pandas统计每行数据中的空值的方法示例》处理缺失数据(NaN值)是一个非常常见的问题,本文主要介绍了Pandas统计每行数据中的空值的方法示例,具有一定的参考价值,感兴趣的可以了解一下... 目录什么是空值?为什么要统计空值?准备工作创建示例数据统计每行空值数量进一步分析www.chinasem.cn处

利用Python调试串口的示例代码

《利用Python调试串口的示例代码》在嵌入式开发、物联网设备调试过程中,串口通信是最基础的调试手段本文将带你用Python+ttkbootstrap打造一款高颜值、多功能的串口调试助手,需要的可以了... 目录概述:为什么需要专业的串口调试工具项目架构设计1.1 技术栈选型1.2 关键类说明1.3 线程模

如何为Yarn配置国内源的详细教程

《如何为Yarn配置国内源的详细教程》在使用Yarn进行项目开发时,由于网络原因,直接使用官方源可能会导致下载速度慢或连接失败,配置国内源可以显著提高包的下载速度和稳定性,本文将详细介绍如何为Yarn... 目录一、查询当前使用的镜像源二、设置国内源1. 设置为淘宝镜像源2. 设置为其他国内源三、还原为官方

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Android实现在线预览office文档的示例详解

《Android实现在线预览office文档的示例详解》在移动端展示在线Office文档(如Word、Excel、PPT)是一项常见需求,这篇文章为大家重点介绍了两种方案的实现方法,希望对大家有一定的... 目录一、项目概述二、相关技术知识三、实现思路3.1 方案一:WebView + Office Onl