TorchScript模型和普通PyTorch模型

2023-10-25 16:28

本文主要是介绍TorchScript模型和普通PyTorch模型,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • TorchScript模型和普通PyTorch模型
    • 普通的PyTorch模型(基于Python的序列化):
      • 例程:
    • TorchScript模型:
      • 例程:
      • 结论:

TorchScript模型和普通PyTorch模型

PyTorch提供了两种主要的模型保存和加载机制,一种是基于Python的序列化,另一种是TorchScript。

普通的PyTorch模型(基于Python的序列化):

  • 保存: 使用torch.save(model.state_dict(), 'model_path.pth'),它保存了模型的权重和参数,但不保存模型的结构。
  • 加载:
    1. 首先,您需要有模型的类定义。
    2. 创建该类的一个实例。
    3. 使用model.load_state_dict(torch.load('model_path.pth'))来加载权重。
  • 特点:
    • 需要Python环境和模型的原始代码来加载和运行模型。
    • 保存的文件是Python特定的,并且依赖于特定的类结构。
    • 主要用于继续训练或在Python环境中进行推断。

保存模型:

torch.save(model.state_dict(), 'model_weights.pth')

加载模型:

model = ModelClass() 
model.load_state_dict(torch.load('model_weights.pth'))
model.eval() # 设置为评估模式

例程:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()
torch.save(model.state_dict(), 'simple_model_weights.pth')# 在其他地方或时间
loaded_model = SimpleModel()
loaded_model.load_state_dict(torch.load('simple_model_weights.pth'))
loaded_model.eval()

TorchScript模型:

  • TorchScript是PyTorch的一个子集,它创建了一个可以独立于Python运行的序列化模型。
  • 生成方法:
    1. Tracing: 使用torch.jit.trace方法。这涉及到通过模型运行一个输入示例,从而跟踪模型的执行路径。
    2. Scripting: 使用torch.jit.script方法。这转化Python代码到TorchScript,允许更复杂的模型和控制流。
  • 保存: 使用torch.jit.save(traced_model, 'model_path.pt')
  • 加载: 使用torch.jit.load('model_path.pt')。注意,加载不需要原始的模型类定义。
  • 特点:
    • 可以在没有Python运行时的环境中运行,如C++。
    • 提供了一种方法,将模型从Python转移到其他平台或部署环境。
    • 包含模型的完整定义,包括结构、权重和参数。

Tracing方法:

example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_model.pt')

Scripting方法:

scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_model.pt')

加载模型:

loaded_model = torch.jit.load('model_path.pt')

例程:

import torch
import torch.nn as nnclass SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 10)def forward(self, x):return self.fc(x)model = SimpleModel()# Tracing
example_input = torch.randn(1, 10)
traced_model = torch.jit.trace(model, example_input)
torch.jit.save(traced_model, 'traced_simple_model.pt')# Scripting
scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'scripted_simple_model.pt')# 加载模型
loaded_model = torch.jit.load('traced_simple_model.pt')

结论:

  • 普通PyTorch模型: 适用于Python环境,需要模型的原始定义来加载。它是基于Python的序列化。
  • TorchScript: 为了跨平台部署和运行,例如在C++环境中。它提供了一个独立于Python的序列化模型,可以在没有Python的环境中使用。

这篇关于TorchScript模型和普通PyTorch模型的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/283768

相关文章

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应

使用PyTorch实现手写数字识别功能

《使用PyTorch实现手写数字识别功能》在人工智能的世界里,计算机视觉是最具魅力的领域之一,通过PyTorch这一强大的深度学习框架,我们将在经典的MNIST数据集上,见证一个神经网络从零开始学会识... 目录当计算机学会“看”数字搭建开发环境MNIST数据集解析1. 认识手写数字数据库2. 数据预处理的

Pytorch微调BERT实现命名实体识别

《Pytorch微调BERT实现命名实体识别》命名实体识别(NER)是自然语言处理(NLP)中的一项关键任务,它涉及识别和分类文本中的关键实体,BERT是一种强大的语言表示模型,在各种NLP任务中显著... 目录环境准备加载预训练BERT模型准备数据集标记与对齐微调 BERT最后总结环境准备在继续之前,确