(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别

2023-11-24 15:15

本文主要是介绍(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

      • 1. 导入相关库
      • 2. 加载数据集
      • 3. 整理数据集
      • 4. 图像增广
      • 5. 读取数据
      • 6. 微调预训练模型
      • 7. 定义损失函数和评价损失函数
      • 9. 训练模型

1. 导入相关库

import os
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

2. 加载数据集

- 该数据集是完整数据集的小规模样本
# 下载数据集
d2l.DATA_HUB['dog_tiny'] = (d2l.DATA_URL + 'kaggle_dog_tiny.zip','0cb91d09b814ecdc07b50f31f8dcad3e81d6a86d')# 如果使用Kaggle比赛的完整数据集,请将下面的变量更改为False
demo = True
if demo:data_dir = d2l.download_extract('dog_tiny')
else:data_dir = os.path.join('..', 'data', 'dog-breed-identification')

3. 整理数据集

def reorg_dog_data(data_dir, valid_ratio):labels = d2l.read_csv_labels(os.path.join(data_dir, 'labels.csv'))d2l.reorg_train_valid(data_dir, labels, valid_ratio)d2l.reorg_test(data_dir)batch_size = 32 if demo else 128
valid_ratio = 0.1
reorg_dog_data(data_dir, valid_ratio)

4. 图像增广

transform_train = torchvision.transforms.Compose([torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3.0/4.0,4.0/3.0)),torchvision.transforms.RandomHorizontalFlip(),torchvision.transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
transform_test = torchvision.transforms.Compose([torchvision.transforms.Resize(256),torchvision.transforms.CenterCrop(224),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

5. 读取数据

train_ds, train_valid_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_train) for folder in ['train', 'train_valid']
]
valid_ds, test_ds = [torchvision.datasets.ImageFolder(os.path.join(data_dir, 'train_valid_test', folder),transform=transform_test) for folder in ['valid', 'test']
]
train_iter, train_valid_iter = [torch.utils.data.DataLoader(dataset, batch_size, shuffle=True, drop_last=True) for dataset in (train_ds, train_valid_ds)
]
valid_iter = torch.utils.data.DataLoader(valid_ds, batch_size, shuffle=False, drop_last=True
)
test_iter = torch.utils.data.DataLoader(test_ds, batch_size, shuffle=False, drop_last=True
)

6. 微调预训练模型

def get_net(devices):finetune_net = nn.Sequential()finetune_net.features = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)# 定义一个新的输出网络,共有120个输出类别finetune_net.output_new = nn.Sequential(nn.Linear(1000, 256),nn.ReLU(),nn.Linear(256, 120))finetune_net = finetune_net.to(devices[0])# 冻结参数for param in finetune_net.features.parameters():param.requires_grad = Falsereturn finetune_net
# 查看网络模型
get_net(devices=d2l.try_all_gpus())

在这里插入图片描述

7. 定义损失函数和评价损失函数

# 定义损失函数
loss = nn.CrossEntropyLoss(reduction='none')def evaluate_loss(data_iter, net, device):l_sum, n = 0.0, 0for features, labels in data_iter:features, labels = features.to(device[0]), labels.to(device[0])outputs = net(features)l = loss(outputs, labels)l_sum += l.sum()n += labels.numel()return (l_sum / n).to('cpu')
  1. 定义训练函数
def train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay):# 只训练小型定义输出网络net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.SGD((param for param in net.parameters() if param.requires_grad),lr=lr, momentum=0.9, weight_decay=wd)scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_period, lr_decay)num_batches, timer = len(train_iter), d2l.Timer()legend = ['train loss']if valid_iter is not None:legend.append('valid loss')animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], legend=legend)for epoch in range(num_epochs):metric = d2l.Accumulator(2)for i, (features, labels) in enumerate(train_iter):timer.start()features, labels = features.to(devices[0]), labels.to(devices[0])trainer.zero_grad()output = net(features)l = loss(output, labels).sum()l.backward()trainer.step()metric.add(l, labels.shape[0])timer.stop()if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches, (metric[0] / metric[1], None))measures = f'train loss {metric[0] / metric[1]:.3f}'if valid_iter is not None :valid_loss = evaluate_loss(valid_iter, net, devices)animator.add(epoch + 1, (None, valid_loss.detach().cpu()))scheduler.step()if valid_iter is not None:measures += f', valid loss {valid_loss:.3f}'print(measures + f'\n{metric[1] * num_epochs / timer.sum():.1f}'f'examples/sec on {str(devices)}')

