本文主要是介绍PyTorch核心方法之state_dict()、parameters()参数打印与应用案例,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
《PyTorch核心方法之state_dict()、parameters()参数打印与应用案例》PyTorch是一个流行的开源深度学习框架,提供了灵活且高效的方式来训练和部署神经网络,这篇文章主要介绍...
前言
本文以 LeNet-5 模型为案例,介绍了 PyTorch 中打印模型参数的相关方法。首先展示了 LeNet-5 模型的结构定义及打印结果;随后详细说明了三种获取模型参数的方式:
state_dict()方法返回有序字典形式的可学习参数,包含参数名称和对应张量;parameters()方法返回生成器,仅包含各层参数信息;named_parameters()方法返回生成器,包含模型名称和对应参数信息;
最后提供了利用named_parameters()进行模型结构冻结的示例,可打印确认冻结的网络名称。
模型案例
本文以LeNet-5为基础模型,快速验证模型参数打印过程。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import torch
import torch.nn.functional as F
import torch.nn as nn
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
# 1 input image channel, 6 output channels, 5x5 square convolution
# kernel
self.conv1 = nn.Conv2d(1, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
# an affine operation: y = Wx + b
self.fc1 = nn.Linear(16 * 5 * 5, 120) # 这里论文上写的是conv,官方教程用了线性层
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
# Max pooling over a (2, 2) window
x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
# If the size is a square you can only specify a single number
python x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = selphpf.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] # all dimensions except the BATch dimension
num_features = 1
for s in size:
num_features *= s
return num_features
net = LeNet5()
print(net)
模型结构打印如下。

A. state_dict()方法验证
在 PyTorch 中,state_dict() 是核心方法之一,用于以有序字典(OrderedDict)的形式返回模型 / 优化器等实例的可学习参数(或状态),是模型保存、加载、迁移学习的基础。
state_dict() 本质是一个 python 字典(PyTorch 中为 OrderedDict),键为参数 / 状态的名称(字符串),值为对应的张量(torch.Tensor)。
print(type(net.state_dict())) # <class 'collections.OrderedDict'>
## 遍历打印
for model_key in net.state_dict(): # 【字典格式】的遍历,获取的是模型的名称
print(f"{model_key}: {net.state_dict()[model_key].size()}")
对于Lenet-5模型进行打印,可以看到state_dict()的类型为 <class 'collections.OrderedDict'>,各层名称及参数尺寸如下图所示。

B. parameters()
parameters()方法也可以获取到模型的参数。可以看出,parameters()获取到的是一个生成器,其中仅包含各层参数的信息。
params = net.parameters()
print(type(params)) # <class 'generator'> 生成器
for param in params:
print(param.size()) # 只包含参数信息:具体的参数尺寸
对Lenet-5进行模型参数打印。

如果也需要模型名称信息,可以使javascript用named_parameters()方法。该方法获取的也是一个生成器,其中返回的是一个元组,包括模型名称和对应的参数。
named_params = net.named_parameters()
print(type(named_params)) # <class 'generator'> 也是一个生成器
for name, param in named_params:
print(f"{name}: {param.sizChina编程e()}") # 同时获取网络名称和网络参数
对Lenet-5进行模型名称及参数尺寸信息打印:

C. 模型结构冻结示例
该方法可以在对模型结构冻结时使用,如下述示例对模型结构m的参数进行冻结,同时打印确认冻结包含哪些网络结构。
# 示例
for name, param in m.named_parameters():
param.requires_grad = False
print(f"Freezing layer {name}")
总结
到此这篇关于PyTorch核心方法之state_dict()、parameters()参数打印与应用案例的文章就介绍到这了,更多相关PyTorch state_dict()、parameters()参数打印内容请搜索China编程(www.chinasem.cn)以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程China编程(www.chinasem.cn)!
这篇关于PyTorch核心方法之state_dict()、parameters()参数打印与应用案例的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!