Attention 机制深度解析
Self-Attention 直觉理解
Attention 机制让模型在处理每个 token 时,能"关注"序列中其他所有 token,动态计算相关性权重。
python
import torch
import torch.nn.functional as F
import math
def attention_demo():
"""用简单例子理解 Attention"""
# 假设句子: "招商银行 发布 年报"
# 每个词用 4 维向量表示
seq_len, d_model = 3, 4
# 模拟 Q, K, V(实际由线性层生成)
Q = torch.randn(seq_len, d_model) # 查询:我想找什么
K = torch.randn(seq_len, d_model) # 键:我有什么
V = torch.randn(seq_len, d_model) # 值:实际内容
# 计算注意力分数
scores = torch.matmul(Q, K.T) / math.sqrt(d_model)
# scores[i][j] = token_i 对 token_j 的关注程度
# Softmax 归一化为概率
attn_weights = F.softmax(scores, dim=-1)
print("注意力权重矩阵:")
print(attn_weights)
# 加权求和
output = torch.matmul(attn_weights, V)
return output
attention_demo()因果掩码(Causal Mask)
GPT 类模型使用因果掩码,确保每个 token 只能看到之前的 token:
python
def create_causal_mask(seq_len: int) -> torch.Tensor:
"""创建下三角掩码"""
mask = torch.tril(torch.ones(seq_len, seq_len))
# 将上三角设为 -inf,softmax 后变为 0
mask = mask.masked_fill(mask == 0, float('-inf'))
mask = mask.masked_fill(mask == 1, 0.0)
return mask
mask = create_causal_mask(5)
print(mask)
# tensor([[0., -inf, -inf, -inf, -inf],
# [0., 0., -inf, -inf, -inf],
# [0., 0., 0., -inf, -inf],
# [0., 0., 0., 0., -inf],
# [0., 0., 0., 0., 0.]])Flash Attention
Flash Attention 通过分块计算优化显存访问,速度提升 2-4x:
python
# 使用 Flash Attention(需要安装 flash-attn)
# pip install flash-attn --no-build-isolation
from flash_attn import flash_attn_qkvpacked_func
# 或在 HuggingFace 中启用
model = AutoModelForCausalLM.from_pretrained(
"Qwen/Qwen2-7B-Instruct",
attn_implementation="flash_attention_2", # 启用 Flash Attention 2
torch_dtype=torch.float16,
device_map="auto"
)GQA(分组查询注意力)
现代 LLM(Llama 3、千问2)使用 GQA 减少 KV Cache 显存:
MHA: Q头数 = K头数 = V头数 = 32
GQA: Q头数 = 32, K头数 = K头数 = 8(每组 4 个 Q 共享一对 KV)
MQA: Q头数 = 32, K头数 = K头数 = 1(极端情况)python
# GQA 实现示意
class GroupedQueryAttention(torch.nn.Module):
def __init__(self, d_model, num_q_heads, num_kv_heads):
super().__init__()
self.num_q_heads = num_q_heads
self.num_kv_heads = num_kv_heads
self.num_groups = num_q_heads // num_kv_heads
self.head_dim = d_model // num_q_heads
self.q_proj = torch.nn.Linear(d_model, d_model)
self.k_proj = torch.nn.Linear(d_model, num_kv_heads * self.head_dim)
self.v_proj = torch.nn.Linear(d_model, num_kv_heads * self.head_dim)
self.o_proj = torch.nn.Linear(d_model, d_model)KV Cache 原理
python
# 推理时 KV Cache 避免重复计算
# 第一次生成 token_1:
# 计算 K1, V1 并缓存
# 第二次生成 token_2:
# 只计算 K2, V2,与缓存的 K1,V1 拼接
# 不需要重新计算 K1, V1
# HuggingFace 自动管理 KV Cache
outputs = model.generate(
input_ids,
max_new_tokens=100,
use_cache=True # 默认开启
)