返回

攻克LlavaOneVision微调:Labels张量的正确生成方法

Ai

搞定 LlavaOneVision 微调:如何正确处理 Labels?

在尝试微调像 llava-hf/llava-onevision-qwen2-0.5b-ov-hf 这样的多模态模型时,一个常见的卡壳点是如何正确地准备 labels 张量。你可能已经搞定了图像和文本的预处理,也写好了 collate_fn 来将数据打包成批次,甚至模型的前向传播和训练步骤也基本就绪,但唯独 labels 这一环让人头疼。不少人尝试过直接克隆 input_ids 作为 labels,或者单独对答案进行分词,结果往往不尽人意,不是报错就是模型学不到东西。这篇就来聊聊怎么搞定这个 labels

问题来了:labels 到底该是啥?

我们先看看典型的 collate_fn 里的问题点:

# (部分 collate_fn 代码)
    texts = []
    answers = [] # 单独存储答案似乎没直接用上
    for example in batch:
        question, answer, rgb_image_np = example
        images.append(rgb_image_np)
        # answers.append(answer) # 这一行在当前逻辑下可能不需要单独存

        # 构建对话历史
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": question},
                    {"type": "image"},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": answer},
                ],
            }
        ]
        # 应用聊天模板,生成完整的输入文本
        text_prompt = self.processor.apply_chat_template(conversation, tokenize=False) # 注意这里先不 tokenize
        texts.append(text_prompt)

    # 使用 processor 处理图像和文本,获得 tokenized input_ids 和 attention_mask
    model_inputs = self.processor(
            images=images,
            text=texts,
            return_tensors="pt",
            padding=True # 注意 padding=True 很重要
        ).to(torch.float16) # 确保数据类型匹配

    # --- 问题核心:如何生成 labels? ---
    labels = ???

    return {
            "pixel_values": model_inputs["pixel_values"], # 别忘了图像数据也要返回
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"], # attention_mask 也通常需要
            "labels": labels
        }

# (省略 LightningModule 定义和 training_step)
# forward 函数需要 pixel_values, input_ids, attention_mask, labels
def forward(self, pixel_values, input_ids, attention_mask, labels):
    outputs = self.model(
        pixel_values=pixel_values.to(self.device),
        input_ids=input_ids.to(self.device),
        attention_mask=attention_mask.to(self.device),
        labels=labels.to(self.device)
    )
    return outputs

# training_step 需要解包所有输入
def training_step(self, batch, batch_idx):
    outputs = self(
        pixel_values=batch['pixel_values'],
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        labels=batch['labels']
    )
    loss = outputs.loss
    self.log('train_loss', loss, prog_bar=True)
    return loss

从代码里能看到,processor 将图像和应用了聊天模板的文本(包含了问题 <image> 占位符和答案)一起处理,生成了 model_inputs,其中含有 input_ids。这个 input_ids 包含了整个对话序列的 token ID。

模型训练时,尤其是对于自回归模型(比如 LlavaOneVision 底层的语言模型 Qwen2),目标是预测序列中的下一个 token。损失函数(通常是交叉熵损失)需要将模型的预测 logits 与真实的下一个 token(即 labels)进行比较。

关键点在于:模型只需要在预测“答案”部分时计算损失 ,而不应该在预测“问题”或提示(prompt)部分时计算损失。

