Pytorch使用DataLoader, num_workers!=0时的内存泄露

2023-10-08 05:30

本文主要是介绍Pytorch使用DataLoader, num_workers!=0时的内存泄露,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!

  • 描述一下背景,和遇到的问题:

我在做一个超大数据集的多分类,设备Ubuntu 22.04+i9 13900K+Nvidia 4090+64GB RAM,第一次的训练的训练集有700万张,训练成功。后面收集到更多数据集,数据增强后达到了1000万张。但第二次训练4个小时后,就被系统杀掉进程了,原因是Out of Memory。找了很久的原因,发现内存随着训练step的增加而线性增加,猜测是内存泄露,最后定位到了DataLoader的num_workers参数(只要num_workers=0就没有问题)。

  • 真正原因:

Python(Pytorch)中的list转换成tensor时,会发生内存泄漏,要避免list的使用,可以通过使用np.array来代替list。

  • 解决办法:

自定义DataLoader中的Dataset类,然后Dataset类中的list全部用np.array来代替。这样的话,DataLoader将np.array转换成Tensor的过程就不会发生内存泄露。

  • 下面给两个错误的示例代码和一个正确的代码:(都是我自己犯过的错误)

1.错误的DataLoader加载数据集方法1

# 加载数据
train_data = datasets.ImageFolder(root=TRAIN_DIR_ARG, transform=transform)
valid_data = datasets.ImageFolder(root=VALIDATION_DIR, transform=transform)
test_data = datasets.ImageFolder(root=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

2.错误的DataLoader加载数据集方法2(重写了Dataset方法)


class CustomDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.image_paths = []self.labels = []# 遍历数据目录并收集图像文件路径和对应的标签classes = os.listdir(data_dir)for i, class_name in enumerate(classes):class_dir = os.path.join(data_dir, class_name)if os.path.isdir(class_dir):for image_name in os.listdir(class_dir):image_path = os.path.join(class_dir, image_name)self.image_paths.append(image_path)self.labels.append(i)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]label = self.labels[idx]# # 在需要时加载图像image = Image.open(image_path)if self.transform:image = self.transform(image)return image, labeltrain_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=18)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=False)

3.重写Dataset的正确方法(重写了Dataset方法,list全部转成np.array)

class CustomDataset(Dataset):def __init__(self, data_dir, transform=None):self.data_dir = data_dirself.transform = transformself.image_paths = []  # 使用Python列表self.labels = []  # 使用Python列表# 遍历数据目录并收集图像文件路径和对应的标签classes = os.listdir(data_dir)for i, class_name in enumerate(classes):class_dir = os.path.join(data_dir, class_name)if os.path.isdir(class_dir):for image_name in os.listdir(class_dir):image_path = os.path.join(class_dir, image_name)self.image_paths.append(image_path)  # 添加到Python列表self.labels.append(i)  # 添加到Python列表# 转换为NumPy数组,这里就是解决内存泄露的关键代码self.image_paths = np.array(self.image_paths)self.labels = np.array(self.labels)def __len__(self):return len(self.image_paths)def __getitem__(self, idx):image_path = self.image_paths[idx]label = self.labels[idx]# 在需要时加载图像image = Image.open(image_path)if self.transform:image = self.transform(image)# 将图像数据转换为NumPy数组image = np.array(image)return image, labeltrain_data = CustomDataset(data_dir=TRAIN_DIR_ARG, transform=transform)
valid_data = CustomDataset(data_dir=VALIDATION_DIR, transform=transform)
test_data = CustomDataset(data_dir=TEST_DIR, transform=transform)train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=18)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=8, pin_memory=False)

这篇关于Pytorch使用DataLoader, num_workers!=0时的内存泄露的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!



http://www.chinasem.cn/article/163008

相关文章

使用Python实现IP地址和端口状态检测与监控

《使用Python实现IP地址和端口状态检测与监控》在网络运维和服务器管理中,IP地址和端口的可用性监控是保障业务连续性的基础需求,本文将带你用Python从零打造一个高可用IP监控系统,感兴趣的小伙... 目录概述:为什么需要IP监控系统使用步骤说明1. 环境准备2. 系统部署3. 核心功能配置系统效果展

使用Java将各种数据写入Excel表格的操作示例

《使用Java将各种数据写入Excel表格的操作示例》在数据处理与管理领域,Excel凭借其强大的功能和广泛的应用,成为了数据存储与展示的重要工具,在Java开发过程中,常常需要将不同类型的数据,本文... 目录前言安装免费Java库1. 写入文本、或数值到 Excel单元格2. 写入数组到 Excel表格

redis中使用lua脚本的原理与基本使用详解

《redis中使用lua脚本的原理与基本使用详解》在Redis中使用Lua脚本可以实现原子性操作、减少网络开销以及提高执行效率,下面小编就来和大家详细介绍一下在redis中使用lua脚本的原理... 目录Redis 执行 Lua 脚本的原理基本使用方法使用EVAL命令执行 Lua 脚本使用EVALSHA命令

Java 中的 @SneakyThrows 注解使用方法(简化异常处理的利与弊)

《Java中的@SneakyThrows注解使用方法(简化异常处理的利与弊)》为了简化异常处理,Lombok提供了一个强大的注解@SneakyThrows,本文将详细介绍@SneakyThro... 目录1. @SneakyThrows 简介 1.1 什么是 Lombok?2. @SneakyThrows

判断PyTorch是GPU版还是CPU版的方法小结

《判断PyTorch是GPU版还是CPU版的方法小结》PyTorch作为当前最流行的深度学习框架之一,支持在CPU和GPU(NVIDIACUDA)上运行,所以对于深度学习开发者来说,正确识别PyTor... 目录前言为什么需要区分GPU和CPU版本?性能差异硬件要求如何检查PyTorch版本?方法1:使用命

使用Python和Pyecharts创建交互式地图

《使用Python和Pyecharts创建交互式地图》在数据可视化领域,创建交互式地图是一种强大的方式,可以使受众能够以引人入胜且信息丰富的方式探索地理数据,下面我们看看如何使用Python和Pyec... 目录简介Pyecharts 简介创建上海地图代码说明运行结果总结简介在数据可视化领域,创建交互式地

Java Stream流使用案例深入详解

《JavaStream流使用案例深入详解》:本文主要介绍JavaStream流使用案例详解,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价值,需要的朋友参考下吧... 目录前言1. Lambda1.1 语法1.2 没参数只有一条语句或者多条语句1.3 一个参数只有一条语句或者多

Java Spring 中 @PostConstruct 注解使用原理及常见场景

《JavaSpring中@PostConstruct注解使用原理及常见场景》在JavaSpring中,@PostConstruct注解是一个非常实用的功能,它允许开发者在Spring容器完全初... 目录一、@PostConstruct 注解概述二、@PostConstruct 注解的基本使用2.1 基本代

C#使用StackExchange.Redis实现分布式锁的两种方式介绍

《C#使用StackExchange.Redis实现分布式锁的两种方式介绍》分布式锁在集群的架构中发挥着重要的作用,:本文主要介绍C#使用StackExchange.Redis实现分布式锁的... 目录自定义分布式锁获取锁释放锁自动续期StackExchange.Redis分布式锁获取锁释放锁自动续期分布式

springboot使用Scheduling实现动态增删启停定时任务教程

《springboot使用Scheduling实现动态增删启停定时任务教程》:本文主要介绍springboot使用Scheduling实现动态增删启停定时任务教程,具有很好的参考价值,希望对大家有... 目录1、配置定时任务需要的线程池2、创建ScheduledFuture的包装类3、注册定时任务,增加、删