【深度学习|Pytorch】torchvision.datasets.ImageFolder详解

2024-04-04 07:36

本文主要是介绍【深度学习|Pytorch】torchvision.datasets.ImageFolder详解,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

ImageFolder详解

  • 1、数据准备
  • 2、ImageFolder类的定义
    • transforms.ToTensor()解析
  • 3、ImageFolder返回对象

1、数据准备

创建一个文件夹,比如叫dataset,将cat和dog文件夹都放在dataset文件夹路径下:
在这里插入图片描述

2、ImageFolder类的定义

class ImageFolder(DatasetFolder):def __init__(self,root: str,transform: Optional[Callable] = None,target_transform: Optional[Callable] = None,loader: Callable[[str], Any] = default_loader,is_valid_file: Optional[Callable[[str], bool]] = None,):

可以看到,ImageFolder类有这几个参数:
root:图片存储的根目录,即存放不同类别图片文件夹的前一个路径。
transform:即对加载的这些图片进行的前处理的方式,这里可以传入一个实例化的torchvision.Compose()对象,里面包含了各种预处理的操作。
target_transform:对图片类别进行预处理,通常来说不会用到这一步,因此可以直接不传入参数,默认图像标签没有变换,如果需要进行标签的处理,同样可以传入一个实例化的torchvision.Compose()对象。
loader:表示图像数据加载的方式,通常采用默认的加载方式,ImageFolder加载图像的方式为调用PIL库,因此图像的通道顺序是RGB而非opencv的BGR
is_valid_file:获取图像文件路径的函数,并且可以检查是否有损坏的文件。
示例代码:

ROOT_TEST = 'dataset' #dataset/cat, dataset/dog
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize
])# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)

transforms.ToTensor()解析

这里需要特别说一下ToTensor()这个函数的作用,刚接触深度学习的我那时以为只是单纯的将图像的ndarray和PIL格式转成Tensor格式,后来查看了一下源码之后发现,事情并没有这么简答!

   """Convert a PIL Image or ndarray to tensor and scale the values accordingly.This transform does not support torchscript.Converts a PIL Image or numpy.ndarray (H x W x C) in the range[0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)or if the numpy.ndarray has dtype = np.uint8In the other cases, tensors are returned without scaling... note::Because the input image is scaled to [0.0, 1.0], this transformation should not be used whentransforming target image masks. See the `references`_ for implementing the transforms for image masks... _references: https://github.com/pytorch/vision/tree/main/references/segmentation"""

这是关于ToTensor()函数的注解,这里明确指出了ToTensor()可以将PIL和ndarray格式的图像数据转成Tensor并缩放它们的值,这里的缩放他们的值的意思在下面也指出了,即将[0, 255]的像素值域归一化[0, 1.0],并且图像转换成Tensor格式之后,维度的顺序也会发生一点变化,从一开始的HWC变成了CHW的排列方式。

3、ImageFolder返回对象

以第一部分为例,我们用一个val_dataset接收了ImageFolder的返回值,那么这个Val_dataset对象里面包含了什么呢:
val_dataset.classes:存放着根目录下的子文件夹的名称(类别名称)的列表。
val_dataset.class_to_idx:存放着类别名称和各自的索引,字典类型。
val_dataset.extensions:存放着ImageFolder可以读取的图像格式名称,元组类型。
val_dataset.targets:存放着根目录下每一张图的类别索引。
val_dataset.transform:我们提供的transform的方式。
val_dataset.imgs:存放着根目录下每一张图的路径和类别索引。元组列表类型。
以上是关于这个ImageFolder返回的对象的属性的解析。

此外,我们可以通过一个for循环来遍历整个val_dataset的所有图像数据,其中val_dataset[i]是一个元组类型的数据,val_dataset[i][0]代表了前处理后的图像数据,类型为tensor,以AlexNet为例,此时的tensor应该是3 * 224 * 224的维度。val_dataset[i][1]代表了图像的类别索引。
完整示例代码:

