动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet

本文主要是介绍动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

23卷积神经网络LeNet

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt# 定义一个卷积神经网络
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)# 通过在每一层打印输出的形状,我们可以检查模型
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32) # 随机生成一个形状为(1,1,28,28)的张量,作为输入
for layer in net:X = layer(X) # 将输入依次通过每一层print(layer.__class__.__name__, 'output shape: \t', X.shape) # 打印每一层的输出形状
"""
Conv2d output shape:     torch.Size([1, 6, 28, 28])
ReLU output shape:       torch.Size([1, 6, 28, 28])
AvgPool2d output shape:          torch.Size([1, 6, 14, 14])
Conv2d output shape:     torch.Size([1, 16, 10, 10])
ReLU output shape:       torch.Size([1, 16, 10, 10])
AvgPool2d output shape:          torch.Size([1, 16, 5, 5])
Flatten output shape:    torch.Size([1, 400])
Linear output shape:     torch.Size([1, 120])
ReLU output shape:       torch.Size([1, 120])
Linear output shape:     torch.Size([1, 84])
ReLU output shape:       torch.Size([1, 84])
Linear output shape:     torch.Size([1, 10])
"""
# 模型训练
batch_size = 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size) # 加载Fashion-MNIST数据集#分类精度
def accuracy(y_hat,y): #@save"""计算预测正确的数量"""#判断y_hat.shape是否为二维以上的矩阵#并且列数大于1if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:#axis = 1 表示按照每一行#argmax(axis = 1)得到每行最大值的下标y_hat = y_hat.argmax(axis = 1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval() # 将模型设置为评估模式metric = lp.Accumulator(2) # 正确预测数、预测总数with torch.no_grad(): # 禁用梯度计算for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel()) # 累加正确预测数和样本总数return metric[0] / metric[1] # 返回精度def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) # 初始化权重net.apply(init_weights) # 对网络应用权重初始化print('training on', device)net.to(device) # 将模型加载到设备上optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用随机梯度下降优化器loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc']) # 动画工具,绘制训练曲线timer, num_batches = lp.Timer(), len(train_iter) # 计时器和批次数for epoch in range(num_epochs):metric = lp.Accumulator(3) # 训练损失之和,训练准确率之和,样本数net.train() # 训练模式for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将数据加载到设备上y_hat = net(X) # 前向传播l = loss(y_hat, y) # 计算损失l.backward() # 反向传播optimizer.step() # 更新参数with torch.no_grad(): # 禁用梯度计算metric.add(l * X.shape[0], lp.accuracy(y_hat, y), X.shape[0]) # 累加损失、准确率和样本数timer.stop()train_l = metric[0] / metric[2] # 计算平均训练损失train_acc = metric[1] / metric[2] # 计算平均训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None)) # 更新动画test_acc = evaluate_accuracy_gpu(net, test_iter, device) # 计算测试集上的准确率animator.add(epoch + 1, (None, None, test_acc)) # 更新动画print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')lr, num_epochs = 0.5, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu()) # 训练模型
# d2l.plt.show() # 显示训练曲线
plt.show() # 显示训练曲线# lr = 0.9,Sigmoid()
# loss 0.466, train acc 0.825, test acc 0.808# lr = 0.1,Sigmoid()
# loss 1.277, train acc 0.551, test acc 0.568# lr = 0.1,ReLU()
# loss 0.339, train acc 0.874, test acc 0.803# lr = 0.5,ReLU()
# loss 0.302, train acc 0.887, test acc 0.857# lr = 0.6,ReLU()
# loss 0.316, train acc 0.878, test acc 0.861

运行结果:
在这里插入图片描述

这篇关于动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1086588

相关文章

Django开发时如何避免频繁发送短信验证码(python图文代码)

《Django开发时如何避免频繁发送短信验证码(python图文代码)》Django开发时,为防止频繁发送验证码,后端需用Redis限制请求频率,结合管道技术提升效率,通过生产者消费者模式解耦业务逻辑... 目录避免频繁发送 验证码1. www.chinasem.cn避免频繁发送 验证码逻辑分析2. 避免频繁

精选20个好玩又实用的的Python实战项目(有图文代码)

《精选20个好玩又实用的的Python实战项目(有图文代码)》文章介绍了20个实用Python项目,涵盖游戏开发、工具应用、图像处理、机器学习等,使用Tkinter、PIL、OpenCV、Kivy等库... 目录① 猜字游戏② 闹钟③ 骰子模拟器④ 二维码⑤ 语言检测⑥ 加密和解密⑦ URL缩短⑧ 音乐播放

Spring Boot集成/输出/日志级别控制/持久化开发实践

《SpringBoot集成/输出/日志级别控制/持久化开发实践》SpringBoot默认集成Logback,支持灵活日志级别配置(INFO/DEBUG等),输出包含时间戳、级别、类名等信息,并可通过... 目录一、日志概述1.1、Spring Boot日志简介1.2、日志框架与默认配置1.3、日志的核心作用

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

破茧 JDBC:MyBatis 在 Spring Boot 中的轻量实践指南

《破茧JDBC:MyBatis在SpringBoot中的轻量实践指南》MyBatis是持久层框架,简化JDBC开发,通过接口+XML/注解实现数据访问,动态代理生成实现类,支持增删改查及参数... 目录一、什么是 MyBATis二、 MyBatis 入门2.1、创建项目2.2、配置数据库连接字符串2.3、入

深度解析Spring Security 中的 SecurityFilterChain核心功能

《深度解析SpringSecurity中的SecurityFilterChain核心功能》SecurityFilterChain通过组件化配置、类型安全路径匹配、多链协同三大特性,重构了Spri... 目录Spring Security 中的SecurityFilterChain深度解析一、Security

Android Paging 分页加载库使用实践

《AndroidPaging分页加载库使用实践》AndroidPaging库是Jetpack组件的一部分,它提供了一套完整的解决方案来处理大型数据集的分页加载,本文将深入探讨Paging库... 目录前言一、Paging 库概述二、Paging 3 核心组件1. PagingSource2. Pager3.

深度解析Nginx日志分析与499状态码问题解决

《深度解析Nginx日志分析与499状态码问题解决》在Web服务器运维和性能优化过程中,Nginx日志是排查问题的重要依据,本文将围绕Nginx日志分析、499状态码的成因、排查方法及解决方案展开讨论... 目录前言1. Nginx日志基础1.1 Nginx日志存放位置1.2 Nginx日志格式2. 499

在Java中使用OpenCV实践

《在Java中使用OpenCV实践》用户分享了在Java项目中集成OpenCV4.10.0的实践经验,涵盖库简介、Windows安装、依赖配置及灰度图测试,强调其在图像处理领域的多功能性,并计划后续探... 目录前言一 、OpenCV1.简介2.下载与安装3.目录说明二、在Java项目中使用三 、测试1.测

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em