How the Transformer Architecture Works

Architecture Overview

The Transformer was introduced by Vaswani et al. in the 2017 paper "Attention Is All You Need" and has since displaced RNNs/LSTMs as the dominant architecture in NLP.

Input sequence [x1, x2, ..., xn]

            ▼
┌─────────────────────────────┐
│      Input Embedding        │  token embedding lookup
│   + Positional Encoding     │  positional encoding
└─────────────────────────────┘

    ▼  × N layers
┌─────────────────────────────┐
│     Multi-Head Attention    │  multi-head self-attention
│     Add & Layer Norm        │
│     Feed Forward Network    │  feed-forward network
│     Add & Layer Norm        │
└─────────────────────────────┘

            ▼
┌─────────────────────────────┐
│      Linear + Softmax       │  output probability distribution
└─────────────────────────────┘

Self-Attention: the Core Computation

python
import torch
import torch.nn.functional as F
import math

def scaled_dot_product_attention(Q, K, V, mask=None):
    """
    Q: (batch, heads, seq_len, d_k)
    K: (batch, heads, seq_len, d_k)
    V: (batch, heads, seq_len, d_v)
    """
    d_k = Q.size(-1)
    
    # Compute raw attention scores
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

    # Optionally apply a mask (e.g., the decoder's causal mask)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    # Normalize scores into attention weights with softmax
    attn_weights = F.softmax(scores, dim=-1)

    # Weighted sum of the value vectors
    output = torch.matmul(attn_weights, V)
    return output, attn_weights
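As a sanity check, here is the same computation condensed, run on toy tensors (the sizes are arbitrary): softmax makes each row of the attention weights sum to 1, and the output keeps the query's shape.

```python
import math
import torch
import torch.nn.functional as F

# Toy tensors: batch=2, heads=4, seq_len=5, d_k=8
Q = torch.randn(2, 4, 5, 8)
K = torch.randn(2, 4, 5, 8)
V = torch.randn(2, 4, 5, 8)

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, V)

print(out.shape)  # torch.Size([2, 4, 5, 8])
print(attn.sum(-1).allclose(torch.ones(2, 4, 5)))  # True: each row sums to 1
```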

Multi-Head Attention

python
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=512, num_heads=8):
        super().__init__()
        assert d_model % num_heads == 0
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projection layers
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
    
    def split_heads(self, x, batch_size):
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)  # (batch, heads, seq, d_k)
    
    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        
        Q = self.split_heads(self.W_q(Q), batch_size)
        K = self.split_heads(self.W_k(K), batch_size)
        V = self.split_heads(self.W_v(V), batch_size)
        
        attn_output, _ = scaled_dot_product_attention(Q, K, V, mask)
        
        # Merge the heads back into d_model
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        
        return self.W_o(attn_output)
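The head split/merge reshapes are where most shape bugs hide. A standalone round trip with the same defaults as above (d_model=512, num_heads=8) shows that the two reshapes are exact inverses:

```python
import torch

batch, seq, d_model, num_heads = 2, 10, 512, 8
d_k = d_model // num_heads

x = torch.randn(batch, seq, d_model)

# Split: (batch, seq, d_model) -> (batch, heads, seq, d_k)
heads = x.view(batch, seq, num_heads, d_k).transpose(1, 2)
print(heads.shape)  # torch.Size([2, 8, 10, 64])

# Merge: inverse of the split; recovers the original tensor exactly
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)
print(torch.equal(merged, x))  # True
```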

Positional Encoding

The Transformer has no inherent notion of token order, so position information is injected via positional encoding:

python
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Build the positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * 
            (-math.log(10000.0) / d_model)
        )
        
        pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
        
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
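Two properties of this encoding are easy to verify standalone: every entry is bounded in [-1, 1] (so it never drowns out the embeddings), and position 0 maps to sin(0) = 0 on every even dimension:

```python
import math
import torch

d_model, max_len = 512, 100
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# All values are bounded sin/cos outputs
print(bool(pe.abs().max() <= 1.0))  # True
# Position 0: sin(0) = 0 on even dimensions
print(torch.equal(pe[0, 0::2], torch.zeros(d_model // 2)))  # True
```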

Feed-Forward Network (FFN)

python
class FeedForward(nn.Module):
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        return self.linear2(
            self.dropout(F.relu(self.linear1(x)))
        )
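The components above combine into the encoder block shown in the overview diagram. A minimal Post-LN sketch, using PyTorch's built-in nn.MultiheadAttention for brevity (layer sizes follow the defaults used above; this is an illustration, not a full reference implementation):

```python
import torch
import torch.nn as nn

class EncoderLayer(nn.Module):
    """One Post-LN encoder block: attention -> Add & Norm -> FFN -> Add & Norm."""
    def __init__(self, d_model=512, num_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads,
                                          dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        attn_out, _ = self.attn(x, x, x, attn_mask=mask)
        x = self.norm1(x + self.dropout(attn_out))      # Add & Norm
        x = self.norm2(x + self.dropout(self.ffn(x)))   # Add & Norm
        return x

layer = EncoderLayer()
x = torch.randn(2, 10, 512)
print(layer(x).shape)  # torch.Size([2, 10, 512])
```

Note that residual connections keep the shape (batch, seq, d_model) unchanged through the whole block, which is what lets N such layers be stacked.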

Architectural Improvements in Modern LLMs

| Improvement | Original Transformer | Modern LLMs (e.g., Llama) |
|---|---|---|
| Norm placement | Post-LN | Pre-LN (more stable training) |
| Norm type | LayerNorm | RMSNorm (cheaper) |
| Activation | ReLU | SwiGLU / GeGLU |
| Positional encoding | Sinusoidal absolute | RoPE (rotary) |
| Attention | Standard MHA | GQA (grouped-query attention) |
| KV cache | None | Yes (key to fast inference) |
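As a concrete example of one row in the table: RMSNorm drops LayerNorm's mean subtraction and bias, keeping only a learned gain. A minimal sketch (the eps default is a common choice, not taken from any specific model):

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))  # gain only, no bias

    def forward(self, x):
        # Normalize by the root mean square over the feature dimension;
        # no mean subtraction, unlike LayerNorm
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight

norm = RMSNorm(512)
x = torch.randn(2, 10, 512)
y = norm(x)
print(y.shape)  # torch.Size([2, 10, 512])
```

The saving comes from computing one statistic (mean square) instead of two (mean and variance).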

RoPE: Rotary Position Embedding

Modern LLMs almost universally adopt RoPE, which extrapolates better to longer contexts:

python
def apply_rotary_emb(xq, xk, freqs_cis):
    """
    xq, xk: (batch, seq, heads, head_dim)
    freqs_cis: precomputed complex rotations, shape (max_seq, head_dim // 2)
    """
    # View adjacent feature pairs as complex numbers: (..., head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    
    # Trim to the current sequence length and add singleton batch/head axes
    # so the rotations broadcast correctly
    freqs_cis = freqs_cis[:xq_.shape[1]].view(1, xq_.shape[1], 1, xq_.shape[-1])
    
    # Complex multiplication rotates each feature pair by a position-dependent angle
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)
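apply_rotary_emb assumes its freqs_cis argument was precomputed elsewhere. A sketch of that precomputation, modeled on Llama's approach (the function name and defaults here are assumptions): each position gets one unit-modulus complex rotation per frequency pair, so RoPE preserves vector norms.

```python
import torch

def precompute_freqs_cis(dim, end, theta=10000.0):
    # One rotation angle pos * freq per position and per feature pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.outer(torch.arange(end).float(), freqs)
    # Unit-modulus complex numbers e^{i * angle}: shape (end, dim // 2)
    return torch.polar(torch.ones_like(angles), angles)

freqs_cis = precompute_freqs_cis(dim=64, end=128)
print(freqs_cis.shape)  # torch.Size([128, 32])

# All rotations have modulus 1, so multiplying by them preserves norms
print(torch.allclose(freqs_cis.abs(), torch.ones(128, 32)))  # True
```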

Key Parameter Scales

| Model | Parameters | Layers | Hidden dim | Attention heads |
|---|---|---|---|---|
| GPT-2 | 1.5B | 48 | 1600 | 25 |
| Llama 2 7B | 7B | 32 | 4096 | 32 |
| Llama 2 70B | 70B | 80 | 8192 | 64 |
| GPT-4 | ~1.8T (MoE) | — | — | — |
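A quick consistency check on the table: the per-head dimension is the hidden dimension divided by the head count, which comes out to the usual 64 or 128:

```python
# (hidden dim, attention heads) from the table above
models = {
    "GPT-2":       (1600, 25),
    "Llama 2 7B":  (4096, 32),
    "Llama 2 70B": (8192, 64),
}
for name, (d_model, heads) in models.items():
    print(f"{name}: head_dim = {d_model // heads}")
# GPT-2: head_dim = 64
# Llama 2 7B: head_dim = 128
# Llama 2 70B: head_dim = 128
```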

Engineering Takeaways

Understanding the Transformer architecture helps you:

  1. Set max_tokens and context length sensibly
  2. Understand why long-context inference is slower (the O(n²) attention complexity)
  3. Pick a suitable fine-tuning strategy (e.g., LoRA typically targets the attention projections)
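Point 2 can be made concrete: the attention score matrix alone holds seq_len² entries per head per layer, so its memory grows quadratically with context length:

```python
# float32 score matrix: seq_len x seq_len x 4 bytes, per head per layer
for seq_len in (1024, 4096, 16384):
    bytes_per_head = seq_len * seq_len * 4
    print(f"{seq_len:>6} tokens -> {bytes_per_head / 2**20:8.1f} MiB per head per layer")
# 1024 -> 4.0 MiB, 4096 -> 64.0 MiB, 16384 -> 1024.0 MiB
```

This is why 16x more context costs 256x more score-matrix memory, and why techniques like FlashAttention avoid materializing the full matrix.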

Content on this site compiled by 褚成志, for learning and reference only.