Skip to content

Reranker — 重排序提升 RAG 精度

简介

Reranker 对初步检索的候选文档进行精细排序,显著提升 RAG 答案质量。典型流程:向量检索 Top-20 → Reranker 精排 → 取 Top-3 送入 LLM。

```bash
pip install sentence-transformers
```

BGE Reranker

python
from sentence_transformers import CrossEncoder

# Load the reranker (a cross-encoder); query+doc pairs are truncated at 512 tokens.
reranker = CrossEncoder("BAAI/bge-reranker-base", max_length=512)

# Candidate documents (from a first-stage vector search)
query = "银行不良贷款率如何计算"
candidates = [
    "不良贷款率 = 不良贷款余额 / 贷款总余额 × 100%",
    "银行资本充足率的计算方法",
    "不良贷款包括次级、可疑、损失三类",
    "今天股市行情分析",
]

# Rerank: a cross-encoder scores each (query, document) pair jointly,
# which is more accurate than comparing independent embeddings.
pairs = [[query, doc] for doc in candidates]
scores = reranker.predict(pairs)

# Sort by relevance score only. An explicit key is required: a bare tuple
# sort would fall back to comparing the document strings on score ties
# (and matches the key-based sort used elsewhere in this tutorial).
ranked = sorted(zip(scores, candidates), key=lambda pair: pair[0], reverse=True)
for score, doc in ranked[:2]:
    print(f"分数: {score:.3f} | {doc}")

集成到 RAG 流程

python
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from langchain_openai import ChatOpenAI

embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-small-zh-v1.5")
vectorstore = Chroma(persist_directory="./chroma_db", embedding_function=embeddings)
reranker = CrossEncoder("BAAI/bge-reranker-base")
# NOTE: the original passed a literal `...` as a positional argument, which
# raises TypeError — ChatOpenAI accepts keyword arguments only. Supply
# api_key / base_url here (or via environment variables) as needed.
llm = ChatOpenAI(model="qwen-plus")

def rag_with_rerank(query: str, initial_k: int = 10, final_k: int = 3) -> str:
    """Answer *query* with a retrieve → rerank → generate RAG pipeline.

    Args:
        query: The user question.
        initial_k: Number of candidates fetched by the vector search
            (wide recall stage).
        final_k: Number of top-reranked documents passed to the LLM.

    Returns:
        The LLM's answer text.
    """
    # 1. Vector retrieval (wide recall)
    docs = vectorstore.similarity_search(query, k=initial_k)

    # 2. Cross-encoder reranking of (query, document) pairs
    pairs = [[query, doc.page_content] for doc in docs]
    scores = reranker.predict(pairs)

    ranked_docs = sorted(zip(scores, docs), key=lambda x: x[0], reverse=True)
    top_docs = [doc for _, doc in ranked_docs[:final_k]]

    # 3. Build the prompt from the top-ranked context
    context = "\n\n".join([doc.page_content for doc in top_docs])
    prompt = f"基于以下文档回答问题:\n\n{context}\n\n问题:{query}"

    # 4. LLM generation
    response = llm.invoke(prompt)
    return response.content

result = rag_with_rerank("不良贷款率的监管标准是什么")
print(result)

LangChain Reranker

python
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder

# Same pipeline via LangChain's built-in wrappers: the cross-encoder is
# wrapped as a document "compressor" that keeps only the top_n candidates.
cross_encoder = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base")
rerank_compressor = CrossEncoderReranker(model=cross_encoder, top_n=3)

# Compose retrieval + reranking: the base retriever fetches 10 candidates,
# then the compressor reranks them and returns the best 3.
retriever = ContextualCompressionRetriever(
    base_compressor=rerank_compressor,
    base_retriever=vectorstore.as_retriever(search_kwargs={"k": 10}),
)

docs = retriever.invoke("不良贷款率计算方法")

本站内容由 褚成志 整理编写,仅供学习参考