李宏毅 2022机器学习 HW3 boss baseline 上分记录

2023-10-09 02:52

本文主要是介绍李宏毅 2022机器学习 HW3 boss baseline 上分记录,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

作业数据是所有数据都有标签的版本。

李宏毅 2022机器学习 HW3 boss baseline 上分记录

    • 1. 训练数据增强, private 0.76056
    • 2. cross validation&ensemble, private 0.81647
    • 3. test dataset augmentation, private 0.82458
    • 4. resnet, private 0.86555
    • 5. Image Normalization, private 0.87494
    • 6. 减小batch_size, private 0.895

1. 训练数据增强, private 0.76056

结论:训练数据增强、更长时间的训练、dropout都证明很有效果,实验效果提升至接近strong baseline

增强1:crop + geometry
增强2:crop + geometry + gray
另外epochs数目增加到100,patience增加到10个epochs,FC层增加 dropout(0.3)

增强代码如下

#训练数据增强代码train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),
])

具体实验结果如下:
在这里插入图片描述

2. cross validation&ensemble, private 0.81647

使用5-fold cross validation,划分的时候使用分层抽样,
2.1)epochs=100, patience=10
训练时发现通常在60 epochs左右就early stop了,最终public score不如之前,但private score有提升,说明cross validation在过拟合上还是有效果的。
在这里插入图片描述
2.2)epochs=100, patience=16,再看看效果
patience增大后,效果有了一个非常明显的提升,超过strong baseline。具体看实验过程,会发现之前patience=10的时候,基本60epochs就停了,而现在patience=100的时候,early stop没有起作用,都是训练满100个epochs。猜测应该是使用5-fold的cross validation时,对比默认的train/valid,一方面训练数据更多,另一方面valid数据变少波动性更大,所以应该给更多的时间训练。
在这里插入图片描述

3. test dataset augmentation, private 0.82458

结论:此方式有效,分数进一步提升
在这里插入图片描述
测试数据的具体增强方式如下:
在步骤2的基础上,对test数据集使用了train数据集的数据增强方式,生成5张图片预测,对预测结果值平均,然后再用这个结果与原预测结果平均。以下为作业PPT相关部分。
在这里插入图片描述

4. resnet, private 0.86555

使用torchvision自带的resnet模型(按照作业要求,pretrained=False),尝试了resnet18和resnet50,效果进一步有了明显提升。public榜上超过bossline,但是从private榜上,可以看出存在一定过拟合。 另外resnet50的效果并没有比resnet18好,可能是小数据集的原因。这里均使用epochs=200,patience=16, lr=0.0003, weight_decay=1e-5。
在这里插入图片描述
在这里插入图片描述

两个注意点:
1,图片size设成224x224(论文中使用的图片尺寸),对比了128和224,两者差别很大。
2,resnet中的全连接层需要从原来的1000改成此次任务预测的类别数目11,代码如下:

def model_resnet():resnet = resnet18(pretrained=False)resnet.fc = nn.Sequential(nn.Linear(resnet.fc.in_features, 512),nn.ReLU(),nn.Dropout(0.3),nn.Linear(512, 11))return resnet

5. Image Normalization, private 0.87494

尝试了image normalization,略微有一些提升,尤其是做了test augmentation的private榜,达到了目前最高分。
在这里插入图片描述
image normalization的代码如下,注意train和test需要加上同样的norm。

