AI-For-Beginners语言建模:CBOW与Skip-gram实战
2026-02-04 05:15:14作者:钟日瑜
引言:为什么需要词嵌入?
在传统的自然语言处理中,我们通常使用词袋模型(Bag-of-Words)或TF-IDF来表示文本。这些方法虽然简单有效,但存在两个主要问题:
- 维度灾难:词汇表大小可能达到数万甚至数十万,导致特征向量维度极高
- 语义缺失:one-hot编码无法表达词语之间的语义相似性
词嵌入(Word Embedding)技术通过将词语映射到低维稠密向量空间,完美解决了这两个问题。Word2Vec作为最经典的词嵌入算法,包含两种核心架构:CBOW(Continuous Bag-of-Words)和Skip-gram。
CBOW与Skip-gram原理对比
CBOW(连续词袋模型)
CBOW通过上下文词语预测中心词,训练目标是最大化给定上下文时中心词的条件概率。
flowchart TD
A[输入层: 上下文词语] --> B[嵌入层: 词向量查找]
B --> C[隐藏层: 向量平均或求和]
C --> D[输出层: Softmax分类]
D --> E[预测中心词]
Skip-gram(跳字模型)
Skip-gram与CBOW相反,通过中心词预测上下文词语,训练目标是最大化给定中心词时上下文词语的条件概率。
flowchart TD
A[输入层: 中心词语] --> B[嵌入层: 词向量查找]
B --> C[隐藏层: 共享权重矩阵]
C --> D[输出层: Softmax分类]
D --> E[预测上下文词语]
性能对比表
| 特性 | CBOW | Skip-gram |
|---|---|---|
| 训练速度 | 较快 | 较慢 |
| 低频词处理 | 一般 | 优秀 |
| 语义捕捉 | 整体语义 | 细粒度语义 |
| 适用场景 | 大规模语料 | 小规模语料 |
实战环境搭建
首先确保安装必要的依赖库:
# PyTorch版本
import torch
import torchtext
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
# 数据处理
import collections
import random
import numpy as np
CBOW模型实现详解
1. 数据预处理与词汇表构建
def build_vocabulary(text_corpus, vocab_size=5000):
"""构建词汇表"""
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')
counter = collections.Counter()
for text in text_corpus:
tokens = tokenizer(text)
counter.update(tokens)
# 选择前vocab_size个最频繁的词语
vocab = torchtext.vocab.Vocab(
collections.Counter(dict(counter.most_common(vocab_size))),
min_freq=1
)
return vocab, tokenizer
def text_to_indices(text, vocab, tokenizer):
"""将文本转换为索引序列"""
tokens = tokenizer(text)
return [vocab[token] for token in tokens if token in vocab.stoi]
2. CBOW训练数据生成
def generate_cbow_pairs(sentence_indices, window_size=2):
"""生成CBOW训练样本对"""
pairs = []
n = len(sentence_indices)
for center_pos in range(n):
# 获取上下文窗口
context_start = max(0, center_pos - window_size)
context_end = min(n, center_pos + window_size + 1)
# 排除中心词本身
context_indices = [
sentence_indices[i]
for i in range(context_start, context_end)
if i != center_pos
]
if context_indices:
pairs.append((context_indices, sentence_indices[center_pos]))
return pairs
3. CBOW模型架构
class CBOWModel(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(CBOWModel, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, context_words):
# 上下文词嵌入并求平均
embedded = self.embedding(context_words) # [batch_size, context_size, embedding_dim]
embedded_mean = torch.mean(embedded, dim=1) # [batch_size, embedding_dim]
# 预测中心词
output = self.linear(embedded_mean) # [batch_size, vocab_size]
return output
4. 训练循环
def train_cbow(model, train_data, vocab_size, epochs=10, learning_rate=0.01):
"""训练CBOW模型"""
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(epochs):
total_loss = 0
for context_words, center_word in train_data:
optimizer.zero_grad()
# 前向传播
output = model(context_words)
loss = criterion(output, center_word.unsqueeze(0))
# 反向传播
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_data):.4f}')
Skip-gram模型实现
1. Skip-gram训练数据生成
def generate_skipgram_pairs(sentence_indices, window_size=2):
"""生成Skip-gram训练样本对"""
pairs = []
n = len(sentence_indices)
for center_pos in range(n):
center_word = sentence_indices[center_pos]
# 获取上下文窗口
context_start = max(0, center_pos - window_size)
context_end = min(n, center_pos + window_size + 1)
for context_pos in range(context_start, context_end):
if context_pos != center_pos:
context_word = sentence_indices[context_pos]
pairs.append((center_word, context_word))
return pairs
2. Skip-gram模型架构
class SkipGramModel(nn.Module):
def __init__(self, vocab_size, embedding_dim):
super(SkipGramModel, self).__init__()
self.center_embedding = nn.Embedding(vocab_size, embedding_dim)
self.context_embedding = nn.Embedding(vocab_size, embedding_dim)
def forward(self, center_words, context_words):
# 中心词和上下文词嵌入
center_embedded = self.center_embedding(center_words) # [batch_size, embedding_dim]
context_embedded = self.context_embedding(context_words) # [batch_size, embedding_dim]
# 计算相似度得分
scores = torch.matmul(center_embedded, context_embedded.t()) # [batch_size, batch_size]
return scores
3. 负采样训练
class NegativeSamplingLoss(nn.Module):
def __init__(self):
super(NegativeSamplingLoss, self).__init__()
def forward(self, center_vectors, context_vectors, negative_vectors):
# 正样本得分
positive_score = torch.sum(center_vectors * context_vectors, dim=1)
positive_loss = -F.logsigmoid(positive_score).mean()
# 负样本得分
negative_score = torch.sum(center_vectors.unsqueeze(1) * negative_vectors, dim=2)
negative_loss = -F.logsigmoid(-negative_score).mean()
return positive_loss + negative_loss
实战案例:AG新闻数据词嵌入
数据加载与预处理
def load_ag_news_dataset(sample_size=10000):
"""加载AG新闻数据集"""
train_dataset, test_dataset = torchtext.datasets.AG_NEWS(root='./data')
texts = []
for i, (label, text) in enumerate(train_dataset):
if i >= sample_size:
break
texts.append(text)
return texts
# 加载数据并构建词汇表
news_texts = load_ag_news_dataset(10000)
vocab, tokenizer = build_vocabulary(news_texts, vocab_size=5000)
# 生成训练数据
all_cbow_pairs = []
for text in news_texts:
indices = text_to_indices(text, vocab, tokenizer)
cbow_pairs = generate_cbow_pairs(indices, window_size=2)
all_cbow_pairs.extend(cbow_pairs)
模型训练与评估
# 初始化模型
embedding_dim = 100
vocab_size = len(vocab)
cbow_model = CBOWModel(vocab_size, embedding_dim)
# 训练CBOW模型
train_cbow(cbow_model, all_cbow_pairs, vocab_size, epochs=20, learning_rate=0.01)
# 获取词向量
word_vectors = cbow_model.embedding.weight.data
语义相似度查询
def find_similar_words(query_word, word_vectors, vocab, top_n=5):
"""查找语义相似的词语"""
if query_word not in vocab.stoi:
return f"'{query_word}' not in vocabulary"
query_idx = vocab.stoi[query_word]
query_vector = word_vectors[query_idx]
# 计算余弦相似度
similarities = torch.nn.functional.cosine_similarity(
word_vectors, query_vector.unsqueeze(0), dim=1
)
# 获取最相似的词语(排除查询词本身)
top_indices = similarities.argsort(descending=True)[1:top_n+1]
similar_words = [vocab.itos[idx] for idx in top_indices]
return similar_words
# 测试语义相似度
print("与'microsoft'相似的词语:", find_similar_words('microsoft', word_vectors, vocab))
print("与'basketball'相似的词语:", find_similar_words('basketball', word_vectors, vocab))
性能优化技巧
1. 负采样加速训练
def negative_sampling(vocab, num_negatives=5):
"""负采样实现"""
word_freq = np.array([vocab.freqs[word] for word in vocab.itos])
word_probs = word_freq ** 0.75 # 平滑处理
word_probs /= word_probs.sum()
negative_samples = np.random.choice(
len(vocab), size=num_negatives, p=word_probs, replace=False
)
return torch.tensor(negative_samples)
2. 分层Softmax
class HierarchicalSoftmax(nn.Module):
"""分层Softmax实现"""
def __init__(self, vocab_size, embedding_dim):
super(HierarchicalSoftmax, self).__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
# 构建霍夫曼树等实现...
3. 批量训练优化
class CBOWDataset(Dataset):
"""自定义数据集类"""
def __init__(self, pairs):
self.pairs = pairs
def __len__(self):
return len(self.pairs)
def __getitem__(self, idx):
context, center = self.pairs[idx]
return torch.tensor(context), torch.tensor(center)
# 使用DataLoader进行批量训练
dataset = CBOWDataset(all_cbow_pairs)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
实际应用场景
1. 文本分类增强
class TextClassifierWithEmbeddings(nn.Module):
"""使用预训练词嵌入的文本分类器"""
def __init__(self, vocab_size, embedding_dim, num_classes, pretrained_embeddings=None):
super(TextClassifierWithEmbeddings, self).__init__()
if pretrained_embeddings is not None:
self.embedding = nn.Embedding.from_pretrained(pretrained_embeddings)
else:
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, 128, batch_first=True, bidirectional=True)
self.classifier = nn.Linear(256, num_classes)
def forward(self, x):
embedded = self.embedding(x)
lstm_out, _ = self.lstm(embedded)
pooled = torch.mean(lstm_out, dim=1)
return self.classifier(pooled)
2. 推荐系统语义匹配
def semantic_similarity(item1_text, item2_text, word_vectors, vocab):
"""计算两个文本的语义相似度"""
def text_to_vector(text):
tokens = tokenizer(text)
indices = [vocab[token] for token in tokens if token in vocab.stoi]
if not indices:
return None
vectors = word_vectors[indices]
return torch.mean(vectors, dim=0)
vec1 = text_to_vector(item1_text)
vec2 = text_to_vector(item2_text)
if vec1 is None or vec2 is None:
return 0.0
return F.cosine_similarity(vec1.unsqueeze(0), vec2.unsqueeze(0)).item()
常见问题与解决方案
1. 内存不足问题
# 使用稀疏更新
optimizer = optim.SparseAdam(model.parameters(), lr=0.001)
# 梯度累积
def train_with_gradient_accumulation(model, dataloader, accumulation_steps=4):
optimizer.zero_grad()
for i, (context, center) in enumerate(dataloader):
output = model(context)
loss = criterion(output, center)
loss = loss / accumulation_steps
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
2. 训练不稳定问题
# 梯度裁剪
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
# 学习率调度
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=2
)
3. 低频词处理
# 子词嵌入
import fasttext
# 或者使用字符级CNN
class CharCNNEmbedding(nn.Module):
def __init__(self, char_vocab_size, char_embedding_dim, word_embedding_dim):
super(CharCNNEmbedding, self).__init__()
self.char_embedding = nn.Embedding(char_vocab_size, char_embedding_dim)
self.conv = nn.Conv1d(char_embedding_dim, word_embedding_dim, kernel_size=3)
def forward(self, char_indices):
embedded = self.char_embedding(char_indices)
conv_out = self.conv(embedded.permute(0, 2, 1))
return torch.max(conv_out, dim=2)[0]
总结与展望
通过本教程,我们深入探讨了CBOW和Skip-gram两种Word2Vec架构的原理和实现。这两种方法虽然简单,但为现代NLP奠定了坚实基础。
关键收获:
- CBOW适合快速训练:上下文预测中心词,训练效率高
- Skip-gram处理低频词更优:中心词预测上下文,对罕见词更友好
- 负采样大幅提升性能:通过采样负样本加速训练过程
- 词嵌入是NLP基础:为下游任务提供高质量的语义表示
进阶方向:
- 探索GloVe、fastText等改进算法
- 尝试BERT、GPT等预训练语言模型
- 应用于具体业务场景的定制化词嵌入
词嵌入技术仍在不断发展,但CBOW和Skip-gram作为经典方法,仍然是理解现代NLP的重要基础。掌握这些基础技术,将为你深入学习更复杂的语言模型打下坚实基础。
登录后查看全文
热门项目推荐
相关项目推荐
Kimi-K2.5Kimi K2.5 是一款开源的原生多模态智能体模型,它在 Kimi-K2-Base 的基础上,通过对约 15 万亿混合视觉和文本 tokens 进行持续预训练构建而成。该模型将视觉与语言理解、高级智能体能力、即时模式与思考模式,以及对话式与智能体范式无缝融合。Python00- QQwen3-Coder-Next2026年2月4日,正式发布的Qwen3-Coder-Next,一款专为编码智能体和本地开发场景设计的开源语言模型。Python00
xw-cli实现国产算力大模型零门槛部署,一键跑通 Qwen、GLM-4.7、Minimax-2.1、DeepSeek-OCR 等模型Go06
PaddleOCR-VL-1.5PaddleOCR-VL-1.5 是 PaddleOCR-VL 的新一代进阶模型,在 OmniDocBench v1.5 上实现了 94.5% 的全新 state-of-the-art 准确率。 为了严格评估模型在真实物理畸变下的鲁棒性——包括扫描伪影、倾斜、扭曲、屏幕拍摄和光照变化——我们提出了 Real5-OmniDocBench 基准测试集。实验结果表明,该增强模型在新构建的基准测试集上达到了 SOTA 性能。此外,我们通过整合印章识别和文本检测识别(text spotting)任务扩展了模型的能力,同时保持 0.9B 的超紧凑 VLM 规模,具备高效率特性。Python00
KuiklyUI基于KMP技术的高性能、全平台开发框架,具备统一代码库、极致易用性和动态灵活性。 Provide a high-performance, full-platform development framework with unified codebase, ultimate ease of use, and dynamic flexibility. 注意:本仓库为Github仓库镜像,PR或Issue请移步至Github发起,感谢支持!Kotlin07
VLOOKVLOOK™ 是优雅好用的 Typora/Markdown 主题包和增强插件。 VLOOK™ is an elegant and practical THEME PACKAGE × ENHANCEMENT PLUGIN for Typora/Markdown.Less00
热门内容推荐
最新内容推荐
5分钟掌握ImageSharp色彩矩阵变换:图像色调调整的终极指南3分钟解决Cursor试用限制:go-cursor-help工具全攻略Transmission数据库迁移工具:转移种子状态到新设备如何在VMware上安装macOS?解锁神器Unlocker完整使用指南如何为so-vits-svc项目贡献代码:从提交Issue到创建PR的完整指南Label Studio数据处理管道设计:ETL流程与标注前预处理终极指南突破拖拽限制:React Draggable社区扩展与实战指南如何快速安装 JSON Formatter:让 JSON 数据阅读更轻松的终极指南Element UI表格数据地图:Table地理数据可视化Formily DevTools:让表单开发调试效率提升10倍的神器
项目优选
收起
deepin linux kernel
C
27
11
OpenHarmony documentation | OpenHarmony开发者文档
Dockerfile
525
3.73 K
Ascend Extension for PyTorch
Python
332
396
暂无简介
Dart
766
189
本项目是CANN提供的数学类基础计算算子库,实现网络在NPU上加速计算。
C++
878
586
openEuler内核是openEuler操作系统的核心,既是系统性能与稳定性的基石,也是连接处理器、设备与服务的桥梁。
C
336
166
React Native鸿蒙化仓库
JavaScript
302
352
Nop Platform 2.0是基于可逆计算理论实现的采用面向语言编程范式的新一代低代码开发平台,包含基于全新原理从零开始研发的GraphQL引擎、ORM引擎、工作流引擎、报表引擎、规则引擎、批处理引引擎等完整设计。nop-entropy是它的后端部分,采用java语言实现,可选择集成Spring框架或者Quarkus框架。中小企业可以免费商用
Java
12
1
🎉 (RuoYi)官方仓库 基于SpringBoot,Spring Security,JWT,Vue3 & Vite、Element Plus 的前后端分离权限管理系统
Vue
1.33 K
749
openJiuwen agent-studio提供零码、低码可视化开发和工作流编排,模型、知识库、插件等各资源管理能力
TSX
985
246