第十二章 迁移学习-实战宝可梦精灵

2024-08-22 08:52

本文主要是介绍第十二章 迁移学习-实战宝可梦精灵,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

文章目录

  • 一、Pokemon数据集
    • 1.1 数据集收集
    • 1.2 数据集划分
    • 1.3 数据集加载
    • 1.4 数据预处理
    • 1.5 pytorch自定义数据库实现
  • 二、ResNet网络搭建
  • 三、训练与测试
  • 四、迁移学习
    • 4.1 pytorch实现迁移学习

一、Pokemon数据集

1.1 数据集收集

在这里插入图片描述

# git下载
git lfs install
git clone https://www.modelscope.cn/datasets/ModelBulider/pokemon.git

1.2 数据集划分

在这里插入图片描述


1.3 数据集加载

在这里插入图片描述

  • 加载数据
    ① 继承 torch.utils.data.Dataset
    ② 实现 __len__ 函数,其返回数据集的数量(整型数字)
    ③ 实现 __getitem__函数,根据索引值返回一个数据
    在这里插入图片描述

举例:
在这里插入图片描述


1.4 数据预处理

将尺寸大小不一致的数据(图片)预处理为大小一致的1数据
② 数据增强(旋转、裁剪等)
③ 归一化(均值、方差)
④ 转换为 Tensor 数据类型
在这里插入图片描述


