返回

CycleGAN单图推理模糊?3招搞定预处理

Ai

搞定 CycleGAN 单张图片推理:为啥我的图这么糊?

咱们用 pytorch-CycleGAN-and-pix2pix 这个库训练完 CycleGAN 模型,效果不错。用官方提供的数据加载器(DataLoader)跑测试集,生成的图片也挺好。就像这样:

# 假设 opt 已经配置好 TestOptions
# opt = TestOptions().parse() # 通常是这样获取,但这里简化
# opt.num_threads = 0   # test code only supports num_threads = 0
# opt.batch_size = 1    # test code only supports batch_size = 1
# opt.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
# opt.no_flip = True    # no flip; comment this line if results on flipped images are needed.
# opt.display_id = -1   # no visdom display; the test code saves the results to a HTML file.

# --- 这是能正常工作的代码片段 ---
dataset = create_dataset(opt)  # 使用官方 API 创建数据集
model = create_model(opt)      # 创建模型
model.setup(opt)               # 设置模型 (加载权重等)
model.eval()                   # 切换到评估模式

# 遍历数据集进行推理 (这里只取一个样本演示)
data_iter = iter(dataset.dataloader)
data_dict = next(data_iter)

input_image_tensor = data_dict['A']
data = {'A': input_image_tensor, 'A_paths': data_dict['A_paths']} # 注意路径也需要传入
model.set_input(data)          # 设置模型输入
model.test()                   # 执行推理
visuals = model.get_current_visuals() # 获取结果
output_image = visuals['fake']  # 'fake_B''fake_A' 取决于你的模型和方向

# 后处理显示图片
output_image_np = output_image.squeeze().cpu().numpy()
# 将 Tensor [-1, 1] 转换为 [0, 255] 的 NumPy 数组 (CHW -> HWC)
output_image_np = (output_image_np.transpose(1, 2, 0) + 1) / 2.0 * 255.0
output_image_np = output_image_np.astype(np.uint8)

# 如果需要从 RGB 转回 BGR (cv2 默认)
# output_image_np = cv2.cvtColor(output_image_np, cv2.COLOR_RGB2BGR)

# 使用 Colab 的显示函数或 cv2.imwrite 保存
# from google.colab.patches import cv2_imshow
# cv2_imshow(output_image_np)
# cv2.imwrite('output_image.png', output_image_np)
# --- 正常工作代码片段结束 ---

上面的代码跑起来没毛病,生成的图片质量符合预期。

问题来了:我现在不想用 create_dataset 加载一整个文件夹,就想对单张 图片进行推理,咋办?

很自然地,咱会想到模仿 create_dataset 内部的图像预处理逻辑。于是乎,写了下面这样的 preprocess 函数:

import cv2
import torch
from torchvision import transforms
import numpy as np
from PIL import Image # 推荐使用 PIL 处理,与 torchvision 配合更好

# --- 这是尝试手动预处理的代码片段 (有问题) ---
def preprocess_buggy(image_path):
    # 用 OpenCV 读取
    image = cv2.imread(image_path)
    # 颜色通道处理 (这里可能存在隐患)
    if image.ndim == 2 or image.shape[2] == 1:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.shape[2] == 4:
        # 注意:cv2 读取 BGRA,转换目标应该是 BGR 或 RGB
        # 如果后续用 PIL 和 torchvision,转 RGB 更合适
        image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGB) # 改为转 RGB
    elif image.shape[2] == 3:
        # OpenCV 默认读取 BGR,需要转 RGB 给 PIL/torchvision
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 修正:只需一次转换

    # 转为 PIL Image,这是 torchvision transforms 期望的输入格式
    # 注意:上面 cv2 转的 RGB 是 NumPy HWC 格式
    pil_image = Image.fromarray(image) # 从 NumPy 数组创建 PIL Image

    # 定义变换流程 (尝试模仿)
    # 参数应该从 opt 获取,这里先写死,与默认值接近
    load_size = 286
    crop_size = 256
    transform_pipeline = transforms.Compose([
        transforms.Resize([load_size, load_size], interpolation=transforms.InterpolationMode.BICUBIC), # 明确指定插值方法
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(), # 转换成 Tensor,并自动将 [0, 255] 归一化到 [0, 1]
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化到 [-1, 1]
    ])

    image_tensor = transform_pipeline(pil_image)
    # 添加 Batch 维度
    image_tensor = image_tensor.unsqueeze(0) # [C, H, W] -> [1, C, H, W]

    return image_tensor

