动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet

本文主要是介绍动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

23卷积神经网络LeNet

在这里插入图片描述

import torch
from torch import nn
import liliPytorch as lp
import matplotlib.pyplot as plt# 定义一个卷积神经网络
net = nn.Sequential(nn.Conv2d(1, 6,  kernel_size=5, padding=2), # 卷积层1:输入通道数1,输出通道数6,卷积核大小5x5,填充2nn.ReLU(), # 激活函数nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层1:池化窗口大小2x2,步幅2nn.Conv2d(6, 16, kernel_size=5), # 卷积层2:输入通道数6,输出通道数16,卷积核大小5x5nn.ReLU(), nn.AvgPool2d(kernel_size=2, stride=2), # 平均池化层2:池化窗口大小2x2,步幅2nn.Flatten(), # 展平层:将多维输入展平为1维nn.Linear(16 * 5 * 5, 120), # 全连接层1:输入节点数16*5*5,输出节点数120nn.ReLU(),nn.Linear(120, 84), # 全连接层2:输入节点数120,输出节点数84nn.ReLU(), nn.Linear(84, 10) # 全连接层3:输入节点数84,输出节点数10(对应10个分类)
)# 通过在每一层打印输出的形状,我们可以检查模型
X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32) # 随机生成一个形状为(1,1,28,28)的张量,作为输入
for layer in net:X = layer(X) # 将输入依次通过每一层print(layer.__class__.__name__, 'output shape: \t', X.shape) # 打印每一层的输出形状
"""
Conv2d output shape:     torch.Size([1, 6, 28, 28])
ReLU output shape:       torch.Size([1, 6, 28, 28])
AvgPool2d output shape:          torch.Size([1, 6, 14, 14])
Conv2d output shape:     torch.Size([1, 16, 10, 10])
ReLU output shape:       torch.Size([1, 16, 10, 10])
AvgPool2d output shape:          torch.Size([1, 16, 5, 5])
Flatten output shape:    torch.Size([1, 400])
Linear output shape:     torch.Size([1, 120])
ReLU output shape:       torch.Size([1, 120])
Linear output shape:     torch.Size([1, 84])
ReLU output shape:       torch.Size([1, 84])
Linear output shape:     torch.Size([1, 10])
"""
# 模型训练
batch_size = 256
train_iter, test_iter = lp.loda_data_fashion_mnist(batch_size) # 加载Fashion-MNIST数据集#分类精度
def accuracy(y_hat,y): #@save"""计算预测正确的数量"""#判断y_hat.shape是否为二维以上的矩阵#并且列数大于1if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:#axis = 1 表示按照每一行#argmax(axis = 1)得到每行最大值的下标y_hat = y_hat.argmax(axis = 1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())def evaluate_accuracy_gpu(net, data_iter, device=None):"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval() # 将模型设置为评估模式metric = lp.Accumulator(2) # 正确预测数、预测总数with torch.no_grad(): # 禁用梯度计算for X, y in data_iter:if isinstance(X, list):X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(accuracy(net(X), y), y.numel()) # 累加正确预测数和样本总数return metric[0] / metric[1] # 返回精度def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight) # 初始化权重net.apply(init_weights) # 对网络应用权重初始化print('training on', device)net.to(device) # 将模型加载到设备上optimizer = torch.optim.SGD(net.parameters(), lr=lr) # 使用随机梯度下降优化器loss = nn.CrossEntropyLoss() # 定义交叉熵损失函数animator = lp.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc']) # 动画工具,绘制训练曲线timer, num_batches = lp.Timer(), len(train_iter) # 计时器和批次数for epoch in range(num_epochs):metric = lp.Accumulator(3) # 训练损失之和,训练准确率之和,样本数net.train() # 训练模式for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad() # 梯度清零X, y = X.to(device), y.to(device) # 将数据加载到设备上y_hat = net(X) # 前向传播l = loss(y_hat, y) # 计算损失l.backward() # 反向传播optimizer.step() # 更新参数with torch.no_grad(): # 禁用梯度计算metric.add(l * X.shape[0], lp.accuracy(y_hat, y), X.shape[0]) # 累加损失、准确率和样本数timer.stop()train_l = metric[0] / metric[2] # 计算平均训练损失train_acc = metric[1] / metric[2] # 计算平均训练准确率if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None)) # 更新动画test_acc = evaluate_accuracy_gpu(net, test_iter, device) # 计算测试集上的准确率animator.add(epoch + 1, (None, None, test_acc)) # 更新动画print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')lr, num_epochs = 0.5, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, lp.try_gpu()) # 训练模型
# d2l.plt.show() # 显示训练曲线
plt.show() # 显示训练曲线# lr = 0.9,Sigmoid()
# loss 0.466, train acc 0.825, test acc 0.808# lr = 0.1,Sigmoid()
# loss 1.277, train acc 0.551, test acc 0.568# lr = 0.1,ReLU()
# loss 0.339, train acc 0.874, test acc 0.803# lr = 0.5,ReLU()
# loss 0.302, train acc 0.887, test acc 0.857# lr = 0.6,ReLU()
# loss 0.316, train acc 0.878, test acc 0.861

