基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

本文主要是介绍基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

 第一步:准备数据

5种中草药数据:self.class_indict = ["百合", "党参", "山魈", "枸杞", "槐花", "金银花"]

,总共有900张图片,每个文件夹单独放一种数据

第二步:搭建模型

本文选择一个EfficientNetV2网络,其原理介绍如下:

        该网络主要使用训练感知神经结构搜索缩放的组合;在EfficientNetV1的基础上,引入了Fused-MBConv到搜索空间中;引入渐进式学习策略自适应正则强度调整机制使得训练更快;进一步关注模型的推理速度训练速度

与EfficientV1相比,主要有以下不同:

  1. V2中除了使用MBConv模块外,还使用了Fused-MBConv模块
  2. V2中会使用较小的expansion ratio,在V1中基本都是6。这样的好处是能够减少内存访问开销
  3. V2中更偏向使用更小的kernel_size(3 x 3),在V1中很多5 x 5。优于3 x 3的感受野是比5 x 5小的,所以需要堆叠更多的层结构以增加感受野
  4. 移除了V1中最优一个步距为1的stage

第三步:训练代码

1)损失函数为:交叉熵损失函数

2)训练代码:

import os
import math
import argparseimport torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_schedulerfrom model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluatedef main(args):device = torch.device(args.device if torch.cuda.is_available() else "cpu")print(args)print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')tb_writer = SummaryWriter()if os.path.exists("./weights") is False:os.makedirs("./weights")train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)img_size = {"s": [300, 384],  # train_size, val_size"m": [384, 480],"l": [384, 480]}num_model = "s"data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),"val": transforms.Compose([transforms.Resize(img_size[num_model][1]),transforms.CenterCrop(img_size[num_model][1]),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}# 实例化训练数据集train_dataset = MyDataSet(images_path=train_images_path,images_class=train_images_label,transform=data_transform["train"])# 实例化验证数据集val_dataset = MyDataSet(images_path=val_images_path,images_class=val_images_label,transform=data_transform["val"])batch_size = args.batch_sizenw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,pin_memory=True,num_workers=nw,collate_fn=train_dataset.collate_fn)val_loader = torch.utils.data.DataLoader(val_dataset,batch_size=batch_size,shuffle=False,pin_memory=True,num_workers=nw,collate_fn=val_dataset.collate_fn)# 如果存在预训练权重则载入model = create_model(num_classes=args.num_classes).to(device)if args.weights != "":if os.path.exists(args.weights):weights_dict = torch.load(args.weights, map_location=device)load_weights_dict = {k: v for k, v in weights_dict.items()if model.state_dict()[k].numel() == v.numel()}print(model.load_state_dict(load_weights_dict, strict=False))else:raise FileNotFoundError("not found weights file: {}".format(args.weights))# 是否冻结权重if args.freeze_layers:for name, para in model.named_parameters():# 除head外,其他权重全部冻结if "head" not in name:para.requires_grad_(False)else:print("training {}".format(name))pg = [p for p in model.parameters() if p.requires_grad]optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)# Scheduler https://arxiv.org/pdf/1812.01187.pdflf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosinescheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)for epoch in range(args.epochs):# traintrain_loss, train_acc = train_one_epoch(model=model,optimizer=optimizer,data_loader=train_loader,device=device,epoch=epoch)scheduler.step()# validateval_loss, val_acc = evaluate(model=model,data_loader=val_loader,device=device,epoch=epoch)tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]tb_writer.add_scalar(tags[0], train_loss, epoch)tb_writer.add_scalar(tags[1], train_acc, epoch)tb_writer.add_scalar(tags[2], val_loss, epoch)tb_writer.add_scalar(tags[3], val_acc, epoch)tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))if __name__ == '__main__':parser = argparse.ArgumentParser()parser.add_argument('--num_classes', type=int, default=5)parser.add_argument('--epochs', type=int, default=100)parser.add_argument('--batch-size', type=int, default=4)parser.add_argument('--lr', type=float, default=0.01)parser.add_argument('--lrf', type=float, default=0.01)# 数据集所在根目录# https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgzparser.add_argument('--data-path', type=str,default=r"G:\demo\data\ChineseMedicine")# download model weights# 链接: https://pan.baidu.com/s/1uZX36rvrfEss-JGj4yfzbQ  密码: 5gu1parser.add_argument('--weights', type=str, default='./pre_efficientnetv2-s.pth',help='initial weights path')parser.add_argument('--freeze-layers', type=bool, default=True)parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')opt = parser.parse_args()main(opt)

第四步:统计正确率

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码

代码的下载路径(新窗口打开链接):基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码

有问题可以私信或者留言,有问必答

这篇关于基于Pytorch框架的深度学习EfficientNetV2神经网络中草药识别分类系统源码的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1018199

相关文章

C++ 右值引用(rvalue references)与移动语义(move semantics)深度解析

《C++右值引用(rvaluereferences)与移动语义(movesemantics)深度解析》文章主要介绍了C++右值引用和移动语义的设计动机、基本概念、实现方式以及在实际编程中的应用,... 目录一、右值引用(rvalue references)与移动语义(move semantics)设计动机1

Java使用Spire.Barcode for Java实现条形码生成与识别

《Java使用Spire.BarcodeforJava实现条形码生成与识别》在现代商业和技术领域,条形码无处不在,本教程将引导您深入了解如何在您的Java项目中利用Spire.Barcodefor... 目录1. Spire.Barcode for Java 简介与环境配置2. 使用 Spire.Barco

SQL 注入攻击(SQL Injection)原理、利用方式与防御策略深度解析

《SQL注入攻击(SQLInjection)原理、利用方式与防御策略深度解析》本文将从SQL注入的基本原理、攻击方式、常见利用手法,到企业级防御方案进行全面讲解,以帮助开发者和安全人员更系统地理解... 目录一、前言二、SQL 注入攻击的基本概念三、SQL 注入常见类型分析1. 基于错误回显的注入(Erro

C++简单日志系统实现代码示例

《C++简单日志系统实现代码示例》日志系统是成熟软件中的一个重要组成部分,其记录软件的使用和运行行为,方便事后进行故障分析、数据统计等,:本文主要介绍C++简单日志系统实现的相关资料,文中通过代码... 目录前言Util.hppLevel.hppLogMsg.hppFormat.hppSink.hppBuf

Java枚举类型深度详解

《Java枚举类型深度详解》Java的枚举类型(enum)是一种强大的工具,它不仅可以让你的代码更简洁、可读,而且通过类型安全、常量集合、方法重写和接口实现等特性,使得枚举在很多场景下都非常有用,本文... 目录前言1. enum关键字的使用:定义枚举类型什么是枚举类型?如何定义枚举类型?使用枚举类型:2.

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

linux系统中java的cacerts的优先级详解

《linux系统中java的cacerts的优先级详解》文章讲解了Java信任库(cacerts)的优先级与管理方式,指出JDK自带的cacerts默认优先级更高,系统级cacerts需手动同步或显式... 目录Java 默认使用哪个?如何检查当前使用的信任库?简要了解Java的信任库总结了解 Java 信