Skip to content

PyTorch 核心原理

简介

PyTorch 是 LLM 训练和推理的主流框架，HuggingFace 生态(如 Transformers)以 PyTorch 为首要后端。理解 PyTorch 的核心机制是进行模型微调的基础。

bash
# Install torch, torchvision and torchaudio from PyTorch's wheel index built for CUDA 12.1 (the "cu121" suffix).
# NOTE(review): the cu121 index must match the local CUDA setup; pick the right URL from the pytorch.org install selector.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

张量操作

python
import torch

# Build tensors: from a nested Python list, filled with zeros, or random.
x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
zeros = torch.zeros((3, 4))
rand = torch.randn((2, 3))  # i.i.d. samples from the standard normal N(0, 1)

# Use the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
x = x.to(device)

# Matrix product: (3, 4) @ (4, 5) -> (3, 5).
a, b = torch.randn(3, 4), torch.randn(4, 5)
c = a @ b

# Broadcasting stretches size-1 dimensions: (3, 1) + (1, 4) -> (3, 4).
x, y = torch.randn(3, 1), torch.randn(1, 4)
z = x + y

# Shape manipulation on a (2, 3, 4) tensor.
x = torch.randn(2, 3, 4)
print(x.shape)                    # torch.Size([2, 3, 4])
print(x.view(2, 12).shape)        # torch.Size([2, 12]); view shares storage, so the layout must be compatible (use reshape otherwise)
print(x.permute(0, 2, 1).shape)   # torch.Size([2, 4, 3]); dims 1 and 2 swapped
print(x.unsqueeze(0).shape)       # torch.Size([1, 2, 3, 4]); new leading axis

自动微分(Autograd)

python
# Scalar example: autograd records y = x^2 + 3x + 1 while it is computed.
x = torch.tensor(2.0, requires_grad=True)
y = x.pow(2) + 3 * x + 1

# Backpropagate: dy/dx = 2x + 3, which is 7.0 at x = 2.
y.backward()
print(x.grad)  # tensor(7.)

# Vector example: d/dx_i of sum(x_j^2) is 2 * x_i.
x = torch.randn(3, requires_grad=True)
y = x.pow(2).sum()
y.backward()
print(x.grad)  # equals 2 * x

# Disable gradient tracking for inference: no autograd graph is recorded,
# which saves memory when gradients are not needed.
# A tiny stand-in model and batch are built here so the snippet runs on its
# own; the page's real model is only defined in the next section.
demo_model = torch.nn.Linear(3, 2)
input_data = torch.randn(4, 3)
with torch.no_grad():
    output = demo_model(input_data)

# Decorator form: every call of the wrapped function runs under no_grad.
@torch.no_grad()
def inference(model, data):
    """Run ``model`` on ``data`` without recording gradients and return its output."""
    predictions = model(data)
    return predictions

构建神经网络

python
import torch.nn as nn
import torch.nn.functional as F

class FinancialClassifier(nn.Module):
    """Financial risk classifier: a three-layer MLP that outputs class logits.

    The stack is Linear -> BatchNorm1d -> ReLU -> Dropout -> Linear -> ReLU ->
    Dropout -> Linear. The modules live in ``self.layers`` in exactly that
    order, so ``state_dict`` keys (``layers.0.weight``, ``layers.1.running_mean``,
    ...) stay compatible with checkpoints saved by earlier versions.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the first hidden layer. The second hidden layer
            has ``hidden_dim // 2`` units, so this must be at least 2.
        num_classes: Number of output logits. No softmax is applied, so pair
            the model with a logits-based loss such as ``nn.CrossEntropyLoss``.
        dropout1: Dropout probability after the first hidden layer.
        dropout2: Dropout probability after the second hidden layer.

    Raises:
        ValueError: If ``input_dim`` or ``num_classes`` is < 1, or ``hidden_dim`` < 2.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int,
                 dropout1: float = 0.3, dropout2: float = 0.2):
        super().__init__()
        if input_dim < 1 or num_classes < 1:
            raise ValueError(
                f"input_dim and num_classes must be >= 1, got {input_dim} and {num_classes}"
            )
        if hidden_dim < 2:
            # hidden_dim // 2 would be 0, leaving an empty second hidden layer.
            raise ValueError(f"hidden_dim must be >= 2, got {hidden_dim}")

        half_dim = hidden_dim // 2
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout1),
            nn.Linear(hidden_dim, half_dim),
            nn.ReLU(),
            nn.Dropout(dropout2),
            nn.Linear(half_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (batch, input_dim) feature tensor to (batch, num_classes) logits."""
        return self.layers(x)