import torch
from AlexNet import AlexNet
from torch.autograd import Variable
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader# ROOT_TRAIN = 'D:/pycharm/AlexNet/data/train'
ROOT_TEST = 'dataset'# 将图像的像素值归一化到[-1,1]之间
normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])val_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),normalize
])# 加载训练数据集
val_dataset = ImageFolder(ROOT_TEST, transform=val_transform)# 如果有NVIDA显卡,转到GPU训练,否则用CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'# 模型实例化,将模型转到device
model = AlexNet().to(device)# 加载train.py里训练好的模型
model.load_state_dict(torch.load(r'save_model/model_best.pth'))# 结果类型
classes = ["cat","dog"
]# 把Tensor转化为图片,方便可视化
show = ToPILImage()# 进入验证阶段
model.eval()
for i in range(10):x, y = val_dataset[i][0], val_dataset[i][1]# show():显示图片# show(x).show()# torch.unsqueeze(input, dim),input(Tensor):输入张量,dim (int):插入维度的索引,最终扩展张量维度为4维x = Variable(torch.unsqueeze(x, dim=0).float(), requires_grad=False).to(device)with torch.no_grad():pred = model(x)# argmax(input):返回指定维度最大值的序号# 得到预测类别中最高的那一类,再把最高的这一类对应classes中的那一类predicted, actual = classes[torch.argmax(pred[0])], classes[y]# 输出预测值与真实值print(f'predicted:"{predicted}", actual:"{actual}"')

这篇关于【深度学习|Pytorch】torchvision.datasets.ImageFolder详解的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/875182

相关文章

一文全面详解Python变量作用域

《一文全面详解Python变量作用域》变量作用域是Python中非常重要的概念,它决定了在哪里可以访问变量,下面我将用通俗易懂的方式,结合代码示例和图表,带你全面了解Python变量作用域,需要的朋友... 目录一、什么是变量作用域?二、python的四种作用域作用域查找顺序图示三、各作用域详解1. 局部作

Java SWT库详解与安装指南(最新推荐)

《JavaSWT库详解与安装指南(最新推荐)》:本文主要介绍JavaSWT库详解与安装指南,在本章中,我们介绍了如何下载、安装SWTJAR包,并详述了在Eclipse以及命令行环境中配置Java... 目录1. Java SWT类库概述2. SWT与AWT和Swing的区别2.1 历史背景与设计理念2.1.

C++作用域和标识符查找规则详解

《C++作用域和标识符查找规则详解》在C++中,作用域(Scope)和标识符查找(IdentifierLookup)是理解代码行为的重要概念,本文将详细介绍这些规则,并通过实例来说明它们的工作原理,需... 目录作用域标识符查找规则1. 普通查找(Ordinary Lookup)2. 限定查找(Qualif

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

SpringBoot 中 CommandLineRunner的作用示例详解

《SpringBoot中CommandLineRunner的作用示例详解》SpringBoot提供的一种简单的实现方案就是添加一个model并实现CommandLineRunner接口,实现功能的... 目录1、CommandLineRunnerSpringBoot中CommandLineRunner的作用

Java死锁问题解决方案及示例详解

《Java死锁问题解决方案及示例详解》死锁是指两个或多个线程因争夺资源而相互等待,导致所有线程都无法继续执行的一种状态,本文给大家详细介绍了Java死锁问题解决方案详解及实践样例,需要的朋友可以参考下... 目录1、简述死锁的四个必要条件:2、死锁示例代码3、如何检测死锁?3.1 使用 jstack3.2

详解Linux中常见环境变量的特点与设置

《详解Linux中常见环境变量的特点与设置》环境变量是操作系统和用户设置的一些动态键值对,为运行的程序提供配置信息,理解环境变量对于系统管理、软件开发都很重要,下面小编就为大家详细介绍一下吧... 目录前言一、环境变量的概念二、常见的环境变量三、环境变量特点及其相关指令3.1 环境变量的全局性3.2、环境变

Java日期类详解(最新推荐)

《Java日期类详解(最新推荐)》早期版本主要使用java.util.Date、java.util.Calendar等类,Java8及以后引入了新的日期和时间API(JSR310),包含在ja... 目录旧的日期时间API新的日期时间 API(Java 8+)获取时间戳时间计算与其他日期时间类型的转换Dur

Linux系统中的firewall-offline-cmd详解(收藏版)

《Linux系统中的firewall-offline-cmd详解(收藏版)》firewall-offline-cmd是firewalld的一个命令行工具,专门设计用于在没有运行firewalld服务的... 目录主要用途基本语法选项1. 状态管理2. 区域管理3. 服务管理4. 端口管理5. ICMP 阻断

Java中常见队列举例详解(非线程安全)

《Java中常见队列举例详解(非线程安全)》队列用于模拟队列这种数据结构,队列通常是指先进先出的容器,:本文主要介绍Java中常见队列(非线程安全)的相关资料,文中通过代码介绍的非常详细,需要的朋... 目录一.队列定义 二.常见接口 三.常见实现类3.1 ArrayDeque3.1.1 实现原理3.1.2