Skip to content

Attention 机制深度解析

Self-Attention 直觉理解

Attention 机制让模型在处理每个 token 时,能"关注"序列中其他所有 token,动态计算相关性权重。

python
import torch
import torch.nn.functional as F
import math

def attention_demo():
    """用简单例子理解 Attention"""
    # 假设句子: "招商银行 发布 年报"
    # 每个词用 4 维向量表示
    seq_len, d_model = 3, 4
    
    # 模拟 Q, K, V(实际由线性层生成)
    Q = torch.randn(seq_len, d_model)  # 查询:我想找什么
    K = torch.randn(seq_len, d_model)  # 键:我有什么
    V = torch.randn(seq_len, d_model)  # 值:实际内容
    
    # 计算注意力分数
    scores = torch.matmul(Q, K.T) / math.sqrt(d_model)
    # scores[i][j] = token_i 对 token_j 的关注程度
    
    # Softmax 归一化为概率
    attn_weights = F.softmax(scores, dim=-1)
    print("注意力权重矩阵:")
    print(attn_weights)
    
    # 加权求和
    output = torch.matmul(attn_weights, V)
    return output

attention_demo()

因果掩码(Causal Mask)

GPT 类模型使用因果掩码,确保每个 token 只能看到之前的 token:

python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """创建下三角掩码"""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    # 将上三角设为 -inf,softmax 后变为 0
    mask = mask.masked_fill(mask == 0, float('-inf'))
    mask = mask.masked_fill(mask == 1, 0.0)
    return mask

mask = create_causal_mask(5)
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0.,   0., -inf, -inf, -inf],
#         [0.,   0.,   0., -inf, -inf],
#         [0.,   0.,   0.,   0., -inf],
#         [0.,   0.,   0.,   0.,   0.]])

Flash Attention

Flash Attention 通过分块计算优化显存访问,速度提升 2-4x:

python
# 使用 Flash Attention(需要安装 flash-attn)
# pip install flash-attn --no-build-isolation

from flash_attn import flash_attn_qkvpacked_func

# 或在 HuggingFace 中启用
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B-Instruct",
    attn_implementation="flash_attention_2",  # 启用 Flash Attention 2
    torch_dtype=torch.float16,
    device_map="auto"
)

GQA(分组查询注意力)

现代 LLM(Llama 3、千问2)使用 GQA 减少 KV Cache 显存:

MHA: Q头数 = K头数 = V头数 = 32
GQA: Q头数 = 32, K头数 = K头数 = 8(每组 4 个 Q 共享一对 KV)
MQA: Q头数 = 32, K头数 = K头数 = 1(极端情况)
python
# GQA 实现示意
class GroupedQueryAttention(torch.nn.Module):
    def __init__(self, d_model, num_q_heads, num_kv_heads):
        super().__init__()
        self.num_q_heads = num_q_heads
        self.num_kv_heads = num_kv_heads
        self.num_groups = num_q_heads // num_kv_heads
        self.head_dim = d_model // num_q_heads
        
        self.q_proj = torch.nn.Linear(d_model, d_model)
        self.k_proj = torch.nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.v_proj = torch.nn.Linear(d_model, num_kv_heads * self.head_dim)
        self.o_proj = torch.nn.Linear(d_model, d_model)

KV Cache 原理

python
# 推理时 KV Cache 避免重复计算
# 第一次生成 token_1:
#   计算 K1, V1 并缓存
# 第二次生成 token_2:
#   只计算 K2, V2,与缓存的 K1,V1 拼接
#   不需要重新计算 K1, V1

# HuggingFace 自动管理 KV Cache
outputs = model.generate(
    input_ids,
    max_new_tokens=100,
    use_cache=True  # 默认开启
)

本站内容由 褚成志 整理编写,仅供学习参考