1.5 pytorch自定义数据库实现

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: code - pokemon.py
@author: yonghao
@Description: 
@since 2021/03/01 19:41
'''
from visdom import Visdom
import time
import torch
import os, glob
import random, csv
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoaderroot = 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon'class Pokemon(Dataset):
def __init__(self, root, resize, mode='train'):
'''
初始化数据集
:paramroot: 图片存储的位置
:paramresize: 重新编辑图片的尺寸
:parammode: 初始化图片的类型(可以是数据集中各中分类)
'''
super(Pokemon, self).__init__()
self.root = rootself.resize = resizeself.mode = modeself.name2label = {}
# 创建 类名-> label 的映射字典
# os.listdir()每次顺序都不一样,故使用sorted()排序,使 类名-> label 的映射字典固定
for name in sorted(os.listdir(os.path.join(root))):
# 只读取文件夹名
if not os.path.isdir(os.path.join(root, name)):
continue
self.name2label[name] = len(self.name2label)
self.images, self.labels = self.load_csv('images.csv')
# 根据mode设定数据集的比例
if mode == 'train': # 60%
self.images = self.images[:int(0.6 * len(self.images))]
self.labels = self.labels[:int(0.6 * len(self.labels))]
elif mode == 'val': # 20%
self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
else: # 20%
self.images = self.images[int(0.8 * len(self.images)):]
self.labels = self.labels[int(0.8 * len(self.labels)):]def __len__(self):
return len(self.images)def __getitem__(self, item) -> tuple:
# item ~ [0,len(images)-1]
# self.images , self.labels
# image , label
img, label = self.images[item], self.labels[item]
tf = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path => image data
transforms.Resize((int(1.25 * self.resize), int(1.25 * self.resize))), # 调整尺寸
transforms.RandomRotation(15), # 旋转
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
# 注意transforms.Normalize() 应该在transforms.ToTensor() 后面
# 数据在通道层上归一化,会使变化图片的像素
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 返回由img,label 组成的Tensor 元组
img = tf(img)
label = torch.tensor(label)return img, labeldef denormalize(self, x_het):
'''
图像逆正则化显示
:paramx_het: 正则化后的数据
:return:
'''# x_het = (x - mean) / std
mean, std = torch.tensor([0.485, 0.456, 0.406]), torch.tensor([0.229, 0.224, 0.225])
# x = x_het * std + mean
# x:[channel , h , w] , mean:[3] -> [3,1,1] , std:[3] -> [3,1,1]
mean = mean.unsqueeze(dim=-1).unsqueeze(dim=-1)std = std.unsqueeze(dim=-1).unsqueeze(dim=-1)x = x_het * std + meanreturn xdef load_csv(self, filename):
'''
加载图片数据 与 其label数据
:paramfilename: 加载数据的文件名
:return:
'''
# 仅在第一次调用时创建csv文件,保存 图片路径——>label 的映射关系
if not os.path.exists(os.path.join(self.root, filename)):
images = []
for name in self.name2label.keys():
'''
python在模块glob中定义了glob()函数,实现了对目录内容进行匹配的功能,
glob.glob()函数接受通配模式作为输入,并返回所有匹配的文件名和路径名列表
与os.listdir类似
'''
images += glob.glob(os.path.join(self.root, name, '*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
# 1167 , 'D:\\个人\\学习资料\\学习视频\\深度学习与PyTorch入门实战教程\\12.迁移学习-实战宝可梦精灵\\project_code\\pokemon\\bulbasaur\\00000000.png'# 打乱的是图片的存储路径
random.shuffle(images)# 使用上下文管理,对文件进行操作
'''
with是从Python2.5引入的一个新的语法,它是一种上下文管理协议,目的在于从流程图中把try,except 和finally 关键字和资源分配释放相关代码统统去掉,简化try….except….finlally的处理流程。with通过__enter__方法初始化,然后在__exit__中做善后以及处理异常。所以使用with处理的对象必须有__enter__()和__exit__()这两个方法。其中__enter__()方法在语句体(with语句包裹起来的代码块)执行之前进入运行,__exit__()方法在语句体执行完毕退出后运行。with 语句适用于对资源进行访问的场合,确保不管使用过程中是否发生异常都会执行必要的“清理”操作,释放资源,比如文件使用后自动关闭、线程中锁的自动获取和释放等。紧跟with后面的语句会被求值,返回对象的__enter__()方法被调用,这个方法的返回值将被赋值给as关键字后面的变量,当with后面的代码块全部被执行完之后,将调用前面返回对象的__exit__()方法
'''
with open(os.path.join(self.root, filename), mode='w', newline='') as f:
writer = csv.writer(f)
for img in images:
# os.sep 为系统自动识别的文件路径分隔符
name = img.split(os.sep)[-2]
label = self.name2label[name]
writer.writerow([img, label])images, labels = [], []
with open(os.path.join(root, filename), mode='r') as f:
reader = csv.reader(f)
for row in reader:
img, label = rowimages.append(img)
labels.append(int(label))assert len(images) == len(labels)return images, labelsdef main():
vis = Visdom()
# 获取数据集(单个数据做返回)
db = Pokemon(root, 64, mode='train')
img, label = next(iter(db))
print('sample:', img.shape, label.shape)
vis.image(img, win='img_win_het', opts=dict(title='norm_img_show'))
vis.image(db.denormalize(img), win='img_win', opts=dict(title='img_show'))# 批量导出数据
loader = DataLoader(db, batch_size=32, shuffle=True)
for x, y in loader:
vis.images(db.denormalize(x), nrow=8, win='batch', opts=dict(title='batch'))
vis.text(str(y.numpy()), win='label', opts=dict(title='bacth-y'))
time.sleep(10)if __name__ == '__main__':
main()

二、ResNet网络搭建

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: 实战代码- resnet.py
@author: yonghao
@Description: 创建残差网络结构
@since 2021/03/01 17:51
'''
import torch
import torch.nn.functional as F
from torch import nn
import utilsclass ResBlk(nn.Module):
'''
创建ResBlock
'''def __init__(self, ch_in, ch_out, stride=1):
'''
创建ResBlock模块
:paramch_in: 输入的通道数
:paramch_out: 输出的通道数
:paramstride: 卷积步长
'''
super(ResBlk, self).__init__()
self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(ch_out)
self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.bn2 = nn.BatchNorm2d(ch_out)
if ch_in == ch_out:
self.extra = nn.Sequential()
else:
self.extra = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
nn.BatchNorm2d(ch_out)
)def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))out = out + self.extra(x)
out = F.relu(out)
return outclass ResNet18(nn.Module):def __init__(self, num_class):
'''
创建18层的ResNet
:paramnum_class:分类数量
'''
super(ResNet18, self).__init__()self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=2),
nn.BatchNorm2d(16)
)# followed 4 blocks
# [b , 16 , h , w] => [b , 32 , h , w]
self.blk1 = ResBlk(16, 32, stride=3)
# [b , 32 , h , w] => [b , 64 , h , w]
self.blk2 = ResBlk(32, 64, stride=3)
# [b , 64 , h , w] => [b , 128 , h , w]
self.blk3 = ResBlk(64, 128, stride=2)
# [b , 128 , h , w] => [b , 256 , h , w]
self.blk4 = ResBlk(128, 256, stride=2)
# [b , 256 , h , 2] => [b , 256*h*w]
self.flat = utils.Flatten()
# [b , 256*h*w] => [b , num_class]
self.out_layer = nn.Linear(256 * 3 * 3, num_class)def forward(self, x):
x = F.relu(self.conv1(x), inplace=True)
x = self.blk1(x)
x = self.blk2(x)
x = self.blk3(x)
x = self.blk4(x)
# print(x.shape)
x = self.flat(x)
out = self.out_layer(x)
return outdef mian():
# 测试ResBlk,当ch_in==ch_out时正确
# 当ch_in==ch_out时报异常
blk = ResBlk(64, 128, stride=2)
tmp = torch.randn(2, 64, 64, 64)
out = blk(tmp)
print('block:', out.shape)model = ResNet18(5)
tmp = torch.randn(2, 3, 224, 224)
out = model(tmp)
print("resnet:", out.shape)
p = sum([i.numel() for i in model.parameters()])
print('parameters size:', p)if __name__ == '__main__':
mian()