# 使用手动预处理函数
input_image_path = '/content/drive/MyDrive/dataset/testA/image_1.jpg'
input_image_tensor = preprocess_buggy(input_image_path) # 使用我们写的函数

# 后续推理步骤 (与之前类似)
data = {'A': input_image_tensor, 'A_paths': ['manual_path']} # A_paths 可以随便给个字符串
model.set_input(data)
model.test()
visuals = model.get_current_visuals()
output_image = visuals['fake']

# 后处理显示 (同上)
output_image_np = output_image.squeeze().cpu().numpy()
output_image_np = (output_image_np.transpose(1, 2, 0) + 1) / 2.0 * 255.0
output_image_np = output_image_np.astype(np.uint8)
# cv2_imshow(output_image_np)
# --- 有问题的代码片段结束 ---

结果呢?生成的图片变得非常模糊,完全没有之前用 DataLoader 得到的好效果。这到底是哪儿出了问题?

刨根问底:为啥手动处理就不行?

问题的核心在于:手动实现的预处理步骤和 pytorch-CycleGAN-and-pix2pix 库内部 create_dataset 使用的预处理步骤不完全一致。

GAN(生成对抗网络),特别是像 CycleGAN 这样的图像转换模型,对输入数据的分布非常敏感。训练时模型看到的是经过特定预处理流程的数据,推理时如果输入数据的预处理方式稍有偏差(比如缩放方式不同、裁剪位置不对、归一化参数错误、颜色通道顺序混乱等),模型就可能产生奇怪或质量低劣的结果。

我们上面那个 preprocess_buggy 函数看起来似乎做了缩放、裁剪、转 Tensor、归一化,但魔鬼藏在细节里:

  1. 图像读取库和颜色通道: cv2.imread 默认读取的是 BGR 格式,而 torchvision.transforms 通常期望 PIL Image 的 RGB 格式。虽然代码里做了 cv2.cvtColor,但这个转换过程是否完全等价于库内使用的 PIL 读取和处理,需要打个问号。尤其是原始代码片段里 cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 写了两次,第二次实际是在 RGB 上又做了一次 BGR<->RGB 转换,变成了 BGR,这肯定不对。
  2. 缩放 (Resize) 参数:
    • transforms.Resize([load_size, load_size]): 这会将图像调整到 load_size x load_size。但原库根据 opt.preprocess 参数可能有不同的行为,比如 resize_and_crop 模式。
    • 插值方法 (interpolation): transforms.Resize 默认使用 BILINEAR(双线性插值)。而 pytorch-CycleGAN-and-pix2pix 库里默认可能用的是 BICUBIC(双三次插值),这在 data/base_dataset.pyget_transform 函数里可以看到。不同的插值方法会产生像素级别的差异。我们上面的修正尝试加了 BICUBIC
  3. 裁剪 (Crop) 类型: 手动代码用了 transforms.CenterCrop,这在测试时通常是正确的。但也需要确认训练和测试的 opt 配置是否一致,原库也支持 RandomCrop 等。
  4. ToTensor 的细节: transforms.ToTensor() 不仅将 PIL Image 或 NumPy 数组转为 Tensor,还会自动将像素值从 [0, 255] 的范围缩放到 [0.0, 1.0]。这个行为是标准的。
  5. Normalize 参数: transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))[0.0, 1.0] 的数据变换到 [-1.0, 1.0]。这个参数对于 CycleGAN 很常用,但最好还是确认和你训练时用的 opt 设置一致。
  6. opt 参数的影响: create_dataset 函数会接收一个 opt 对象,里面包含了像 load_size, crop_size, preprocess, no_flip, input_nc 等等一大堆参数。这些参数共同决定了最终的预处理流程。手动复刻时,很容易忽略掉某个关键参数。例如,opt.preprocess 可以是 resize_and_crop, scale_width, scale_width_and_crop 等,每种策略的缩放和裁剪逻辑都不同。

总而言之,手动模仿很容易因为细节疏漏,导致预处理结果与训练时或官方加载器产生的数据分布不匹配,进而造成推理效果变差,图片模糊不清。

对症下药:几种靠谱的解决方案

