在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差)

2024-03-10 09:08

本文主要是介绍在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

这个内容是将随便做了一个网络结构,然后简单的训练几次,生成模型,并且存储起来,主要是为了学习获得pytorch中的BatchNorm2d层的各个特征图的平均值和方差。代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets,transforms
from torch.optim import lr_scheduler
import torch.optim as optimclass VGG(nn.Module):def __init__(self):super(VGG,self).__init__()self.conv1 = nn.Conv2d(3,64,3,padding=(1,1))self.bn1 = nn.BatchNorm2d(64)self.maxpool1 = nn.MaxPool2d((2,2))self.conv2 = nn.Conv2d(64,128,3,padding=(1,1))# self.bn2 = nn.BatchNorm2d(128)self.maxpool2 = nn.MaxPool2d((2,2))self.conv3 = nn.Conv2d(128,256,3,padding=(1,1))# self.bn3 = nn.BatchNorm2d(256)self.maxpool3 = nn.MaxPool2d((2,2))self.fc1 = nn.Linear(256*16*8,4096)self.fc2 = nn.Linear(4096,1000)self.fc3 = nn.Linear(1000,10)def forward(self,x):in_size = x.size(0)out = self.conv1(x)out = self.bn1(out)out = F.relu(out)out = self.maxpool1(out)out = self.conv2(out)out = F.relu(out)out = self.maxpool2(out)out = self.conv3(out)out = F.relu(out)out = self.maxpool3(out)out = out.view(out.size(0),-1)out = self.fc1(out)out = F.relu(out)out = self.fc2(out)out = F.relu(out)out = self.fc3(out)return outtransform_train_list = transforms.Compose([transforms.Resize( (256,128),interpolation=3 ),transforms.RandomCrop((128,64)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])train_dataset = datasets.ImageFolder('./train',transform_train_list)
dataloaders = torch.utils.data.DataLoader(train_dataset,batch_size=2,num_workers=0)dataset_size = len(train_dataset)
class_names = train_dataset.classesprint(dataset_size)
print(class_names)
net=VGG()
optimizer = optim.SGD(net.parameters(),lr=0.001,momentum=0.9)
criterion = nn.CrossEntropyLoss()for epoch in range(2):#训练模型print(epoch)net.train(True)running_loss = 0.0running_corrects = 0.0for data in dataloaders:inputs,labels = datanow_batch,c,h,w = inputs.shapeoptimizer.zero_grad()outputs = net(inputs)# print(outputs)_,preds = torch.max(outputs.data,1)loss = criterion(outputs,labels)loss.backward()optimizer.step()running_loss = running_loss + loss.item() * now_batchrunning_corrects += float( torch.sum( preds == labels.data ) )epoch_loss = running_loss/dataset_sizeepoch_acc = running_corrects/dataset_sizeprint(epoch_loss)print(epoch_acc)torch.save(net.cpu().state_dict(),'first.pth')  ##将训练好的模型保存起来net = VGG()
net.load_state_dict( torch.load('first.pth') )
net.eval()  #产生一个模型并且加载已经训练好的模型的参数# for data in dataloaders:
#     inputs,labels = data
#     # print(inputs)
#     print(labels)
#     outputs = net(inputs)
#     print(outputs)
#     breakm = VGG()
# m.eval()
m.load_state_dict( torch.load('second.pth') )
print(m.bn1.running_mean.size()) ##获得一共有多少个mean  要是想获得var只要将mean改为var即可
print(m.bn1.running_mean.data[0])
print(m.bn1.running_mean.data[1])
print(m.bn1.running_mean.data[2])print(m.bn1.running_var.data[0])
print(type(m.bn1.running_mean.data[0]))m.bn1.running_mean.data[0] = m.bn1.running_mean.data[2]  ##可以对模型参数进行更改,然后保存更改后的模型
m.bn1.running_mean.data[1] = m.bn1.running_mean.data[2]torch.save(m.cpu().state_dict(),'second.pth')

对于输入到BatchNorm2d层的数据格式为(batch_size,channels_size,h,w),channels_size为多少,就会生成多少个mean和var。

举个例子,如果输入的数据是batch_size=16,channels_size=64,h=32,w=16,则每对mean和var都是 16张某一个特征图中的所有数据的mean和var

这篇关于在pytorch模型中如何获得BatchNorm2d层的各个mean和var(平均值和方差)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/793745

相关文章

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

详解如何使用Python从零开始构建文本统计模型

《详解如何使用Python从零开始构建文本统计模型》在自然语言处理领域,词汇表构建是文本预处理的关键环节,本文通过Python代码实践,演示如何从原始文本中提取多尺度特征,并通过动态调整机制构建更精确... 目录一、项目背景与核心思想二、核心代码解析1. 数据加载与预处理2. 多尺度字符统计3. 统计结果可

SpringBoot整合Sa-Token实现RBAC权限模型的过程解析

《SpringBoot整合Sa-Token实现RBAC权限模型的过程解析》:本文主要介绍SpringBoot整合Sa-Token实现RBAC权限模型的过程解析,本文给大家介绍的非常详细,对大家的学... 目录前言一、基础概念1.1 RBAC模型核心概念1.2 Sa-Token核心功能1.3 环境准备二、表结

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

PyTorch高级特性与性能优化方式

《PyTorch高级特性与性能优化方式》:本文主要介绍PyTorch高级特性与性能优化方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、自动化机制1.自动微分机制2.动态计算图二、性能优化1.内存管理2.GPU加速3.多GPU训练三、分布式训练1.分布式数据

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

pytorch自动求梯度autograd的实现

《pytorch自动求梯度autograd的实现》autograd是一个自动微分引擎,它可以自动计算张量的梯度,本文主要介绍了pytorch自动求梯度autograd的实现,具有一定的参考价值,感兴趣... autograd是pytorch构建神经网络的核心。在 PyTorch 中,结合以下代码例子,当你

在PyCharm中安装PyTorch、torchvision和OpenCV详解

《在PyCharm中安装PyTorch、torchvision和OpenCV详解》:本文主要介绍在PyCharm中安装PyTorch、torchvision和OpenCV方式,具有很好的参考价值,... 目录PyCharm安装PyTorch、torchvision和OpenCV安装python安装PyTor

pytorch之torch.flatten()和torch.nn.Flatten()的用法

《pytorch之torch.flatten()和torch.nn.Flatten()的用法》:本文主要介绍pytorch之torch.flatten()和torch.nn.Flatten()的用... 目录torch.flatten()和torch.nn.Flatten()的用法下面举例说明总结torch