返回

BERT 模型输入与推理:填坑指南及代码示例

Ai

BERT 模型输入与推理:填坑指南

最近我搞定了一个基于 BERT 的模型训练。训练完后,模型文件夹里出现了几个文件:pytorch_model.bintraining_args.binmerges.txtvocab.json。现在我想测试一下模型,给它个输入,看看输出是啥。但我完全不知道该怎么下手,犯难了!

我上网搜了搜,有人建议用 Gradio,但实际操作时还是有不少疑问。这篇博客就来记录一下我踩过的坑,以及最终解决问题的方法。

一、 问题出在哪?

搞不清怎么给 BERT 模型输入,其实是因为不清楚 BERT 模型接收什么样的输入,以及如何将原始文本转换成这种输入。我们平时看到的文本是字符串,而模型处理的是数字化的张量(tensor)。 关键在于找到模型使用的分词器(Tokenizer),并正确地进行文本预处理。

另外,网上有很多例子,各种方法都有,让人眼花缭乱。 需要一个简单、直接、有效的方式来做模型推理。

二、 解决方案

下面我总结了三种方法,分别用transformers库直接进行推理、构建简易Pipeline,以及使用Gradio搭建一个可视化交互界面。 难度由简入深,大家各取所需。

2.1 方法一:transformers 库直接推理 (最简)

这是最基础,也是最直接的方式。 你只需要加载模型和分词器,然后用分词器处理输入文本,再将处理后的结果输入模型即可。

原理:

transformers 库提供了一个高层次的 API,屏蔽了很多底层细节。 我们主要用到了:

  • AutoTokenizer: 自动根据模型配置文件加载对应的分词器。
  • AutoModelForXXX: 自动根据任务类型(例如问答、序列分类等)加载相应的模型。
  • 分词器的 encodeencode_plus 方法:将文本转换为模型需要的输入格式(通常是 input_ids, attention_mask, token_type_ids)。
  • 模型的 forward 方法(直接调用模型实例即可):接收处理后的输入,返回模型的输出。

代码示例(以文本分类为例):

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. 加载模型和分词器
model_path = "./your_model_directory"  # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# 2. 准备输入文本
text = "这是一段测试文本。"

# 3. 使用分词器处理文本
inputs = tokenizer.encode_plus(
    text,
    add_special_tokens=True,  # 添加 [CLS] 和 [SEP] 标记
    return_tensors="pt"       # 返回 PyTorch tensors
)

# 4. 模型推理
with torch.no_grad():  # 推理时不需要计算梯度
    outputs = model(**inputs)

# 5. 获取预测结果
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)

print(f"预测的类别: {predictions.item()}") # 假设是二分类,输出0或1

安全建议:

  • 如果你处理的是敏感数据,请确保你的模型部署环境安全可靠。

进阶技巧:
如果你是多条句子需要处理:

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. 加载模型和分词器
model_path = "./your_model_directory"  # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# 2. 准备输入文本
texts = ["这是一段测试文本。", "这是第二段测试文本"]

# 3. 使用分词器处理文本, 并注意参数的不同.
inputs = tokenizer(
    texts,
    add_special_tokens=True,
    padding=True,  # 填充到最大长度
    truncation=True, # 截断到最大长度
    return_tensors="pt"
)

# 4. 模型推理
with torch.no_grad():
    outputs = model(**inputs)

# 5. 获取预测结果
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
print(predictions) # 注意这时候预测的是一个列表了.

2.2 方法二:构建简易 Pipeline

transformers 库的 pipeline 模块进一步简化了模型的使用。你可以用几行代码就创建一个用于特定任务的 pipeline,无需手动处理分词和模型调用。

原理:

pipeline 模块将模型的加载、预处理、推理、后处理等步骤封装在一起,提供了一个更高级别的接口。你只需要指定任务类型和模型路径,就可以直接输入文本并获得结果。

代码示例(以情感分析为例):

from transformers import pipeline

# 1. 创建 pipeline
model_path = "./your_model_directory"  # 替换成你的模型路径
sentiment_pipeline = pipeline("sentiment-analysis", model=model_path, tokenizer=model_path)

# 2. 使用 pipeline 进行预测
text = "这个电影真棒!"
result = sentiment_pipeline(text)

# 3. 打印结果
print(result)  # 输出:[{'label': 'POSITIVE', 'score': 0.999}] (示例)

安全建议:

  • 与方法一相同,注意数据安全。

2.3 方法三:Gradio 可视化交互 (最炫)

如果你想让模型更易用,甚至分享给其他人体验,Gradio 是个好东西。它可以快速创建一个 Web 界面,让你通过浏览器与模型交互。

原理:

Gradio 基于 Web 框架,将 Python 函数映射为网页上的交互组件。 你需要定义一个函数,该函数接收输入,调用模型进行推理,并返回输出。Gradio 会自动处理界面的生成和事件处理。

代码示例:

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# 1. 加载模型和分词器
model_path = "./your_model_directory"  # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)

# 2. 定义推理函数
def predict_sentiment(text):
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        return_tensors="pt"
    )
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    # 这里我们假设是二分类,并且 labels = ['negative', 'positive']。 如果有变化,需要根据真实情况修改.
    labels = ['negative', 'positive']
    return labels[predictions.item()]

# 3. 创建 Gradio 界面
iface = gr.Interface(
    fn=predict_sentiment,
    inputs="text",
    outputs="text",
    title="情感分析 Demo",
    description="输入一段文本,模型会判断其情感倾向。"
)

# 4. 启动界面
iface.launch()
#   如果你需要共享给别人用, 把下面这行取消注释.
#iface.launch(share=True)

操作步骤:

  1. 安装 Gradio:pip install gradio
  2. 运行上面的代码。
  3. 在浏览器中打开显示的 URL(通常是 http://127.0.0.1:7860)。
  4. 在文本框中输入文本,点击 "Submit" 按钮,查看模型的预测结果。

安全建议:

  • 如果将 Gradio 应用部署到公网,务必采取安全措施,例如身份验证、限制访问等。
  • Gradio 的 share=True 功能会创建一个临时公开链接,使用后及时关闭,避免泄露。

总结

搞清楚怎么给 BERT 输入, 最核心是弄明白你的模型配套的 Tokenizer 是什么。 拿到 Tokenizer 以后, 你可以通过transformers直接写逻辑调用, 也可以利用现成的pipeline直接跑, 或者也可以通过 Gradio 部署一个页面. 这三种办法是比较常见的了, 你可以选择任何一个合适你自己的办法.