BERT 模型输入与推理:填坑指南及代码示例
2025-03-13 03:15:21
BERT 模型输入与推理:填坑指南
最近我搞定了一个基于 BERT 的模型训练。训练完后,模型文件夹里出现了几个文件:pytorch_model.bin
、training_args.bin
、merges.txt
、vocab.json
。现在我想测试一下模型,给它个输入,看看输出是啥。但我完全不知道该怎么下手,犯难了!
我上网搜了搜,有人建议用 Gradio,但实际操作时还是有不少疑问。这篇博客就来记录一下我踩过的坑,以及最终解决问题的方法。
一、 问题出在哪?
搞不清怎么给 BERT 模型输入,其实是因为不清楚 BERT 模型接收什么样的输入,以及如何将原始文本转换成这种输入。我们平时看到的文本是字符串,而模型处理的是数字化的张量(tensor)。 关键在于找到模型使用的分词器(Tokenizer),并正确地进行文本预处理。
另外,网上有很多例子,各种方法都有,让人眼花缭乱。 需要一个简单、直接、有效的方式来做模型推理。
二、 解决方案
下面我总结了三种方法,分别用transformers
库直接进行推理、构建简易Pipeline,以及使用Gradio
搭建一个可视化交互界面。 难度由简入深,大家各取所需。
2.1 方法一:transformers
库直接推理 (最简)
这是最基础,也是最直接的方式。 你只需要加载模型和分词器,然后用分词器处理输入文本,再将处理后的结果输入模型即可。
原理:
transformers
库提供了一个高层次的 API,屏蔽了很多底层细节。 我们主要用到了:
AutoTokenizer
: 自动根据模型配置文件加载对应的分词器。AutoModelForXXX
: 自动根据任务类型(例如问答、序列分类等)加载相应的模型。- 分词器的
encode
或encode_plus
方法:将文本转换为模型需要的输入格式(通常是 input_ids, attention_mask, token_type_ids)。 - 模型的
forward
方法(直接调用模型实例即可):接收处理后的输入,返回模型的输出。
代码示例(以文本分类为例):
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 1. 加载模型和分词器
model_path = "./your_model_directory" # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# 2. 准备输入文本
text = "这是一段测试文本。"
# 3. 使用分词器处理文本
inputs = tokenizer.encode_plus(
text,
add_special_tokens=True, # 添加 [CLS] 和 [SEP] 标记
return_tensors="pt" # 返回 PyTorch tensors
)
# 4. 模型推理
with torch.no_grad(): # 推理时不需要计算梯度
outputs = model(**inputs)
# 5. 获取预测结果
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
print(f"预测的类别: {predictions.item()}") # 假设是二分类,输出0或1
安全建议:
- 如果你处理的是敏感数据,请确保你的模型部署环境安全可靠。
进阶技巧:
如果你是多条句子需要处理:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 1. 加载模型和分词器
model_path = "./your_model_directory" # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# 2. 准备输入文本
texts = ["这是一段测试文本。", "这是第二段测试文本"]
# 3. 使用分词器处理文本, 并注意参数的不同.
inputs = tokenizer(
texts,
add_special_tokens=True,
padding=True, # 填充到最大长度
truncation=True, # 截断到最大长度
return_tensors="pt"
)
# 4. 模型推理
with torch.no_grad():
outputs = model(**inputs)
# 5. 获取预测结果
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
print(predictions) # 注意这时候预测的是一个列表了.
2.2 方法二:构建简易 Pipeline
transformers
库的 pipeline
模块进一步简化了模型的使用。你可以用几行代码就创建一个用于特定任务的 pipeline,无需手动处理分词和模型调用。
原理:
pipeline
模块将模型的加载、预处理、推理、后处理等步骤封装在一起,提供了一个更高级别的接口。你只需要指定任务类型和模型路径,就可以直接输入文本并获得结果。
代码示例(以情感分析为例):
from transformers import pipeline
# 1. 创建 pipeline
model_path = "./your_model_directory" # 替换成你的模型路径
sentiment_pipeline = pipeline("sentiment-analysis", model=model_path, tokenizer=model_path)
# 2. 使用 pipeline 进行预测
text = "这个电影真棒!"
result = sentiment_pipeline(text)
# 3. 打印结果
print(result) # 输出:[{'label': 'POSITIVE', 'score': 0.999}] (示例)
安全建议:
- 与方法一相同,注意数据安全。
2.3 方法三:Gradio 可视化交互 (最炫)
如果你想让模型更易用,甚至分享给其他人体验,Gradio 是个好东西。它可以快速创建一个 Web 界面,让你通过浏览器与模型交互。
原理:
Gradio 基于 Web 框架,将 Python 函数映射为网页上的交互组件。 你需要定义一个函数,该函数接收输入,调用模型进行推理,并返回输出。Gradio 会自动处理界面的生成和事件处理。
代码示例:
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# 1. 加载模型和分词器
model_path = "./your_model_directory" # 替换成你的模型路径
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
# 2. 定义推理函数
def predict_sentiment(text):
inputs = tokenizer.encode_plus(
text,
add_special_tokens=True,
return_tensors="pt"
)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
# 这里我们假设是二分类,并且 labels = ['negative', 'positive']。 如果有变化,需要根据真实情况修改.
labels = ['negative', 'positive']
return labels[predictions.item()]
# 3. 创建 Gradio 界面
iface = gr.Interface(
fn=predict_sentiment,
inputs="text",
outputs="text",
title="情感分析 Demo",
description="输入一段文本,模型会判断其情感倾向。"
)
# 4. 启动界面
iface.launch()
# 如果你需要共享给别人用, 把下面这行取消注释.
#iface.launch(share=True)
操作步骤:
- 安装 Gradio:
pip install gradio
- 运行上面的代码。
- 在浏览器中打开显示的 URL(通常是
http://127.0.0.1:7860
)。 - 在文本框中输入文本,点击 "Submit" 按钮,查看模型的预测结果。
安全建议:
- 如果将 Gradio 应用部署到公网,务必采取安全措施,例如身份验证、限制访问等。
- Gradio 的
share=True
功能会创建一个临时公开链接,使用后及时关闭,避免泄露。
总结
搞清楚怎么给 BERT 输入, 最核心是弄明白你的模型配套的 Tokenizer 是什么。 拿到 Tokenizer 以后, 你可以通过transformers
直接写逻辑调用, 也可以利用现成的pipeline
直接跑, 或者也可以通过 Gradio
部署一个页面. 这三种办法是比较常见的了, 你可以选择任何一个合适你自己的办法.