SQLAlchemy ORM
简介
SQLAlchemy 是 Python 最主流的 ORM 框架，支持 MySQL、PostgreSQL、SQLite 等数据库，是 FastAPI + LLM 应用中常用的数据库层。
bash
pip install sqlalchemy pymysql aiomysql
模型定义
python
from datetime import datetime, timezone

from sqlalchemy import Column, Integer, String, Text, DateTime, Float, Boolean, ForeignKey
from sqlalchemy.orm import DeclarativeBase, relationship
class Base(DeclarativeBase):
    """Declarative base for the ORM models below.

    ``Base.metadata`` collects the table definitions of every subclass.
    """
class User(Base):
    """Application user; linked to ChatSession rows through ``sessions``."""

    __tablename__ = "users"

    id = Column(Integer, primary_key=True)
    # String identifier referenced by ChatSession.user_id (ForeignKey target).
    user_id = Column(String(64), unique=True, index=True)
    name = Column(String(100))
    email = Column(String(200), unique=True)
    is_active = Column(Boolean, default=True)
    # Naive UTC timestamp: the same value datetime.utcnow() returned, but
    # without the deprecated call (deprecated since Python 3.12).
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    sessions = relationship("ChatSession", back_populates="user")
class ChatSession(Base):
    """One conversation owned by a User, with running token and cost totals."""

    __tablename__ = "chat_sessions"

    id = Column(Integer, primary_key=True)
    # String identifier referenced by ChatMessage.session_id (ForeignKey target).
    session_id = Column(String(64), unique=True, index=True)
    user_id = Column(String(64), ForeignKey("users.user_id"))
    title = Column(String(200))
    model = Column(String(50), default="qwen-turbo")
    total_tokens = Column(Integer, default=0)
    # NOTE(review): Float is approximate; consider Numeric if this holds money.
    total_cost = Column(Float, default=0.0)
    # Naive UTC timestamp without the deprecated datetime.utcnow().
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    user = relationship("User", back_populates="sessions")
    # Without order_by, the collection order is whatever the database returns;
    # ordering by primary key yields insertion order.
    messages = relationship(
        "ChatMessage",
        back_populates="session",
        order_by="ChatMessage.id",
    )
class ChatMessage(Base):
    """A single message belonging to a ChatSession."""

    __tablename__ = "chat_messages"

    id = Column(Integer, primary_key=True)
    # Messages are always looked up by session; index it explicitly because
    # not every backend auto-indexes foreign-key columns.
    session_id = Column(
        String(64), ForeignKey("chat_sessions.session_id"), index=True
    )
    # Speaker label; presumably "user"/"assistant"/"system" — not enforced here.
    role = Column(String(20))
    content = Column(Text)
    tokens = Column(Integer, default=0)
    # Naive UTC timestamp without the deprecated datetime.utcnow().
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
    )

    session = relationship("ChatSession", back_populates="messages")


## Repository 模式
python
from sqlalchemy.orm import Session
from sqlalchemy import select
class ChatRepository:
    """Data-access helper for chat sessions and messages.

    Works on a *synchronous* SQLAlchemy ``Session``; every write method
    commits before returning.
    """

    def __init__(self, db: Session):
        self.db = db

    def create_session(self, user_id: str, model: str = "qwen-turbo") -> ChatSession:
        """Insert a ChatSession with a random UUID4 ``session_id`` and return it.

        The object is committed and then refreshed so database-populated
        columns (``id``, column defaults) are loaded on the returned instance.
        """
        import uuid

        chat_session = ChatSession(
            session_id=str(uuid.uuid4()),
            user_id=user_id,
            model=model,
        )
        self.db.add(chat_session)
        self.db.commit()
        self.db.refresh(chat_session)
        return chat_session

    def add_message(self, session_id: str, role: str, content: str, tokens: int = 0):
        """Append a message to a session and add ``tokens`` to its total.

        The INSERT and the counter UPDATE run in the same transaction and are
        committed together.
        """
        # Local import keeps this snippet self-contained (only `select` is
        # imported at the top of the snippet).
        from sqlalchemy import text

        msg = ChatMessage(
            session_id=session_id,
            role=role,
            content=content,
            tokens=tokens,
        )
        self.db.add(msg)
        # Increment in SQL (single UPDATE) rather than read-modify-write in
        # Python, so concurrent writers do not overwrite each other's totals.
        self.db.execute(
            text(
                "UPDATE chat_sessions SET total_tokens = total_tokens + :t "
                "WHERE session_id = :sid"
            ),
            {"t": tokens, "sid": session_id},
        )
        self.db.commit()

    def get_messages(self, session_id: str, limit: int = 20) -> list[ChatMessage]:
        """Return the most recent ``limit`` messages of a session, oldest first."""
        stmt = (
            select(ChatMessage)
            .where(ChatMessage.session_id == session_id)
            # Newest first so LIMIT keeps the latest rows. created_at alone can
            # tie (e.g. a plain MySQL DATETIME stores whole seconds, and a
            # user/assistant pair is often written in the same second), so the
            # primary key breaks ties to keep the order deterministic.
            .order_by(ChatMessage.created_at.desc(), ChatMessage.id.desc())
            .limit(limit)
        )
        messages = self.db.execute(stmt).scalars().all()
        # Flip back to chronological order for the caller.
        return list(reversed(messages))


## FastAPI 集成
python
from fastapi import FastAPI, Depends
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
import os

# Read the DSN from the environment so credentials are not hard-coded; the
# literal is only a local-development fallback.
DATABASE_URL = os.getenv(
    "DATABASE_URL", "mysql+aiomysql://user:pass@localhost/finance_db"
)
# pool_pre_ping tests a pooled connection before handing it out, so a
# connection the MySQL server has already closed (e.g. after wait_timeout)
# is replaced instead of failing the request.
engine = create_async_engine(DATABASE_URL, echo=False, pool_pre_ping=True)
# expire_on_commit=False keeps loaded attributes readable after commit, which
# avoids implicit refresh I/O on attribute access in async code.
AsyncSessionLocal = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
async def get_db():
    """FastAPI dependency: yield one AsyncSession per request.

    The ``async with`` block closes the session when the request finishes,
    even on error, so no explicit ``close()`` is needed.
    """
    async with AsyncSessionLocal() as session:
        yield session
from contextlib import asynccontextmanager


async def startup():
    """Create every table declared on ``Base.metadata``."""
    async with engine.begin() as conn:
        # create_all is a synchronous API; run_sync adapts it to the async connection.
        await conn.run_sync(Base.metadata.create_all)


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Run ``startup`` before the app serves requests.

    This replaces the deprecated ``@app.on_event("startup")`` hook.
    """
    await startup()
    yield


app = FastAPI(lifespan=lifespan)
@app.post("/sessions")
async def create_session(user_id: str, db: AsyncSession = Depends(get_db)):
    """Create a chat session for ``user_id`` and return its ``session_id``.

    ChatRepository uses the synchronous Session API. Calling it directly on
    an AsyncSession leaves ``commit()``/``refresh()`` as un-awaited
    coroutines, so nothing would be persisted. ``AsyncSession.run_sync``
    instead invokes the callable with the AsyncSession's synchronous Session.
    """

    def _create(sync_db):
        return ChatRepository(sync_db).create_session(user_id)

    chat_session = await db.run_sync(_create)
    return {"session_id": chat_session.session_id}