Skip to content

SQLAlchemy ORM

简介

SQLAlchemy 是 Python 最主流的 ORM 框架，支持 MySQL、PostgreSQL、SQLite 等多种数据库，常被用作 FastAPI + LLM 应用的数据库层。

```bash
pip install sqlalchemy pymysql aiomysql
```

模型定义

python
from sqlalchemy import Column, Integer, String, Text, DateTime, Float, Boolean, ForeignKey
from sqlalchemy.orm import DeclarativeBase, relationship
from datetime import datetime

class Base(DeclarativeBase):
    """Declarative base class that the ORM models in this module inherit from.

    Its ``metadata`` collects the tables declared by the subclasses.
    """
    pass

class User(Base):
    """ORM model mapped to the ``users`` table.

    ``user_id`` (not the integer ``id``) is the column that
    ``ChatSession.user_id`` references through its foreign key.
    """
    __tablename__ = "users"
    
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # String identifier (max 64 chars); unique and indexed.
    user_id = Column(String(64), unique=True, index=True)
    name = Column(String(100))
    # Must be unique across users.
    email = Column(String(200), unique=True)
    # Defaults to True when no value is given on insert.
    is_active = Column(Boolean, default=True)
    # Defaults to datetime.utcnow() (a naive UTC timestamp) at insert time.
    # NOTE(review): datetime.utcnow is deprecated since Python 3.12.
    created_at = Column(DateTime, default=datetime.utcnow)
    
    # ChatSession rows linked to this user; paired with ChatSession.user.
    sessions = relationship("ChatSession", back_populates="user")

class ChatSession(Base):
    """ORM model mapped to the ``chat_sessions`` table.

    Linked to ``User`` through ``user_id`` and to ``ChatMessage`` rows
    through ``session_id``.
    """
    __tablename__ = "chat_sessions"
    
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # String identifier (max 64 chars); unique and indexed. Target of the
    # chat_messages.session_id foreign key.
    session_id = Column(String(64), unique=True, index=True)
    # Foreign key to users.user_id (the string identifier, not users.id).
    user_id = Column(String(64), ForeignKey("users.user_id"))
    title = Column(String(200))
    # Model name; defaults to "qwen-turbo" when not supplied.
    model = Column(String(50), default="qwen-turbo")
    # Running token counter, starts at 0.
    total_tokens = Column(Integer, default=0)
    # Running cost, starts at 0.0; stored as a float.
    # NOTE(review): unit/currency is not specified here — confirm.
    total_cost = Column(Float, default=0.0)
    # Defaults to datetime.utcnow() (a naive UTC timestamp) at insert time.
    created_at = Column(DateTime, default=datetime.utcnow)
    
    # Owning user; paired with User.sessions.
    user = relationship("User", back_populates="sessions")
    # Messages in this session; paired with ChatMessage.session.
    messages = relationship("ChatMessage", back_populates="session")

class ChatMessage(Base):
    """ORM model mapped to the ``chat_messages`` table."""
    __tablename__ = "chat_messages"
    
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # Foreign key to chat_sessions.session_id (the string identifier).
    session_id = Column(String(64), ForeignKey("chat_sessions.session_id"))
    # Short role label (max 20 chars).
    # NOTE(review): presumably values such as "user"/"assistant" — confirm.
    role = Column(String(20))
    # Message body, unbounded text.
    content = Column(Text)
    # Token count for this message, defaults to 0.
    tokens = Column(Integer, default=0)
    # Defaults to datetime.utcnow() (a naive UTC timestamp) at insert time.
    created_at = Column(DateTime, default=datetime.utcnow)
    
    # Parent session; paired with ChatSession.messages.
    session = relationship("ChatSession", back_populates="messages")

Repository 模式

python
from sqlalchemy.orm import Session
from sqlalchemy import select

