学习基于pytorch的VGG图像分类 day2

2024-04-10 14:36

本文主要是介绍学习基于pytorch的VGG图像分类 day2,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

注:本系列博客在于汇总CSDN的精华帖,类似自用笔记,不做学习交流,方便以后的复习回顾,博文中的引用都注明出处,并点赞收藏原博主.

目录

VGG网络搭建(模型文件)

        1.字典文件配置

         2.提取特征网络结构

        3. VGG类的定义

         4.VGG网络实例化


VGG网络搭建(模型文件)

        1.字典文件配置

#字典文件,对应各个配置,数字对应卷积核的个数,'M'对应最大液化(即maxpool)
cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

         2.提取特征网络结构

#提取特征网络结构
def make_features(cfg: list): #传入对应的列表layers = [] #定义一个空列表,存放每层的结果in_channels = 3 #输入为RGB彩色图片,输入通道为3for v in cfg: #通过for循环遍历列表if v == "M":                                                    #maxpool size = 2,stride = 2layers += [nn.MaxPool2d(kernel_size=2, stride=2)] #创建最大池化下载量程,池化核为2,布局也为2else:                                                           #conv padding = 1,stride = 1conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) #创建卷积操作(输入特征矩阵深度,输出特征矩阵深度(卷积核个数),卷积核为3,填充为1,stride默认为1(不用写))layers += [conv2d, nn.ReLU(True)] #使用ReLU激活函数in_channels = v #输出深度改变成vreturn nn.Sequential(*layers) #通过Sequential函数将列表以非关键字参数的形式传入(*代表非关键字传入)

        3. VGG类的定义

class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False): #(通过make_features生成的提取特征网络结构,分类的类别个数,是否对网络权重初始化)super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential( #生成分类网络nn.Linear(512*7*7, 4096), #全连接层上下的节点个数nn.ReLU(True),  #ReLU函数激活nn.Dropout(p=0.5), #Dropout函数减少过拟合,以50%的比例随机失活神经元nn.Linear(4096, 4096), #第一层和第二层nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, num_classes) #第二层和第三层,总计3层全连接层,最后连接到输出层,输出num_classes的所需个数)if init_weights: #初始化权重函数self._initialize_weights()def forward(self, x): #正向传播 x就是输入的图像数据 # N x 3 x 224 x 224x = self.features(x) #用features提取特征网络结构# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1) #对输出进行一个展平处理,(start_dim定义从哪个维度开始展平处理)# N x 512*7*7x = self.classifier(x) #输入到分类网络结构return xdef _initialize_weights(self):for m in self.modules(): #遍历网络的每一个子模块if isinstance(m, nn.Conv2d): #遍历到卷积层# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight) #使用xavier函数初始化,初始化卷积核的权重if m.bias is not None: #卷积核采用偏置nn.init.constant_(m.bias, 0) #将偏执初始化为0elif isinstance(m, nn.Linear): #遍历到全连接层,下面同理nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

         4.VGG网络实例化

#实例化VGG网络结构
def vgg(model_name="vgg16", **kwargs):assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)cfg = cfgs[model_name]model = VGG(make_features(cfg), **kwargs) #通过VGG这个类实现实例化网络,(**可变长度的字典变量)return model

 内容参考来源:

 ​​​​​​使用pytorch搭建VGG网络_哔哩哔哩_bilibili

这篇关于学习基于pytorch的VGG图像分类 day2的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/891335

相关文章

PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例

《PyTorch中的词嵌入层(nn.Embedding)详解与实战应用示例》词嵌入解决NLP维度灾难,捕捉语义关系,PyTorch的nn.Embedding模块提供灵活实现,支持参数配置、预训练及变长... 目录一、词嵌入(Word Embedding)简介为什么需要词嵌入?二、PyTorch中的nn.Em

基于Python开发一个图像水印批量添加工具

《基于Python开发一个图像水印批量添加工具》在当今数字化内容爆炸式增长的时代,图像版权保护已成为创作者和企业的核心需求,本方案将详细介绍一个基于PythonPIL库的工业级图像水印解决方案,有需要... 目录一、系统架构设计1.1 整体处理流程1.2 类结构设计(扩展版本)二、核心算法深入解析2.1 自

MySQL中的索引结构和分类实战案例详解

《MySQL中的索引结构和分类实战案例详解》本文详解MySQL索引结构与分类,涵盖B树、B+树、哈希及全文索引,分析其原理与优劣势,并结合实战案例探讨创建、管理及优化技巧,助力提升查询性能,感兴趣的朋... 目录一、索引概述1.1 索引的定义与作用1.2 索引的基本原理二、索引结构详解2.1 B树索引2.2

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

Python中OpenCV与Matplotlib的图像操作入门指南

《Python中OpenCV与Matplotlib的图像操作入门指南》:本文主要介绍Python中OpenCV与Matplotlib的图像操作指南,本文通过实例代码给大家介绍的非常详细,对大家的学... 目录一、环境准备二、图像的基本操作1. 图像读取、显示与保存 使用OpenCV操作2. 像素级操作3.

C/C++的OpenCV 进行图像梯度提取的几种实现

《C/C++的OpenCV进行图像梯度提取的几种实现》本文主要介绍了C/C++的OpenCV进行图像梯度提取的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的... 目录预www.chinasem.cn备知识1. 图像加载与预处理2. Sobel 算子计算 X 和 Y

c/c++的opencv图像金字塔缩放实现

《c/c++的opencv图像金字塔缩放实现》本文主要介绍了c/c++的opencv图像金字塔缩放实现,通过对原始图像进行连续的下采样或上采样操作,生成一系列不同分辨率的图像,具有一定的参考价值,感兴... 目录图像金字塔简介图像下采样 (cv::pyrDown)图像上采样 (cv::pyrUp)C++ O