为啥会卡住?几种错误尝试剖析

  1. 直接克隆 input_ids

    • labels = model_inputs["input_ids"].clone()
    • 问题 :这样做会让模型在预测输入提示(问题部分)时也计算损失。模型的目标应该是根据提示生成答案,而不是预测提示本身。这会导致模型学习目标混乱,效果很差。
  2. 单独 Tokenize 答案

    • label_tokens = self.processor.tokenizer(text=answers, return_tensors="pt", padding=True)
    • labels = label_tokens["input_ids"]
    • 问题 :这种方法生成的 labels 长度通常与 model_inputs["input_ids"] 的长度不匹配。model_inputs["input_ids"] 包含了完整的对话序列(问题+答案+特殊token),而 label_tokens 只包含答案。Hugging Face 的 LlavaOnevisionForConditionalGeneration 或类似模型在内部计算损失时,期望 input_idslabels 有相同的序列长度。长度不一致会直接导致运行时错误。
  3. 关于右移 (Right Shifting)

    • 你可能听说过需要将 labels 相对于 input_ids 向右移动一位。
    • 解释 :这个概念是正确的,因为模型在位置 i 的输入是 input_ids[i],它应该预测的下一个 token 是原始序列中的 input_ids[i+1]。所以理论上,labels[i] 应该等于 input_ids[i+1]
    • 好消息 :大多数 Hugging Face 的 Causal LM(包括 LlavaOneVision 依赖的底层模型)在 forward 函数内部自动处理了这种移位 。当你提供 labels 参数时,模型内部会将其与 logits 对齐进行损失计算,通常是比较 logits[:, :-1, :]labels[:, 1:]所以,你不需要手动进行右移操作

动手搞定 Labels:正确的处理姿势

核心思路是:labels 张量应该与 input_ids 形状完全相同,但其中对应于“输入提示”(问题、图像占位符、用户角色标识等)部分的 token ID 需要被替换为一个特殊的忽略值(通常是 -100),而对应于“答案”部分的 token ID 则保持不变。Padding 部分的 token ID 也需要被替换为 -100

-100 这个值是 PyTorch 交叉熵损失函数默认的 ignore_index。设置了这个值的 labels 位置不会对最终的损失值产生贡献。

下面是改进后的 collate_fn,重点看如何构造 labels

import torch
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
import pytorch_lightning as pl # 假设你用了 Lightning

# 假设 self.processor 是已经加载好的 AutoProcessor
# 假设 self.processor.tokenizer.pad_token_id 存在

