【图像分类】ConViT从入门到实战——使用ConViT实现植物幼苗的分类(pytorch)

2024-03-22 16:10

本文主要是介绍【图像分类】ConViT从入门到实战——使用ConViT实现植物幼苗的分类(pytorch),希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 摘要
  • 导入项目使用的库
  • 设置全局参数
  • 图像预处理
  • 读取数据
  • 设置模型
  • 定义训练和验证函数
  • 测试
  • 完整的训练代码

摘要

来自 Facebook 的研究者提出了一种名为 ConViT 的新计算机视觉模型,它结合了两种广泛使用的 AI 架构——卷积神经网络 (CNN) 和 Transformer,该模型取长补短,克服了 CNN 和 Transformer 本身的一些局限性。同时,借助这两种架构的优势,这种基于视觉 Transformer 的模型可以胜过现有架构,尤其是在小数据的情况下,同时在大数据的情况下也能实现类似的优秀性能。

• 论文地址:https://arxiv.org/pdf/2103.10697.pdf
• 论文翻译:ConViT:使用软卷积归纳偏置改进视觉变换器
• GitHub 地址:https://github.com/facebookresearch/convit
• 本文用的数据集
链接:https://pan.baidu.com/s/1gYb-3XCZBhBoEFyj6d_kdw
提取码:q060
其工作原理
ConViT 在 vision Transformer 的基础上进行了调整,以利用 soft 卷积归纳偏置,从而激励网络进行卷积操作。同时最重要的是,ConViT 允许模型自行决定是否要保持卷积。为了利用这种 soft 归纳偏置,研究者引入了一种称为「门控位置自注意力(gated positional self-attention,GPSA)」的位置自注意力形式,其模型学习门控参数 lambda,该参数用于平衡基于内容的自注意力和卷积初始化位置自注意力。
在这里插入图片描述
如上图所示,ConViT(左)在 ViT 的基础上,将一些自注意力(SA)层用门控位置自注意力层(GPSA,右)替代。因为 GPSA 层涉及位置信息,因此在最后一个 GPSA层之后,类 token 会与隐藏表征联系到一起。
有了 GPSA 层加持,ConViT 的性能优于 Facebook 去年提出的 DeiT 模型。例如,ConViT-S+ 性能略优于 DeiT-B(对比结果为 82.2% vs. 81.8%),而 ConViT-S + 使用的参数量只有 DeiT-B 的一半左右 (48M vs 86M)。而 ConViT 最大的改进是在有限的数据范围内,soft 卷积归纳偏置发挥了重要作用。例如,仅使用 5% 的训练数据时,ConViT 的性能明显优于 DeiT(对比结果为 47.8% vs. 34.8%)。

此外,ConViT 在样本效率和参数效率方面也都优于 DeiT。如上图所示,左图为 ConViT-S 与 DeiT-S 的样本效率对比结果,这两个模型是在相同的超参数,且都是在 ImageNet-1k 的子集上训练完成的。图中绿色折线是 ConViT 相对于 DeiT 的提升。研究者还在 ImageNet-1k 上比较了 ConViT 模型与其他 ViT 以及 CNN 的 top-1 准确率,如上右图所示。
除了 ConViT 的性能优势外,门控参数提供了一种简单的方法来理解模型训练后每一层的卷积程度。查看所有层,研究者发现 ConViT 在训练过程中对卷积位置注意力的关注逐渐减少。对于靠后的层,门控参数最终会收敛到接近 0,这表明卷积归纳偏置实际上被忽略了。然而,对于起始层来说,许多注意力头保持较高的门控值,这表明该网络利用早期层的卷积归纳偏置来辅助训练。

在这里插入图片描述
上图展示了 DeiT (b) 及 ConViT © 注意力图的几个例子。σ(λ) 表示可学习的门控参数。接近 1 的值表示使用了卷积初始化,而接近 0 的值表示只使用了基于内容的注意力。注意,早期的 ConViT 层部分地维护了卷积初始化,而后面的层则完全基于内容。

导入项目使用的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from models import convit_tiny
from torch.autograd import Variable
from mydatasets import SeedlingData
import os

设置全局参数