既然知道了问题所在,解决起来就有方向了。核心思路就是:确保单张图片推理时的预处理与原始训练/测试流程完全一致。

方案一:复刻官方数据加载逻辑 (推荐)

这是最稳妥、最不容易出错的方法。直接利用 pytorch-CycleGAN-and-pix2pix 库里提供的函数来获取正确的预处理变换流程。

原理:

库里的 data/base_dataset.py 文件中有一个 get_transform 函数,它会根据传入的 opt 参数和是否为训练阶段 (isTrain) 返回一个配置好的 transforms.Compose 对象。我们只需要拿到这个对象,然后用它来处理单张图片即可。

步骤:

  1. 导入必要模块:

    import torch
    from torchvision import transforms
    from data.base_dataset import get_transform # 关键!
    from PIL import Image
    import numpy as np
    # 可能还需要 opt 配置相关的类
    # from options.test_options import TestOptions # 如果需要完整 opt
    import argparse # 用于创建简单的命名空间对象模拟 opt
    
  2. 准备 opt 参数: 你不需要一个完整的 TestOptions 解析结果,只需要一个包含必要预处理参数的对象即可。可以用 argparse.Namespace 手动创建一个。关键参数包括:

    • preprocess: 和你训练、测试时用的模式一致,如 'resize_and_crop'
    • load_size: 和训练、测试时一致。
    • crop_size: 和训练、测试时一致。
    • no_flip: 测试时通常设置为 True
    • input_nc: 输入图像通道数 (如 3 代表 RGB)。
    • output_nc: 输出图像通道数。
    • grayscale: 是否是灰度图转换,影响 get_transform 中的 input_nc 判断。
    • 可能还有 load_features 等,取决于你是否用了特殊功能。
    # 创建一个简单的对象来模拟 opt
    # 这些值需要和你训练/测试时的设置完全匹配!
    opt_preprocess = argparse.Namespace(
        preprocess='resize_and_crop', # 例如, 确认你的模式
        load_size=286,             # 确认你的 load_size
        crop_size=256,             # 确认你的 crop_size
        no_flip=True,              # 测试时不翻转
        input_nc=3,                # 输入通道数
        output_nc=3,               # 输出通道数
        grayscale=False,           # 根据你的数据集设置
        phase='test'               # 明确是测试阶段
        # isTrain = False 会在 get_transform 内部根据 phase 设置
        # 可能还需要其他参数,具体看 get_transform 的实现
    )
    
  3. 获取官方的 transform:

    # isTrain=False 告诉 get_transform 我们正在进行测试
    # 如果你的 opt_preprocess 没有 isTrain 属性, get_transform 会根据 opt.phase 判断
    image_transform = get_transform(opt_preprocess, grayscale=(opt_preprocess.input_nc == 1))
    
  4. 加载并处理单张图片:

    def preprocess_official(image_path, transform):
        """使用官方获取的 transform 处理单张图片"""
        # 使用 PIL 加载图像,确保是 RGB
        try:
            img = Image.open(image_path).convert('RGB')
        except IOError as e:
            print(f"Error opening image {image_path}: {e}")
            return None
    
        # 应用官方的变换流程
        img_tensor = transform(img)
    
        # 添加 batch 维度
        img_tensor = img_tensor.unsqueeze(0) # [C, H, W] -> [1, C, H, W]
        return img_tensor
    
    # --- 使用官方 transform 进行推理 ---
    input_image_path = '/content/drive/MyDrive/dataset/testA/image_1.jpg'
    
    # 获取官方的 transform
    image_transform = get_transform(opt_preprocess, grayscale=(opt_preprocess.input_nc == 1))
    
    # 预处理图片
    input_image_tensor = preprocess_official(input_image_path, image_transform)
    
    if input_image_tensor is not None:
        # 后续推理步骤 (和之前一样)
        data = {'A': input_image_tensor, 'A_paths': [input_image_path]} # A_paths 给个路径
        # 假设 model 已经创建并 setup 好
        # model.setup(opt) # 如果没 setup 的话需要 setup
        model.eval()
        model.set_input(data)
        model.test()
        visuals = model.get_current_visuals()
        output_image = visuals['fake'] # 或 'fake_B' / 'fake_A'
    
        # 后处理显示 (同上)
        output_image_np = output_image.squeeze().cpu().numpy()
        output_image_np = (output_image_np.transpose(1, 2, 0) + 1) / 2.0 * 255.0
        output_image_np = output_image_np.astype(np.uint8)
    
        # 在 Colab 中显示或保存
        # from google.colab.patches import cv2_imshow
        # cv2_imshow(output_image_np)
        # cv2.imwrite('output_image_official.png', cv2.cvtColor(output_image_np, cv2.COLOR_RGB2BGR)) # 保存时注意颜色通道
    # --- 官方 transform 推理结束 ---
    

