PyTorch GPU显存持续增长? 解决DataLoader与.cuda()陷阱
2025-04-22 17:35:46
PyTorch 训练 GPU 显存持续增长?揪出 DataLoader
和 .cuda()
的隐藏陷阱
跑 PyTorch 训练时,遇上 CUDA out of memory
错误是家常便饭。但有时候,明明代码里规规矩矩地 del
了不再需要的张量,甚至调用了 torch.cuda.empty_cache()
,GPU 显存(VRAM)还是像脱缰的野马一样持续上涨,几个 epoch 下来就爆了。这是咋回事?
最近就有朋友遇到了这个问题,他的代码结构大致如下 (已简化):
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os # 模拟文件创建
# --- 准备模拟数据 ---
os.makedirs('dummy_data', exist_ok=True)
for i in range(10000):
np.save(f'dummy_data/img_{i}.npy', np.random.rand(3, 224, 224).astype(np.float32))
np.save(f'dummy_data/label_{i}.npy', np.random.randint(0, 10, size=()).astype(np.int64))
# --- 模拟数据准备完毕 ---
class CustomDataset(Dataset):
def __init__(self, data_paths):
self.data_paths = data_paths
def __len__(self):
return len(self.data_paths)
def __getitem__(self, idx):
# 从磁盘加载数据 (模拟)
image_path = self.data_paths[idx]['image']
label_path = self.data_paths[idx]['label']
# 假设数据是以 .npy 格式存储
# 注意:实际应用中, I/O 操作可能更复杂
try:
image = np.load(image_path).astype(np.float32)
label = np.load(label_path).astype(np.int64)
except FileNotFoundError:
print(f"Warning: File not found for index {idx}. Skipping.")
# 返回None或者占位符, 需要在collate_fn中处理
# 这里简单返回零值张量示意
image = np.zeros((3, 224, 224), dtype=np.float32)
label = np.int64(0)
# !!! 问题关键点:在 Dataset 中就把张量放到了 GPU 上 !!!
image = torch.tensor(image).cuda()
label = torch.tensor(label).cuda()
return image, label
# 生成数据路径列表
data_paths = [{'image': f'dummy_data/img_{i}.npy', 'label': f'dummy_data/label_{i}.npy'} for i in range(10000)]
# 创建 Dataset 和 DataLoader
dataset = CustomDataset(data_paths)
# 注意:num_workers > 0 是触发问题的关键之一
dataloader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)
# 模拟训练循环
print("Starting simulated training loop...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 后面会用到
# 手动模拟模型放置到 GPU (实际训练会有)
# dummy_model = torch.nn.Module().to(device) # 例子中没用到,但实际情况会有
for epoch in range(5): # 减少 epoch 数量便于观察
print(f"--- Epoch {epoch+1} ---")
for i, batch in enumerate(dataloader):
if batch is None: # 处理 __getitem__ 可能返回 None 的情况 (虽然例子里没返回None)
print(f"Skipping batch {i} due to loading error.")
continue
# 数据已经直接是 GPU 张量了
images, labels = batch
print(f"Batch {i}: images on {images.device}, labels on {labels.device}")
# 模拟一些简单的计算
# output = images.mean() # 简化计算, 原代码的这步可能引入其他问题, 但非核心
# loss = output.sum() # 同样简化
try:
# 模拟操作,例如与模型交互 (如果模型在GPU上)
# output = dummy_model(images) # 假如有模型
# loss = some_loss_function(output, labels)
# loss.backward() # 模拟反向传播
# 为了演示显存占用,做个简单的张量操作
output = images + 1
loss = output.sum() # 创建一个计算图节点
except Exception as e:
print(f"Error during batch {i} processing: {e}")
# 清理以防万一,尽管问题根源不在此
if 'images' in locals(): del images
if 'labels' in locals(): del labels
if 'loss' in locals(): del loss
if 'output' in locals(): del output
torch.cuda.empty_cache()
continue # 继续下一个 batch
print(f"Batch {i} processed. VRAM usage might be increasing.")
# 打印显存占用信息(需要安装 pynvml:pip install nvidia-ml-py3)
# 注意: 这只是粗略估计,实际占用受 PyTorch caching 影响
# import pynvml
# pynvml.nvmlInit()
# handle = pynvml.nvmlDeviceGetHandleByIndex(0) # 假设使用 GPU 0
# info = pynvml.nvmlDeviceGetMemoryInfo(handle)
# print(f"Batch {i}: Used VRAM: {info.used // 1024**2} MB")
# pynvml.nvmlShutdown()
# !!! 尝试手动删除,但效果不佳 !!!
del images
del labels
del loss
del output # 删除计算结果
# !!! 调用 empty_cache() 也不能阻止显存持续增长 !!!
torch.cuda.empty_cache()
print(f"--- Epoch {epoch+1} finished ---")
# 在 Epoch 结束时也检查显存
# import pynvml
# pynvml.nvmlInit()
# handle = pynvml.nvmlDeviceGetHandleByIndex(0)
# info = pynvml.nvmlDeviceGetMemoryInfo(handle)
# print(f"End of Epoch {epoch+1}: Used VRAM: {info.used // 1024**2} MB")
# pynvml.nvmlShutdown()
print("Simulated training loop finished.")
# 清理模拟数据
import shutil
shutil.rmtree('dummy_data')
print("Cleaned up dummy data.")
即使在循环的最后用 del
删除所有张量,并且调用 torch.cuda.empty_cache()
,用 nvidia-smi
观察,显存占用依然是步步高升,直到 OOM 崩溃。把 num_workers
设置成 0 似乎能缓解,但并不是根本解决办法,而且失去了多进程加载数据的优势。
问题到底出在哪?
一、揪出幕后黑手:DataLoader、多进程和 cuda()
问题的根源在于 DataLoader
的工作机制与在 Dataset.__getitem__
方法中调用 .cuda()
的方式发生了冲突,尤其是当 num_workers > 0
时。
核心原因:CUDA 上下文 (Context) 与进程
- 多进程工作: 当你设置
DataLoader
的num_workers
大于 0 时,PyTorch 会启动多个子进程来并行加载数据。每个子进程负责调用Dataset.__getitem__
方法获取单个样本。 - 独立的 CUDA 上下文: 每个进程(主进程和所有 worker 子进程)在使用 GPU 时,都会初始化自己的 CUDA 上下文。CUDA 上下文可以看作是 GPU 上的一个独立工作空间,包含了当前进程分配的显存、加载的 CUDA 核函数等状态。
__getitem__
中的.cuda()
: 当你在__getitem__
中调用torch.tensor(image).cuda()
时,这个张量是在 某个特定的 worker 子进程 的 CUDA 上下文中创建并分配显存的。- 数据传输回主进程:
DataLoader
会从 worker 子进程收集这些已经在 GPU 上的张量,并将它们传递给主进程(也就是你的训练循环所在的进程)。 - 隐式的上下文切换和保留: 这个从 worker 进程到主进程的 GPU 张量传递过程,背后涉及一些复杂的 CUDA 操作。关键在于,即使张量被传递到了主进程,它可能仍然隐式地与创建它的那个 worker 子进程的 CUDA 上下文有关联。更糟糕的是,由于多进程共享和通信的机制,这些来自 worker 进程的 GPU 张量引用可能没有被完全释放,即使你在主进程中
del
了对应的 Python 变量。 worker 进程自身退出时也许会清理,但 DataLoader 的 worker 可能被复用,导致生命周期管理复杂化。 empty_cache()
的局限:torch.cuda.empty_cache()
只能释放 PyTorch 缓存的、当前未被任何活动张量引用 的显存块。它无法释放那些虽然在主进程的 Python 代码里看似已经del
了,但由于上述跨进程引用问题、或是计算图未释放等原因,在 CUDA 层面仍然被认为"在使用中"的显存。
简单来说: 你在子进程 A 的 GPU 工作区创建了一个东西,把它传给主进程 B。主进程 B 用完说“我不要了 (del
)”,但因为这东西的"户口"还在子进程 A 那里,或者交接手续没彻底办完,导致 CUDA 系统觉得这块显存还在用,empty_cache()
这个“清洁工”也无权清理。随着一批批数据从不同的 worker 进程传过来,越来越多这样的"孤儿"显存被占用,最终爆掉。
当 num_workers=0
时,所有数据加载都在主进程完成,不存在跨进程 CUDA 上下文的问题,所以显存通常能被正确释放。但你就损失了并行加载的性能。
二、釜底抽薪:正确的解决方案
知道了原因,解决起来就思路清晰了。核心原则是:数据准备(CPU 操作)和模型计算(GPU 操作)分离,只在主训练循环中将数据移动到目标设备。
方案一:坚守阵地,数据只在主循环上 GPU(推荐)
这是最标准、最推荐的做法。让 Dataset
只负责加载数据、进行必要的 CPU 预处理(如 NumPy 操作、基本的数据增强),并返回 CPU 张量。然后在主训练循环中,获取到一个 batch 的数据后,再将其 .to(device)
或 .cuda()
。
1. 修改 CustomDataset
:
import torch
from torch.utils.data import Dataset
import numpy as np
import os # 需要os来处理文件路径等
# 确保模拟数据存在 (如果运行环境清理了之前的数据)
os.makedirs('dummy_data', exist_ok=True)
if not os.path.exists('dummy_data/img_0.npy'): # 简单检查
print("Recreating dummy data...")
for i in range(10000): # 根据你的实际数据量调整
np.save(f'dummy_data/img_{i}.npy', np.random.rand(3, 224, 224).astype(np.float32))
np.save(f'dummy_data/label_{i}.npy', np.random.randint(0, 10, size=()).astype(np.int64))
print("Dummy data recreated.")
class CustomDatasetFixed(Dataset):
def __init__(self, data_paths):
self.data_paths = data_paths
def __len__(self):
return len(self.data_paths)
def __getitem__(self, idx):
image_path = self.data_paths[idx]['image']
label_path = self.data_paths[idx]['label']
try:
# 加载数据为 NumPy 数组
image = np.load(image_path).astype(np.float32)
label = np.load(label_path).astype(np.int64)
# 将 NumPy 数组转换为 PyTorch CPU 张量
image = torch.tensor(image)
label = torch.tensor(label)
# 不在这里调用 .cuda() 或 .to(device) !
return image, label
except FileNotFoundError:
# 更健壮的错误处理:可以选择记录日志、返回特定标记等
print(f"Warning: Data file not found for index {idx}, path: {image_path} or {label_path}. Returning zeros.")
# 返回一个有效的零张量,确保后续collate_fn或循环能处理
# 尺寸应与正常数据匹配
return torch.zeros((3, 224, 224), dtype=torch.float32), torch.tensor(0, dtype=torch.int64)
except Exception as e:
# 捕获其他可能的加载错误
print(f"Error loading data for index {idx}: {e}. Returning zeros.")
return torch.zeros((3, 224, 224), dtype=torch.float32), torch.tensor(0, dtype=torch.int64)
2. 修改训练循环:
import torch
from torch.utils.data import DataLoader
import time # 用于简单计时
# --- 复用上面的 CustomDatasetFixed 类 ---
# 设置目标设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 生成数据路径列表
data_paths = [{'image': f'dummy_data/img_{i}.npy', 'label': f'dummy_data/label_{i}.npy'} for i in range(10000)] # 假设数据已生成
# 创建 Dataset 和 DataLoader
dataset = CustomDatasetFixed(data_paths) # 使用修改后的 Dataset
dataloader = DataLoader(dataset,
batch_size=32,
num_workers=4, # 可以安全使用多进程
pin_memory=True, # pin_memory 配合后续 .to(device, non_blocking=True) 效果更佳
persistent_workers=True if torch.__version__ >= "1.8.0" and os.name != 'nt' else False # 建议开启以减少 worker 重启开销 (需要 PyTorch 1.8+ 且非 Windows)
)
print("Starting fixed training loop...")
# 模拟训练循环
for epoch in range(5):
print(f"--- Epoch {epoch+1} ---")
epoch_start_time = time.time()
for i, batch in enumerate(dataloader):
batch_start_time = time.time()
# 从 DataLoader 获取的是 CPU 张量 batch
images_cpu, labels_cpu = batch
# !!! 在主进程中,将数据移动到目标设备 !!!
images = images_cpu.to(device, non_blocking=True) # 使用 non_blocking=True 可以与数据拷贝和后续计算轻微重叠
labels = labels_cpu.to(device, non_blocking=True)
# 模拟计算... (这里省略了模型调用、损失计算、反向传播等)
# 示例:简单的操作
output = images + 1
loss = output.mean() # 保持一个标量损失用于示例
# backward() 等操作应在此处进行 (如果模拟完整训练)
# loss.backward()
# optimizer.step()
# optimizer.zero_grad()
batch_end_time = time.time()
if i % 50 == 0: # 每 50 个 batch 打印一次信息
print(f"Epoch {epoch+1}, Batch {i}: Processing time: {batch_end_time - batch_start_time:.4f}s. Images shape: {images.shape}, Device: {images.device}")
# 这里可以再次加入显存监控代码,会发现稳定多了
# import pynvml ... (略)
# 变量会在下次循环开始时被覆盖,Python 会自动处理引用计数
# 通常不需要显式 `del`,除非确实需要立即释放大块内存且后续代码还很长
# del images, labels, output, loss # 通常不需要
epoch_end_time = time.time()
print(f"--- Epoch {epoch+1} finished in {epoch_end_time - epoch_start_time:.2f} seconds ---")
# 调用 empty_cache() 仍然可以做,主要用于清理 PyTorch 的内部缓存碎片,而不是解决泄漏问题
torch.cuda.empty_cache()
print("Fixed training loop finished.")
# 清理模拟数据
import shutil
shutil.rmtree('dummy_data', ignore_errors=True) # 忽略错误以防万一
print("Cleaned up dummy data.")
原理: 这样做,所有 .to(device)
操作都在主进程的单一 CUDA 上下文中执行。数据从 CPU (可能经过 pin_memory
优化) 传输到 GPU。当主循环中的 images
和 labels
变量在下一次迭代被重新赋值时,旧的 GPU 张量如果没有其他引用(比如没有保存在列表里),其引用计数会变为 0,PyTorch 就能知道这块显存可以被回收了。即使 PyTorch 的缓存机制暂时保留了这块内存,它也是标记为“可重用”的,不会导致显存无限增长。pin_memory=True
和 .to(device, non_blocking=True)
结合,可以在 CPU 到 GPU 的数据传输发生时,让 CPU 继续执行一小部分后续代码,略微提升效率,但这与显存泄漏问题本身无关。
方案二:精细管理变量生命周期(辅助手段)
即使采用了方案一,有时候显存问题依然存在,这通常是由于其他原因导致的张量引用未被释放。
1. 检查张量累积:
是不是在循环中不小心把每一批的 output
或者 loss
(或者其他包含计算图的张量) 添加到了一个 Python 列表或字典里,并且没有 .detach()
?
# 错误示例:保存了带有计算图的张量历史
loss_history = []
for batch in dataloader:
images, labels = batch
images = images.to(device)
labels = labels.to(device)
# ... model forward ...
output = model(images)
loss = criterion(output, labels)
loss.backward()
# ... optimizer step ...
# !!! 错误:将带有计算图的 loss 直接存入列表 !!!
# 这会导致每一轮的计算图(包括中间变量占用的显存)都无法释放
loss_history.append(loss)
# 正确做法:只保存数值,切断计算图
loss_history = []
for batch in dataloader:
# ... (同上) ...
loss_history.append(loss.item()) # .item() 获取 Python 标量,不含计算图
# 如果需要保存张量本身(比如用于后续分析),但不需要梯度信息:
tensor_history = []
for batch in dataloader:
# ...
output = model(images)
# ...
tensor_history.append(output.detach().cpu()) # .detach()阻断梯度,.cpu()移回CPU节省显存
2. 合理使用 del
和 torch.cuda.empty_cache()
:
del variable
: 主要作用是减少 Python 对象的引用计数。当引用计数为 0 时,Python 会回收该对象。对于 PyTorch 张量,这 可能 会触发 GPU 显存的释放,但这依赖于 PyTorch 的内存管理机制以及是否有其他隐式引用。在循环结束或者明确知道某个大张量不再需要时使用del
,可以提示系统回收,但不是万能药。torch.cuda.empty_cache()
: 前面说过,它清理的是 PyTorch 缓存的空闲显存块,不能释放仍被引用的张量占用的显存。它对于解决碎片化问题有一定帮助,比如在一系列不同大小的张量分配和释放后,调用它可以整合碎片化的空闲显存。但是,频繁调用它会带来额外的同步开销,拖慢训练速度。通常建议只在关键节点(如 epoch 结束、显存压力确实很大时)或者调试时使用。
简单建议: 优先确保没有意外的张量累积。让 Python 的作用域管理自动处理大部分变量的释放。只有在遇到确实难以解决的显存问题,或者需要精确控制大块显存释放时,才考虑使用 del
,并把 empty_cache()
作为最后的手段或者用于缓解碎片化。
三、进阶排查:GPU 显存分析工具
如果上述方法都试了,显存还是有问题,可能就需要更专业的工具来剖析显存占用了。
-
torch.cuda.memory_summary()
或torch.cuda.memory_stats()
: PyTorch 内置的工具,可以在代码中打印出当前 GPU 的详细显存分配情况、峰值占用、缓存使用等信息。# 在你怀疑显存异常的地方插入 print(torch.cuda.memory_summary(device=device, abbreviated=False)) # 或者打印统计数据 # print(torch.cuda.memory_stats(device=device))
分析这个输出,可以看到哪些尺寸的张量块被分配了,有多少活动内存,有多少缓存等,有助于定位是哪个环节分配的显存没有释放。
-
PyTorch Profiler: 功能更强大,可以记录算子执行时间、CPU/GPU 活动以及显存事件。
from torch.profiler import profile, record_function, ProfilerActivity with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof: with record_function("model_training_step"): # 给代码块命名 # 把你的一个训练迭代代码放在这里 # images, labels = next(iter(dataloader)) # images = images.to(device) # ... forward, backward, step ... pass # 示例占位 print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10)) # 保存结果供 TensorBoard 查看 # prof.export_chrome_trace("memory_trace.json")
Profiler 能关联显存分配/释放事件到具体的代码操作,是精确定位内存问题的利器。
-
外部工具: 如 NVIDIA 的 Nsight Systems / Nsight Compute,提供更底层的系统级和 CUDA 核函数级别的性能和显存分析。
四、总结一下
PyTorch 训练中遇到显存持续增长,即使 del
和 empty_cache()
都用了,多半是 DataLoader
的 num_workers > 0
和在 Dataset.__getitem__
中提前 .cuda()
造成的。
记住最佳实践:
- 保持
Dataset
干净: 只负责加载数据到 CPU 张量。 - 主循环负责上设备: 在训练循环内部,拿到 batch 数据后再
.to(device)
。 - 警惕张量累积: 检查列表、字典等是否保存了带计算图的张量。用
.item()
或.detach().cpu()
保存历史记录。 - 明智使用
del
和empty_cache()
:del
辅助,empty_cache()
主要处理缓存碎片,不能解决根本的引用泄漏。 - 善用分析工具:
memory_summary
, PyTorch Profiler 是深入排查的好帮手。
遵循这些原则,就能大大降低踩中显存陷阱的概率,让你的 GPU 安心炼丹。