设置使用哪块GPU,0指的是第一块GPU,设置学习率、BatchSize、epoch等参数

os.environ['CUDA_VISIBLE_DEVICE'] = '0'torch.cuda.set_device(0)
# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 100
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

图像预处理

数据处理比较简单,没有做复杂的尝试,有兴趣的可以加入一些处理。

# 数据预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

读取数据

将数据集解压后放到data文件夹下面,如图:

在这里插入图片描述

然后我们在dataset文件夹下面新建 init.py和dataset.py,在mydatasets.py文件夹写入下面的代码:

说一下代码的核心逻辑。

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_splitLabels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}class SeedlingData (data.Dataset):def __init__(self, root, transforms=None, train=True, test=False):"""主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据"""self.test = testself.transforms = transformsif self.test:imgs = [os.path.join(root, img) for img in os.listdir(root)]self.imgs = imgselse:imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]imgs = []for imglable in imgs_labels:for imgname in os.listdir(imglable):imgpath = os.path.join(imglable, imgname)imgs.append(imgpath)trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)if train:self.imgs = trainval_fileselse:self.imgs = val_filesdef __getitem__(self, index):"""一次返回一张图片的数据"""img_path = self.imgs[index]img_path=img_path.replace("\\",'/')if self.test:label = -1else:labelname = img_path.split('/')[-2]label = Labels[labelname]data = Image.open(img_path).convert('RGB')data = self.transforms(data)return data, labeldef __len__(self):return len(self.imgs)

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from mydatasets import SeedlingData)

dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型

  • 设置loss函数为nn.CrossEntropyLoss()。
  • 设置模型为convit_tiny,num_classes设置为12,embed_dim是编码的个数,设置为48,dropout设置为0.5。
  • 优化器设置为adam。
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = convit_tiny(pretrained=False, num_classes=12, embed_dim=48, drop_rate=0.5)model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""modellrnew = modellr * (0.1 ** (epoch // 50))print("lr:", modellrnew)for param_group in optimizer.param_groups:param_group['lr'] = modellrnew

定义训练和验证函数

# 定义训练过程def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))# 验证过程
def val(model, device, test_loader):model.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))# 训练for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model_ft, DEVICE, train_loader, optimizer, epoch)val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:
在这里插入图片描述

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat','Fat Hen', 'Loose Silky-bent','Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)path='data/test/'
testList=os.listdir(path)
for file in testList:img=Image.open(path+file)img=transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out=model(img)# Predict_, pred = torch.max(out.data, 1)print('Image Name:{},predict:{}'.format(file,classes[pred.data.item()]))

第二种 使用自定义的Dataset读取图片

import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variableclasses = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat','Fat Hen', 'Loose Silky-bent','Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)dataset_test =SeedlingData('data/test/', transform_test,test=True)
print(len(dataset_test))
# 对应文件夹的labelfor index in range(len(dataset_test)):item = dataset_test[index]img, label = itemimg.unsqueeze_(0)data = Variable(img).to(DEVICE)output = model(data)_, pred = torch.max(output.data, 1)print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))index += 1

完整的训练代码

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from models import convit_tiny
from torch.autograd import Variable
from mydatasets import SeedlingData
import os# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 数据预处理transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = convit_tiny(pretrained=False, num_classes=12, embed_dim=48, drop_rate=0.5)model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)def adjust_learning_rate(optimizer, epoch):"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""modellrnew = modellr * (0.1 ** (epoch // 50))print("lr:", modellrnew)for param_group in optimizer.param_groups:param_group['lr'] = modellrnew# 定义训练过程def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)optimizer.zero_grad()loss.backward()optimizer.step()print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item()))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))# 验证过程
def val(model, device, test_loader):model.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))# 训练for epoch in range(1, EPOCHS + 1):adjust_learning_rate(optimizer, epoch)train(model_ft, DEVICE, train_loader, optimizer, epoch)val(model_ft, DEVICE, test_loader)
torch.save(model_ft, 'model.pth')

代码链接:https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/20593432?spm=1001.2014.3001.5501

这篇关于【图像分类】ConViT从入门到实战——使用ConViT实现植物幼苗的分类(pytorch)的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/835639

