# Transformer Architecture Principles

## Architecture Overview
The Transformer was introduced by Vaswani et al. in the 2017 paper "Attention Is All You Need" and has largely replaced RNNs/LSTMs as the dominant architecture in NLP.
```
Input sequence [x1, x2, ..., xn]
        │
        ▼
┌─────────────────────────────┐
│  Input Embedding            │  token embedding
│  + Positional Encoding      │  positional encoding
└─────────────────────────────┘
        │
        ▼  × N layers
┌─────────────────────────────┐
│  Multi-Head Attention       │  multi-head self-attention
│  Add & Layer Norm           │
│  Feed Forward Network       │  feed-forward network
│  Add & Layer Norm           │
└─────────────────────────────┘
        │
        ▼
┌─────────────────────────────┐
│  Linear + Softmax           │  output probability distribution
└─────────────────────────────┘
```

## Self-Attention: The Core Computation
```python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)
    # Compute attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    # Optional: apply a mask (e.g. the causal mask in the decoder)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # Softmax normalization
    attn_weights = F.softmax(scores, dim=-1)
    # Weighted sum of the values
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
```

## Multi-Head Attention
```python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, heads, seq, d_k)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate the heads back together
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        return self.W_o(attn_output)
```

## Positional Encoding
The Transformer itself has no inherent sense of token order, so positional encodings are used to inject position information:
```python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Build the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() *
            (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
```

## Feed-Forward Network (FFN)
```python
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(
            self.dropout(F.relu(self.linear1(x)))
        )
```

## Architectural Improvements in Modern LLMs
| Improvement | Original Transformer | Modern LLMs (e.g. Llama) |
|---|---|---|
| Normalization placement | Post-LN | Pre-LN (more stable training) |
| Normalization type | LayerNorm | RMSNorm (faster) |
| Activation function | ReLU | SwiGLU / GeGLU |
| Positional encoding | Sinusoidal absolute | RoPE (rotary position embedding) |
| Attention variant | Standard MHA | GQA (grouped-query attention) |
| KV cache | None | Yes (key to inference speed) |
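The RMSNorm row above can be made concrete in a few lines. This is an illustrative sketch of the idea (not code taken from any particular model): unlike LayerNorm, it skips mean-centering and the bias term, normalizing only by the root mean square.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """RMSNorm as used in Llama-family models: scale-only normalization
    by the root mean square over the feature dimension."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # 1 / RMS over the last dimension, with eps for numerical safety
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
```

Dropping the mean and bias removes two reductions per call, which is where the "faster" in the table comes from.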
## RoPE (Rotary Position Embedding)

Modern LLMs widely adopt RoPE, which extrapolates better to longer contexts:
```python
def precompute_freqs_cis(head_dim, max_seq, theta=10000.0):
    # Per-channel-pair frequencies, outer-multiplied with positions,
    # returned as unit complex numbers e^{i * position * freq}
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq).float()
    return torch.polar(torch.ones_like(torch.outer(t, freqs)), torch.outer(t, freqs))

def apply_rotary_emb(xq, xk, freqs_cis):
    """
    xq, xk: (batch, seq, heads, head_dim)
    freqs_cis: precomputed complex rotations, (max_seq, head_dim // 2)
    """
    # View consecutive channel pairs as complex numbers
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Slice to the current sequence length and reshape for broadcasting
    freqs_cis = freqs_cis[:xq_.shape[1]].view(1, xq_.shape[1], 1, xq_.shape[-1])
    # Rotate via complex multiplication, then flatten back to real pairs
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
```

## Key Parameter Scales
| Model | Parameters | Layers | Hidden dim | Attention heads |
|---|---|---|---|---|
| GPT-2 | 1.5B | 48 | 1600 | 25 |
| Llama 2 7B | 7B | 32 | 4096 | 32 |
| Llama 2 70B | 70B | 80 | 8192 | 64 |
| GPT-4 | ~1.8T (MoE, unconfirmed estimate) | - | - | - |
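As a sanity check on these numbers: a standard decoder layer has roughly 4·d² attention parameters (Q, K, V, and output projections) and 2·d·d_ff FFN parameters, plus the token embedding matrix. A rough back-of-the-envelope sketch (the `vocab_size` default is GPT-2's 50257; biases, norms, and variants like SwiGLU or GQA are ignored, so real models will differ):

```python
def approx_param_count(num_layers, d_model, vocab_size=50257, d_ff=None):
    """Rough GPT-style parameter estimate: per layer, ~4*d^2 for
    attention projections plus 2*d*d_ff for the FFN (d_ff = 4*d by
    default), plus the token embedding matrix."""
    d_ff = d_ff or 4 * d_model
    per_layer = 4 * d_model**2 + 2 * d_model * d_ff
    return num_layers * per_layer + vocab_size * d_model

# GPT-2 XL: 48 layers, d_model = 1600
print(f"{approx_param_count(48, 1600) / 1e9:.2f}B")  # ≈ 1.55B, close to the reported 1.5B
```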
## Engineering Practice

Understanding the Transformer architecture helps you:

- Set `max_tokens` and the context length sensibly
- Understand why long-context inference is slower (attention is O(n²) in sequence length)
- Choose an appropriate fine-tuning strategy (e.g. LoRA typically targets the attention projection layers)
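The KV cache from the improvements table is the standard answer to part of that O(n²) cost: during autoregressive decoding, keys and values of past tokens are stored so each new step only projects and attends from the newest token. A minimal single-head sketch (illustrative only; the function name and cache layout are this sketch's own, not any library's API):

```python
import torch
import torch.nn.functional as F
import math

def decode_step(x_new, W_q, W_k, W_v, cache):
    """One autoregressive decoding step with a KV cache.
    x_new: (batch, 1, d_model) -- hidden state of the newest token only.
    cache: dict with 'k' and 'v' tensors of shape (batch, past_len, d_k)."""
    q = x_new @ W_q                                   # (batch, 1, d_k)
    k = torch.cat([cache["k"], x_new @ W_k], dim=1)   # append new key
    v = torch.cat([cache["v"], x_new @ W_v], dim=1)   # append new value
    cache["k"], cache["v"] = k, v
    # Attend over all cached positions; no causal mask is needed because
    # the cache only ever contains the current and past tokens.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return F.softmax(scores, dim=-1) @ v              # (batch, 1, d_k)
```

With the cache, each decoding step costs O(n·d) instead of recomputing full O(n²·d) attention over the whole prefix, which is why the table calls it the key to inference speed.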