首页
/ 基于BERT模型的20Newsgroups文本分类实战教程

基于BERT模型的20Newsgroups文本分类实战教程

2026-02-04 04:38:58作者:薛曦旖Francesca

前言

在自然语言处理(NLP)领域,文本分类是一项基础且重要的任务。随着深度学习技术的发展,预训练语言模型如BERT已经显著提升了文本分类的性能。本文将详细介绍如何使用BERT模型对20Newsgroups数据集进行微调(fine-tuning),实现高效的文本分类。

环境准备

首先需要安装必要的Python库:

!pip install --upgrade --user transformers

然后导入所需的库:

import torch
from transformers import BertTokenizerFast, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
import numpy as np
import random
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split

设置随机种子

为了保证实验的可重复性,我们需要设置随机种子:

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
set_seed(1)

模型和参数配置

我们使用BERT的基础版本(uncased)作为预训练模型:

model_name = "bert-base-uncased"
max_length = 512  # BERT的最大输入长度

数据准备

加载20Newsgroups数据集

20Newsgroups是一个经典的文本分类数据集,包含20个不同主题的新闻组文档:

def read_20newsgroups(test_size=0.2):
    dataset = fetch_20newsgroups(subset="all", shuffle=True, remove=("headers", "footers", "quotes"))
    documents = dataset.data
    labels = dataset.target
    return train_test_split(documents, labels, test_size=test_size), dataset.target_names

(train_texts, valid_texts, train_labels, valid_labels), target_names = read_20newsgroups()

数据预处理

使用BERT的分词器对文本进行编码:

tokenizer = BertTokenizerFast.from_pretrained(model_name, do_lower_case=True)

train_encodings = tokenizer(train_texts, truncation=True, padding=True, max_length=max_length)
valid_encodings = tokenizer(valid_texts, truncation=True, padding=True, max_length=max_length)

创建PyTorch数据集

将编码后的数据转换为PyTorch Dataset格式:

class NewsGroupsDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor([self.labels[idx]])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = NewsGroupsDataset(train_encodings, train_labels)
valid_dataset = NewsGroupsDataset(valid_encodings, valid_labels)

模型训练

加载预训练模型

model = BertForSequenceClassification.from_pretrained(model_name, num_labels=len(target_names)).to("cuda")

定义评估指标

from sklearn.metrics import accuracy_score

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    acc = accuracy_score(labels, preds)
    return {'accuracy': acc}

设置训练参数

training_args = TrainingArguments(
    output_dir='./results',          # 输出目录
    num_train_epochs=3,              # 训练轮数
    per_device_train_batch_size=8,   # 训练批次大小
    per_device_eval_batch_size=20,   # 评估批次大小
    warmup_steps=500,                # 学习率预热步数
    weight_decay=0.01,               # 权重衰减
    logging_dir='./logs',            # 日志目录
    load_best_model_at_end=True,     # 训练结束时加载最佳模型
    logging_steps=400,               # 日志记录间隔
    save_steps=400,                  # 模型保存间隔
    evaluation_strategy="steps",     # 评估策略
)

创建Trainer并开始训练

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

模型评估与保存

评估模型性能

trainer.evaluate()

保存微调后的模型

model_path = "20newsgroups-bert-base-uncased"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

模型应用

加载已保存的模型

model = BertForSequenceClassification.from_pretrained(model_path, num_labels=len(target_names)).to("cuda")
tokenizer = BertTokenizerFast.from_pretrained(model_path)

预测函数

def get_prediction(text):
    inputs = tokenizer(text, padding=True, truncation=True, max_length=max_length, return_tensors="pt").to("cuda")
    outputs = model(**inputs)
    probs = outputs[0].softmax(1)
    return target_names[probs.argmax()]

示例预测

# 示例1:科技类
text = """With the pace of smartphone evolution moving so fast..."""
print(get_prediction(text))

# 示例2:科学类
text = """A black hole is a place in space where gravity pulls so much..."""
print(get_prediction(text))

# 示例3:医学类
text = """Respiratory illness is a common health condition..."""
print(get_prediction(text))

总结

本教程详细介绍了如何使用BERT模型对20Newsgroups数据集进行文本分类,包括:

  1. 环境准备和数据集加载
  2. 数据预处理和编码
  3. 模型配置和训练参数设置
  4. 模型训练、评估和保存
  5. 实际应用示例

通过这个流程,读者可以掌握BERT模型在文本分类任务中的完整应用方法,并可以将其扩展到其他类似的NLP任务中。BERT的强大表示能力结合适当的微调策略,可以在各种文本分类任务上取得优异的表现。

登录后查看全文
热门项目推荐
相关项目推荐