Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类)

本文主要是介绍Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
https://download.csdn.net/download/weixin_50592077/88809977?spm=1001.2014.3001.5501

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

1、项目配置

本项目来自SwinTransformer 的GitHub官方源码:

Image Classification: Included in this repo. See get_started.md for a quick start.
Object Detection and Instance Segmentation: See Swin Transformer for Object Detection.
Semantic Segmentation: See Swin Transformer for Semantic Segmentation.
Video Action Recognition: See Video Swin Transformer.
Semi-Supervised Object Detection: See Soft Teacher.
SSL: Contrasitive Learning: See Transformer-SSL.
SSL: Masked Image Modeling: See get_started.md#simmim-support.
Mixture-of-Experts: See get_started for more instructions.
Feature-Distillation: See Feature-Distillation.

此处包含多个版本(分类、检测、分割、视频 ),但是仅仅学习算法建议选择第一个图像分类的基础版本就可以了

安装需求:

pip install timm==0.4.12
pip install yacs==0.1.8
pip install termcolor==1.1.0
pytorch
opencv
Apex(linux版本)

原本的数据是imagenet,这个数据太多了,有很多开源的imagenet小版本,本文配套的资源就是已经配好的imagenet小版本,目录信息、数据标注、数据划分都已经做好了

本项目的执行文件就是main.py,源码我已经修改了部分

配置参数:

--cfg configs/swin_tiny_patch4_window7_224.yaml
--data-path imagenet
--local_rank 0
--batch-size 4

–local rank 0这个参数表示的是分布式训练,直接用当前的这个就好

2、SwinTransformer类

打开models有两个构建模型的源码:
build.py
swin_transformer.py

构建模型的部分主要就在swin_transformer.py,一共有600多行代码

首先看SwinTransformer类的前向传播函数:

class SwinTransformer(nn.Module):def forward(self, x):x = self.forward_features(x)x = self.head(x)return x

打印这个过程的shape值:

  1. 原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像,4是batch,3是通道数,图像是224*244的长宽
  2. self.forward_features(x):torch.Size([4, 768]),经过forward_features函数后,变成了768维的向量
  3. self.head(x):torch.Size([4, 1000]),head是一个全连接层,很显然这个1000是最后的分类数

所以整个体征提取的过程都在self.forward_features()函数中:

    def forward_features(self, x):x = self.patch_embed(x)if self.ape:x = x + self.absolute_pos_embedx = self.pos_drop(x)for layer in self.layers:x = layer(x)x = self.norm(x)  # B L Cx = self.avgpool(x.transpose(1, 2))  # B C 1x = torch.flatten(x, 1)return x
  1. 原始输入x: torch.Size([4, 3, 224, 224]),原始输入是一张彩色图像

  2. patch_embed: torch.Size([4, 3136, 96]),图像经过patch_embbeding变成一个Transformer需要的序列,相当于序列是3136个向量,每个向量维度是96。这个过程通常包括将图像分割成多个patches,然后将每个patch线性投影到一个指定的维度。

  3. if self.ape: x = x + self.absolute_pos_embed,如果模型配置了绝对位置编码(self.ape为真),这行代码会将绝对位置嵌入加到patch的嵌入上。绝对位置嵌入提供了每个patch在图像中位置的信息,帮助模型理解图像中不同部分的空间关系, 不改变维度

  4. pos_drop: torch.Size([4, 3136, 96]),一层Dropout

  5. layer: torch.Size([4, 784, 192]),for循环主要是Swin Transformer Block的堆叠

  6. layer: torch.Size([4, 196, 384]),4次循环,序列长度减小

  7. layer: torch.Size([4, 49, 768]),特征图个数增多,即向量维度变大

  8. layer: torch.Size([4, 49, 768]),最后一次维度不变

  9. norm: torch.Size([4, 49, 768]),层归一化,维度不变

  10. avgpool: torch.Size([4, 768, 1]),平均池化

  11. flatten: torch.Size([4, 768]),拉平操作,去掉多余的维度

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

