【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑

本文主要是介绍【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

3、cnn基础

卷积神经网络

在这里插入图片描述

输入层 —输入图片矩阵

  • 输入层一般是 RGB 图像或单通道的灰度图像,图片像素值在[0,255],可以用矩阵表示图片
  • 在这里插入图片描述

卷积层 —特征提取

  • 人通过特征进行图像识别,根据左图直的笔画判断X,右图曲的笔画判断圆
    在这里插入图片描述
  • 卷积操作
    在这里插入图片描述
    在这里插入图片描述

激活层 —加强特征

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

池化层 —压缩数据

在这里插入图片描述
在这里插入图片描述

全连接层 —进行分类

输出层 —输出分类概率

在这里插入图片描述

4、基于LeNet实现cifar10数据集分类

1、数据集

在这里插入图片描述

  • DataSet
import torchvision
from torchvision import transforms
def get_train_dataset(data_root):transform = transforms.Compose([transforms.ToTensor(),#0~1transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])#-1,1#(0-0.5)/0.5=-1 (1-0.5)/0.5=1train_dataset = torchvision.datasets.CIFAR10(root=data_root, train=True,download=True, transform=None)return train_dataset
train_dataset = get_train_dataset("dataset")
image,label = train_dataset[1]
print(train_dataset.classes[label])
print(type(image))

transforms 设定处理图片的规则

# Composes several transforms together (authority)
# 把一系列图片操作【组合】起来
transform = transforms.Compose
# Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor
# 将PIL Image格式或者numpy.ndarray格式的数据格式化为可被pytorch快速处理的tensor类型
transforms.ToTensor() 
# Normalize a tensor image with mean and standard deviation
# output[channel] = (input[channel] - mean[channel]) / std[channel]
# 使用均值和方差对数据归一化,保证程序运行时收敛加快,训练次数少点
# 因为tensor是0-1,经过normalize可以变成(-1,1) 所以可以使用0.5,当然并不一定都是0.5
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

在这里插入图片描述

# root :Root directory of dataset where directory(数据集的根目录)
# train  If True, creates dataset from training set
# download If true, downloads the dataset from the internet
# transform A transform that takes in an PIL image and returns a transformed version
torchvision.datasets.CIFAR10(root=data_root, train=True,download=True, transform=transform)                              

验证检查

train_dataset = get_train_dataset("dataset")
image,label = train_dataset[1]
print(train_dataset.classes[label])
image.show()
  • DataLoader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=32,shuffle=True, num_workers=2)

说明

dataset:数据集
batch_size:how many samples per batch to load
shuffle:set to ``True`` to have the data reshuffled
num_workers:
how many subprocesses to use for data 加载数据集采用单进程还是多进程
0 means that the data will be loaded in the main process.数据在主进程加载(windows建议0

测试

val_dataset = get_val_dataset("dataset")
val_loader = get_val_loader(val_dataset)
for data in val_loader:images,labels = dataprint(images.shape)print(labels)

2、模型

  • 搭建网络

import torch
from torch import nnclass Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6,kernel_size=5,stride=1,padding=0)self.relu1 = nn.ReLU()self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(in_channels=6, out_channels=16,kernel_size=5,stride=1,padding=0)self.relu2 = nn.ReLU()self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)self.flatten = nn.Flatten()self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)self.relu3 = nn.ReLU()self.fc2 = nn.Linear(in_features=120, out_features=84)self.relu4 = nn.ReLU()self.fc3 = nn.Linear(in_features=84, out_features=10)def forward(self, x):x = self.conv1(x)x = self.relu1(x)x = self.pool1(x)x = self.conv2(x)x = self.relu2(x)x = self.pool2(x)x = self.flatten(x)print(x.shape)x = self.fc1(x)x = self.relu3(x)x = self.fc2

这篇关于【超级干货】2天速成PyTorch深度学习入门教程,缓解研究生焦虑的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1147264

相关文章

深度解析Python中递归下降解析器的原理与实现

《深度解析Python中递归下降解析器的原理与实现》在编译器设计、配置文件处理和数据转换领域,递归下降解析器是最常用且最直观的解析技术,本文将详细介绍递归下降解析器的原理与实现,感兴趣的小伙伴可以跟随... 目录引言:解析器的核心价值一、递归下降解析器基础1.1 核心概念解析1.2 基本架构二、简单算术表达

深度解析Java @Serial 注解及常见错误案例

《深度解析Java@Serial注解及常见错误案例》Java14引入@Serial注解,用于编译时校验序列化成员,替代传统方式解决运行时错误,适用于Serializable类的方法/字段,需注意签... 目录Java @Serial 注解深度解析1. 注解本质2. 核心作用(1) 主要用途(2) 适用位置3

Java MCP 的鉴权深度解析

《JavaMCP的鉴权深度解析》文章介绍JavaMCP鉴权的实现方式,指出客户端可通过queryString、header或env传递鉴权信息,服务器端支持工具单独鉴权、过滤器集中鉴权及启动时鉴权... 目录一、MCP Client 侧(负责传递,比较简单)(1)常见的 mcpServers json 配置

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶

深度剖析SpringBoot日志性能提升的原因与解决

《深度剖析SpringBoot日志性能提升的原因与解决》日志记录本该是辅助工具,却为何成了性能瓶颈,SpringBoot如何用代码彻底破解日志导致的高延迟问题,感兴趣的小伙伴可以跟随小编一起学习一下... 目录前言第一章:日志性能陷阱的底层原理1.1 日志级别的“双刃剑”效应1.2 同步日志的“吞吐量杀手”

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

深度解析Python yfinance的核心功能和高级用法

《深度解析Pythonyfinance的核心功能和高级用法》yfinance是一个功能强大且易于使用的Python库,用于从YahooFinance获取金融数据,本教程将深入探讨yfinance的核... 目录yfinance 深度解析教程 (python)1. 简介与安装1.1 什么是 yfinance?

Python库 Django 的简介、安装、用法入门教程

《Python库Django的简介、安装、用法入门教程》Django是Python最流行的Web框架之一,它帮助开发者快速、高效地构建功能强大的Web应用程序,接下来我们将从简介、安装到用法详解,... 目录一、Django 简介 二、Django 的安装教程 1. 创建虚拟环境2. 安装Django三、创

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security