9. 训练模型

devices, num_epochs, lr, wd = d2l.try_all_gpus(), 10, 1e-4, 1e-4
lr_period, lr_decay, net, = 2, 0.9, get_net(devices)
import time# 在开头设置开始时间
start = time.perf_counter()  # start = time.clock() python3.8之前可以train(net, train_iter, valid_iter, num_epochs, lr, wd, devices, lr_period, lr_decay)# 在程序运行结束的位置添加结束时间
end = time.perf_counter()  # end = time.clock()  python3.8之前可以# 再将其进行打印,即可显示出程序完成的运行耗时
print(f'运行耗时{(end-start):.4f}')

在这里插入图片描述

这篇关于(动手学习深度学习)第13章 实战kaggle竞赛:狗的品种识别的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/422148

相关文章

Java中Redisson 的原理深度解析

《Java中Redisson的原理深度解析》Redisson是一个高性能的Redis客户端,它通过将Redis数据结构映射为Java对象和分布式对象,实现了在Java应用中方便地使用Redis,本文... 目录前言一、核心设计理念二、核心架构与通信层1. 基于 Netty 的异步非阻塞通信2. 编解码器三、

Java HashMap的底层实现原理深度解析

《JavaHashMap的底层实现原理深度解析》HashMap基于数组+链表+红黑树结构,通过哈希算法和扩容机制优化性能,负载因子与树化阈值平衡效率,是Java开发必备的高效数据结构,本文给大家介绍... 目录一、概述:HashMap的宏观结构二、核心数据结构解析1. 数组(桶数组)2. 链表节点(Node

Java 虚拟线程的创建与使用深度解析

《Java虚拟线程的创建与使用深度解析》虚拟线程是Java19中以预览特性形式引入,Java21起正式发布的轻量级线程,本文给大家介绍Java虚拟线程的创建与使用,感兴趣的朋友一起看看吧... 目录一、虚拟线程简介1.1 什么是虚拟线程?1.2 为什么需要虚拟线程?二、虚拟线程与平台线程对比代码对比示例:三

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1

Python函数作用域与闭包举例深度解析

《Python函数作用域与闭包举例深度解析》Python函数的作用域规则和闭包是编程中的关键概念,它们决定了变量的访问和生命周期,:本文主要介绍Python函数作用域与闭包的相关资料,文中通过代码... 目录1. 基础作用域访问示例1:访问全局变量示例2:访问外层函数变量2. 闭包基础示例3:简单闭包示例4

Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题

《Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题》在爬虫工程里,“HTTPS”是绕不开的话题,HTTPS为传输加密提供保护,同时也给爬虫带来证书校验、... 目录一、核心问题与优先级检查(先问三件事)二、基础示例:requests 与证书处理三、高并发选型:

Oracle Scheduler任务故障诊断方法实战指南

《OracleScheduler任务故障诊断方法实战指南》Oracle数据库作为企业级应用中最常用的关系型数据库管理系统之一,偶尔会遇到各种故障和问题,:本文主要介绍OracleSchedul... 目录前言一、故障场景:当定时任务突然“消失”二、基础环境诊断:搭建“全局视角”1. 数据库实例与PDB状态2

如何正确识别一台POE交换机的好坏? 选购可靠的POE交换机注意事项

《如何正确识别一台POE交换机的好坏?选购可靠的POE交换机注意事项》POE技术已经历多年发展,广泛应用于安防监控和无线覆盖等领域,需求量大,但质量参差不齐,市场上POE交换机的品牌繁多,如何正确识... 目录生产标识1. 必须包含的信息2. 劣质设备的常见问题供电标准1. 正规的 POE 标准2. 劣质设

Git进行版本控制的实战指南

《Git进行版本控制的实战指南》Git是一种分布式版本控制系统,广泛应用于软件开发中,它可以记录和管理项目的历史修改,并支持多人协作开发,通过Git,开发者可以轻松地跟踪代码变更、合并分支、回退版本等... 目录一、Git核心概念解析二、环境搭建与配置1. 安装Git(Windows示例)2. 基础配置(必

MyBatis分页查询实战案例完整流程

《MyBatis分页查询实战案例完整流程》MyBatis是一个强大的Java持久层框架,支持自定义SQL和高级映射,本案例以员工工资信息管理为例,详细讲解如何在IDEA中使用MyBatis结合Page... 目录1. MyBATis框架简介2. 分页查询原理与应用场景2.1 分页查询的基本原理2.1.1 分