新增RAG测试脚本

This commit is contained in:
2026-01-13 10:37:19 +08:00
parent d5ee00d404
commit e5ac2dde03
5 changed files with 386 additions and 21 deletions

View File

@@ -142,9 +142,41 @@ class CrawlerService:
return {"msg": "Batch processed", "count": processed}
def search(self, query: str, task_id, return_num: int):
    """Full-pipeline search: embedding -> hybrid retrieval (coarse) -> rerank (fine).

    Args:
        query: Free-text search query.
        task_id: Identifier scoping the search to a single crawl task
            (type not visible here — presumably an id accepted by
            ``data_service.search``; confirm against that API).
        return_num: Number of results the caller ultimately wants returned.

    Returns:
        dict with ``"results"`` (final reranked documents) and ``"msg"``
        (human-readable status text).
    """
    # 1. Embed the query; without a vector the hybrid retrieval cannot run.
    vector = llm_service.get_embedding(query)
    if not vector:
        return {"msg": "Embedding failed", "results": []}

    # 2. Coarse-recall size: recall at least 50 candidates, or 10x the
    #    requested count when the caller asks for many results.
    coarse_limit = max(return_num * 10, 50)

    # 3. Hybrid retrieval (coarse ranking) over the candidate pool.
    coarse_results = data_service.search(
        query_text=query,
        query_vector=vector,
        task_id=task_id,
        candidates_num=coarse_limit,  # use the computed coarse-recall size
    )
    candidates = coarse_results.get('results', [])
    if not candidates:
        return {"msg": "No documents found", "results": []}

    # 4. Rerank (fine ranking) down to the count the caller requested.
    final_results = llm_service.rerank(
        query=query,
        documents=candidates,
        top_n=return_num,  # final number of results returned to the user
    )
    return {
        "results": final_results,
        "msg": f"Reranked {len(final_results)} from {len(candidates)} candidates"
    }
crawler_service = CrawlerService()