Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练

2024-03-23 06:32

本文主要是介绍Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

pytorch进行CIFAR-10分类(4)训练

我的系列博文:

Pytorch打怪路(一)pytorch进行CIFAR-10分类(1)CIFAR-10数据加载和处理

Pytorch打怪路(一)pytorch进行CIFAR-10分类(2)定义卷积神经网络

Pytorch打怪路(一)pytorch进行CIFAR-10分类(3)定义损失函数和优化器

Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练本文

Pytorch打怪路(一)pytorch进行CIFAR-10分类(5)测试

1、简述

经过前面的数据加载和网络定义后,就可以开始训练了,这里会看到前面遇到的一些东西究竟在后面会有什么用,所以这一步希望各位也能仔细研究一下

2、代码

for epoch in range(2):  # loop over the dataset multiple times 指定训练一共要循环几个epochrunning_loss = 0.0  #定义一个变量方便我们对loss进行输出for i, data in enumerate(trainloader, 0): # 这里我们遇到了第一步中出现的trailoader,代码传入数据# enumerate是python的内置函数,既获得索引也获得数据,详见下文# get the inputsinputs, labels = data   # data是从enumerate返回的data,包含数据和标签信息,分别赋值给inputs和labels# wrap them in Variableinputs, labels = Variable(inputs), Variable(labels) # 将数据转换成Variable,第二步里面我们已经引入这个模块# 所以这段程序里面就直接使用了,下文会分析# zero the parameter gradientsoptimizer.zero_grad()                # 要把梯度重新归零,因为反向传播过程中梯度会累加上一次循环的梯度# forward + backward + optimize      outputs = net(inputs)                # 把数据输进网络net,这个net()在第二步的代码最后一行我们已经定义了loss = criterion(outputs, labels)    # 计算损失值,criterion我们在第三步里面定义了loss.backward()                      # loss进行反向传播,下文详解optimizer.step()                     # 当执行反向传播之后,把优化器的参数进行更新,以便进行下一轮# print statistics                   # 这几行代码不是必须的,为了打印出loss方便我们看而已,不影响训练过程running_loss += loss.data[0]         # 从下面一行代码可以看出它是每循环0-1999共两千次才打印一次if i % 2000 == 1999:    # print every 2000 mini-batches   所以每个2000次之类先用running_loss进行累加print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))  # 然后再除以2000,就得到这两千次的平均损失值running_loss = 0.0               # 这一个2000次结束后,就把running_loss归零,下一个2000次继续使用print('Finished Training')
 

3、分析

①autograd


在第二步中我们定义网络时定义了前向传播函数,但是并没有定义反向传播函数,可是深度学习是需要反向传播求导的,
Pytorch其实利用的是Autograd模块来进行自动求导,反向传播
Autograd中最核心的类就是Variable了,它封装了Tensor,并几乎支持所有Tensor的操作,这里可以参考官方给的详细解释:
http://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py
以上链接详细讲述了variable究竟是怎么能够实现自动求导的,怎么用它来实现反向传播的。
这里涉及到计算图的相关概念,这里我不详细讲,后面会写相关博文来讨论这个东西,暂时不会对我们理解这个程序造成影响
只说一句, 想要计算各个variable的梯度,只需调用根节点的backward方法,Autograd就会自动沿着整个计算图进行反向计算
而在此例子中,根节点就是我们的loss,所以:

程序中的loss.backward()代码就是在实现反向传播,自动计算所有的梯度。

所以训练部分的代码其实比较简单:
running_loss和后面负责打印损失值的那部分并不是必须的,所以关键行不多,总得来说分成三小节
第一节:把最开始放在trainloader里面的数据给转换成variable,然后指定为网络的输入;
第二节:每次循环新开始的时候,要确保梯度归零
第三节:forward+backward,就是调用我们在第三步里面实例化的net()实现前传,loss.backward()实现后传
每结束一次循环,要确保梯度更新

