pytoch如何加载批量数据--Dataset,DataLoader

2024-03-11 11:12

本文主要是介绍pytoch如何加载批量数据--Dataset,DataLoader,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

Dataset

Dataset 是一个抽象类。我们可以定义一个类继承这个类,从而加载数据,构造数据集(索引)

DataLoader

DataLoader是一个帮助我们在Pytorch中加载数据的类,在训练测试时加载数据,获取mini-batch
使用说明:¶
epoch:One forward pass and one backward pass of all the training examples. 训练次数
batch-size:The number of tarining examples in one forward backward pass. 每次用的样本数量
iterations:Number of passes,each pass using [batch size] number of examples. 迭代的数量 样本数量/batch
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
准备数据: 实现三个魔法方法
# 继承抽象类Dataset,实现三个方法
class DiabetesDataset(Dataset):# def __init__(self,filepath):#     data =  np.loadtxt(filepath,delimiter=',',dtype=np.float32)#     self.len = data.shape[0]  # (d0,d1,d2..) d0#     self.X = torch.from_numpy(data[:,:-1])#     self.Y = torch.from_numpy(data[:,[-1]])# 1.所有数据一次性加载# 2.每次只加载batch的数据def __init__(self,data,label):self.len = data.shape[0]self.X = torch.from_numpy(data)self.Y = torch.from_numpy(label)# 实例化类后,使对象支持索引操作 The expression,dataset[index],will call this magic functiondef __getitem__(self,index):return self.X[index], self.Y[index] # 返回元组 # this magic function returns length of datasetdef __len__(self):return self.len# 划分训练集和测试集    
data =  np.loadtxt('diabetes.csv',delimiter=',',dtype=np.float32)
X = data[:,:-1]
Y = data[:,[-1]]
# train_data,test_data = train_test_split(data,test_size=0.2,random_state=42)
Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,Y,test_size=0.2)# 实例化dataset类
dataset = DiabetesDataset(Xtrain,Ytrain)
# DataLoader:加载器 Initialize loader with batch-size,shuffle,process number
train_loader = DataLoader(dataset=dataset,batch_size=15,shuffle=True# ,num_workers=2)
# num_workers 线程,并发数,用了发现超conda虚拟内存

diabetes.csv糖尿病数据集

        链接:https://pan.baidu.com/s/1a-6ToVlXr7QfYAnHIpWnNg?pwd=1234 
        提取码:1234 

设计模型:
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6)self.linear2 = torch.nn.Linear(6,4)self.linear3 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid()def forward(self,x):x = self.sigmoid(self.linear1(x))x = self.sigmoid(self.linear2(x))x = self.sigmoid(self.linear3(x))return xmodel = Model()

构造损失函数和优化器
criterion = torch.nn.BCELoss(size_average=True,reduction='mean')
optimizer = torch.optim.SGD(model.parameters(),lr=0.01)

训练
# if name==__main__():
for epoch in range(10000):for i,data in enumerate(train_loader,0):# for i,(x,y) in enumerate(train_loader,0):inputs,labels = datay_pred = model(inputs)loss = criterion(y_pred,labels)# print(epoch,i,loss.item())optimizer.zero_grad()loss.backward()optimizer.step()

测试
x_test = torch.from_numpy(Xtest)
y_test = torch.from_numpy(Ytest)
y_pred = model(x_test)
y_pred_label = torch.where(y_pred>=0.5,torch.tensor([1.0]),torch.tensor([0.0]))
acc = torch.eq(y_pred_label, y_test).sum().item()/y_test.size(0)
print("acc = ",acc)

这篇关于pytoch如何加载批量数据--Dataset,DataLoader的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/797639

相关文章

MyBatisPlus如何优化千万级数据的CRUD

《MyBatisPlus如何优化千万级数据的CRUD》最近负责的一个项目,数据库表量级破千万,每次执行CRUD都像走钢丝,稍有不慎就引起数据库报警,本文就结合这个项目的实战经验,聊聊MyBatisPl... 目录背景一、MyBATis Plus 简介二、千万级数据的挑战三、优化 CRUD 的关键策略1. 查

