小白的Pytorch读取Cifar10数据集-------学习笔记

2024-01-20 06:18

本文主要是介绍小白的Pytorch读取Cifar10数据集-------学习笔记,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

                                  机器学习
  • Cifar10数据集
    Cifar10是一个由彩色图像组成的分类的数据集,其中包含了飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车10个类别,如下图所示,且每个类中包含了1000张图片。整个数据集中包含了60 000张32×32的彩色图片。该数据集被分成50 000和10 000两部分
    50 000是training set,用来做训练;
    10 000是test set,用来做验证。 在这里插入图片描述
    下面让我们来使用Pytorch来读取Cifar10数据集,代码如下:
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import matplotlib.pyplot as plt
batch_size = 100
#Cifar10 dataset                    #选择数据的根目录   #选择训练集    #从网上下载图片
train_dataset = dsets.CIFAR10(root = '/ml/pycifar', train= True, download= True)#选择数据的根目录   #选择训练集    #从网上下载图片
test_dataset = dsets.CIFAR10(root = '/ml/pycifar', train= False, download= True)
#加载数据
#将数据打乱
train_loader = torch.utils.data.DataLoader(dataset = train_dataset,batch_size = batch_size,shuffle= True)
test_loader = torch.utils.data.DataLoader(dataset = test_dataset,batch_size = batch_size,shuffle= True)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
digit = train_loader.dataset.data[0]
plt.imshow(digit,cmap=plt.cm.binary)
plt.show()
print(classes[train_loader.dataset.targets[0]])

在这里插入图片描述
在这里插入图片描述
**小注意
digit = train_loader.dataset.train_data[0]
print(classes[train_loader.dataset.train_labels[0]])报错
应该是因为torch版本的问题,不同torch对应的后缀不同,将train_data变成data,把train_labels变成targets就可以了

对于代码的理解:

  1. from torch.utils.data.DataLoader import DataLoaded
    PyTorch中数据读取的一个重要接口是torch.utils.data.DataLoader,该接口定义在dataloader.py脚本中,只要是用PyTorch来训练模型基本都会用到该接口,该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入,因此该接口有点承上启下的作用,比较重要。这篇博客介绍该接口的源码,主要包含DataLoader和DataLoaderIter两个类。
    对上面的全面解释请详见这位博主 https://blog.csdn.net/u014380165/article/details/79058479?ops_request_misc=%25257B%252522request%25255Fid%252522%25253A%252522161123796816780299050394%252522%25252C%252522scm%252522%25253A%25252220140713.130102334.pc%25255Fall.%252522%25257D&request_id=161123796816780299050394&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2allfirst_rank_v2~hot_rank-1-79058479.first_rank_v2_pc_rank_v29&utm_term=%E8%A7%A3%E8%AF%BBtorch.utils.data.DataLoader

  2. batch_size = 100
    batch_size 是批大小, 通常是用在数据库的批量操作里面, 为了提高性能, 比如: batch_size = 1000,就是每次数据库交互, 处理1000条数据。Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。
    对上面的全面解释请详见这位博主
    https://blog.csdn.net/qq_42380515/article/details/87885996

                                               参考文献:深度学习与图像识别:原理与实践
    

这篇关于小白的Pytorch读取Cifar10数据集-------学习笔记的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/625053

相关文章

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

Spring Boot 整合 Redis 实现数据缓存案例详解

《SpringBoot整合Redis实现数据缓存案例详解》Springboot缓存,默认使用的是ConcurrentMap的方式来实现的,然而我们在项目中并不会这么使用,本文介绍SpringB... 目录1.添加 Maven 依赖2.配置Redis属性3.创建 redisCacheManager4.使用Sp

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

Python Pandas高效处理Excel数据完整指南

《PythonPandas高效处理Excel数据完整指南》在数据驱动的时代,Excel仍是大量企业存储核心数据的工具,Python的Pandas库凭借其向量化计算、内存优化和丰富的数据处理接口,成为... 目录一、环境搭建与数据读取1.1 基础环境配置1.2 数据高效载入技巧二、数据清洗核心战术2.1 缺失

Python处理超大规模数据的4大方法详解

《Python处理超大规模数据的4大方法详解》在数据的奇妙世界里,数据量就像滚雪球一样,越变越大,从最初的GB级别的小数据堆,逐渐演变成TB级别的数据大山,所以本文我们就来看看Python处理... 目录1. Mars:数据处理界的 “变形金刚”2. Dask:分布式计算的 “指挥家”3. CuPy:GPU

使用Vue-ECharts实现数据可视化图表功能

《使用Vue-ECharts实现数据可视化图表功能》在前端开发中,经常会遇到需要展示数据可视化的需求,比如柱状图、折线图、饼图等,这类需求不仅要求我们准确地将数据呈现出来,还需要兼顾美观与交互体验,所... 目录前言为什么选择 vue-ECharts?1. 基于 ECharts,功能强大2. 更符合 Vue

Java如何根据word模板导出数据

《Java如何根据word模板导出数据》这篇文章主要为大家详细介绍了Java如何实现根据word模板导出数据,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... pom.XML文件导入依赖 <dependency> <groupId>cn.afterturn</groupId>

Python实现获取带合并单元格的表格数据

《Python实现获取带合并单元格的表格数据》由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,所以本文我们就来聊聊如何使用Python实现获取带合并单元格的表格数据吧... 由于在日常运维中经常出现一些合并单元格的表格,如果要获取数据比较麻烦,现将将封装成类,并通过调用list_exc

Mysql数据库中数据的操作CRUD详解

《Mysql数据库中数据的操作CRUD详解》:本文主要介绍Mysql数据库中数据的操作(CRUD),详细描述对Mysql数据库中数据的操作(CRUD),包括插入、修改、删除数据,还有查询数据,包括... 目录一、插入数据(insert)1.插入数据的语法2.注意事项二、修改数据(update)1.语法2.有

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据