height, width = 224, 224# Normally, We don't need augmentations in testing and validation.
# All we need here is to resize the PIL image and transform it into Tensor.
test_tfm = transforms.Compose([# transforms.Resize((128, 128)),transforms.Resize((height, width)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])# However, it is also possible to use augmentation in the testing phase.
# You may use train_tfm to produce a variety of images and then test using ensemble methods
train_tfm = transforms.Compose([# Resize the image into a fixed shape (height = width = 128)# transforms.Resize((128, 128)),# transforms.RandomResizedCrop(size=(128, 128), scale=(0.8, 1)),transforms.RandomResizedCrop(size=(height, width), scale=(0.8, 1)),# 几何变换transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=180),transforms.RandomAffine(degrees=30),#像素变换transforms.RandomGrayscale(p=0.2), # You may add some transforms here.# ToTensor() should be the last one of the transforms.transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])

6. 减小batch_size, private 0.895

将batch_size从64减小到16,模型效果进一步提升,加上image normalization后,private和public双双达到目前最高分。
在这里插入图片描述

这篇关于李宏毅 2022机器学习 HW3 boss baseline 上分记录的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/169842

相关文章

Unity新手入门学习殿堂级知识详细讲解(图文)

《Unity新手入门学习殿堂级知识详细讲解(图文)》Unity是一款跨平台游戏引擎,支持2D/3D及VR/AR开发,核心功能模块包括图形、音频、物理等,通过可视化编辑器与脚本扩展实现开发,项目结构含A... 目录入门概述什么是 UnityUnity引擎基础认知编辑器核心操作Unity 编辑器项目模式分类工程

Python学习笔记之getattr和hasattr用法示例详解

《Python学习笔记之getattr和hasattr用法示例详解》在Python中,hasattr()、getattr()和setattr()是一组内置函数,用于对对象的属性进行操作和查询,这篇文章... 目录1.getattr用法详解1.1 基本作用1.2 示例1.3 原理2.hasattr用法详解2.

基于Spring Boot 的小区人脸识别与出入记录管理系统功能

《基于SpringBoot的小区人脸识别与出入记录管理系统功能》文章介绍基于SpringBoot框架与百度AI人脸识别API的小区出入管理系统,实现自动识别、记录及查询功能,涵盖技术选型、数据模型... 目录系统功能概述技术栈选择核心依赖配置数据模型设计出入记录实体类出入记录查询表单出入记录 VO 类(用于

java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)

《java中pdf模版填充表单踩坑实战记录(itextPdf、openPdf、pdfbox)》:本文主要介绍java中pdf模版填充表单踩坑的相关资料,OpenPDF、iText、PDFBox是三... 目录准备Pdf模版方法1:itextpdf7填充表单(1)加入依赖(2)代码(3)遇到的问题方法2:pd

Zabbix在MySQL性能监控方面的运用及最佳实践记录

《Zabbix在MySQL性能监控方面的运用及最佳实践记录》Zabbix通过自定义脚本和内置模板监控MySQL核心指标(连接、查询、资源、复制),支持自动发现多实例及告警通知,结合可视化仪表盘,可有效... 目录一、核心监控指标及配置1. 关键监控指标示例2. 配置方法二、自动发现与多实例管理1. 实践步骤

Visual Studio 2022 编译C++20代码的图文步骤

《VisualStudio2022编译C++20代码的图文步骤》在VisualStudio中启用C++20import功能,需设置语言标准为ISOC++20,开启扫描源查找模块依赖及实验性标... 默认创建Visual Studio桌面控制台项目代码包含C++20的import方法。右键项目的属性:

在Spring Boot中集成RabbitMQ的实战记录

《在SpringBoot中集成RabbitMQ的实战记录》本文介绍SpringBoot集成RabbitMQ的步骤,涵盖配置连接、消息发送与接收,并对比两种定义Exchange与队列的方式:手动声明(... 目录前言准备工作1. 安装 RabbitMQ2. 消息发送者(Producer)配置1. 创建 Spr

k8s上运行的mysql、mariadb数据库的备份记录(支持x86和arm两种架构)

《k8s上运行的mysql、mariadb数据库的备份记录(支持x86和arm两种架构)》本文记录在K8s上运行的MySQL/MariaDB备份方案,通过工具容器执行mysqldump,结合定时任务实... 目录前言一、获取需要备份的数据库的信息二、备份步骤1.准备工作(X86)1.准备工作(arm)2.手

SpringBoot3应用中集成和使用Spring Retry的实践记录

《SpringBoot3应用中集成和使用SpringRetry的实践记录》SpringRetry为SpringBoot3提供重试机制,支持注解和编程式两种方式,可配置重试策略与监听器,适用于临时性故... 目录1. 简介2. 环境准备3. 使用方式3.1 注解方式 基础使用自定义重试策略失败恢复机制注意事项

Python UV安装、升级、卸载详细步骤记录

《PythonUV安装、升级、卸载详细步骤记录》:本文主要介绍PythonUV安装、升级、卸载的详细步骤,uv是Astral推出的下一代Python包与项目管理器,主打单一可执行文件、极致性能... 目录安装检查升级设置自动补全卸载UV 命令总结 官方文档详见:https://docs.astral.sh/