class YourDataCollator: # 可以将 collate_fn 封装在类里
    def __init__(self, processor):
        self.processor = processor
        # 确定 padding token id,如果 tokenizer 没有 pad_token,通常用 eos_token
        self.pad_token_id = processor.tokenizer.pad_token_id
        if self.pad_token_id is None:
            self.pad_token_id = processor.tokenizer.eos_token_id
            print(f"Warning: pad_token_id not set, using eos_token_id ({self.pad_token_id}) as pad token.")
        # 检查 processor 是否有 apply_chat_template 方法
        if not hasattr(self.processor, 'apply_chat_template'):
            raise AttributeError("The processor does not have the 'apply_chat_template' method. "
                                 "Make sure you are using a processor compatible with chat templates.")


    def __call__(self, batch):
        images = []
        texts_full = [] # 存储完整的对话文本
        texts_prompt_only = [] # 存储仅包含用户提示的文本,用于定位答案起点

        for example in batch:
            question, answer, rgb_image_np = example
            images.append(rgb_image_np)

            # 构建用户提示部分对话 (不包含答案)
            conversation_prompt = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image"},
                    ],
                },
                # 注意:这里故意只到 assistant role,但不加 content
                {
                    "role": "assistant",
                    "content": [], # 或者省略 "content" 键,取决于模板处理方式
                }
            ]
            # 应用模板得到只有用户部分的文本,末尾包含助手开始标记
            # 注意:tokenize=False 得到的是字符串
            # Add_generation_prompt=True 确保像 "ASSISTANT:" 这样的触发词被加上
            # 具体参数可能需要根据你的 processor 和模型微调
            try:
                # 尝试标准用法
                text_prompt_part = self.processor.apply_chat_template(
                    conversation_prompt,
                    add_generation_prompt=True, # 确保 assistant 角色后的分隔符被添加
                    tokenize=False
                )
            except Exception as e:
                 # 有些模型或模板需要一个空的 assistant content
                conversation_prompt[-1]['content'] = [{"type": "text", "text": ""}]
                text_prompt_part = self.processor.apply_chat_template(
                    conversation_prompt,
                    add_generation_prompt=True, # 理论上这参数在有空content时可能不需要,但最好保留
                    tokenize=False
                )

            texts_prompt_only.append(text_prompt_part)


            # 构建完整对话 (包含答案)
            conversation_full = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": question},
                        {"type": "image"},
                    ],
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer},
                    ],
                }
            ]
             # 应用模板得到完整对话文本
            text_full_part = self.processor.apply_chat_template(
                 conversation_full,
                 add_generation_prompt=False, # 因为已有答案,不再需要 generation prompt
                 tokenize=False
            )
            texts_full.append(text_full_part)

        # 使用 processor 处理图像和完整文本
        # Padding='longest' 或 'max_length' 都可以,这里用 longest
        model_inputs = self.processor(
                images=images,
                text=texts_full,
                return_tensors="pt",
                padding="longest", # 使用 padding='longest' 使得批内序列等长
                truncation=True, # 如果需要截断,加上 truncation=True 和 max_length
                # max_length=YOUR_MAX_LENGTH, # 例如 2048
            )#.to(torch.float16) # 数据类型转换可以放到 LightningModule 里

        input_ids = model_inputs["input_ids"]
        labels = input_ids.clone() # 先克隆 input_ids

        # ---- 核心处理逻辑:识别并 Mask 掉 Prompt 部分 ----
        # 对只有 prompt 部分的文本进行 tokenize,以确定答案开始的位置
        # 这里 tokenize 时不要 padding,我们需要每个 prompt 的实际长度
        prompt_tokens_outputs = self.processor.tokenizer(
            texts_prompt_only,
            return_tensors="pt",
            padding=False, # 重要:这里不能 padding
            truncation=True, # 如果 prompt 也可能超长,也要截断
            # max_length=YOUR_MAX_LENGTH, # 同上
            add_special_tokens=False # 模板应用时已处理特殊token,这里通常不需要再加
        )

        # 遍历批次中的每个样本
        for i in range(len(texts_full)):
            prompt_len = prompt_tokens_outputs.input_ids[i].ne(self.processor.tokenizer.pad_token_id).sum().item()

            # 另一种获取 prompt 长度的方法:基于 tokenizer 处理结果
            # 有些 processor 的 apply_chat_template 返回的是 token_ids,可以简化
            # 如果你的 processor/tokenizer 可以直接获取提示部分长度,用那个更佳
            # 检查 prompt_len 是否合理,防止边界情况
            if prompt_len >= labels.shape[1]: # 如果提示长度等于或超过了总长度(可能因为截断或空答案)
                prompt_len = labels.shape[1] - 1 # 至少保留最后一个 token 不被 mask,虽然理想情况不应发生

            # 将 labels 中对应 prompt 部分的 token 设置为 -100
            # 注意:是从 0 到 prompt_len - 1
            labels[i, :prompt_len] = -100

        # ---- Mask 掉 Padding 部分 ----
        # 找出 input_ids 中是 pad_token 的位置
        padding_mask = input_ids.eq(self.pad_token_id)
        # 将 labels 中对应 padding 的位置也设置为 -100
        labels[padding_mask] = -100

        # --- 检查点:确保 labels 不是全为 -100 ---
        # 如果某个样本的 labels 全是 -100,说明 prompt_len 计算可能有误,或者答案为空/被截断
        # 可以加一个断言或警告
        # assert not torch.all(labels[i] == -100), f"Sample {i} has all labels masked!"

        # 确保返回的数据在需要时移到 GPU(LightningModule 通常会处理)
        return {
                "pixel_values": model_inputs["pixel_values"],
                "input_ids": input_ids,
                "attention_mask": model_inputs["attention_mask"],
                "labels": labels
            }

# --- 对应的 LightningModule 部分需要接收所有数据 ---
# class LlavaOnevisionModule(pl.LightningModule):
#     def __init__(self, model_name, processor, learning_rate=2e-5):
#         super().__init__()
#         # ... (之前的初始化代码) ...
#         # self.pad_token_id 可以在这里也存一份,或从 processor 动态获取

#     def forward(self, pixel_values, input_ids, attention_mask, labels):
#         # 将所有输入移动到正确的 device
#         outputs = self.model(
#             pixel_values=pixel_values.to(self.device),
#             input_ids=input_ids.to(self.device),
#             attention_mask=attention_mask.to(self.device),
#             labels=labels.to(self.device)
#         )
#         return outputs