运行结果:
在这里插入图片描述

这篇关于动手学深度学习(Pytorch版)代码实践 -卷积神经网络-23卷积神经网络LeNet的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1086588

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

JDK21对虚拟线程的几种用法实践指南

《JDK21对虚拟线程的几种用法实践指南》虚拟线程是Java中的一种轻量级线程,由JVM管理,特别适合于I/O密集型任务,:本文主要介绍JDK21对虚拟线程的几种用法,文中通过代码介绍的非常详细,... 目录一、参考官方文档二、什么是虚拟线程三、几种用法1、Thread.ofVirtual().start(

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

从基础到高级详解Go语言中错误处理的实践指南

《从基础到高级详解Go语言中错误处理的实践指南》Go语言采用了一种独特而明确的错误处理哲学,与其他主流编程语言形成鲜明对比,本文将为大家详细介绍Go语言中错误处理详细方法,希望对大家有所帮助... 目录1 Go 错误处理哲学与核心机制1.1 错误接口设计1.2 错误与异常的区别2 错误创建与检查2.1 基础

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Java集合之Iterator迭代器实现代码解析

《Java集合之Iterator迭代器实现代码解析》迭代器Iterator是Java集合框架中的一个核心接口,位于java.util包下,它定义了一种标准的元素访问机制,为各种集合类型提供了一种统一的... 目录一、什么是Iterator二、Iterator的核心方法三、基本使用示例四、Iterator的工

springboot依靠security实现digest认证的实践

《springboot依靠security实现digest认证的实践》HTTP摘要认证通过加密参数(如nonce、response)验证身份,避免明文传输,但存在密码存储风险,相比基本认证更安全,却因... 目录概述参数Demopom.XML依赖Digest1Application.JavaMyPasswo

Java 线程池+分布式实现代码

《Java线程池+分布式实现代码》在Java开发中,池通过预先创建并管理一定数量的资源,避免频繁创建和销毁资源带来的性能开销,从而提高系统效率,:本文主要介绍Java线程池+分布式实现代码,需要... 目录1. 线程池1.1 自定义线程池实现1.1.1 线程池核心1.1.2 代码示例1.2 总结流程2. J

JS纯前端实现浏览器语音播报、朗读功能的完整代码

《JS纯前端实现浏览器语音播报、朗读功能的完整代码》在现代互联网的发展中,语音技术正逐渐成为改变用户体验的重要一环,下面:本文主要介绍JS纯前端实现浏览器语音播报、朗读功能的相关资料,文中通过代码... 目录一、朗读单条文本:① 语音自选参数,按钮控制语音:② 效果图:二、朗读多条文本:① 语音有默认值:②