ResNet 皮肤癌分类tricks总结

2024-02-27 13:50

本文主要是介绍ResNet 皮肤癌分类tricks总结,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

project introduction

project使用的数据为皮肤癌的图片数据,分为了训练和测试集,两个数据集内图片没有重合,均为彩色图像,因为为良恶性皮肤癌的二分类任务,所以相对来讲比较简单。对于网络选择我个人不是很赞成使用算力很大的网络来提升参数指标,毕竟大算力可能会造成落地困难并且较大的算力并不亲民。

Approach

与分割任务不同,分类任务对应的label不会改变所以可以随心的添加许许多多的变换操作。对于皮肤癌的分类数据并没有做过多的预处理操作,仅仅使用的pytorch中自带的一些数据增强操作比如随机上下反转,随机左右翻转,随机旋转,说实话由于皮肤癌的图片大多都是居于中央的所以这些操作对于数据的增强效果并不显著。使用随机擦除以及随机仿射变换的操作就需要相应的提升训练的epoch才会达到较好的收敛效果。

在网络的选择上我选用的是ResNet,具体的搭建步骤参考的是B站一位大佬的视频代码搭建的,更改一下目录就可以用了。探讨了18层,34层,50层的网络的分类效果,因为数据量也没有很大所以50层已经可以做到较全面的提取图像特征。过深的网络反而可能会造成过拟合。

在这里插入图片描述

图1 ResNet的网络架构图

使用一些加速、以及提升算力的技巧来提升网络的性能,从而在有限的算力下获得更好的结果。

Tricks

使用了一系列的训练技巧的来提升网络性能,大部分都可以在pytorch的官网找到相应的调用代码

迁移学习

主要是将预训练好的模型权重加载进来

https://pytorch.org/hub/research-models

# 加载预训练预训练模型
model_weight_path = "./resnet34_pre.pth"
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
net.load_state_dict(torch.load(model_weight_path, map_location=device))

Auto Mix Precision

使用16位与32位存储混合精度训练,增加计算速度,但不会影响结果的准确度。

https://pytorch.org/docs/stable/amp.html

在这里插入图片描述

图2 在不同的网络中使用相同的训练超参数均没有出现准确率的下降

from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()                for step, data in enumerate(train_bar):images, labels = datawith autocast():logits = net(images.to(device))loss = loss_function(logits, labels.to(device))/ accumulation_steps            scaler.scale(loss).backward()if((step+1) % accumulation_steps)==0:scaler.step(optimizer)scaler.update()optimizer.zero_grad()

梯度累计

内存不够,梯度累及来凑,计算多个轮次再更新一次权重。这里的accumulation_step用来决定多少个iteration更新一次权重。

loss = loss_function(logits, labels.to(device))
loss = loss / accumulation_steps
#   梯度累计训练
if((step+1) % accumulation_steps)==0:optimizer.step()optimizer.zero_grad()

预处理

添加预处理操作做数据增强

https://pytorch.org/vision/stable/transforms.html

from torchvision import transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

网络层数

调用一下18层、34层、还有50层,实在不行自己造一个10层,不过34层和18层我可以训练到0.9的准确率,auc可以达到0.97

动态学习率

https://pytorch.org/docs/stable/optim.html

optimizer = optim.Adam(params, lr=0.0001) #lr =0.0001
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)
# 添加在epoch for循环的最后面lr_rate.append(optimizer.state_dict()['param_groups'][0]['lr'])scheduler.step()

Metrics

准确率(Accuracy): A c c = T P + T N T P + T N + F P + F N Acc = \frac{TP+TN}{TP+TN+FP+FN} Acc=TP+TN+FP+FNTP+TN

AUC_ROC

在这里插入图片描述

confusion matrix

在这里插入图片描述

使用的评价指标包括了准确率,AUC曲线,以及混淆矩阵

result

整体上分类的准确率可以达到0.9左右,使用梯度累计法可以有效的提高计算精度

使用AMP没有导致计算准确率的下降

