动手学深度学习(pytorch)学习记录21-读写文件(模型与参数)[学习记录]

2024-08-30 07:28

本文主要是介绍动手学深度学习(pytorch)学习记录21-读写文件(模型与参数)[学习记录],希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

目录

  • 加载和保存张量
  • 加载和保存模型参数

保存模型的好处众多,涵盖了从开发到部署的整个机器学习生命周期。

  • 节省资源:训练模型可能需要大量的时间和计算资源。保存模型可以避免重复训练,从而节省时间和计算资源。
  • 快速部署:一旦模型被训练并保存,它可以迅速部署到生产环境中,加速产品上市时间。
  • 版本控制:保存不同版本的模型有助于跟踪模型的迭代过程,便于比较和回滚到之前的版本。
  • 离线使用:保存的模型可以在没有网络连接的情况下使用,这对于需要在本地设备上运行模型的应用程序非常有用。
  • 模型共享:研究人员和开发者可以共享他们的模型,促进合作和知识传播。
  • 模型评估:保存的模型可以在不同的数据集上进行评估,帮助验证模型的泛化能力和性能。
  • 实验复现:保存模型的状态使得其他研究者可以复现实验结果,增加研究的可验证性。
  • 业务连续性:在系统升级或迁移过程中,保存的模型可以确保业务的连续性,减少停机时间。
  • 法律合规:在某些行业,如医疗和金融,保存模型可能是必须的,以满足法律和合规要求。
  • 模型优化:保存的模型可以用于进一步的优化,如模型压缩、加速等,以适应不同的部署环境。
  • 模型监控:在模型部署后,保存的模型可以用于监控和比较,以检测模型性能随时间的变化。
  • 用户信任:提供透明的模型保存信息可以增加用户对模型决策的信任。
  • 教育和研究:保存的模型可以作为教育材料,帮助学生和研究人员学习模型的工作原理。
  • 灾难恢复:在发生系统故障时,保存的模型可以作为备份,快速恢复服务。
  • 长期维护:随着时间的推移,保存的模型可以用于维护和更新,以适应新的数据和需求。

加载和保存张量

# 保存张量
import torch
from torch import nn
from torch.nn import functional as Fx = torch.arange(4)
torch.save(x, 'x-file')

将存储在文件中的数据读回内存。

x2 = torch.load('x-file')
x2
tensor([0, 1, 2, 3])

存储一个张量列表,然后把它们读回内存。

y = torch.zeros(4)
torch.save([x, y],'x-files')
x2, y2 = torch.load('x-files')
(x2, y2)
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

可以写入或读取从字符串映射到张量的字典。 当我们要读取或写入模型中的所有权重时,这很方便。

mydict = {'x': x, 'y': y}
torch.save(mydict, 'mydict')
mydict2 = torch.load('mydict')
mydict2
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

加载和保存模型参数

class MLP(nn.Module):def __init__(self):super().__init__()self.hidden = nn.Linear(20, 256)self.output = nn.Linear(256, 10)def forward(self, x):return self.output(F.relu(self.hidden(x)))net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

将模型的参数存储在一个叫做“mlp.params”的文件中

torch.save(net.state_dict(), 'mlp.params')

为恢复模型,需实例化原始多层感知机模型的一个备份, 直接读取文件中存储的参数作为初始参数。

clone = MLP()
clone.load_state_dict(torch.load('mlp.params'))
clone.eval()
MLP((hidden): Linear(in_features=20, out_features=256, bias=True)(output): Linear(in_features=256, out_features=10, bias=True)
)

由于两个实例具有相同的模型参数,在输入相同的X时, 两个实例的计算结果应该相同。

Y_clone = clone(X)
Y_clone == Y
tensor([[True, True, True, True, True, True, True, True, True, True],[True, True, True, True, True, True, True, True, True, True]])

保存整个模型

torch.save(net, 'net.pt')
net1 = torch.load('net.pt')
net1.eval()
MLP((hidden): Linear(in_features=20, out_features=256, bias=True)(output): Linear(in_features=256, out_features=10, bias=True)
)

原模型和新加载的模型参数应该是相同的。

net.state_dict()['hidden.weight'].data == net1.state_dict()['hidden.weight'].data
tensor([[True, True, True,  ..., True, True, True],[True, True, True,  ..., True, True, True],[True, True, True,  ..., True, True, True],...,[True, True, True,  ..., True, True, True],[True, True, True,  ..., True, True, True],[True, True, True,  ..., True, True, True]])

封面图片来源

欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

这篇关于动手学深度学习(pytorch)学习记录21-读写文件(模型与参数)[学习记录]的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1120168

相关文章

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

Java学习手册之Filter和Listener使用方法

《Java学习手册之Filter和Listener使用方法》:本文主要介绍Java学习手册之Filter和Listener使用方法的相关资料,Filter是一种拦截器,可以在请求到达Servl... 目录一、Filter(过滤器)1. Filter 的工作原理2. Filter 的配置与使用二、Listen

SpringBoot请求参数接收控制指南分享

《SpringBoot请求参数接收控制指南分享》:本文主要介绍SpringBoot请求参数接收控制指南,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录Spring Boot 请求参数接收控制指南1. 概述2. 有注解时参数接收方式对比3. 无注解时接收参数默认位置

Python使用getopt处理命令行参数示例解析(最佳实践)

《Python使用getopt处理命令行参数示例解析(最佳实践)》getopt模块是Python标准库中一个简单但强大的命令行参数处理工具,它特别适合那些需要快速实现基本命令行参数解析的场景,或者需要... 目录为什么需要处理命令行参数?getopt模块基础实际应用示例与其他参数处理方式的比较常见问http

Python中__init__方法使用的深度解析

《Python中__init__方法使用的深度解析》在Python的面向对象编程(OOP)体系中,__init__方法如同建造房屋时的奠基仪式——它定义了对象诞生时的初始状态,下面我们就来深入了解下_... 目录一、__init__的基因图谱二、初始化过程的魔法时刻继承链中的初始化顺序self参数的奥秘默认

Java使用SLF4J记录不同级别日志的示例详解

《Java使用SLF4J记录不同级别日志的示例详解》SLF4J是一个简单的日志门面,它允许在运行时选择不同的日志实现,这篇文章主要为大家详细介绍了如何使用SLF4J记录不同级别日志,感兴趣的可以了解下... 目录一、SLF4J简介二、添加依赖三、配置Logback四、记录不同级别的日志五、总结一、SLF4J

Linux内核参数配置与验证详细指南

《Linux内核参数配置与验证详细指南》在Linux系统运维和性能优化中,内核参数(sysctl)的配置至关重要,本文主要来聊聊如何配置与验证这些Linux内核参数,希望对大家有一定的帮助... 目录1. 引言2. 内核参数的作用3. 如何设置内核参数3.1 临时设置(重启失效)3.2 永久设置(重启仍生效

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在Spring Boot中浅尝内存泄漏的实战记录

《在SpringBoot中浅尝内存泄漏的实战记录》本文给大家分享在SpringBoot中浅尝内存泄漏的实战记录,结合实例代码给大家介绍的非常详细,感兴趣的朋友一起看看吧... 目录使用静态集合持有对象引用,阻止GC回收关键点:可执行代码:验证:1,运行程序(启动时添加JVM参数限制堆大小):2,访问 htt

SpringMVC获取请求参数的方法

《SpringMVC获取请求参数的方法》:本文主要介绍SpringMVC获取请求参数的方法,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友可以参考下... 目录1、通过ServletAPI获取2、通过控制器方法的形参获取请求参数3、@RequestParam4、@