返回

PyTorch GPU显存持续增长? 解决DataLoader与.cuda()陷阱

python

PyTorch 训练 GPU 显存持续增长?揪出 DataLoader.cuda() 的隐藏陷阱

跑 PyTorch 训练时,遇上 CUDA out of memory 错误是家常便饭。但有时候,明明代码里规规矩矩地 del 了不再需要的张量,甚至调用了 torch.cuda.empty_cache(),GPU 显存(VRAM)还是像脱缰的野马一样持续上涨,几个 epoch 下来就爆了。这是咋回事?

最近就有朋友遇到了这个问题,他的代码结构大致如下 (已简化):

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os # 模拟文件创建

# --- 准备模拟数据 ---
os.makedirs('dummy_data', exist_ok=True)
for i in range(10000):
    np.save(f'dummy_data/img_{i}.npy', np.random.rand(3, 224, 224).astype(np.float32))
    np.save(f'dummy_data/label_{i}.npy', np.random.randint(0, 10, size=()).astype(np.int64))
# --- 模拟数据准备完毕 ---


class CustomDataset(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        # 从磁盘加载数据 (模拟)
        image_path = self.data_paths[idx]['image']
        label_path = self.data_paths[idx]['label']

        # 假设数据是以 .npy 格式存储
        # 注意:实际应用中, I/O 操作可能更复杂
        try:
            image = np.load(image_path).astype(np.float32)
            label = np.load(label_path).astype(np.int64)
        except FileNotFoundError:
            print(f"Warning: File not found for index {idx}. Skipping.")
            # 返回None或者占位符, 需要在collate_fn中处理
            # 这里简单返回零值张量示意
            image = np.zeros((3, 224, 224), dtype=np.float32)
            label = np.int64(0)


        # !!! 问题关键点:在 Dataset 中就把张量放到了 GPU 上 !!!
        image = torch.tensor(image).cuda()
        label = torch.tensor(label).cuda()

        return image, label

# 生成数据路径列表
data_paths = [{'image': f'dummy_data/img_{i}.npy', 'label': f'dummy_data/label_{i}.npy'} for i in range(10000)]

# 创建 Dataset 和 DataLoader
dataset = CustomDataset(data_paths)
# 注意:num_workers > 0 是触发问题的关键之一
dataloader = DataLoader(dataset, batch_size=32, num_workers=4, pin_memory=True)

# 模拟训练循环
print("Starting simulated training loop...")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 后面会用到

# 手动模拟模型放置到 GPU (实际训练会有)
# dummy_model = torch.nn.Module().to(device) # 例子中没用到,但实际情况会有

for epoch in range(5): # 减少 epoch 数量便于观察
    print(f"--- Epoch {epoch+1} ---")
    for i, batch in enumerate(dataloader):
        if batch is None: # 处理 __getitem__ 可能返回 None 的情况 (虽然例子里没返回None)
             print(f"Skipping batch {i} due to loading error.")
             continue

        # 数据已经直接是 GPU 张量了
        images, labels = batch
        print(f"Batch {i}: images on {images.device}, labels on {labels.device}")

        # 模拟一些简单的计算
        # output = images.mean() # 简化计算, 原代码的这步可能引入其他问题, 但非核心
        # loss = output.sum() # 同样简化
        try:
             # 模拟操作,例如与模型交互 (如果模型在GPU上)
             # output = dummy_model(images) # 假如有模型
             # loss = some_loss_function(output, labels)
             # loss.backward() # 模拟反向传播

             # 为了演示显存占用,做个简单的张量操作
             output = images + 1
             loss = output.sum() # 创建一个计算图节点

        except Exception as e:
            print(f"Error during batch {i} processing: {e}")
            # 清理以防万一,尽管问题根源不在此
            if 'images' in locals(): del images
            if 'labels' in locals(): del labels
            if 'loss' in locals(): del loss
            if 'output' in locals(): del output
            torch.cuda.empty_cache()
            continue # 继续下一个 batch

        print(f"Batch {i} processed. VRAM usage might be increasing.")
        # 打印显存占用信息(需要安装 pynvml:pip install nvidia-ml-py3)
        # 注意: 这只是粗略估计,实际占用受 PyTorch caching 影响
        # import pynvml
        # pynvml.nvmlInit()
        # handle = pynvml.nvmlDeviceGetHandleByIndex(0) # 假设使用 GPU 0
        # info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        # print(f"Batch {i}: Used VRAM: {info.used // 1024**2} MB")
        # pynvml.nvmlShutdown()


        # !!! 尝试手动删除,但效果不佳 !!!
        del images
        del labels
        del loss
        del output # 删除计算结果

        # !!! 调用 empty_cache() 也不能阻止显存持续增长 !!!
        torch.cuda.empty_cache()

    print(f"--- Epoch {epoch+1} finished ---")
    # 在 Epoch 结束时也检查显存
    # import pynvml
    # pynvml.nvmlInit()
    # handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    # info = pynvml.nvmlDeviceGetMemoryInfo(handle)
    # print(f"End of Epoch {epoch+1}: Used VRAM: {info.used // 1024**2} MB")
    # pynvml.nvmlShutdown()


print("Simulated training loop finished.")
# 清理模拟数据
import shutil
shutil.rmtree('dummy_data')
print("Cleaned up dummy data.")

即使在循环的最后用 del 删除所有张量,并且调用 torch.cuda.empty_cache(),用 nvidia-smi 观察,显存占用依然是步步高升,直到 OOM 崩溃。把 num_workers 设置成 0 似乎能缓解,但并不是根本解决办法,而且失去了多进程加载数据的优势。

问题到底出在哪?

一、揪出幕后黑手:DataLoader、多进程和 cuda()

问题的根源在于 DataLoader 的工作机制与在 Dataset.__getitem__ 方法中调用 .cuda() 的方式发生了冲突,尤其是当 num_workers > 0 时。

核心原因:CUDA 上下文 (Context) 与进程

  1. 多进程工作: 当你设置 DataLoadernum_workers 大于 0 时,PyTorch 会启动多个子进程来并行加载数据。每个子进程负责调用 Dataset.__getitem__ 方法获取单个样本。
  2. 独立的 CUDA 上下文: 每个进程(主进程和所有 worker 子进程)在使用 GPU 时,都会初始化自己的 CUDA 上下文。CUDA 上下文可以看作是 GPU 上的一个独立工作空间,包含了当前进程分配的显存、加载的 CUDA 核函数等状态。
  3. __getitem__ 中的 .cuda(): 当你在 __getitem__ 中调用 torch.tensor(image).cuda() 时,这个张量是在 某个特定的 worker 子进程 的 CUDA 上下文中创建并分配显存的。
  4. 数据传输回主进程: DataLoader 会从 worker 子进程收集这些已经在 GPU 上的张量,并将它们传递给主进程(也就是你的训练循环所在的进程)。
  5. 隐式的上下文切换和保留: 这个从 worker 进程到主进程的 GPU 张量传递过程,背后涉及一些复杂的 CUDA 操作。关键在于,即使张量被传递到了主进程,它可能仍然隐式地与创建它的那个 worker 子进程的 CUDA 上下文有关联。更糟糕的是,由于多进程共享和通信的机制,这些来自 worker 进程的 GPU 张量引用可能没有被完全释放,即使你在主进程中 del 了对应的 Python 变量。 worker 进程自身退出时也许会清理,但 DataLoader 的 worker 可能被复用,导致生命周期管理复杂化。
  6. empty_cache() 的局限: torch.cuda.empty_cache() 只能释放 PyTorch 缓存的、当前未被任何活动张量引用 的显存块。它无法释放那些虽然在主进程的 Python 代码里看似已经 del 了,但由于上述跨进程引用问题、或是计算图未释放等原因,在 CUDA 层面仍然被认为"在使用中"的显存。

简单来说: 你在子进程 A 的 GPU 工作区创建了一个东西,把它传给主进程 B。主进程 B 用完说“我不要了 (del)”,但因为这东西的"户口"还在子进程 A 那里,或者交接手续没彻底办完,导致 CUDA 系统觉得这块显存还在用,empty_cache() 这个“清洁工”也无权清理。随着一批批数据从不同的 worker 进程传过来,越来越多这样的"孤儿"显存被占用,最终爆掉。

num_workers=0 时,所有数据加载都在主进程完成,不存在跨进程 CUDA 上下文的问题,所以显存通常能被正确释放。但你就损失了并行加载的性能。

二、釜底抽薪:正确的解决方案

知道了原因,解决起来就思路清晰了。核心原则是:数据准备(CPU 操作)和模型计算(GPU 操作)分离,只在主训练循环中将数据移动到目标设备。

方案一:坚守阵地,数据只在主循环上 GPU(推荐)

这是最标准、最推荐的做法。让 Dataset 只负责加载数据、进行必要的 CPU 预处理(如 NumPy 操作、基本的数据增强),并返回 CPU 张量。然后在主训练循环中,获取到一个 batch 的数据后,再将其 .to(device).cuda()

1. 修改 CustomDataset:

import torch
from torch.utils.data import Dataset
import numpy as np
import os # 需要os来处理文件路径等

# 确保模拟数据存在 (如果运行环境清理了之前的数据)
os.makedirs('dummy_data', exist_ok=True)
if not os.path.exists('dummy_data/img_0.npy'): # 简单检查
    print("Recreating dummy data...")
    for i in range(10000): # 根据你的实际数据量调整
        np.save(f'dummy_data/img_{i}.npy', np.random.rand(3, 224, 224).astype(np.float32))
        np.save(f'dummy_data/label_{i}.npy', np.random.randint(0, 10, size=()).astype(np.int64))
    print("Dummy data recreated.")


class CustomDatasetFixed(Dataset):
    def __init__(self, data_paths):
        self.data_paths = data_paths

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        image_path = self.data_paths[idx]['image']
        label_path = self.data_paths[idx]['label']

        try:
            # 加载数据为 NumPy 数组
            image = np.load(image_path).astype(np.float32)
            label = np.load(label_path).astype(np.int64)

            # 将 NumPy 数组转换为 PyTorch CPU 张量
            image = torch.tensor(image)
            label = torch.tensor(label)

            # 不在这里调用 .cuda() 或 .to(device) !
            return image, label

        except FileNotFoundError:
             # 更健壮的错误处理:可以选择记录日志、返回特定标记等
             print(f"Warning: Data file not found for index {idx}, path: {image_path} or {label_path}. Returning zeros.")
             # 返回一个有效的零张量,确保后续collate_fn或循环能处理
             # 尺寸应与正常数据匹配
             return torch.zeros((3, 224, 224), dtype=torch.float32), torch.tensor(0, dtype=torch.int64)
        except Exception as e:
            # 捕获其他可能的加载错误
            print(f"Error loading data for index {idx}: {e}. Returning zeros.")
            return torch.zeros((3, 224, 224), dtype=torch.float32), torch.tensor(0, dtype=torch.int64)

2. 修改训练循环:

import torch
from torch.utils.data import DataLoader
import time # 用于简单计时

# --- 复用上面的 CustomDatasetFixed 类 ---

# 设置目标设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 生成数据路径列表
data_paths = [{'image': f'dummy_data/img_{i}.npy', 'label': f'dummy_data/label_{i}.npy'} for i in range(10000)] # 假设数据已生成

# 创建 Dataset 和 DataLoader
dataset = CustomDatasetFixed(data_paths) # 使用修改后的 Dataset
dataloader = DataLoader(dataset,
                          batch_size=32,
                          num_workers=4,     # 可以安全使用多进程
                          pin_memory=True,    # pin_memory 配合后续 .to(device, non_blocking=True) 效果更佳
                          persistent_workers=True if torch.__version__ >= "1.8.0" and os.name != 'nt' else False # 建议开启以减少 worker 重启开销 (需要 PyTorch 1.8+ 且非 Windows)
                         )

print("Starting fixed training loop...")

# 模拟训练循环
for epoch in range(5):
    print(f"--- Epoch {epoch+1} ---")
    epoch_start_time = time.time()
    for i, batch in enumerate(dataloader):
        batch_start_time = time.time()

        # 从 DataLoader 获取的是 CPU 张量 batch
        images_cpu, labels_cpu = batch

        # !!! 在主进程中,将数据移动到目标设备 !!!
        images = images_cpu.to(device, non_blocking=True) # 使用 non_blocking=True 可以与数据拷贝和后续计算轻微重叠
        labels = labels_cpu.to(device, non_blocking=True)

        # 模拟计算... (这里省略了模型调用、损失计算、反向传播等)
        # 示例:简单的操作
        output = images + 1
        loss = output.mean() # 保持一个标量损失用于示例

        # backward() 等操作应在此处进行 (如果模拟完整训练)
        # loss.backward()
        # optimizer.step()
        # optimizer.zero_grad()

        batch_end_time = time.time()
        if i % 50 == 0: # 每 50 个 batch 打印一次信息
            print(f"Epoch {epoch+1}, Batch {i}: Processing time: {batch_end_time - batch_start_time:.4f}s. Images shape: {images.shape}, Device: {images.device}")
            # 这里可以再次加入显存监控代码,会发现稳定多了
            # import pynvml ... (略)

        # 变量会在下次循环开始时被覆盖,Python 会自动处理引用计数
        # 通常不需要显式 `del`,除非确实需要立即释放大块内存且后续代码还很长
        # del images, labels, output, loss # 通常不需要

    epoch_end_time = time.time()
    print(f"--- Epoch {epoch+1} finished in {epoch_end_time - epoch_start_time:.2f} seconds ---")
    # 调用 empty_cache() 仍然可以做,主要用于清理 PyTorch 的内部缓存碎片,而不是解决泄漏问题
    torch.cuda.empty_cache()


print("Fixed training loop finished.")
# 清理模拟数据
import shutil
shutil.rmtree('dummy_data', ignore_errors=True) # 忽略错误以防万一
print("Cleaned up dummy data.")

原理: 这样做,所有 .to(device) 操作都在主进程的单一 CUDA 上下文中执行。数据从 CPU (可能经过 pin_memory 优化) 传输到 GPU。当主循环中的 imageslabels 变量在下一次迭代被重新赋值时,旧的 GPU 张量如果没有其他引用(比如没有保存在列表里),其引用计数会变为 0,PyTorch 就能知道这块显存可以被回收了。即使 PyTorch 的缓存机制暂时保留了这块内存,它也是标记为“可重用”的,不会导致显存无限增长。pin_memory=True.to(device, non_blocking=True) 结合,可以在 CPU 到 GPU 的数据传输发生时,让 CPU 继续执行一小部分后续代码,略微提升效率,但这与显存泄漏问题本身无关。

方案二:精细管理变量生命周期(辅助手段)

即使采用了方案一,有时候显存问题依然存在,这通常是由于其他原因导致的张量引用未被释放。

1. 检查张量累积:

是不是在循环中不小心把每一批的 output 或者 loss (或者其他包含计算图的张量) 添加到了一个 Python 列表或字典里,并且没有 .detach()

# 错误示例:保存了带有计算图的张量历史
loss_history = []
for batch in dataloader:
    images, labels = batch
    images = images.to(device)
    labels = labels.to(device)
    # ... model forward ...
    output = model(images)
    loss = criterion(output, labels)
    loss.backward()
    # ... optimizer step ...

    # !!! 错误:将带有计算图的 loss 直接存入列表 !!!
    # 这会导致每一轮的计算图(包括中间变量占用的显存)都无法释放
    loss_history.append(loss)

# 正确做法:只保存数值,切断计算图
loss_history = []
for batch in dataloader:
    # ... (同上) ...
    loss_history.append(loss.item()) # .item() 获取 Python 标量,不含计算图

# 如果需要保存张量本身(比如用于后续分析),但不需要梯度信息:
tensor_history = []
for batch in dataloader:
    # ...
    output = model(images)
    # ...
    tensor_history.append(output.detach().cpu()) # .detach()阻断梯度,.cpu()移回CPU节省显存

2. 合理使用 deltorch.cuda.empty_cache():

  • del variable: 主要作用是减少 Python 对象的引用计数。当引用计数为 0 时,Python 会回收该对象。对于 PyTorch 张量,这 可能 会触发 GPU 显存的释放,但这依赖于 PyTorch 的内存管理机制以及是否有其他隐式引用。在循环结束或者明确知道某个大张量不再需要时使用 del,可以提示系统回收,但不是万能药。
  • torch.cuda.empty_cache(): 前面说过,它清理的是 PyTorch 缓存的空闲显存块,不能释放仍被引用的张量占用的显存。它对于解决碎片化问题有一定帮助,比如在一系列不同大小的张量分配和释放后,调用它可以整合碎片化的空闲显存。但是,频繁调用它会带来额外的同步开销,拖慢训练速度。通常建议只在关键节点(如 epoch 结束、显存压力确实很大时)或者调试时使用。

简单建议: 优先确保没有意外的张量累积。让 Python 的作用域管理自动处理大部分变量的释放。只有在遇到确实难以解决的显存问题,或者需要精确控制大块显存释放时,才考虑使用 del,并把 empty_cache() 作为最后的手段或者用于缓解碎片化。

三、进阶排查:GPU 显存分析工具

如果上述方法都试了,显存还是有问题,可能就需要更专业的工具来剖析显存占用了。

  • torch.cuda.memory_summary()torch.cuda.memory_stats(): PyTorch 内置的工具,可以在代码中打印出当前 GPU 的详细显存分配情况、峰值占用、缓存使用等信息。

    # 在你怀疑显存异常的地方插入
    print(torch.cuda.memory_summary(device=device, abbreviated=False))
    # 或者打印统计数据
    # print(torch.cuda.memory_stats(device=device))
    

    分析这个输出,可以看到哪些尺寸的张量块被分配了,有多少活动内存,有多少缓存等,有助于定位是哪个环节分配的显存没有释放。

  • PyTorch Profiler: 功能更强大,可以记录算子执行时间、CPU/GPU 活动以及显存事件。

    from torch.profiler import profile, record_function, ProfilerActivity
    
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True) as prof:
        with record_function("model_training_step"): # 给代码块命名
            # 把你的一个训练迭代代码放在这里
            # images, labels = next(iter(dataloader))
            # images = images.to(device)
            # ... forward, backward, step ...
            pass # 示例占位
    
    print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
    # 保存结果供 TensorBoard 查看
    # prof.export_chrome_trace("memory_trace.json")
    

    Profiler 能关联显存分配/释放事件到具体的代码操作,是精确定位内存问题的利器。

  • 外部工具: 如 NVIDIA 的 Nsight Systems / Nsight Compute,提供更底层的系统级和 CUDA 核函数级别的性能和显存分析。

四、总结一下

PyTorch 训练中遇到显存持续增长,即使 delempty_cache() 都用了,多半是 DataLoadernum_workers > 0 和在 Dataset.__getitem__ 中提前 .cuda() 造成的。

记住最佳实践:

  1. 保持 Dataset 干净: 只负责加载数据到 CPU 张量。
  2. 主循环负责上设备: 在训练循环内部,拿到 batch 数据后再 .to(device)
  3. 警惕张量累积: 检查列表、字典等是否保存了带计算图的张量。用 .item().detach().cpu() 保存历史记录。
  4. 明智使用 delempty_cache() del 辅助,empty_cache() 主要处理缓存碎片,不能解决根本的引用泄漏。
  5. 善用分析工具: memory_summary, PyTorch Profiler 是深入排查的好帮手。

遵循这些原则,就能大大降低踩中显存陷阱的概率,让你的 GPU 安心炼丹。