这篇关于Transformer实战-系列教程8:SwinTransformer 源码解读1(项目配置/SwinTransformer类)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/693925

相关文章

Windows环境下解决Matplotlib中文字体显示问题的详细教程

《Windows环境下解决Matplotlib中文字体显示问题的详细教程》本文详细介绍了在Windows下解决Matplotlib中文显示问题的方法,包括安装字体、更新缓存、配置文件设置及编码調整,并... 目录引言问题分析解决方案详解1. 检查系统已安装字体2. 手动添加中文字体(以SimHei为例)步骤

IntelliJ IDEA2025创建SpringBoot项目的实现步骤

《IntelliJIDEA2025创建SpringBoot项目的实现步骤》本文主要介绍了IntelliJIDEA2025创建SpringBoot项目的实现步骤,文中通过示例代码介绍的非常详细,对大家... 目录一、创建 Spring Boot 项目1. 新建项目2. 基础配置3. 选择依赖4. 生成项目5.

nginx 负载均衡配置及如何解决重复登录问题

《nginx负载均衡配置及如何解决重复登录问题》文章详解Nginx源码安装与Docker部署,介绍四层/七层代理区别及负载均衡策略,通过ip_hash解决重复登录问题,对nginx负载均衡配置及如何... 目录一:源码安装:1.配置编译参数2.编译3.编译安装 二,四层代理和七层代理区别1.二者混合使用举例

Java JDK1.8 安装和环境配置教程详解

《JavaJDK1.8安装和环境配置教程详解》文章简要介绍了JDK1.8的安装流程,包括官网下载对应系统版本、安装时选择非系统盘路径、配置JAVA_HOME、CLASSPATH和Path环境变量,... 目录1.下载JDK2.安装JDK3.配置环境变量4.检验JDK官网下载地址:Java Downloads

Linux下进程的CPU配置与线程绑定过程

《Linux下进程的CPU配置与线程绑定过程》本文介绍Linux系统中基于进程和线程的CPU配置方法,通过taskset命令和pthread库调整亲和力,将进程/线程绑定到特定CPU核心以优化资源分配... 目录1 基于进程的CPU配置1.1 对CPU亲和力的配置1.2 绑定进程到指定CPU核上运行2 基于

MySQL 多列 IN 查询之语法、性能与实战技巧(最新整理)

《MySQL多列IN查询之语法、性能与实战技巧(最新整理)》本文详解MySQL多列IN查询,对比传统OR写法,强调其简洁高效,适合批量匹配复合键,通过联合索引、分批次优化提升性能,兼容多种数据库... 目录一、基础语法:多列 IN 的两种写法1. 直接值列表2. 子查询二、对比传统 OR 的写法三、性能分析

Spring Boot spring-boot-maven-plugin 参数配置详解(最新推荐)

《SpringBootspring-boot-maven-plugin参数配置详解(最新推荐)》文章介绍了SpringBootMaven插件的5个核心目标(repackage、run、start... 目录一 spring-boot-maven-plugin 插件的5个Goals二 应用场景1 重新打包应用

Java中读取YAML文件配置信息常见问题及解决方法

《Java中读取YAML文件配置信息常见问题及解决方法》:本文主要介绍Java中读取YAML文件配置信息常见问题及解决方法,本文给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要... 目录1 使用Spring Boot的@ConfigurationProperties2. 使用@Valu

Python办公自动化实战之打造智能邮件发送工具

《Python办公自动化实战之打造智能邮件发送工具》在数字化办公场景中,邮件自动化是提升工作效率的关键技能,本文将演示如何使用Python的smtplib和email库构建一个支持图文混排,多附件,多... 目录前言一、基础配置:搭建邮件发送框架1.1 邮箱服务准备1.2 核心库导入1.3 基础发送函数二、

Jenkins分布式集群配置方式

《Jenkins分布式集群配置方式》:本文主要介绍Jenkins分布式集群配置方式,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1.安装jenkins2.配置集群总结Jenkins是一个开源项目,它提供了一个容易使用的持续集成系统,并且提供了大量的plugin满