攻克LlavaOneVision微调:Labels张量的正确生成方法
2025-03-25 23:12:21
搞定 LlavaOneVision 微调:如何正确处理 Labels?
在尝试微调像 llava-hf/llava-onevision-qwen2-0.5b-ov-hf
这样的多模态模型时,一个常见的卡壳点是如何正确地准备 labels
张量。你可能已经搞定了图像和文本的预处理,也写好了 collate_fn
来将数据打包成批次,甚至模型的前向传播和训练步骤也基本就绪,但唯独 labels
这一环让人头疼。不少人尝试过直接克隆 input_ids
作为 labels
,或者单独对答案进行分词,结果往往不尽人意,不是报错就是模型学不到东西。这篇就来聊聊怎么搞定这个 labels
。
问题来了:labels
到底该是啥?
我们先看看典型的 collate_fn
里的问题点:
# (部分 collate_fn 代码)
texts = []
answers = [] # 单独存储答案似乎没直接用上
for example in batch:
question, answer, rgb_image_np = example
images.append(rgb_image_np)
# answers.append(answer) # 这一行在当前逻辑下可能不需要单独存
# 构建对话历史
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": answer},
],
}
]
# 应用聊天模板,生成完整的输入文本
text_prompt = self.processor.apply_chat_template(conversation, tokenize=False) # 注意这里先不 tokenize
texts.append(text_prompt)
# 使用 processor 处理图像和文本,获得 tokenized input_ids 和 attention_mask
model_inputs = self.processor(
images=images,
text=texts,
return_tensors="pt",
padding=True # 注意 padding=True 很重要
).to(torch.float16) # 确保数据类型匹配
# --- 问题核心:如何生成 labels? ---
labels = ???
return {
"pixel_values": model_inputs["pixel_values"], # 别忘了图像数据也要返回
"input_ids": model_inputs["input_ids"],
"attention_mask": model_inputs["attention_mask"], # attention_mask 也通常需要
"labels": labels
}
# (省略 LightningModule 定义和 training_step)
# forward 函数需要 pixel_values, input_ids, attention_mask, labels
def forward(self, pixel_values, input_ids, attention_mask, labels):
outputs = self.model(
pixel_values=pixel_values.to(self.device),
input_ids=input_ids.to(self.device),
attention_mask=attention_mask.to(self.device),
labels=labels.to(self.device)
)
return outputs
# training_step 需要解包所有输入
def training_step(self, batch, batch_idx):
outputs = self(
pixel_values=batch['pixel_values'],
input_ids=batch['input_ids'],
attention_mask=batch['attention_mask'],
labels=batch['labels']
)
loss = outputs.loss
self.log('train_loss', loss, prog_bar=True)
return loss
从代码里能看到,processor
将图像和应用了聊天模板的文本(包含了问题 <image>
占位符和答案)一起处理,生成了 model_inputs
,其中含有 input_ids
。这个 input_ids
包含了整个对话序列的 token ID。
模型训练时,尤其是对于自回归模型(比如 LlavaOneVision 底层的语言模型 Qwen2),目标是预测序列中的下一个 token。损失函数(通常是交叉熵损失)需要将模型的预测 logits 与真实的下一个 token(即 labels
)进行比较。
关键点在于:模型只需要在预测“答案”部分时计算损失 ,而不应该在预测“问题”或提示(prompt)部分时计算损失。
为啥会卡住?几种错误尝试剖析
-
直接克隆
input_ids
:labels = model_inputs["input_ids"].clone()
- 问题 :这样做会让模型在预测输入提示(问题部分)时也计算损失。模型的目标应该是根据提示生成答案,而不是预测提示本身。这会导致模型学习目标混乱,效果很差。
-
单独 Tokenize 答案 :
label_tokens = self.processor.tokenizer(text=answers, return_tensors="pt", padding=True)
labels = label_tokens["input_ids"]
- 问题 :这种方法生成的
labels
长度通常与model_inputs["input_ids"]
的长度不匹配。model_inputs["input_ids"]
包含了完整的对话序列(问题+答案+特殊token),而label_tokens
只包含答案。Hugging Face 的LlavaOnevisionForConditionalGeneration
或类似模型在内部计算损失时,期望input_ids
和labels
有相同的序列长度。长度不一致会直接导致运行时错误。
-
关于右移 (Right Shifting) :
- 你可能听说过需要将
labels
相对于input_ids
向右移动一位。 - 解释 :这个概念是正确的,因为模型在位置
i
的输入是input_ids[i]
,它应该预测的下一个 token 是原始序列中的input_ids[i+1]
。所以理论上,labels[i]
应该等于input_ids[i+1]
。 - 好消息 :大多数 Hugging Face 的 Causal LM(包括 LlavaOneVision 依赖的底层模型)在
forward
函数内部自动处理了这种移位 。当你提供labels
参数时,模型内部会将其与logits
对齐进行损失计算,通常是比较logits[:, :-1, :]
和labels[:, 1:]
。所以,你不需要手动进行右移操作 。
- 你可能听说过需要将
动手搞定 Labels:正确的处理姿势
核心思路是:labels
张量应该与 input_ids
形状完全相同,但其中对应于“输入提示”(问题、图像占位符、用户角色标识等)部分的 token ID 需要被替换为一个特殊的忽略值(通常是 -100
),而对应于“答案”部分的 token ID 则保持不变。Padding 部分的 token ID 也需要被替换为 -100
。
-100
这个值是 PyTorch 交叉熵损失函数默认的 ignore_index
。设置了这个值的 labels
位置不会对最终的损失值产生贡献。
下面是改进后的 collate_fn
,重点看如何构造 labels
:
import torch
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
import pytorch_lightning as pl # 假设你用了 Lightning
# 假设 self.processor 是已经加载好的 AutoProcessor
# 假设 self.processor.tokenizer.pad_token_id 存在
class YourDataCollator: # 可以将 collate_fn 封装在类里
def __init__(self, processor):
self.processor = processor
# 确定 padding token id,如果 tokenizer 没有 pad_token,通常用 eos_token
self.pad_token_id = processor.tokenizer.pad_token_id
if self.pad_token_id is None:
self.pad_token_id = processor.tokenizer.eos_token_id
print(f"Warning: pad_token_id not set, using eos_token_id ({self.pad_token_id}) as pad token.")
# 检查 processor 是否有 apply_chat_template 方法
if not hasattr(self.processor, 'apply_chat_template'):
raise AttributeError("The processor does not have the 'apply_chat_template' method. "
"Make sure you are using a processor compatible with chat templates.")
def __call__(self, batch):
images = []
texts_full = [] # 存储完整的对话文本
texts_prompt_only = [] # 存储仅包含用户提示的文本,用于定位答案起点
for example in batch:
question, answer, rgb_image_np = example
images.append(rgb_image_np)
# 构建用户提示部分对话 (不包含答案)
conversation_prompt = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
{"type": "image"},
],
},
# 注意:这里故意只到 assistant role,但不加 content
{
"role": "assistant",
"content": [], # 或者省略 "content" 键,取决于模板处理方式
}
]
# 应用模板得到只有用户部分的文本,末尾包含助手开始标记
# 注意:tokenize=False 得到的是字符串
# Add_generation_prompt=True 确保像 "ASSISTANT:" 这样的触发词被加上
# 具体参数可能需要根据你的 processor 和模型微调
try:
# 尝试标准用法
text_prompt_part = self.processor.apply_chat_template(
conversation_prompt,
add_generation_prompt=True, # 确保 assistant 角色后的分隔符被添加
tokenize=False
)
except Exception as e:
# 有些模型或模板需要一个空的 assistant content
conversation_prompt[-1]['content'] = [{"type": "text", "text": ""}]
text_prompt_part = self.processor.apply_chat_template(
conversation_prompt,
add_generation_prompt=True, # 理论上这参数在有空content时可能不需要,但最好保留
tokenize=False
)
texts_prompt_only.append(text_prompt_part)
# 构建完整对话 (包含答案)
conversation_full = [
{
"role": "user",
"content": [
{"type": "text", "text": question},
{"type": "image"},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": answer},
],
}
]
# 应用模板得到完整对话文本
text_full_part = self.processor.apply_chat_template(
conversation_full,
add_generation_prompt=False, # 因为已有答案,不再需要 generation prompt
tokenize=False
)
texts_full.append(text_full_part)
# 使用 processor 处理图像和完整文本
# Padding='longest' 或 'max_length' 都可以,这里用 longest
model_inputs = self.processor(
images=images,
text=texts_full,
return_tensors="pt",
padding="longest", # 使用 padding='longest' 使得批内序列等长
truncation=True, # 如果需要截断,加上 truncation=True 和 max_length
# max_length=YOUR_MAX_LENGTH, # 例如 2048
)#.to(torch.float16) # 数据类型转换可以放到 LightningModule 里
input_ids = model_inputs["input_ids"]
labels = input_ids.clone() # 先克隆 input_ids
# ---- 核心处理逻辑:识别并 Mask 掉 Prompt 部分 ----
# 对只有 prompt 部分的文本进行 tokenize,以确定答案开始的位置
# 这里 tokenize 时不要 padding,我们需要每个 prompt 的实际长度
prompt_tokens_outputs = self.processor.tokenizer(
texts_prompt_only,
return_tensors="pt",
padding=False, # 重要:这里不能 padding
truncation=True, # 如果 prompt 也可能超长,也要截断
# max_length=YOUR_MAX_LENGTH, # 同上
add_special_tokens=False # 模板应用时已处理特殊token,这里通常不需要再加
)
# 遍历批次中的每个样本
for i in range(len(texts_full)):
prompt_len = prompt_tokens_outputs.input_ids[i].ne(self.processor.tokenizer.pad_token_id).sum().item()
# 另一种获取 prompt 长度的方法:基于 tokenizer 处理结果
# 有些 processor 的 apply_chat_template 返回的是 token_ids,可以简化
# 如果你的 processor/tokenizer 可以直接获取提示部分长度,用那个更佳
# 检查 prompt_len 是否合理,防止边界情况
if prompt_len >= labels.shape[1]: # 如果提示长度等于或超过了总长度(可能因为截断或空答案)
prompt_len = labels.shape[1] - 1 # 至少保留最后一个 token 不被 mask,虽然理想情况不应发生
# 将 labels 中对应 prompt 部分的 token 设置为 -100
# 注意:是从 0 到 prompt_len - 1
labels[i, :prompt_len] = -100
# ---- Mask 掉 Padding 部分 ----
# 找出 input_ids 中是 pad_token 的位置
padding_mask = input_ids.eq(self.pad_token_id)
# 将 labels 中对应 padding 的位置也设置为 -100
labels[padding_mask] = -100
# --- 检查点:确保 labels 不是全为 -100 ---
# 如果某个样本的 labels 全是 -100,说明 prompt_len 计算可能有误,或者答案为空/被截断
# 可以加一个断言或警告
# assert not torch.all(labels[i] == -100), f"Sample {i} has all labels masked!"
# 确保返回的数据在需要时移到 GPU(LightningModule 通常会处理)
return {
"pixel_values": model_inputs["pixel_values"],
"input_ids": input_ids,
"attention_mask": model_inputs["attention_mask"],
"labels": labels
}
# --- 对应的 LightningModule 部分需要接收所有数据 ---
# class LlavaOnevisionModule(pl.LightningModule):
# def __init__(self, model_name, processor, learning_rate=2e-5):
# super().__init__()
# # ... (之前的初始化代码) ...
# # self.pad_token_id 可以在这里也存一份,或从 processor 动态获取
# def forward(self, pixel_values, input_ids, attention_mask, labels):
# # 将所有输入移动到正确的 device
# outputs = self.model(
# pixel_values=pixel_values.to(self.device),
# input_ids=input_ids.to(self.device),
# attention_mask=attention_mask.to(self.device),
# labels=labels.to(self.device)
# )
# return outputs
# def training_step(self, batch, batch_idx):
# # 直接解包 batch 字典
# outputs = self(**batch) # 使用 ** batch 传递所有需要的参数
# loss = outputs.loss
# self.log('train_loss', loss, prog_bar=True) # prog_bar=True 使其显示在进度条上
# return loss
# # 需要配置 optimizer
# def configure_optimizers(self):
# optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
# return optimizer
原理解释
- 完整序列
input_ids
:通过processor(images=..., text=texts_full, ...)
得到,包含了图像标记、问题文本、角色转换标记和答案文本的所有 token ID。还包括了 Attention Mask 和 padding。 - 克隆
labels
:labels = input_ids.clone()
创建一个与input_ids
完全相同的副本,作为labels
的基础。 - 定位答案起点 :
- 我们利用
processor.apply_chat_template
生成只包含用户提示和助手起始标记 的文本texts_prompt_only
。 - 然后单独对
texts_prompt_only
进行 tokenize(注意:这里不加 padding ),得到每个样本提示部分的实际 token 数量prompt_len
。这个长度就是答案在完整input_ids
中开始的位置索引(从0开始计数的话,答案第一个 token 在索引prompt_len
处)。 - 重要细节 :
apply_chat_template
的行为可能依赖于模型配置和 Transformers 版本。参数add_generation_prompt=True
的作用是确保模板在用户对话后添加必要的、引导模型生成回答的分隔符或角色标识(如ASSISTANT:
)。你需要根据你使用的具体模型和processor
确认模板的行为,可能需要查看processor.chat_template
或相关文档。有时,即使没有助手内容,也需要一个空的{"role": "assistant", "content": [...]}
条目才能正确触发add_generation_prompt
。
- 我们利用
- Mask 提示部分 :
labels[i, :prompt_len] = -100
将labels
张量中,从开头到答案开始前(不包括答案第一个 token)的所有 token ID 都设置为-100
。 - Mask Padding 部分 :
labels[input_ids == self.pad_token_id] = -100
确保所有因 padding 添加的 token ID 在labels
中也被设置为-100
。这是因为模型也不应该在 padding 位置计算损失。
这样处理后,labels
张量就满足了要求:与 input_ids
形状一致,且只有答案部分的 token ID 用于损失计算,提示和 padding 部分都被忽略了。模型在 forward
时接收这个 labels
,内部会自动处理移位并计算正确的损失。
额外的安全建议和技巧
- Tokenizer 行为确认 :不同模型的 Tokenizer 和
apply_chat_template
行为可能略有差异。务必打印并检查texts_prompt_only
、tokenize 后的prompt_tokens_outputs.input_ids
以及最终的input_ids
和labels
,确保 tokenization 和 mask 操作符合预期。特别留意特殊 token(如<s>
,</s>
,<|im_start|>
,<|im_end|>
等)是否被正确处理和计数。 - 处理空答案或截断 :如果数据中存在空答案,或者序列因为
max_length
被截断导致答案部分丢失,对应的labels
可能会全变成-100
。训练循环应该能处理这种情况(损失为 0),但大量此类样本可能影响训练效率。可以考虑在数据预处理阶段过滤掉或修正这些样本。上面代码中增加了一个简单的检查,防止prompt_len
越界。 - 数据类型 :
labels
的数据类型应为torch.long
。input_ids
通常也是long
。pixel_values
通常是float
(如torch.float16
或torch.float32
)。确保传递给模型的数据类型符合预期。LightningModule 会自动处理设备(CPU/GPU)转移,但初始类型要对。 - 调试 :遇到问题时,可以尝试用一个非常小的 batch(比如
batch_size=1
)进行调试,逐步打印中间结果 (text_prompt_part
,text_full_part
,input_ids.shape
,prompt_len
,labels
),看看哪里出了偏差。
现在,你应该能比较顺利地为你的 LlavaOneVision 模型准备好正确的 labels
,让微调过程跑起来了。