三、训练与测试

在这里插入图片描述

# -*- coding: UTF-8 -*-
'''
@version: 1.0
@PackageName: project_code - process.py
@author: yonghao
@Description: 实现训练过程 与 测试过程
@since 2021/03/02 18:54
'''
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from model.resnet import ResNet18
from pokemon import Pokemon# 批量数量
bacthsz = 32# 学习率
lr = 1e-3# 迭代次数
epochs = 10# device = torch.device('cpu')
# if torch.cuda.is_available():
#     device = torch.device('cuda')# 设置固定随机初始值
torch.manual_seed(1234)# 训练集
train_db = Pokemon('pokemon', 224, mode='train')
train_loader = DataLoader(train_db, batch_size=bacthsz, shuffle=True, num_workers=4)# 验证集
val_db = Pokemon('pokemon', 224, mode='val')
val_loader = DataLoader(val_db, batch_size=bacthsz, num_workers=2)# 测试集
test_db = Pokemon('pokemon', 224, mode='test')
test_loader = DataLoader(test_db, batch_size=bacthsz, num_workers=2)def evaluate(model, loader):
correct = 0
total = len(loader.dataset)
for x, y in loader:
# x, y = x.to(device), y.to(device)
# x:[b , c , h , w] , y:[b]
# out:[b,class_num]
with torch.no_grad():
out = model(x)
pred = out.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()return correct / totaldef main():
# model = ResNet18(5).to(device)
model = ResNet18(5)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()# 用于保存最高精度
best_acc = 0
best_epoch = 0
# 训练过程
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
# [b , c , h , w] , y[b]
# x, y = x.to(device), y.to(device)
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()# validation
if epoch % 2 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_epoch = epochbest_acc = val_acctorch.save(model.state_dict(), 'best.mdl')
print('best acc:', best_acc, "best epoch:", best_epoch)# 测试过程
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)if __name__ == '__main__':
'''
best acc: 0.8969957081545065 best epoch: 8
loaded from ckpt!
test acc: 0.8931623931623932
'''
main()

四、迁移学习

将处理相类似信号(特别是数据量较大)的神经网络嫁接过来,应用到本实验中
在这里插入图片描述

  • 具体的嫁接过程
    ① 尽量保留网络前、中部分
    ② 去除最后一层,根据自己的分类任务定制最后一层
    在这里插入图片描述

4.1 pytorch实现迁移学习

from torchvision.models import resnet18model = resnet18(pretrained=True)
# 17 layer out:[32, 512, 1, 1]
model = nn.Sequential(*list(model.children())[:-1],
utils.Flatten(),# 降维度
nn.Linear(512, 5))