相关文章

分布式锁在Spring Boot应用中的实现过程

《分布式锁在SpringBoot应用中的实现过程》文章介绍在SpringBoot中通过自定义Lock注解、LockAspect切面和RedisLockUtils工具类实现分布式锁,确保多实例并发操作... 目录Lock注解LockASPect切面RedisLockUtils工具类总结在现代微服务架构中,分布

Java使用Thumbnailator库实现图片处理与压缩功能

《Java使用Thumbnailator库实现图片处理与压缩功能》Thumbnailator是高性能Java图像处理库,支持缩放、旋转、水印添加、裁剪及格式转换,提供易用API和性能优化,适合Web应... 目录1. 图片处理库Thumbnailator介绍2. 基本和指定大小图片缩放功能2.1 图片缩放的

精选20个好玩又实用的的Python实战项目(有图文代码)

《精选20个好玩又实用的的Python实战项目(有图文代码)》文章介绍了20个实用Python项目,涵盖游戏开发、工具应用、图像处理、机器学习等,使用Tkinter、PIL、OpenCV、Kivy等库... 目录① 猜字游戏② 闹钟③ 骰子模拟器④ 二维码⑤ 语言检测⑥ 加密和解密⑦ URL缩短⑧ 音乐播放

Python使用Tenacity一行代码实现自动重试详解

《Python使用Tenacity一行代码实现自动重试详解》tenacity是一个专为Python设计的通用重试库,它的核心理念就是用简单、清晰的方式,为任何可能失败的操作添加重试能力,下面我们就来看... 目录一切始于一个简单的 API 调用Tenacity 入门:一行代码实现优雅重试精细控制:让重试按我

MySQL中EXISTS与IN用法使用与对比分析

《MySQL中EXISTS与IN用法使用与对比分析》在MySQL中,EXISTS和IN都用于子查询中根据另一个查询的结果来过滤主查询的记录,本文将基于工作原理、效率和应用场景进行全面对比... 目录一、基本用法详解1. IN 运算符2. EXISTS 运算符二、EXISTS 与 IN 的选择策略三、性能对比

Redis客户端连接机制的实现方案

《Redis客户端连接机制的实现方案》本文主要介绍了Redis客户端连接机制的实现方案,包括事件驱动模型、非阻塞I/O处理、连接池应用及配置优化,具有一定的参考价值,感兴趣的可以了解一下... 目录1. Redis连接模型概述2. 连接建立过程详解2.1 连php接初始化流程2.2 关键配置参数3. 最大连

Python实现网格交易策略的过程

《Python实现网格交易策略的过程》本文讲解Python网格交易策略,利用ccxt获取加密货币数据及backtrader回测,通过设定网格节点,低买高卖获利,适合震荡行情,下面跟我一起看看我们的第一... 网格交易是一种经典的量化交易策略,其核心思想是在价格上下预设多个“网格”,当价格触发特定网格时执行买

使用Python构建智能BAT文件生成器的完美解决方案

《使用Python构建智能BAT文件生成器的完美解决方案》这篇文章主要为大家详细介绍了如何使用wxPython构建一个智能的BAT文件生成器,它不仅能够为Python脚本生成启动脚本,还提供了完整的文... 目录引言运行效果图项目背景与需求分析核心需求技术选型核心功能实现1. 数据库设计2. 界面布局设计3

SQL Server跟踪自动统计信息更新实战指南

《SQLServer跟踪自动统计信息更新实战指南》本文详解SQLServer自动统计信息更新的跟踪方法,推荐使用扩展事件实时捕获更新操作及详细信息,同时结合系统视图快速检查统计信息状态,重点强调修... 目录SQL Server 如何跟踪自动统计信息更新:深入解析与实战指南 核心跟踪方法1️⃣ 利用系统目录

使用IDEA部署Docker应用指南分享

《使用IDEA部署Docker应用指南分享》本文介绍了使用IDEA部署Docker应用的四步流程:创建Dockerfile、配置IDEADocker连接、设置运行调试环境、构建运行镜像,并强调需准备本... 目录一、创建 dockerfile 配置文件二、配置 IDEA 的 Docker 连接三、配置 Do