手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构

2024-02-27 09:20

本文主要是介绍手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

一. 构建Dataset

构建Dataset无非就是创建一个类继承dataset并重写三个方法:

from torch.utils.data import Dataset
from PIL import Imageclass MyDataset(Dataset):def __init__(self):# 初始化数据集passdef __getitem__(self, index):# 根据索引获取数据样本passdef __len__(self):# 返回数据集大小pass# 创建自定义数据集实例
dataset = MyDataset()

1. __init__(self,params...) 方法

通过创建实例输入实例属性,如数据集文件夹地址;

需设置一些其他方法所需的属性,如self.data_path指明数据所在文件夹等一些属性。

def __init__(self, data_path):self.data_path = data_pathself.data_list = os.listdir(self.data_path)

2. __getitem__(self, index)方法

参数index一般用于data_list[index] 指定某个与index捆绑的实例。

该方法需返回所需的数据,返回的数据相当于Dataset的一个实例。

    def __getitem__(self, index):file_name = self.data_list[index] #用index指定文件data_label = file_name.split('.')[0] #将标签置于文件名data_path = os.path.join(self.data_path, file_name) #获取文件地址img = Image.open(data_path) #读取数据return img, data_label #返回实例数据

3. __len__(self)方法

返回数据集实例的数量,也就是数据集的大小。

    def __len__(self):return len(self.data_list)

二. 创建dataloader

dataset(数据集):需要提取数据的数据集,Dataset对象
batch_size(批大小):每一次装载样本的个数,int型
 shuffle(洗牌):进行新一轮epoch时是否要重新洗牌,Boolean型
num_workers:是否多进程读取机制
drop_last:当样本数不能被batchsize整除时, 是否舍弃最后一批数据

#导入dataloader的包
from torch.utils.data import DataLoader#读取文件夹下数据以创建数据集
test_Dataset = MyDataset(dir_address)#创建一个dataloader,设置批大小为4,每一个epoch重新洗牌,
#不进行多进程读取机制,不舍弃不能被整除的批次
dataloader = DataLoader(dataset=test_dataset,batch_size=4,shuffle=True,num_workers=0,drop_last=False)

三. 搭建模型基础架构(pytorch版)

1. 导包(torch,nn,tqdm几乎必须)

import torch
from torch import nn
from tqdm import tqdm

2. 搭建通用核心训练架构函数

def train_loop(model, loader, epochs, optim, device, display=False, store_path='model.pt'):mse = nn.MSELoss()best_loss = float('inf')for epoch in tqdm(range(epochs),desc=f"Training process", colour='#00f00'):epoch_loss = 0.0for step, batch in enumerate(tqdm(loader, leave=False, desc=f"Epoch {epoch + 1}/{epochs}", colour="#005500")):#loading data,对应了Dataset返回的实例数据x = batch[0].to(device)y = batch[1].to(device)n = len(x) #batchsizeloss = mse(x, y) # the average loss of one batchoptim.zero_grad()optim.step()epoch_loss += loss.item() * n / len(loader.dataset)#display results generated at this epochif display:passlog_string = f"Loss at epoch {epoch + 1}: {epoch_loss:.3f}"#save the modelif best_loss > epoch_loss:best_loss = epoch_losstorch.save(ddpm.state_dict(), store_path)log_string += " --> Best model ever(stored)"print(log_string)

这篇关于手撸AI-1:构建DatasetDataloader,搭建模型训练基础架构的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/751943

相关文章

Java中的StringBuilder之如何高效构建字符串

《Java中的StringBuilder之如何高效构建字符串》本文将深入浅出地介绍StringBuilder的使用方法、性能优势以及相关字符串处理技术,结合代码示例帮助读者更好地理解和应用,希望对大家... 目录关键点什么是 StringBuilder?为什么需要 StringBuilder?如何使用 St

使用Python构建一个Hexo博客发布工具

《使用Python构建一个Hexo博客发布工具》虽然Hexo的命令行工具非常强大,但对于日常的博客撰写和发布过程,我总觉得缺少一个直观的图形界面来简化操作,下面我们就来看看如何使用Python构建一个... 目录引言Hexo博客系统简介设计需求技术选择代码实现主框架界面设计核心功能实现1. 发布文章2. 加

利用Python快速搭建Markdown笔记发布系统

《利用Python快速搭建Markdown笔记发布系统》这篇文章主要为大家详细介绍了使用Python生态的成熟工具,在30分钟内搭建一个支持Markdown渲染、分类标签、全文搜索的私有化知识发布系统... 目录引言:为什么要自建知识博客一、技术选型:极简主义开发栈二、系统架构设计三、核心代码实现(分步解析

Spring Security基于数据库的ABAC属性权限模型实战开发教程

《SpringSecurity基于数据库的ABAC属性权限模型实战开发教程》:本文主要介绍SpringSecurity基于数据库的ABAC属性权限模型实战开发教程,本文给大家介绍的非常详细,对大... 目录1. 前言2. 权限决策依据RBACABAC综合对比3. 数据库表结构说明4. 实战开始5. MyBA

使用Python实现快速搭建本地HTTP服务器

《使用Python实现快速搭建本地HTTP服务器》:本文主要介绍如何使用Python快速搭建本地HTTP服务器,轻松实现一键HTTP文件共享,同时结合二维码技术,让访问更简单,感兴趣的小伙伴可以了... 目录1. 概述2. 快速搭建 HTTP 文件共享服务2.1 核心思路2.2 代码实现2.3 代码解读3.

MySQL双主搭建+keepalived高可用的实现

《MySQL双主搭建+keepalived高可用的实现》本文主要介绍了MySQL双主搭建+keepalived高可用的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,... 目录一、测试环境准备二、主从搭建1.创建复制用户2.创建复制关系3.开启复制,确认复制是否成功4.同

Java的IO模型、Netty原理解析

《Java的IO模型、Netty原理解析》Java的I/O是以流的方式进行数据输入输出的,Java的类库涉及很多领域的IO内容:标准的输入输出,文件的操作、网络上的数据传输流、字符串流、对象流等,这篇... 目录1.什么是IO2.同步与异步、阻塞与非阻塞3.三种IO模型BIO(blocking I/O)NI

一文详解如何从零构建Spring Boot Starter并实现整合

《一文详解如何从零构建SpringBootStarter并实现整合》SpringBoot是一个开源的Java基础框架,用于创建独立、生产级的基于Spring框架的应用程序,:本文主要介绍如何从... 目录一、Spring Boot Starter的核心价值二、Starter项目创建全流程2.1 项目初始化(

使用Java实现通用树形结构构建工具类

《使用Java实现通用树形结构构建工具类》这篇文章主要为大家详细介绍了如何使用Java实现通用树形结构构建工具类,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录完整代码一、设计思想与核心功能二、核心实现原理1. 数据结构准备阶段2. 循环依赖检测算法3. 树形结构构建4. 搜索子

基于Flask框架添加多个AI模型的API并进行交互

《基于Flask框架添加多个AI模型的API并进行交互》:本文主要介绍如何基于Flask框架开发AI模型API管理系统,允许用户添加、删除不同AI模型的API密钥,感兴趣的可以了解下... 目录1. 概述2. 后端代码说明2.1 依赖库导入2.2 应用初始化2.3 API 存储字典2.4 路由函数2.5 应