进阶使用技巧:

  • 最小化 opt 对象: 你不需要加载所有 TestOptions。仔细阅读 data/base_dataset.pyget_transform 函数的实现,只把你需要的参数(如 load_size, crop_size, preprocess, no_flip 等)放进 opt_preprocess 对象里就够了。
  • 处理不同 preprocess 模式: 如果你的 opt.preprocess 不是 'resize_and_crop',而是 'scale_width' 或其他,get_transform 函数内部的逻辑会不同。直接调用它能确保你使用的就是对应的正确逻辑,无需手动实现这些复杂的分支。
  • 代码健壮性: 这种方法直接依赖库本身,只要库不发生大的变动,这个方法就能一直正确工作,维护成本最低。

方案二:手动对齐预处理步骤 (需要细心)

如果你坚持要手动实现,或者想深入理解预处理细节,那么就需要非常仔细地比对你的实现和 get_transform 函数的源代码。

原理:

逐行阅读 get_transform 函数,根据你的 opt 参数(特别是 opt.preprocessopt.no_flip)找到对应的代码分支,然后用 torchvision.transforms 复现完全一样的操作序列和参数。

步骤:

  1. 打开 data/base_dataset.py 文件,找到 get_transform 函数。
  2. 根据你的 opt 值,确定执行路径。 例如,如果 opt.preprocess == 'resize_and_crop'isTrain == False (或 opt.phase == 'test'):
    • 它会添加 transforms.Resize([opt.load_size, opt.load_size], interpolation=InterpolationMode.BICUBIC)
    • 然后添加 transforms.CenterCrop(opt.crop_size)。(注意不是 RandomCrop,测试时通常是中心裁剪)。
    • 接着是 transforms.ToTensor()
    • 最后是 transforms.Normalize(...)
  3. 注意细节:
    • 图像读取: 官方常用 PIL.Image.open(path).convert('RGB')。尽量保持一致,避免 cv2 带来的 BGR/RGB 混淆。如果你非要用 cv2 读取,确保在送入 transforms 前正确转换成 RGB 格式的 PIL Image 或 NumPy 数组。
    • 插值方法: 确认 Resize 使用的插值方法 (BICUBIC 还是 BILINEAR 等)。
    • 参数值: 确保 load_size, crop_size 和归一化的均值、标准差都和你训练时完全一样。

修正后的手动预处理代码示例 (假设使用 BICUBIC 插值):

from PIL import Image
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode # 导入插值模式

def preprocess_manual_corrected(image_path, load_size=286, crop_size=256, input_nc=3):
    """手动实现的预处理,力求与官方对齐"""
    
    grayscale = (input_nc == 1)
    
    # 使用 PIL 读取并转换为 RGB (或 灰度)
    try:
        if grayscale:
            img = Image.open(image_path).convert('L')
        else:
            img = Image.open(image_path).convert('RGB')
    except IOError as e:
        print(f"Error opening image {image_path}: {e}")
        return None

    # 构建变换流程 (模拟 'resize_and_crop' 测试模式)
    transform_list = []
    # 1. Resize
    # 检查 opt.preprocess 逻辑, 这里假设是 resize_and_crop
    transform_list.append(transforms.Resize([load_size, load_size], interpolation=InterpolationMode.BICUBIC)) # 明确插值
    
    # 2. Crop (测试时用 CenterCrop)
    transform_list.append(transforms.CenterCrop(crop_size))
    
    # 3. ToTensor (自动归一化到 [0, 1])
    transform_list.append(transforms.ToTensor())
    
    # 4. Normalize (到 [-1, 1])
    if grayscale:
        transform_list.append(transforms.Normalize((0.5,), (0.5,)))
    else:
        transform_list.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
        
    transform_pipeline = transforms.Compose(transform_list)
    
    # 应用变换
    img_tensor = transform_pipeline(img)
    
    # 添加 batch 维度
    img_tensor = img_tensor.unsqueeze(0)
    
    return img_tensor

