pytoch如何加载批量数据--Dataset,DataLoader

2024-03-11 11:12

本文主要是介绍pytoch如何加载批量数据--Dataset,DataLoader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Dataset

Dataset 是一个抽象类。我们可以定义一个类继承这个类,从而加载数据,构造数据集(索引)

DataLoader

DataLoader是一个帮助我们在Pytorch中加载数据的类,在训练测试时加载数据,获取mini-batch
使用说明:¶
epoch:One forward pass and one backward pass of all the training examples. 训练次数
batch-size:The number of tarining examples in one forward backward pass. 每次用的样本数量
iterations:Number of passes,each pass using [batch size] number of examples. 迭代的数量 样本数量/batch
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
准备数据: 实现三个魔法方法
# 继承抽象类Dataset,实现三个方法
class DiabetesDataset(Dataset):# def __init__(self,filepath):#     data =  np.loadtxt(filepath,delimiter=',',dtype=np.float32)#     self.len = data.shape[0]  # (d0,d1,d2..) d0#     self.X = torch.from_numpy(data[:,:-1])#     self.Y = torch.from_numpy(data[:,[-1]])# 1.所有数据一次性加载# 2.每次只加载batch的数据def __init__(self,data,label):self.len = data.shape[0]self.X = torch.from_numpy(data)self.Y = torch.from_numpy(label)# 实例化类后,使对象支持索引操作 The expression,dataset[index],will call this magic functiondef __getitem__(self,index):return self.X[index], self.Y[index] # 返回元组 # this magic function returns length of datasetdef __len__(self):return self.len# 划分训练集和测试集    
data =  np.loadtxt('diabetes.csv',delimiter=',',dtype=np.float32)
X = data[:,:-1]
Y = data[:,[-1]]
# train_data,test_data = train_test_split(data,test_size=0.2,random_state=42)
Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,Y,test_size=0.2)# 实例化dataset类
dataset = DiabetesDataset(Xtrain,Ytrain)
# DataLoader:加载器 Initialize loader with batch-size,shuffle,process number
train_loader = DataLoader(dataset=dataset,batch_size=15,shuffle=True# ,num_workers=2)
# num_workers 线程,并发数,用了发现超conda虚拟内存

diabetes.csv糖尿病数据集

        链接:https://pan.baidu.com/s/1a-6ToVlXr7QfYAnHIpWnNg?pwd=1234 
        提取码:1234 

设计模型:
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid()def forward(self,x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()

构造损失函数和优化器
criterion = torch.nn.BCELoss(size_average=True,reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

训练
# if name==__main__():
for epoch in range(10000):for i,data in enumerate(train_loader,0):# for i,(x,y) in enumerate(train_loader,0):inputs,labels = datay_pred = model(inputs)loss = criterion(y_pred,labels)# print(epoch,i,loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

测试
x_test = torch.from_numpy(Xtest)
y_test = torch.from_numpy(Ytest)
y_pred = model(x_test)
y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))
acc = torch.eq(y_pred_label, y_test).sum().item()/y_test.size(0)
print("acc = ",acc)

这篇关于pytoch如何加载批量数据--Dataset,DataLoader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/797639

相关文章

MyBatis延迟加载与多级缓存全解析

《MyBatis延迟加载与多级缓存全解析》文章介绍MyBatis的延迟加载与多级缓存机制,延迟加载按需加载关联数据提升性能,一级缓存会话级默认开启,二级缓存工厂级支持跨会话共享,增删改操作会清空对应缓... 目录MyBATis延迟加载策略一对多示例一对多示例MyBatis框架的缓存一级缓存二级缓存MyBat

Linux下利用select实现串口数据读取过程

《Linux下利用select实现串口数据读取过程》文章介绍Linux中使用select、poll或epoll实现串口数据读取,通过I/O多路复用机制在数据到达时触发读取,避免持续轮询,示例代码展示设... 目录示例代码(使用select实现)代码解释总结在 linux 系统里,我们可以借助 select、

Ubuntu向多台主机批量传输文件的流程步骤

《Ubuntu向多台主机批量传输文件的流程步骤》:本文主要介绍在Ubuntu中批量传输文件到多台主机的方法,需确保主机互通、用户名密码统一及端口开放,通过安装sshpass工具,准备包含目标主机信... 目录Ubuntu 向多台主机批量传输文件1.安装 sshpass2.准备主机列表文件3.创建一个批处理脚

C#使用iText获取PDF的trailer数据的代码示例

《C#使用iText获取PDF的trailer数据的代码示例》开发程序debug的时候,看到了PDF有个trailer数据,挺有意思,于是考虑用代码把它读出来,那么就用到我们常用的iText框架了,所... 目录引言iText 核心概念C# 代码示例步骤 1: 确保已安装 iText步骤 2: C# 代码程

Pandas处理缺失数据的方式汇总

《Pandas处理缺失数据的方式汇总》许多教程中的数据与现实世界中的数据有很大不同,现实世界中的数据很少是干净且同质的,本文我们将讨论处理缺失数据的一些常规注意事项,了解Pandas如何表示缺失数据,... 目录缺失数据约定的权衡Pandas 中的缺失数据None 作为哨兵值NaN:缺失的数值数据Panda

C++中处理文本数据char与string的终极对比指南

《C++中处理文本数据char与string的终极对比指南》在C++编程中char和string是两种用于处理字符数据的类型,但它们在使用方式和功能上有显著的不同,:本文主要介绍C++中处理文本数... 目录1. 基本定义与本质2. 内存管理3. 操作与功能4. 性能特点5. 使用场景6. 相互转换核心区别

MySQL批量替换数据库字符集的实用方法(附详细代码)

《MySQL批量替换数据库字符集的实用方法(附详细代码)》当需要修改数据库编码和字符集时,通常需要对其下属的所有表及表中所有字段进行修改,下面:本文主要介绍MySQL批量替换数据库字符集的实用方法... 目录前言为什么要批量修改字符集?整体脚本脚本逻辑解析1. 设置目标参数2. 生成修改表默认字符集的语句3

python库pydantic数据验证和设置管理库的用途

《python库pydantic数据验证和设置管理库的用途》pydantic是一个用于数据验证和设置管理的Python库,它主要利用Python类型注解来定义数据模型的结构和验证规则,本文给大家介绍p... 目录主要特点和用途:Field数值验证参数总结pydantic 是一个让你能够 confidentl

JAVA实现亿级千万级数据顺序导出的示例代码

《JAVA实现亿级千万级数据顺序导出的示例代码》本文主要介绍了JAVA实现亿级千万级数据顺序导出的示例代码,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面... 前提:主要考虑控制内存占用空间,避免出现同时导出,导致主程序OOM问题。实现思路:A.启用线程池

SpringBoot分段处理List集合多线程批量插入数据方式

《SpringBoot分段处理List集合多线程批量插入数据方式》文章介绍如何处理大数据量List批量插入数据库的优化方案:通过拆分List并分配独立线程处理,结合Spring线程池与异步方法提升效率... 目录项目场景解决方案1.实体类2.Mapper3.spring容器注入线程池bejsan对象4.创建