# Build the classifier and place it on the selected device.
# nn.Module.to moves parameters in place and returns the same module.
model = FinancialClassifier(input_dim=50, hidden_dim=256, num_classes=3).to(device)

# Count all parameters, and separately those that will receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
print(f"总参数: {total_params:,}, 可训练: {trainable_params:,}")

训练循环

python
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

# Synthetic demo data: 1000 samples x 50 features, labels drawn from {0, 1, 2}.
n_samples, n_features, n_classes = 1000, 50, 3
X = torch.randn(n_samples, n_features)
y = torch.randint(low=0, high=n_classes, size=(n_samples,))

# Pair features with labels and serve shuffled mini-batches of 32.
dataset = TensorDataset(X, y)
loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

# Optimizer and loss. CrossEntropyLoss takes raw logits (the model's last layer
# is a plain Linear) and integer class labels.
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss()
# CosineAnnealingLR follows a half cosine from the initial LR down to eta_min
# over T_max scheduler steps. The training loop calls scheduler.step() once per
# epoch for 20 epochs, so T_max must be 20. With T_max=10 the LR bottomed out
# halfway through and then climbed back up for the remaining epochs.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# 训练
def train_epoch(model, loader, optimizer, criterion, *, device=None, max_grad_norm=1.0):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module to train; it is switched to training mode first.
        loader: Iterable of ``(inputs, targets)`` batches, e.g. a ``DataLoader``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameterless model), so batches
            always land next to the weights instead of relying on a global.
        max_grad_norm: Threshold for global gradient-norm clipping; ``None``
            disables clipping.

    Returns:
        Mean of the per-batch losses, or ``float("nan")`` if ``loader`` yields
        no batches.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    model.train()
    total_loss = 0.0
    num_batches = 0  # counted in the loop rather than len(loader): works for any iterable

    for batch_x, batch_y in loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)

        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()

        # Gradient clipping guards against exploding gradients.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

# 评估
def evaluate(model, loader, *, device=None):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is put in eval mode (and left there) and gradients are disabled.
    Predictions are the arg-max over dimension 1 of the model output.

    Args:
        model: Classifier producing per-class scores with classes on dim 1.
        loader: Iterable of ``(inputs, targets)`` batches with integer targets.
        device: Where each batch is moved. Defaults to the device of the
            model's first parameter (CPU for a parameterless model).

    Returns:
        Fraction of correctly classified samples in ``[0, 1]``, or
        ``float("nan")`` if ``loader`` yields no samples.
    """
    if device is None:
        device = next((p.device for p in model.parameters()), torch.device("cpu"))

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            predicted = model(batch_x).argmax(dim=1)
            correct += predicted.eq(batch_y).sum().item()
            total += batch_y.size(0)

    if total == 0:
        return float("nan")
    return correct / total

# Training loop: one scheduler step per epoch, matching the scheduler's T_max.
num_epochs = 20
for epoch in range(num_epochs):
    loss = train_epoch(model, loader, optimizer, criterion)
    scheduler.step()
    # Log every 5th epoch and always the last one. 19 % 5 != 0, so a plain
    # "epoch % 5 == 0" check never reports the final training loss.
    if epoch % 5 == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch}: loss={loss:.4f}")

# Accuracy measured on the training data itself (this example has no held-out split).
train_acc = evaluate(model, loader)
print(f"训练集准确率: {train_acc:.2%}")

模型保存与加载

python
# Save only the learned weights (state_dict), not the pickled module object.
torch.save(model.state_dict(), "model.pth")

# Load: rebuild the same architecture, move it to the target device, then copy the weights in.
# map_location only controls where the saved tensors are materialized while
# loading. load_state_dict copies them into the new module's own parameters,
# which stay on CPU until the module itself is moved, hence the explicit .to(device).
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered .pth file cannot run arbitrary code during loading.
model = FinancialClassifier(50, 256, 3).to(device)
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=True))
model.eval()  # inference behavior for Dropout / BatchNorm
# NOTE: `optimizer` (and `scheduler`) were built from the previous model object's
# parameters; create new ones before continuing to train this reloaded model.

# Full training checkpoint for resuming: weights, optimizer and LR-scheduler
# state, plus the real progress/metric values from the training loop above.
checkpoint = {
    "epoch": epoch + 1,  # number of completed epochs (the loop index is 0-based)
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "scheduler_state": scheduler.state_dict(),  # needed to resume the cosine LR schedule
    "loss": loss,  # last epoch's mean training loss (previously a hard-coded placeholder)
}
torch.save(checkpoint, "checkpoint.pth")

下一步

本站内容由 褚成志 整理编写,仅供学习参考