python实现对数据公钥加密与私钥解密

《python实现对数据公钥加密与私钥解密》这篇文章主要为大家详细介绍了如何使用python实现对数据公钥加密与私钥解密,文中的示例代码讲解详细,感兴趣的小伙伴可以跟随小编一起学习一下... 目录公钥私钥的生成使用公钥加密使用私钥解密公钥私钥的生成这一部分,使用python生成公钥与私钥,然后保存在两个文

mysql中的数据目录用法及说明

《mysql中的数据目录用法及说明》:本文主要介绍mysql中的数据目录用法及说明,具有很好的参考价值,希望对大家有所帮助,如有错误或未考虑完全的地方,望不吝赐教... 目录1、背景2、版本3、数据目录4、总结1、背景安装mysql之后,在安装目录下会有一个data目录,我们创建的数据库、创建的表、插入的

Navicat数据表的数据添加,删除及使用sql完成数据的添加过程

《Navicat数据表的数据添加,删除及使用sql完成数据的添加过程》:本文主要介绍Navicat数据表的数据添加,删除及使用sql完成数据的添加过程,具有很好的参考价值,希望对大家有所帮助,如有... 目录Navicat数据表数据添加,删除及使用sql完成数据添加选中操作的表则出现如下界面,查看左下角从左

SpringBoot中4种数据水平分片策略

《SpringBoot中4种数据水平分片策略》数据水平分片作为一种水平扩展策略,通过将数据分散到多个物理节点上,有效解决了存储容量和性能瓶颈问题,下面小编就来和大家分享4种数据分片策略吧... 目录一、前言二、哈希分片2.1 原理2.2 SpringBoot实现2.3 优缺点分析2.4 适用场景三、范围分片

利用Python脚本实现批量将图片转换为WebP格式

《利用Python脚本实现批量将图片转换为WebP格式》Python语言的简洁语法和库支持使其成为图像处理的理想选择,本文将介绍如何利用Python实现批量将图片转换为WebP格式的脚本,WebP作为... 目录简介1. python在图像处理中的应用2. WebP格式的原理和优势2.1 WebP格式与传统

Spring如何使用注解@DependsOn控制Bean加载顺序

《Spring如何使用注解@DependsOn控制Bean加载顺序》:本文主要介绍Spring如何使用注解@DependsOn控制Bean加载顺序,具有很好的参考价值,希望对大家有所帮助,如有错误... 目录1.javascript 前言2. 代码实现总结1. 前言默认情况下,Spring加载Bean的顺

Redis分片集群、数据读写规则问题小结

《Redis分片集群、数据读写规则问题小结》本文介绍了Redis分片集群的原理,通过数据分片和哈希槽机制解决单机内存限制与写瓶颈问题,实现分布式存储和高并发处理,但存在通信开销大、维护复杂及对事务支持... 目录一、分片集群解android决的问题二、分片集群图解 分片集群特征如何解决的上述问题?(与哨兵模

浅析如何保证MySQL与Redis数据一致性

《浅析如何保证MySQL与Redis数据一致性》在互联网应用中,MySQL作为持久化存储引擎,Redis作为高性能缓存层,两者的组合能有效提升系统性能,下面我们来看看如何保证两者的数据一致性吧... 目录一、数据不一致性的根源1.1 典型不一致场景1.2 关键矛盾点二、一致性保障策略2.1 基础策略:更新数

Oracle 数据库数据操作如何精通 INSERT, UPDATE, DELETE

《Oracle数据库数据操作如何精通INSERT,UPDATE,DELETE》在Oracle数据库中,对表内数据进行增加、修改和删除操作是通过数据操作语言来完成的,下面给大家介绍Oracle数... 目录思维导图一、插入数据 (INSERT)1.1 插入单行数据,指定所有列的值语法:1.2 插入单行数据,指