基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码

本文主要是介绍基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 第一步:准备数据

6种蝴蝶数据:self.class_indict = ["曙凤蝶", "麝凤蝶", "多姿麝凤蝶", "旖凤蝶", "红珠凤蝶", "热斑凤蝶"],总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个Vision Transformer网络,其原理介绍如下:

Vision Transformer(ViT)是一种基于Transformer架构的深度学习模型,用于图像识别和计算机视觉任务。与传统的卷积神经网络(CNN)不同,ViT直接将图像视为一个序列化的输入,并利用自注意力机制来处理图像中的像素关系。

ViT通过将图像分成一系列的图块(patches),并将每个图块转换为向量表示作为输入序列。然后,这些向量将通过多层的Transformer编码器进行处理,其中包含了自注意力机制和前馈神经网络层。这样可以捕捉到图像中不同位置的上下文依赖关系。最后,通过对Transformer编码器输出进行分类或回归,可以完成特定的视觉任务。

Vit model结构图
Vit的模型结构如下图所示。vit是将图像块应用于transformer。CNN是以滑窗的思想用卷积核在图像上进行卷积得到特征图。为了可以使图像仿照NLP的输入序列,我们可以先将图像分成块(patch),再将这些图像块进行平铺后输入到网络中(这样就变成了图像序列),然后通过transformer进行特征提取,最后再通过MLP对这些特征进行分类【其实就可以理解为在以往的CNN分类任务中,将backbone替换为transformer】。

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torchvision import transformsfrom my_dataset import MyDataSet
from vit_model import vit_base_patch16_224_in21k as create_model
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")if os.path.exists("./weights") is False:os.makedirs("./weights")tb_writer = SummaryWriter()train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)model = create_model(num_classes=args.num_classes, has_logits=False).to(device)if args.weights != "":assert os.path.exists(args.weights), "weights file: '{}' not exist.".format(args.weights)weights_dict = torch.load(args.weights, map_location=device)# 删除不需要的权重del_keys = ['head.weight', 'head.bias'] if model.has_logits \else ['pre_logits.fc.weight', 'pre_logits.fc.bias', 'head.weight', 'head.bias']for k in del_keys:del weights_dict[k]print(model.load_state_dict(weights_dict, strict=False))if args.freeze_layers:for name, para in model.named_parameters():# 除head, pre_logits外,其他权重全部冻结if "head" not in name and "pre_logits" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=5E-5)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=6)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.001)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\Butterfly20")parser.add_argument('--model-name', default='', help='create model name')# 预训练权重路径,如果不想载入就设置为空字符parser.add_argument('--weights', type=str, default='./vit_base_patch16_224_in21k.pth',help='initial weights path')# 是否冻结权重parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码

有问题可以私信或者留言,有问必答

这篇关于基于Pytorch框架的深度学习Vision Transformer神经网络蝴蝶分类识别系统源码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1083385

相关文章

Go学习记录之runtime包深入解析

《Go学习记录之runtime包深入解析》Go语言runtime包管理运行时环境,涵盖goroutine调度、内存分配、垃圾回收、类型信息等核心功能,:本文主要介绍Go学习记录之runtime包的... 目录前言:一、runtime包内容学习1、作用:① Goroutine和并发控制:② 垃圾回收:③ 栈和

Python中文件读取操作漏洞深度解析与防护指南

《Python中文件读取操作漏洞深度解析与防护指南》在Web应用开发中,文件操作是最基础也最危险的功能之一,这篇文章将全面剖析Python环境中常见的文件读取漏洞类型,成因及防护方案,感兴趣的小伙伴可... 目录引言一、静态资源处理中的路径穿越漏洞1.1 典型漏洞场景1.2 os.path.join()的陷

Android学习总结之Java和kotlin区别超详细分析

《Android学习总结之Java和kotlin区别超详细分析》Java和Kotlin都是用于Android开发的编程语言,它们各自具有独特的特点和优势,:本文主要介绍Android学习总结之Ja... 目录一、空安全机制真题 1:Kotlin 如何解决 Java 的 NullPointerExceptio

Pytorch介绍与安装过程

《Pytorch介绍与安装过程》PyTorch因其直观的设计、卓越的灵活性以及强大的动态计算图功能,迅速在学术界和工业界获得了广泛认可,成为当前深度学习研究和开发的主流工具之一,本文给大家介绍Pyto... 目录1、Pytorch介绍1.1、核心理念1.2、核心组件与功能1.3、适用场景与优势总结1.4、优

conda安装GPU版pytorch默认却是cpu版本

《conda安装GPU版pytorch默认却是cpu版本》本文主要介绍了遇到Conda安装PyTorchGPU版本却默认安装CPU的问题,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的... 目录一、问题描述二、网上解决方案罗列【此节为反面方案罗列!!!】三、发现的根本原因[独家]3.1 p

C++ HTTP框架推荐(特点及优势)

《C++HTTP框架推荐(特点及优势)》:本文主要介绍C++HTTP框架推荐的相关资料,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Crow2. Drogon3. Pistache4. cpp-httplib5. Beast (Boos

8种快速易用的Python Matplotlib数据可视化方法汇总(附源码)

《8种快速易用的PythonMatplotlib数据可视化方法汇总(附源码)》你是否曾经面对一堆复杂的数据,却不知道如何让它们变得直观易懂?别慌,Python的Matplotlib库是你数据可视化的... 目录引言1. 折线图(Line Plot)——趋势分析2. 柱状图(Bar Chart)——对比分析3

PyTorch中cdist和sum函数使用示例详解

《PyTorch中cdist和sum函数使用示例详解》torch.cdist是PyTorch中用于计算**两个张量之间的成对距离(pairwisedistance)**的函数,常用于点云处理、图神经网... 目录基本语法输出示例1. 简单的 2D 欧几里得距离2. 批量形式(3D Tensor)3. 使用不

SpringBoot基础框架详解

《SpringBoot基础框架详解》SpringBoot开发目的是为了简化Spring应用的创建、运行、调试和部署等,使用SpringBoot可以不用或者只需要很少的Spring配置就可以让企业项目快... 目录SpringBoot基础 – 框架介绍1.SpringBoot介绍1.1 概述1.2 核心功能2

Spring Boot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)

《SpringBoot拦截器Interceptor与过滤器Filter深度解析(区别、实现与实战指南)》:本文主要介绍SpringBoot拦截器Interceptor与过滤器Filter深度解析... 目录Spring Boot拦截器(Interceptor)与过滤器(Filter)深度解析:区别、实现与实