在这里插入图片描述

Summary

总的来讲实现的过程比较简单,实现的结果也比较初级,感谢大佬提供参考的代码,大佬的B站id是霹雳吧啦Wz

小白上路还有很多不足请大家多多指教!

文中使用的图片来源于ResNet原论文以及NVIDIA官方的文档,侵权即删

这篇关于ResNet 皮肤癌分类tricks总结的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/752609

相关文章

C# List.Sort四种重载总结

《C#List.Sort四种重载总结》本文详细分析了C#中List.Sort()方法的四种重载形式及其实现原理,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友... 目录1. Sort方法的四种重载2. 具体使用- List.Sort();- IComparable

SpringBoot项目整合Netty启动失败的常见错误总结

《SpringBoot项目整合Netty启动失败的常见错误总结》本文总结了SpringBoot集成Netty时常见的8类问题及解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参... 目录一、端口冲突问题1. Tomcat与Netty端口冲突二、主线程被阻塞问题1. Netty启动阻

SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)

《SpringBoot整合Kafka启动失败的常见错误问题总结(推荐)》本文总结了SpringBoot项目整合Kafka启动失败的常见错误,包括Kafka服务器连接问题、序列化配置错误、依赖配置问题、... 目录一、Kafka服务器连接问题1. Kafka服务器无法连接2. 开发环境与生产环境网络不通二、序

python3中正则表达式处理函数用法总结

《python3中正则表达式处理函数用法总结》Python中的正则表达式是一个强大的文本处理工具,用于匹配、查找、替换等操作,在Python中正则表达式的操作主要通过内置的re模块来实现,这篇文章主要... 目录前言re.match函数re.search方法re.match 与 re.search的区别检索

Python版本与package版本兼容性检查方法总结

《Python版本与package版本兼容性检查方法总结》:本文主要介绍Python版本与package版本兼容性检查方法的相关资料,文中提供四种检查方法,分别是pip查询、conda管理、PyP... 目录引言为什么会出现兼容性问题方法一:用 pip 官方命令查询可用版本方法二:conda 管理包环境方法

pycharm跑python项目易出错的问题总结

《pycharm跑python项目易出错的问题总结》:本文主要介绍pycharm跑python项目易出错问题的相关资料,当你在PyCharm中运行Python程序时遇到报错,可以按照以下步骤进行排... 1. 一定不要在pycharm终端里面创建环境安装别人的项目子模块等,有可能出现的问题就是你不报错都安装

Python中logging模块用法示例总结

《Python中logging模块用法示例总结》在Python中logging模块是一个强大的日志记录工具,它允许用户将程序运行期间产生的日志信息输出到控制台或者写入到文件中,:本文主要介绍Pyt... 目录前言一. 基本使用1. 五种日志等级2.  设置报告等级3. 自定义格式4. C语言风格的格式化方法

Spring 依赖注入与循环依赖总结

《Spring依赖注入与循环依赖总结》这篇文章给大家介绍Spring依赖注入与循环依赖总结篇,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录1. Spring 三级缓存解决循环依赖1. 创建UserService原始对象2. 将原始对象包装成工

MySQL中查询和展示LONGBLOB类型数据的技巧总结

《MySQL中查询和展示LONGBLOB类型数据的技巧总结》在MySQL中LONGBLOB是一种二进制大对象(BLOB)数据类型,用于存储大量的二进制数据,:本文主要介绍MySQL中查询和展示LO... 目录前言1. 查询 LONGBLOB 数据的大小2. 查询并展示 LONGBLOB 数据2.1 转换为十

在Java中实现线程之间的数据共享的几种方式总结

《在Java中实现线程之间的数据共享的几种方式总结》在Java中实现线程间数据共享是并发编程的核心需求,但需要谨慎处理同步问题以避免竞态条件,本文通过代码示例给大家介绍了几种主要实现方式及其最佳实践,... 目录1. 共享变量与同步机制2. 轻量级通信机制3. 线程安全容器4. 线程局部变量(ThreadL