脑PET图像分析和疾病预测挑战赛---CNN进阶版

2023-10-15 03:50

本文主要是介绍脑PET图像分析和疾病预测挑战赛---CNN进阶版,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

前情提要:通过使用CNN成功将F1-score刷到0.74214,详情见:

脑PET图像分析和疾病预测挑战赛---CNN_Vector Jason的博客-CSDN博客本次使用卷积神经网络(CNN)进行医学图像预测,其能够自动学习和提取图像中的特征,相比于基于sklearn的逻辑回归统计算法,最终F1-score提高了0.4https://blog.csdn.net/Vectorln/article/details/132383243

目录

数据预处理

数据增强

迁移学习

优化器

保存最佳的参数

集成学习

保存参数

循环载入并预测 

结果

总结

疑问


        本次重点对CNN代码进行优化,分别从数据预处理、数据增强、迁移学习、优化器与集成学习五个方向进行扩展。

数据预处理

        经过查阅相关资料,部分网友提出nii数据具有header属性,再进一步查找,发现有个名为‘db_name’的字符,可能用来表示患者ID、检测设备ID等等,此处不妨假设表示患者ID,这样将会大大提高结果的预测精度---相同患者的诊断结果一定相同。

        在数据预处理阶段初期,根据经验我先确定测试集中每个样本的db_name(以下称为患者ID),然后在训练集当中进行遍历,看看有没有可能寻找到相同的患者ID,然而最终发现具有戏剧性的一幕是:测试集的部分样本出现在训练集当中!

        接下来问题就变得简单起来了,因为训练集的label是真实的,那么我们只需要保证:

               在测试集的预测结果中,已经在训练集出现过的样本,其label依然一致即可

        具体流程如下:

1.在测试集中统计所有nii文件的db_name

2.在训练集中寻找与测试集相同db_name的nii文件

3.将训练集中的labe赋值给相同db_name的测试集文件

4.整理到CSV文件中

定义的函数:

def invert_dict_with_duplicates(original_dict):inverted_dict = {}for key, value in original_dict.items():if value in inverted_dict:inverted_dict[value].append(key)# key value互换elseelse:inverted_dict[value] = [key]return inverted_dict

 正式处理:

import nibabel as nb
import pandas as pd
import glob # 获取文件路径
import os
import numpy as np
import nibabel as nib train_path = glob.glob('./BrainPET/Train/*/*')
test_path = glob.glob('./BrainPET/Test/*')test_dict = {}
for path in test_path:img = nb.load(path)header = img.headername = path.split(os .sep)[-1][:-4]test_dict[name] = header['db_name'].item().decode('utf-8')#翻转字典
invers_dict = invert_dict_with_duplicates(test_dict)
res = []   #debug用
labels = []
test_names = [] 
for path in train_path:img = nb.load(path)header = img.headerdb_name = header['db_name'].item().decode('utf-8')if db_name in invers_dict:if len(invers_dict[db_name]) == 1:test_name = invers_dict[db_name][0]else:test_name = invers_dict[db_name]label = path.split(os.sep)[-2]train_name = path.split(os.sep)[-1][:-4]res.append([train_name, label, test_name])labels.append(label)test_names.append(int(test_name))
submit = pd.DataFrame({    'uuid': test_names,'label': labels})
submit = submit.sort_values(by='uuid')
submit.to_csv('submit2.csv', index=None)
print('done')

数据增强

        对train_loader, val_loader和 test_loader的batch_size修改为4,在val_loader和 test_loader中增加了 RandomContrast(p=0.5),经过测试,发现 RandomBrightnessContrast(p=0.5) 的处理会很大程度影响预测结果,因此未在验证集与测试集中进行运用。

