pytorch -- CIFAR10 完整的模型训练套路

2024-02-27 06:52

本文主要是介绍pytorch -- CIFAR10 完整的模型训练套路,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()# 5 损失函数
loss_fn = nn.CrossEntropyLoss()# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')for i in range(epoch):print('-----------第{}轮训练开始-----------'.format(i+1))# 训练开始# 训练步骤开始 dropout batchNorm仅对某些层次有作用tudui.train()for data in train_dataloader:imgs, targets = dataoutput = tudui(imgs) #训练模型的预测输出loss = loss_fn(output,targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字writer.add_scalar('train_loss',loss.item(),total_train_step)# 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失# 测试步骤开始tudui.eval()total_test_loss = 0total_accuracy = 0# 测试不需要对梯度进行调整with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()# accuracy 正确预测的样本数量accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整体测试集上的loss是{}'.format(total_test_loss))print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()
# model.py
import torch
from torch import nn# 3 搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()# 验证一下输入输出尺寸input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)

运行结果:
在这里插入图片描述

这篇关于pytorch -- CIFAR10 完整的模型训练套路的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751544

相关文章

Kotlin Compose Button 实现长按监听并实现动画效果(完整代码)

《KotlinComposeButton实现长按监听并实现动画效果(完整代码)》想要实现长按按钮开始录音,松开发送的功能,因此为了实现这些功能就需要自己写一个Button来解决问题,下面小编给大... 目录Button 实现原理1. Surface 的作用(关键)2. InteractionSource3.

MySQL数据库实现批量表分区完整示例

《MySQL数据库实现批量表分区完整示例》通俗地讲表分区是将一大表,根据条件分割成若干个小表,:本文主要介绍MySQL数据库实现批量表分区的相关资料,文中通过代码介绍的非常详细,需要的朋友可以参考... 目录一、表分区条件二、常规表和分区表的区别三、表分区的创建四、将既有表转换分区表脚本五、批量转换表为分区

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

Python Pandas高效处理Excel数据完整指南

《PythonPandas高效处理Excel数据完整指南》在数据驱动的时代,Excel仍是大量企业存储核心数据的工具,Python的Pandas库凭借其向量化计算、内存优化和丰富的数据处理接口,成为... 目录一、环境搭建与数据读取1.1 基础环境配置1.2 数据高效载入技巧二、数据清洗核心战术2.1 缺失

使用Java将实体类转换为JSON并输出到控制台的完整过程

《使用Java将实体类转换为JSON并输出到控制台的完整过程》在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用JSON格式,用Java将实体类转换为J... 在软件开发的过程中,Java是一种广泛使用的编程语言,而在众多应用中,数据的传输和存储经常需要使用j

Java实现视频格式转换的完整指南

《Java实现视频格式转换的完整指南》在Java中实现视频格式的转换,通常需要借助第三方工具或库,因为视频的编解码操作复杂且性能需求较高,以下是实现视频格式转换的常用方法和步骤,需要的朋友可以参考下... 目录核心思路方法一:通过调用 FFmpeg 命令步骤示例代码说明优点方法二:使用 Jaffree(FF

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

SpringBoot实现二维码生成的详细步骤与完整代码

《SpringBoot实现二维码生成的详细步骤与完整代码》如今,二维码的应用场景非常广泛,从支付到信息分享,二维码都扮演着重要角色,SpringBoot是一个非常流行的Java基于Spring框架的微... 目录一、环境搭建二、创建 Spring Boot 项目三、引入二维码生成依赖四、编写二维码生成代码五

Java对接Dify API接口的完整流程

《Java对接DifyAPI接口的完整流程》Dify是一款AI应用开发平台,提供多种自然语言处理能力,通过调用Dify开放API,开发者可以快速集成智能对话、文本生成等功能到自己的Java应用中,本... 目录Java对接Dify API接口完整指南一、Dify API简介二、准备工作三、基础对接实现1.

使用Python创建一个功能完整的Windows风格计算器程序

《使用Python创建一个功能完整的Windows风格计算器程序》:本文主要介绍如何使用Python和Tkinter创建一个功能完整的Windows风格计算器程序,包括基本运算、高级科学计算(如三... 目录python实现Windows系统计算器程序(含高级功能)1. 使用Tkinter实现基础计算器2.