Redis — LLM 应用中的缓存与会话存储
简介
Redis 在 LLM 应用中承担三个核心角色:对话历史存储、LLM 响应缓存、限流计数器。
bash
pip install redis hiredis
连接配置
python
import redis
from redis import ConnectionPool
# Connection pool (recommended for production use)
_POOL_KWARGS = dict(
    host="localhost",
    port=6379,
    db=0,
    password="your-password",
    decode_responses=True,
    max_connections=20,
)
pool = ConnectionPool(**_POOL_KWARGS)
r = redis.Redis(connection_pool=pool)
# Async client (for FastAPI scenarios)
import redis.asyncio as aioredis

async_r = aioredis.from_url(
    "redis://localhost:6379", encoding="utf-8", decode_responses=True
)
对话历史存储
python
import json
from datetime import timedelta
class ChatHistoryStore:
    """Redis-backed conversation history, one list per session.

    Messages are stored as JSON strings in a Redis list; every write
    refreshes a TTL so idle sessions expire automatically.
    """

    def __init__(self, redis_client, ttl_hours: int = 24):
        self.r = redis_client
        # TTL in seconds; refreshed on every add_message call.
        self.ttl = int(timedelta(hours=ttl_hours).total_seconds())

    def _key(self, session_id: str) -> str:
        return f"chat:history:{session_id}"

    def add_message(self, session_id: str, role: str, content: str):
        """Append one message and refresh the session TTL."""
        key = self._key(session_id)
        message = json.dumps({"role": role, "content": content}, ensure_ascii=False)
        # Pipeline RPUSH + EXPIRE: one network round trip instead of two.
        pipe = self.r.pipeline()
        pipe.rpush(key, message)
        pipe.expire(key, self.ttl)
        pipe.execute()

    def get_history(self, session_id: str, max_turns: int = 10) -> list:
        """Return up to the last `max_turns` exchanges as role/content dicts.

        One turn is assumed to be two list entries (user + assistant),
        hence the `* 2`.
        """
        if max_turns <= 0:
            # Guard: -0 == 0, so LRANGE(key, 0, -1) would return the WHOLE
            # history for max_turns=0 instead of nothing.
            return []
        messages = self.r.lrange(self._key(session_id), -max_turns * 2, -1)
        return [json.loads(m) for m in messages]

    def clear(self, session_id: str):
        """Delete the session's entire history."""
        self.r.delete(self._key(session_id))
# 使用
store = ChatHistoryStore(r)
store.add_message("user_123", "user", "我想申请贷款")
store.add_message("user_123", "assistant", "请问您需要什么类型的贷款?")
history = store.get_history("user_123")
# 直接传给 LLM
response = client.chat.completions.create(
model="qwen-turbo",
messages=[{"role": "system", "content": "你是金融助手"}] + history
)LLM 响应缓存
python
import hashlib
class LLMCache:
"""缓存相同问题的 LLM 响应,节省 API 费用"""
def __init__(self, redis_client, ttl_hours: int = 1):
self.r = redis_client
self.ttl = int(timedelta(hours=ttl_hours).total_seconds())
def _cache_key(self, prompt: str, model: str) -> str:
content = f"{model}:{prompt}"
return f"llm:cache:{hashlib.md5(content.encode()).hexdigest()}"
def get(self, prompt: str, model: str) -> str | None:
return self.r.get(self._cache_key(prompt, model))
def set(self, prompt: str, model: str, response: str):
self.r.setex(self._cache_key(prompt, model), self.ttl, response)
# 带缓存的 LLM 调用
cache = LLMCache(r)
def cached_llm_call(prompt: str, model: str = "qwen-turbo") -> str:
cached = cache.get(prompt, model)
if cached:
print("命中缓存")
return cached
response = client.chat.completions.create(
model=model,
messages=[{"role": "user", "content": prompt}]
)
result = response.choices[0].message.content
cache.set(prompt, model, result)
return resultAPI 限流
python
import time
class RateLimiter:
    """Sliding-window rate limiter backed by a Redis sorted set.

    Each request is recorded as a sorted-set member scored by its
    timestamp; counting the members inside the window gives the rate.
    Rejected requests are still recorded, so hammering while limited
    keeps the window full (intentional back-pressure).
    """

    def __init__(self, redis_client):
        self.r = redis_client

    def is_allowed(self, user_id: str, max_requests: int = 10, window_seconds: int = 60) -> bool:
        """Record one request for `user_id` and report whether it fits the window."""
        import uuid  # local import keeps this snippet self-contained

        key = f"rate_limit:{user_id}"
        now = time.time()
        window_start = now - window_seconds
        # Unique member per request: using str(now) alone collapses two
        # requests that land on the same timestamp into one ZADD member,
        # undercounting and letting bursts through.
        member = f"{now}:{uuid.uuid4().hex}"
        pipe = self.r.pipeline()
        pipe.zremrangebyscore(key, 0, window_start)  # drop entries outside the window
        pipe.zadd(key, {member: now})                # record this request
        pipe.zcard(key)                              # count requests in the window
        pipe.expire(key, window_seconds)             # let idle keys expire
        results = pipe.execute()
        return results[2] <= max_requests
# FastAPI 中间件
from fastapi import Request, HTTPException
limiter = RateLimiter(r)
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
user_id = request.headers.get("X-User-ID", request.client.host)
if not limiter.is_allowed(user_id, max_requests=20, window_seconds=60):
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后再试")
return await call_next(request)向量缓存(Redis Stack)
python
# Redis Stack 支持向量相似度搜索
from redis.commands.search.field import VectorField, TextField
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import numpy as np
# Create the vector index (idempotent: re-running the script must not crash)
schema = (
    TextField("content"),
    VectorField(
        "embedding",
        "FLAT",
        {"TYPE": "FLOAT32", "DIM": 512, "DISTANCE_METRIC": "COSINE"},
    ),
)
try:
    r.ft("doc_index").create_index(
        schema,
        definition=IndexDefinition(prefix=["doc:"], index_type=IndexType.HASH),
    )
except redis.exceptions.ResponseError as e:
    # FT.CREATE raises when the index is already defined; only that case
    # is safe to ignore — re-raise anything else.
    if "Index already exists" not in str(e):
        raise
# Persist a document together with its embedding vector
def store_document(doc_id: str, content: str, embedding: np.ndarray):
    """Write `content` plus its embedding (as float32 bytes) to the doc:<id> hash."""
    key = f"doc:{doc_id}"
    payload = {
        "content": content,
        # Redis hashes store raw bytes; float32 matches the index's TYPE.
        "embedding": embedding.astype(np.float32).tobytes(),
    }
    r.hset(key, mapping=payload)