import albumentations as A
train_loader = torch.utils.data.DataLoader(XunFeiDataset(train_path[:-10],A.Compose([A.RandomRotate90(),A.RandomCrop(120, 120),A.HorizontalFlip(p=0.5),A.RandomContrast(p=0.5),A.RandomBrightnessContrast(p=0.5),])), batch_size=4, shuffle=True, num_workers=1, pin_memory=False
)val_loader = torch.utils.data.DataLoader(XunFeiDataset(train_path[-10:],A.Compose([#A.RandomCrop(120, 120),#A.HorizontalFlip(p=0.5)A.RandomCrop(120, 120),A.HorizontalFlip(p=0.5),A.RandomContrast(p=0.5),])), batch_size=4, shuffle=False, num_workers=1, pin_memory=False
)test_loader = torch.utils.data.DataLoader(XunFeiDataset(test_path,A.Compose([A.RandomCrop(128, 128),A.HorizontalFlip(p=0.5),A.RandomContrast(p=0.5),])), batch_size=4, shuffle=False, num_workers=1, pin_memory=False
)

        在通道选择方面,经过分析数据集,发现一些医学影像的图片切片数不均匀,此处统一选择设置为32(炼丹经验来源于对batch_size的认识) 。

idx = np.random.choice(range(img.shape[-1]), 32) # 将图片的切片数统一修改为32

迁移学习

        可以直接设置 model = resnet50 进行训练,其中True or False 取决于当前任务下的数据集与Imagenet的数据集是否相似,例如本次数据集是医学影像数据集,严格意义上来说与Imagenet数据集差距很大,因此选择 False ,表示不使用预训练所得到的参数。

        但通过实际测试,发现resnet18的网络规模已经足够,反而如果使用resnet50 训练所得到的效果并不如resnet18 ,究其原因认为是数据集太小,导致越靠近最终全连接分类层的网络几乎没有学习到一些有用的参数。

# 采用迁移学习,使用resnet50的架构进行训练class XunFeiNet(nn.Module):def __init__(self):super(XunFeiNet, self).__init__()model = models.resnet50(False)model.conv1 = torch.nn.Conv2d(32, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)model.avgpool = nn.AdaptiveAvgPool2d(1)model.fc = nn.Linear(512, 2)self.resnet = modeldef forward(self, img):out = self.resnet(img)return out

优化器

        结合所学的深度学习相关知识,SGD所取得的最优效果往往要比Adam好(详见李宏毅老师的机器学习在线课程),所以此处优化器选择SGD。此外,可能存在初期学习率较低,从而最终收敛域一个局部最优,因此导入 lr_scheduler 对学习率进行调节,具体而言表示每隔5个epoch,learning rate 就变为原来的0.1,因此初始学习率设置为0.1,并初始化一个较大的训练次数即可。

# 导入训练所需要的包
from torchvision import models
import torch.optim as optim
from torch.optim import lr_scheduleroptimizer = torch.optim.SGD(model.parameters(), 0.1)
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1) # 采用阶梯变化的学习率

保存最佳的参数

        在上一次的CNN---Baseline中,发现最终测试所用的参数源于训练过程中最后一次的参数,显然这种做法是存在问题的,因此对此部分进行修改:

#if val_acc > best_test_accuracy:
#    # 删除旧的最佳模型文件(如有)
#    old_best_checkpoint_path = 'checkpoints/best{:.3f}.pth'.format(best_test_accuracy)
#    if os.path.exists(old_best_checkpoint_path):
#        os.remove(old_best_checkpoint_path)# 保存新的最佳模型文件
#    new_best_checkpoint_path = 'checkpoints/best-{:.3f}.pth'.format(val_acc)
#    torch.save(model, new_best_checkpoint_path)
#    print('保存新的最佳模型', 'checkpoints/best-{:.3f}.pth'.format(best_test_accuracy))
#    best_test_accuracy = val_acc#model = torch.load('checkpoints/best-{:.3f}.pth'.format(best_test_accuracy))#pred = None
#for _ in range(3):
#    if pred is None:
#        pred = predict(test_loader, model, criterion)
#    else:
#        pred += predict(test_loader, model, criterion)

集成学习

        在对测试集进行预测的过程中,我们可以采取一种集成学习的做法:

1. 使用不同模型进行多次预测,并取出现次数最多的结果作为最终结果

2. 使用相同模型,但加载不同参数进行多次预测,并取出现次数最多的结果作为最终结果

        经过考虑,此处我首先选择了第二种做法(第一种做法还在尝试实现):

        首先保存训练过程中 val_acc > 0.5 的参数文件,然后在预测过程中循环载入这些文件进行预测,最后选择出现次数最多的标签作为最终结果即可。

保存参数

        此处与“保存最佳的参数”步骤发生冲突,只能在训练过程中选择其一进行训练。

 if val_acc > 0.5:new_best_checkpoint_path = 'checkpoints/best-{:.3f}.pth'.format(val_acc)torch.save(model, new_best_checkpoint_path)

循环载入并预测 

pred = Nonefor model_path in ['/content/checkpoints/best-0.600.pth', '/content/checkpoints/best-0.700.pth','/content/checkpoints/best-0.800.pth','/content/checkpoints/best-0.900.pth','/content/checkpoints/best-1.00.pth']:model = torch.load(model_path)for _ in range(5):if pred is None:pred = predict(test_loader, model, criterion)else:pred += predict(test_loader, model, criterion)submit = pd.DataFrame({'uuid': [int(x.split('/')[-1][:-4]) for x in test_path],'label': pred.argmax(1)
})

结果

        综合目前所有技巧,可得到的F1-score为:

 

总结

        本次基于CNN---Baseline在数据增强、迁移学习、优化器、集成学习四个方向进行了优化,相比之下提高了约0.04,但仍未解决主要问题。

        个人认为问题的关键仍然在于对数据集的预处理有失偏颇,因为通过提取各个医学影像图片,可以发现所拍摄的人脑大小不一,这就需要我们重新定义一个函数先对每张图片进行统一(该部分仍在尝试),这样在crop之后得到的每张图片才具有训练价值。

疑问

        竞赛过程中依然存在一些疑问,比如在最初的数据预处理部分,按道理来说医学影像图片为灰白图片,其通道数必定为1,但是为什么能够随机选择通道数(比如将其设置为50)呢?

这篇关于脑PET图像分析和疾病预测挑战赛---CNN进阶版的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!


原文地址:https://blog.csdn.net/Vectorln/article/details/132430252
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.chinasem.cn/article/215151

相关文章

Android kotlin中 Channel 和 Flow 的区别和选择使用场景分析

《Androidkotlin中Channel和Flow的区别和选择使用场景分析》Kotlin协程中,Flow是冷数据流,按需触发,适合响应式数据处理;Channel是热数据流,持续发送,支持... 目录一、基本概念界定FlowChannel二、核心特性对比数据生产触发条件生产与消费的关系背压处理机制生命周期

怎样通过分析GC日志来定位Java进程的内存问题

《怎样通过分析GC日志来定位Java进程的内存问题》:本文主要介绍怎样通过分析GC日志来定位Java进程的内存问题,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录一、GC 日志基础配置1. 启用详细 GC 日志2. 不同收集器的日志格式二、关键指标与分析维度1.

MySQL中的表连接原理分析

《MySQL中的表连接原理分析》:本文主要介绍MySQL中的表连接原理分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、环境3、表连接原理【1】驱动表和被驱动表【2】内连接【3】外连接【4编程】嵌套循环连接【5】join buffer4、总结1、背景

深度解析Python装饰器常见用法与进阶技巧

《深度解析Python装饰器常见用法与进阶技巧》Python装饰器(Decorator)是提升代码可读性与复用性的强大工具,本文将深入解析Python装饰器的原理,常见用法,进阶技巧与最佳实践,希望可... 目录装饰器的基本原理函数装饰器的常见用法带参数的装饰器类装饰器与方法装饰器装饰器的嵌套与组合进阶技巧

python中Hash使用场景分析

《python中Hash使用场景分析》Python的hash()函数用于获取对象哈希值,常用于字典和集合,不可变类型可哈希,可变类型不可,常见算法包括除法、乘法、平方取中和随机数哈希,各有优缺点,需根... 目录python中的 Hash除法哈希算法乘法哈希算法平方取中法随机数哈希算法小结在Python中,

Java Stream的distinct去重原理分析

《JavaStream的distinct去重原理分析》Javastream中的distinct方法用于去除流中的重复元素,它返回一个包含过滤后唯一元素的新流,该方法会根据元素的hashcode和eq... 目录一、distinct 的基础用法与核心特性二、distinct 的底层实现原理1. 顺序流中的去重

关于MyISAM和InnoDB对比分析

《关于MyISAM和InnoDB对比分析》:本文主要介绍关于MyISAM和InnoDB对比分析,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录开篇:从交通规则看存储引擎选择理解存储引擎的基本概念技术原理对比1. 事务支持:ACID的守护者2. 锁机制:并发控制的艺

MyBatis Plus 中 update_time 字段自动填充失效的原因分析及解决方案(最新整理)

《MyBatisPlus中update_time字段自动填充失效的原因分析及解决方案(最新整理)》在使用MyBatisPlus时,通常我们会在数据库表中设置create_time和update... 目录前言一、问题现象二、原因分析三、总结:常见原因与解决方法对照表四、推荐写法前言在使用 MyBATis

从基础到进阶详解Pandas时间数据处理指南

《从基础到进阶详解Pandas时间数据处理指南》Pandas构建了完整的时间数据处理生态,核心由四个基础类构成,Timestamp,DatetimeIndex,Period和Timedelta,下面我... 目录1. 时间数据类型与基础操作1.1 核心时间对象体系1.2 时间数据生成技巧2. 时间索引与数据

Python主动抛出异常的各种用法和场景分析

《Python主动抛出异常的各种用法和场景分析》在Python中,我们不仅可以捕获和处理异常,还可以主动抛出异常,也就是以类的方式自定义错误的类型和提示信息,这在编程中非常有用,下面我将详细解释主动抛... 目录一、为什么要主动抛出异常?二、基本语法:raise关键字基本示例三、raise的多种用法1. 抛