#     def training_step(self, batch, batch_idx):
#         # 直接解包 batch 字典
#         outputs = self(**batch) # 使用 ** batch 传递所有需要的参数
#         loss = outputs.loss
#         self.log('train_loss', loss, prog_bar=True) # prog_bar=True 使其显示在进度条上
#         return loss

#     # 需要配置 optimizer
#     def configure_optimizers(self):
#         optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
#         return optimizer

原理解释

  1. 完整序列 input_ids :通过 processor(images=..., text=texts_full, ...) 得到,包含了图像标记、问题文本、角色转换标记和答案文本的所有 token ID。还包括了 Attention Mask 和 padding。
  2. 克隆 labelslabels = input_ids.clone() 创建一个与 input_ids 完全相同的副本,作为 labels 的基础。
  3. 定位答案起点
    • 我们利用 processor.apply_chat_template 生成只包含用户提示和助手起始标记 的文本 texts_prompt_only
    • 然后单独对 texts_prompt_only 进行 tokenize(注意:这里不加 padding ),得到每个样本提示部分的实际 token 数量 prompt_len。这个长度就是答案在完整 input_ids 中开始的位置索引(从0开始计数的话,答案第一个 token 在索引 prompt_len 处)。
    • 重要细节apply_chat_template 的行为可能依赖于模型配置和 Transformers 版本。参数 add_generation_prompt=True 的作用是确保模板在用户对话后添加必要的、引导模型生成回答的分隔符或角色标识(如 ASSISTANT:)。你需要根据你使用的具体模型和 processor 确认模板的行为,可能需要查看 processor.chat_template 或相关文档。有时,即使没有助手内容,也需要一个空的 {"role": "assistant", "content": [...]} 条目才能正确触发 add_generation_prompt
  4. Mask 提示部分labels[i, :prompt_len] = -100labels 张量中,从开头到答案开始前(不包括答案第一个 token)的所有 token ID 都设置为 -100
  5. Mask Padding 部分labels[input_ids == self.pad_token_id] = -100 确保所有因 padding 添加的 token ID 在 labels 中也被设置为 -100。这是因为模型也不应该在 padding 位置计算损失。

这样处理后,labels 张量就满足了要求:与 input_ids 形状一致,且只有答案部分的 token ID 用于损失计算,提示和 padding 部分都被忽略了。模型在 forward 时接收这个 labels,内部会自动处理移位并计算正确的损失。

额外的安全建议和技巧

  • Tokenizer 行为确认 :不同模型的 Tokenizer 和 apply_chat_template 行为可能略有差异。务必打印并检查 texts_prompt_only、tokenize 后的 prompt_tokens_outputs.input_ids 以及最终的 input_idslabels,确保 tokenization 和 mask 操作符合预期。特别留意特殊 token(如 <s>, </s>, <|im_start|>, <|im_end|> 等)是否被正确处理和计数。
  • 处理空答案或截断 :如果数据中存在空答案,或者序列因为 max_length 被截断导致答案部分丢失,对应的 labels 可能会全变成 -100。训练循环应该能处理这种情况(损失为 0),但大量此类样本可能影响训练效率。可以考虑在数据预处理阶段过滤掉或修正这些样本。上面代码中增加了一个简单的检查,防止 prompt_len 越界。
  • 数据类型labels 的数据类型应为 torch.longinput_ids 通常也是 longpixel_values 通常是 float(如 torch.float16torch.float32)。确保传递给模型的数据类型符合预期。LightningModule 会自动处理设备(CPU/GPU)转移,但初始类型要对。
  • 调试 :遇到问题时,可以尝试用一个非常小的 batch(比如 batch_size=1)进行调试,逐步打印中间结果 (text_prompt_part, text_full_part, input_ids.shape, prompt_len, labels),看看哪里出了偏差。

现在,你应该能比较顺利地为你的 LlavaOneVision 模型准备好正确的 labels,让微调过程跑起来了。