class ChatRepository:
    """Data-access helper for chat sessions and their messages.

    Wraps a synchronous SQLAlchemy ``Session``. Each write method commits
    immediately, so callers do not need to commit themselves.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used for every query.

        Args:
            db: A synchronous SQLAlchemy ``Session``. The methods call
                ``commit()``/``execute()`` without ``await``, so an
                ``AsyncSession`` must not be passed directly.
        """
        self.db = db

    def create_session(self, user_id: str, model: str = "qwen-turbo") -> ChatSession:
        """Create and persist a new chat session.

        Args:
            user_id: Value stored in ``ChatSession.user_id``.
            model: Model name stored on the session.

        Returns:
            The committed and refreshed ``ChatSession``; its ``session_id``
            is a freshly generated UUID4 string.
        """
        import uuid

        session = ChatSession(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            model=model,
        )
        self.db.add(session)
        self.db.commit()
        # Reload so database-generated values (e.g. the primary key) are set.
        self.db.refresh(session)
        return session

    def add_message(self, session_id: str, role: str, content: str, tokens: int = 0):
        """Insert a message and add its tokens to the session's running total.

        The insert and the counter update are committed together.

        Args:
            session_id: ``session_id`` of the chat session the message belongs to.
            role: Role label stored on the message.
            content: Message text.
            tokens: Token count of the message; added to
                ``chat_sessions.total_tokens``.
        """
        # Imported here because `text` is not otherwise in scope; without
        # this import the UPDATE below raised NameError.
        from sqlalchemy import text

        msg = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            tokens=tokens,
        )
        self.db.add(msg)

        # Increment in SQL rather than read-modify-write in Python, so
        # concurrent writers do not overwrite each other's totals.
        self.db.execute(
            text("UPDATE chat_sessions SET total_tokens = total_tokens + :t WHERE session_id = :sid"),
            {"t": tokens, "sid": session_id},
        )
        self.db.commit()

    def get_messages(self, session_id: str, limit: int = 20) -> list[ChatMessage]:
        """Return the most recent messages of a session, oldest first.

        Args:
            session_id: ``session_id`` of the chat session to read.
            limit: Maximum number of messages to return.

        Returns:
            Up to ``limit`` of the newest messages, in chronological order.
        """
        stmt = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            # `id` breaks ties between messages with the same timestamp,
            # which would otherwise come back in arbitrary order.
            .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
            .limit(limit)
        )
        messages = self.db.execute(stmt).scalars().all()
        # The query fetched newest-first to apply the limit; flip it back.
        return list(reversed(messages))

FastAPI 集成

python
from fastapi import FastAPI, Depends
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker

# Async MySQL connection URL using the aiomysql driver.
# NOTE(review): credentials are hard-coded placeholders — load them from
# configuration in a real deployment.
DATABASE_URL = "mysql+aiomysql://user:pass@localhost/finance_db"

# Shared async engine; SQL statement logging is turned off.
engine = create_async_engine(DATABASE_URL, echo=False)

# Factory producing AsyncSession objects bound to the engine. With
# expire_on_commit=False, loaded attributes stay readable after commit().
AsyncSessionLocal = sessionmaker(
    bind=engine,
    class_=AsyncSession,
    expire_on_commit=False,
)

async def get_db():
    """FastAPI dependency that yields one ``AsyncSession`` per request.

    The ``async with`` block closes the session when the request finishes,
    whether it succeeded or raised, so no explicit ``close()`` is needed.

    Yields:
        An ``AsyncSession`` created by ``AsyncSessionLocal``.
    """
    async with AsyncSessionLocal() as session:
        yield session

app = FastAPI()


@app.on_event("startup")
async def startup():
    """Run ``Base.metadata.create_all`` against the engine at app startup.

    NOTE(review): newer FastAPI releases recommend lifespan handlers over
    ``on_event`` — confirm against the FastAPI version in use.
    """
    async with engine.begin() as connection:
        # create_all is a synchronous API, so it goes through run_sync.
        await connection.run_sync(Base.metadata.create_all)

@app.post("/sessions")
async def create_session(user_id: str, db: AsyncSession = Depends(get_db)):
    """Create a chat session for ``user_id`` and return its identifier.

    ``ChatRepository`` is written against the synchronous ``Session`` API.
    Handing it the ``AsyncSession`` directly would leave ``commit()`` and
    ``refresh()`` as un-awaited coroutines, so nothing would be persisted.
    ``run_sync`` runs the repository against the underlying sync session.

    Args:
        user_id: Identifier of the user who owns the new session.
        db: Request-scoped ``AsyncSession`` supplied by ``get_db``.

    Returns:
        ``{"session_id": <str>}`` for the newly created session.
    """
    def _create(sync_db):
        # Read session_id inside the sync context, right after the refresh.
        return ChatRepository(sync_db).create_session(user_id).session_id

    session_id = await db.run_sync(_create)
    return {"session_id": session_id}

本站内容由 褚成志 整理编写，仅供学习参考。