PyTorch中定义自己的数据集

2024-05-09 16:20
文章标签 数据 定义 pytorch

本文主要是介绍PyTorch中定义自己的数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

    • 1. 简介
    • 2. 查看PyTorch自带的数据集(可视化)
    • 3. 准备材料
      • 3.1 图片数据
      • 3.2 标签数据
    • 4. 方法

1. 简介

尽管PyTorch提供了许多自带的数据集,如MNIST、CIFAR-10、ImageNet等,但它们对于没有经验的用户来说,理解数据加载器的工作原理以及如何正确地配置数据加载器可能会有一定难度。 用户需要了解所使用的数据集,包括数据集的内容、结构、标签等信息。对于一些复杂的数据集,用户可能需要理解数据集的结构和标签的含义。通过定义自己的数据集类,您可以更好地控制数据的加载和处理过程,提高代码的灵活性、可读性和可维护性,同时更好地满足模型训练的需求。

2. 查看PyTorch自带的数据集(可视化)

为了更好的定义自己的数据集,我们首先查看PyTorch自带的数据集的内容,代码如下

# 导入所需的库
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化
import torch  # 导入PyTorch库
from torchvision.datasets import MNIST  # 从torchvision中导入MNIST数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库# 加载MNIST数据集
train_mnist_data = MNIST(root='./data',  # 数据集存储路径train=True,  # 加载训练集transform=transforms.Compose([transforms.Resize(size=(28, 28)), transforms.ToTensor()]),  # 数据预处理操作download=True)  # 如果数据集不存在,则自动下载# 设置要显示的样本数量
num_samples = 10# 创建包含多个子图的大图窗口
fig, axes = plt.subplots(1, num_samples, figsize=(10, 6))# 遍历选择要显示的样本
for i in range(num_samples):# 从数据集中获取图像数据和标签image, label = train_mnist_data[i]# 在子图中显示图像axes[i].imshow(image.squeeze().numpy(), cmap='gray')  # 使用imshow函数显示图像,将张量转换为NumPy数组axes[i].set_title(f"Label: {label}")  # 设置子图标题,显示图像对应的标签axes[i].axis('off')  # 关闭坐标轴显示# 将图像保存为PNG格式的图片文件,文件名以图像的标签命名plt.imsave(f"./data/mnist_images/{label}.png", image.squeeze().numpy(), cmap='gray')# 显示图形窗口
plt.show()

这里,我们使用MNIST类加载MNIST数据集。在加载数据集时,通过transform参数指定了数据预处理操作,包括将图像大小调整为28x28像素,并将图像转换为张量。train=True表示加载训练集,download=True表示如果数据集不存在则自动下载到指定的路径。

接下来,我们选择一些样本进行可视化。我们在一个子图中显示了10个样本,每个样本对应一个数字图像和其对应的标签。通过循环遍历这些样本,从数据集中获取图像数据和标签,并使用Matplotlib的imshow()函数将图像显示在子图中。
在这里插入图片描述

同时,使用imsave()函数将每个图像保存为PNG格式的图片文件,文件名以标签命名。最后,使用plt.show()显示图形窗口,显示图像的同时也会将图像保存到指定的路径中。这段代码的执行结果是显示10张MNIST数据集中的数字图像,并将这些图像保存到指定路径下。保存的图片如下所示

在这里插入图片描述

通过上面程序可以看到,数据集主要是由图片数据和对应的标签构成,那么我们就可以用这两个主要构成成分来构建自己的数据集。

3. 准备材料

3.1 图片数据

这里我们就用刚才保存的十张图片,即

在这里插入图片描述

当然,你也可以准备其它的图片,并给图片分别命名为“0.png, 1.png, …”。

这里,十张图片的相对路径为

imgs_path = "./data/mnist_images"

注:你们要根据自己存储的路径来给定。

3.2 标签数据

创建一个txt文件,为每一幅图片指定标签数据,如下所示

在这里插入图片描述

这里,txt文件的相对路径为

labels_path = "labels.txt"

4. 方法

在PyTorch中,您可以通过创建一个自定义的数据集类来定义自己的数据集。这个自定义类需要继承自torch.utils.data.Dataset类,并且实现两个主要的方法:__len____getitem____len__方法应该返回数据集的长度,而__getitem__方法则根据给定的索引返回数据集中的样本。

下面我们展示如何创建一个自定义的数据集类:

import os  # 导入os模块,用于操作文件路径
from PIL import Image  # 导入PIL库中的Image模块,用于图像处理
import torch  # 导入PyTorch库
from torch.utils.data import Dataset  # 从torch.utils.data模块导入Dataset类,用于定义自定义数据集
from torchvision import transforms  # 导入transforms模块,用于数据预处理
import numpy as np  # 导入NumPy库,用于数值处理
import matplotlib.pyplot as plt  # 导入Matplotlib库,用于可视化class CustomDataset(Dataset):def __init__(self, image_dir, label_file, transform=None):super().__init__()  # 调用父类的构造函数self.image_dir = image_dir  # 图像数据的路径self.label_file = label_file  # 标签文本的路径self.transform = transform  # 数据预处理操作self.samples = self._load_samples()  # 加载数据集样本信息def _load_samples(self):samples = []  # 存储样本信息的列表with open(self.label_file, 'r') as f:  # 打开标签文本文件for line in f:  # 逐行读取标签文本文件中的内容image_name, label = line.strip().split(',')  # 根据逗号分隔每行内容,获取图像文件名和标签image_path = os.path.join(self.image_dir, image_name)  # 拼接图像文件的完整路径samples.append((image_path, int(label)))  # 将图像路径和标签组成元组,加入样本列表return samples  # 返回样本列表def __len__(self):return len(self.samples)  # 返回数据集样本的数量def __getitem__(self, index):image_path, label = self.samples[index]  # 获取指定索引处的图像路径和标签image = Image.open(image_path).convert('L')  # 打开图像文件并将其转换为灰度图像if self.transform:  # 如果定义了数据预处理操作image = self.transform(image)  # 对图像进行预处理操作return image, label  # 返回预处理后的图像和标签# 设置图片数据路径和标签文本路径
image_dir = './data/mnist_images'  # 图像数据的路径
label_file = 'labels.txt'  # 标签文本的路径# 定义数据预处理操作,根据需要添加其他预处理操作
transform = transforms.Compose([transforms.Resize((28, 28)),  # 调整图像大小transforms.ToTensor(),  # 将图像转换为张量
])# 创建自定义数据集实例
custom_dataset = CustomDataset(image_dir, label_file, transform=transform)# 创建数据加载器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=1, shuffle=False)# 遍历数据加载器中的每个批次数据
for batch_images, batch_labels in data_loader:# 使用squeeze()函数去除图像张量中的单维度,将图像数据转换为NumPy数组,并存储在变量image中image = batch_images.squeeze().numpy()# 使用imshow()函数显示图像,cmap='gray'指定使用灰度色彩映射plt.imshow(image, cmap='gray')# 设置图像标题,显示图像对应的标签,使用f-string格式化字符串,将batch_labels转换为Python标量并获取其值plt.title(f"Label: {batch_labels.item()}")# 关闭坐标轴显示,即不显示坐标轴plt.axis('off')# 显示图形窗口plt.show()

这段代码实现了加载自定义数据集,并使用 PyTorch 的 DataLoader 将数据加载成批次,然后逐批次地展示图像。

这篇关于PyTorch中定义自己的数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/973886

相关文章

MyBatis-Plus通用中等、大量数据分批查询和处理方法

《MyBatis-Plus通用中等、大量数据分批查询和处理方法》文章介绍MyBatis-Plus分页查询处理,通过函数式接口与Lambda表达式实现通用逻辑,方法抽象但功能强大,建议扩展分批处理及流式... 目录函数式接口获取分页数据接口数据处理接口通用逻辑工具类使用方法简单查询自定义查询方法总结函数式接口

SQL中如何添加数据(常见方法及示例)

《SQL中如何添加数据(常见方法及示例)》SQL全称为StructuredQueryLanguage,是一种用于管理关系数据库的标准编程语言,下面给大家介绍SQL中如何添加数据,感兴趣的朋友一起看看吧... 目录在mysql中,有多种方法可以添加数据。以下是一些常见的方法及其示例。1. 使用INSERT I

Python使用vllm处理多模态数据的预处理技巧

《Python使用vllm处理多模态数据的预处理技巧》本文深入探讨了在Python环境下使用vLLM处理多模态数据的预处理技巧,我们将从基础概念出发,详细讲解文本、图像、音频等多模态数据的预处理方法,... 目录1. 背景介绍1.1 目的和范围1.2 预期读者1.3 文档结构概述1.4 术语表1.4.1 核

MySQL 删除数据详解(最新整理)

《MySQL删除数据详解(最新整理)》:本文主要介绍MySQL删除数据的相关知识,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录一、前言二、mysql 中的三种删除方式1.DELETE语句✅ 基本语法: 示例:2.TRUNCATE语句✅ 基本语

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

Navicat数据表的数据添加,删除及使用sql完成数据的添加过程

《Navicat数据表的数据添加,删除及使用sql完成数据的添加过程》:本文主要介绍Navicat数据表的数据添加,删除及使用sql完成数据的添加过程,具有很好的参考价值,希望对大家有所帮助,如有... 目录Navicat数据表数据添加,删除及使用sql完成数据添加选中操作的表则出现如下界面,查看左下角从左

SpringBoot中4种数据水平分片策略

《SpringBoot中4种数据水平分片策略》数据水平分片作为一种水平扩展策略,通过将数据分散到多个物理节点上,有效解决了存储容量和性能瓶颈问题,下面小编就来和大家分享4种数据分片策略吧... 目录一、前言二、哈希分片2.1 原理2.2 SpringBoot实现2.3 优缺点分析2.4 适用场景三、范围分片

Redis分片集群、数据读写规则问题小结

《Redis分片集群、数据读写规则问题小结》本文介绍了Redis分片集群的原理,通过数据分片和哈希槽机制解决单机内存限制与写瓶颈问题,实现分布式存储和高并发处理,但存在通信开销大、维护复杂及对事务支持... 目录一、分片集群解android决的问题二、分片集群图解 分片集群特征如何解决的上述问题?(与哨兵模