# --- 使用修正后的手动预处理 ---
input_image_path = '/content/drive/MyDrive/dataset/testA/image_1.jpg'
# 确认参数与 opt 一致
input_image_tensor = preprocess_manual_corrected(input_image_path, 
                                               load_size=opt_preprocess.load_size, 
                                               crop_size=opt_preprocess.crop_size,
                                               input_nc=opt_preprocess.input_nc)

# ... 后续推理和后处理步骤同上 ...
# --- 修正手动预处理结束 ---

注意事项:

  • 非常容易出错: 手动复刻非常依赖你对源码的理解和细心程度,一点小小的偏差就可能导致结果变差。
  • 维护成本高: 如果未来 pytorch-CycleGAN-and-pix2pix 库更新了 get_transform 的逻辑,你需要同步修改你的手动代码。
  • 颜色空间再三确认: cv2 (BGR) 和 PIL/torchvision (RGB) 的混用是常见错误源。强烈建议统一使用 PIL 加载图像。

方案三:利用测试脚本 test.py (便捷但不灵活)

如果你只是偶尔需要测下单张图片,又不想改代码,可以“曲线救国”。

原理:

pytorch-CycleGAN-and-pix2pix 自带了一个 test.py 脚本,它本身就是设计用来跑推理的。我们可以创建一个只包含单张图片的临时文件夹,让 test.py 去处理它。

步骤:

  1. 创建临时目录结构: 假设你的模型是将 A 域转到 B 域。

    • 创建一个主目录,比如 temp_inference
    • temp_inference 下创建一个子目录,名字对应输入域,比如 testA
    • 把你想要推理的那张图片复制到 temp_inference/testA/ 目录下。

    目录结构看起来像这样:

    temp_inference/
        └── testA/
            └── your_single_image.jpg
    
  2. 运行 test.py: 在命令行中,导航到 pytorch-CycleGAN-and-pix2pix 的根目录,然后执行 test.py,并设置合适的参数:

    • --dataroot ./temp_inference: 指定包含 testA 的父目录。
    • --name YourModelName: 你的模型名称 (训练时指定的 --name)。
    • --model cycle_gan: 指定模型类型 (或其他你使用的模型)。
    • --phase test: 明确是测试阶段。
    • --no_dropout: 测试时通常禁用 dropout。
    • --num_test 1: 告诉脚本只处理一张(或少量几张)图片。如果 testA 文件夹里只有一张图,它就只会处理那一张。
    • --load_size, --crop_size, --netG, --norm, --input_nc, --output_nc 等参数需要和你训练时的配置保持一致。
    • --results_dir ./results/single_test: 可以指定一个输出目录。

    示例命令:

    python test.py \
      --dataroot ./temp_inference \
      --name F2F \
      --model cycle_gan \
      --phase test \
      --no_dropout \
      --load_size 286 \
      --crop_size 256 \
      --netG resnet_9blocks \
      --norm instance \
      --input_nc 3 \
      --output_nc 3 \
      --num_test 1 \
      --results_dir ./results/single_inference \
      --gpu_ids -1  # 如果不用 GPU
    
  3. 查找结果: 生成的图片会保存在 --results_dir 指定的目录下(例如 ./results/single_inference/F2F/test_latest/images/),通常会包含原图、fake 图和可能的重建图。

进阶技巧/缺点:

  • 自动化: 如果你需要通过 Python 脚本触发这个过程,可以使用 subprocess 模块来调用 test.py 命令。
  • 开销: 这种方法需要启动一个完整的 Python 进程来运行 test.py,对于单张图片来说,初始化模型等操作的开销相对较大。
  • 灵活性差: 你得到的是保存在文件里的结果,不太方便直接在内存中进行后续处理。每次推理都需要操作文件系统。

总结一下, 处理 CycleGAN 单张图片推理模糊问题的关键是确保预处理步骤与模型训练或官方测试流程完全一致 。最推荐的方法是直接利用库提供的 get_transform 函数 (方案一),这样既准确又省心。如果想深入理解细节或者有特殊定制需求,可以尝试小心翼翼地手动对齐 (方案二),但要做好 Debug 和维护的准备。最后,利用 test.py 脚本 (方案三) 是个快速但不灵活的变通办法。根据你的具体场景选择合适的方式吧!