pytorch -- CIFAR10 完整的模型训练套路

2024-02-27 06:52

本文主要是介绍pytorch -- CIFAR10 完整的模型训练套路,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()# 5 损失函数
loss_fn = nn.CrossEntropyLoss()# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')for i in range(epoch):print('-----------第{}轮训练开始-----------'.format(i+1))# 训练开始# 训练步骤开始 dropout batchNorm仅对某些层次有作用tudui.train()for data in train_dataloader:imgs, targets = dataoutput = tudui(imgs) #训练模型的预测输出loss = loss_fn(output,targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字writer.add_scalar('train_loss',loss.item(),total_train_step)# 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失# 测试步骤开始tudui.eval()total_test_loss = 0total_accuracy = 0# 测试不需要对梯度进行调整with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()# accuracy 正确预测的样本数量accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整体测试集上的loss是{}'.format(total_test_loss))print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()
# model.py
import torch
from torch import nn# 3 搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()# 验证一下输入输出尺寸input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)

运行结果:
在这里插入图片描述

这篇关于pytorch -- CIFAR10 完整的模型训练套路的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751544

相关文章

SpringBoot基于注解实现数据库字段回填的完整方案

《SpringBoot基于注解实现数据库字段回填的完整方案》这篇文章主要为大家详细介绍了SpringBoot如何基于注解实现数据库字段回填的相关方法,文中的示例代码讲解详细,感兴趣的小伙伴可以了解... 目录数据库表pom.XMLRelationFieldRelationFieldMapping基础的一些代

Nginx搭建前端本地预览环境的完整步骤教学

《Nginx搭建前端本地预览环境的完整步骤教学》这篇文章主要为大家详细介绍了Nginx搭建前端本地预览环境的完整步骤教学,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录项目目录结构核心配置文件:nginx.conf脚本化操作:nginx.shnpm 脚本集成总结:对前端的意义很多

在Ubuntu上打不开GitHub的完整解决方法

《在Ubuntu上打不开GitHub的完整解决方法》当你满心欢喜打开Ubuntu准备推送代码时,突然发现终端里的gitpush卡成狗,浏览器里的GitHub页面直接变成Whoathere!警告页面... 目录一、那些年我们遇到的"红色惊叹号"二、三大症状快速诊断症状1:浏览器直接无法访问症状2:终端操作异常

Spring Boot分层架构详解之从Controller到Service再到Mapper的完整流程(用户管理系统为例)

《SpringBoot分层架构详解之从Controller到Service再到Mapper的完整流程(用户管理系统为例)》本文将以一个实际案例(用户管理系统)为例,详细解析SpringBoot中Co... 目录引言:为什么学习Spring Boot分层架构?第一部分:Spring Boot的整体架构1.1

mybatis直接执行完整sql及踩坑解决

《mybatis直接执行完整sql及踩坑解决》MyBatis可通过select标签执行动态SQL,DQL用ListLinkedHashMap接收结果,DML用int处理,注意防御SQL注入,优先使用#... 目录myBATiFBNZQs直接执行完整sql及踩坑select语句采用count、insert、u

JS纯前端实现浏览器语音播报、朗读功能的完整代码

《JS纯前端实现浏览器语音播报、朗读功能的完整代码》在现代互联网的发展中,语音技术正逐渐成为改变用户体验的重要一环,下面:本文主要介绍JS纯前端实现浏览器语音播报、朗读功能的相关资料,文中通过代码... 目录一、朗读单条文本:① 语音自选参数,按钮控制语音:② 效果图:二、朗读多条文本:① 语音有默认值:②

nodejs打包作为公共包使用的完整流程

《nodejs打包作为公共包使用的完整流程》在Node.js项目中,打包和部署是发布应用的关键步骤,:本文主要介绍nodejs打包作为公共包使用的相关资料,文中通过代码介绍的非常详细,需要的朋友可... 目录前言一、前置准备二、创建与编码三、一键构建四、本地“白嫖”测试(可选)五、发布公共包六、常见踩坑提醒

macOS彻底卸载Python的超完整指南(推荐!)

《macOS彻底卸载Python的超完整指南(推荐!)》随着python解释器的不断更新升级和项目开发需要,有时候会需要升级或者降级系统中的python的版本,系统中留存的Pytho版本如果没有卸载干... 目录MACOS 彻底卸载 python 的完整指南重要警告卸载前检查卸载方法(按安装方式)1. 卸载

C#实现高性能拍照与水印添加功能完整方案

《C#实现高性能拍照与水印添加功能完整方案》在工业检测、质量追溯等应用场景中,经常需要对产品进行拍照并添加相关信息水印,本文将详细介绍如何使用C#实现一个高性能的拍照和水印添加功能,包含完整的代码实现... 目录1. 概述2. 功能架构设计3. 核心代码实现python3.1 主拍照方法3.2 安全HBIT

C#实现SHP文件读取与地图显示的完整教程

《C#实现SHP文件读取与地图显示的完整教程》在地理信息系统(GIS)开发中,SHP文件是一种常见的矢量数据格式,本文将详细介绍如何使用C#读取SHP文件并实现地图显示功能,包括坐标转换、图形渲染、平... 目录概述功能特点核心代码解析1. 文件读取与初始化2. 坐标转换3. 图形绘制4. 地图交互功能缩放