这篇关于Pytorch打怪路(一)pytorch进行CIFAR-10分类(4)训练的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/837437

相关文章

Nginx中配置使用非默认80端口进行服务的完整指南

《Nginx中配置使用非默认80端口进行服务的完整指南》在实际生产环境中,我们经常需要将Nginx配置在其他端口上运行,本文将详细介绍如何在Nginx中配置使用非默认端口进行服务,希望对大家有所帮助... 目录一、为什么需要使用非默认端口二、配置Nginx使用非默认端口的基本方法2.1 修改listen指令

MySQL按时间维度对亿级数据表进行平滑分表

《MySQL按时间维度对亿级数据表进行平滑分表》本文将以一个真实的4亿数据表分表案例为基础,详细介绍如何在不影响线上业务的情况下,完成按时间维度分表的完整过程,感兴趣的小伙伴可以了解一下... 目录引言一、为什么我们需要分表1.1 单表数据量过大的问题1.2 分表方案选型二、分表前的准备工作2.1 数据评估

MySQL进行分片合并的实现步骤

《MySQL进行分片合并的实现步骤》分片合并是指在分布式数据库系统中,将不同分片上的查询结果进行整合,以获得完整的查询结果,下面就来具体介绍一下,感兴趣的可以了解一下... 目录环境准备项目依赖数据源配置分片上下文分片查询和合并代码实现1. 查询单条记录2. 跨分片查询和合并测试结论分片合并(Shardin

SpringBoot结合Knife4j进行API分组授权管理配置详解

《SpringBoot结合Knife4j进行API分组授权管理配置详解》在现代的微服务架构中,API文档和授权管理是不可或缺的一部分,本文将介绍如何在SpringBoot应用中集成Knife4j,并进... 目录环境准备配置 Swagger配置 Swagger OpenAPI自定义 Swagger UI 底

基于Python Playwright进行前端性能测试的脚本实现

《基于PythonPlaywright进行前端性能测试的脚本实现》在当今Web应用开发中,性能优化是提升用户体验的关键因素之一,本文将介绍如何使用Playwright构建一个自动化性能测试工具,希望... 目录引言工具概述整体架构核心实现解析1. 浏览器初始化2. 性能数据收集3. 资源分析4. 关键性能指

Nginx进行平滑升级的实战指南(不中断服务版本更新)

《Nginx进行平滑升级的实战指南(不中断服务版本更新)》Nginx的平滑升级(也称为热升级)是一种在不停止服务的情况下更新Nginx版本或添加模块的方法,这种升级方式确保了服务的高可用性,避免了因升... 目录一.下载并编译新版Nginx1.下载解压2.编译二.替换可执行文件,并平滑升级1.替换可执行文件

Python进行JSON和Excel文件转换处理指南

《Python进行JSON和Excel文件转换处理指南》在数据交换与系统集成中,JSON与Excel是两种极为常见的数据格式,本文将介绍如何使用Python实现将JSON转换为格式化的Excel文件,... 目录将 jsON 导入为格式化 Excel将 Excel 导出为结构化 JSON处理嵌套 JSON:

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

一文解密Python进行监控进程的黑科技

《一文解密Python进行监控进程的黑科技》在计算机系统管理和应用性能优化中,监控进程的CPU、内存和IO使用率是非常重要的任务,下面我们就来讲讲如何Python写一个简单使用的监控进程的工具吧... 目录准备工作监控CPU使用率监控内存使用率监控IO使用率小工具代码整合在计算机系统管理和应用性能优化中,监

如何使用Lombok进行spring 注入

《如何使用Lombok进行spring注入》本文介绍如何用Lombok简化Spring注入,推荐优先使用setter注入,通过注解自动生成getter/setter及构造器,减少冗余代码,提升开发效... Lombok为了开发环境简化代码,好处不用多说。spring 注入方式为2种,构造器注入和setter