这篇关于第十二章 迁移学习-实战宝可梦精灵的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/1095808

相关文章

Python版本信息获取方法详解与实战

《Python版本信息获取方法详解与实战》在Python开发中,获取Python版本号是调试、兼容性检查和版本控制的重要基础操作,本文详细介绍了如何使用sys和platform模块获取Python的主... 目录1. python版本号获取基础2. 使用sys模块获取版本信息2.1 sys模块概述2.1.1

Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题

《Python爬虫HTTPS使用requests,httpx,aiohttp实战中的证书异步等问题》在爬虫工程里,“HTTPS”是绕不开的话题,HTTPS为传输加密提供保护,同时也给爬虫带来证书校验、... 目录一、核心问题与优先级检查(先问三件事)二、基础示例:requests 与证书处理三、高并发选型:

Oracle Scheduler任务故障诊断方法实战指南

《OracleScheduler任务故障诊断方法实战指南》Oracle数据库作为企业级应用中最常用的关系型数据库管理系统之一,偶尔会遇到各种故障和问题,:本文主要介绍OracleSchedul... 目录前言一、故障场景:当定时任务突然“消失”二、基础环境诊断:搭建“全局视角”1. 数据库实例与PDB状态2

Git进行版本控制的实战指南

《Git进行版本控制的实战指南》Git是一种分布式版本控制系统,广泛应用于软件开发中,它可以记录和管理项目的历史修改,并支持多人协作开发,通过Git,开发者可以轻松地跟踪代码变更、合并分支、回退版本等... 目录一、Git核心概念解析二、环境搭建与配置1. 安装Git(Windows示例)2. 基础配置(必

MyBatis分页查询实战案例完整流程

《MyBatis分页查询实战案例完整流程》MyBatis是一个强大的Java持久层框架,支持自定义SQL和高级映射,本案例以员工工资信息管理为例,详细讲解如何在IDEA中使用MyBatis结合Page... 目录1. MyBATis框架简介2. 分页查询原理与应用场景2.1 分页查询的基本原理2.1.1 分

使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解

《使用Python批量将.ncm格式的音频文件转换为.mp3格式的实战详解》本文详细介绍了如何使用Python通过ncmdump工具批量将.ncm音频转换为.mp3的步骤,包括安装、配置ffmpeg环... 目录1. 前言2. 安装 ncmdump3. 实现 .ncm 转 .mp34. 执行过程5. 执行结

SpringBoot 多环境开发实战(从配置、管理与控制)

《SpringBoot多环境开发实战(从配置、管理与控制)》本文详解SpringBoot多环境配置,涵盖单文件YAML、多文件模式、MavenProfile分组及激活策略,通过优先级控制灵活切换环境... 目录一、多环境开发基础(单文件 YAML 版)(一)配置原理与优势(二)实操示例二、多环境开发多文件版

Three.js构建一个 3D 商品展示空间完整实战项目

《Three.js构建一个3D商品展示空间完整实战项目》Three.js是一个强大的JavaScript库,专用于在Web浏览器中创建3D图形,:本文主要介绍Three.js构建一个3D商品展... 目录引言项目核心技术1. 项目架构与资源组织2. 多模型切换、交互热点绑定3. 移动端适配与帧率优化4. 可

从原理到实战解析Java Stream 的并行流性能优化

《从原理到实战解析JavaStream的并行流性能优化》本文给大家介绍JavaStream的并行流性能优化:从原理到实战的全攻略,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的... 目录一、并行流的核心原理与适用场景二、性能优化的核心策略1. 合理设置并行度:打破默认阈值2. 避免装箱

Maven中生命周期深度解析与实战指南

《Maven中生命周期深度解析与实战指南》这篇文章主要为大家详细介绍了Maven生命周期实战指南,包含核心概念、阶段详解、SpringBoot特化场景及企业级实践建议,希望对大家有一定的帮助... 目录一、Maven 生命周期哲学二、default生命周期核心阶段详解(高频使用)三、clean生命周期核心阶