新增RAG测试脚本

This commit is contained in:
2026-01-13 10:37:19 +08:00
parent d5ee00d404
commit e5ac2dde03
5 changed files with 386 additions and 21 deletions

View File

@@ -142,9 +142,41 @@ class CrawlerService:
return {"msg": "Batch processed", "count": processed}
def search(self, query: str, task_id, limit: int):
def search(self, query: str, task_id, return_num: int):
"""
全链路搜索:向量生成 -> 混合检索(粗排) -> 重排序(精排)
"""
# 1. 生成查询向量
vector = llm_service.get_embedding(query)
if not vector: return {"msg": "Embedding failed", "results": []}
return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit)
# 2. 计算粗排召回数量
# 逻辑:至少召回 50 个,如果用户要很多,则召回 10 倍
coarse_limit = return_num * 10 if return_num * 10 > 50 else 50
# 3. 执行混合检索 (粗排)
coarse_results = data_service.search(
query_text=query,
query_vector=vector,
task_id=task_id,
candidates_num=coarse_limit # 使用计算出的粗排数量
)
candidates = coarse_results.get('results', [])
if not candidates:
return {"msg": "No documents found", "results": []}
# 4. 执行重排序 (精排)
final_results = llm_service.rerank(
query=query,
documents=candidates,
top_n=return_num # 最终返回用户需要的数量
)
return {
"results": final_results,
"msg": f"Reranked {len(final_results)} from {len(candidates)} candidates"
}
crawler_service = CrawlerService()

View File

@@ -91,24 +91,20 @@ class DataService:
return {"msg": f"Saved {count} chunks", "count": count}
def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5):
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 5):
"""
Phase 2: 混合检索 (Hybrid Search)
综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword)
"""
# 向量格式清洗
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
if query_vector and isinstance(query_vector, list) and len(query_vector) > 0:
if isinstance(query_vector[0], list): query_vector = query_vector[0]
results = []
with self.db.engine.connect() as conn:
# 定义混合检索的 SQL 逻辑
# 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法)
keyword_query = func.websearch_to_tsquery('english', query_text)
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
# 综合打分列: 0.7 * Vector + 0.3 * Keyword
# coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
stmt = select(
@@ -123,8 +119,8 @@ class DataService:
if task_id:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
# 按综合分数倒序
stmt = stmt.order_by(desc("score")).limit(limit)
# 使用 candidates_num 控制召回数量
stmt = stmt.order_by(desc("score")).limit(candidates_num)
try:
rows = conn.execute(stmt).fetchall()
@@ -141,23 +137,19 @@ class DataService:
]
except Exception as e:
print(f"[ERROR] Hybrid search failed: {e}")
return self._fallback_vector_search(query_vector, task_id, limit)
return self._fallback_vector_search(query_vector, task_id, candidates_num)
return {"results": results, "msg": f"Hybrid found {len(results)}"}
def _fallback_vector_search(self, vector, task_id, limit):
"""降级兜底:纯向量搜索"""
print("[WARN] Fallback to pure vector search")
with self.db.engine.connect() as conn:
stmt = select(
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
self.db.chunks.c.content, self.db.chunks.c.meta_info
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
if task_id:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
rows = conn.execute(stmt).fetchall()
return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
data_service = DataService()

View File

@@ -5,16 +5,16 @@ from backend.core.config import settings
class LLMService:
"""
LLM 服务封装层
负责与 DashScope 或其他模型供应商交互
负责与 DashScope (通义千问/GTE) 交互,包括 Embedding 和 Rerank
"""
def __init__(self):
dashscope.api_key = settings.DASHSCOPE_API_KEY
def get_embedding(self, text: str, dimension: int = 1536):
"""生成文本向量"""
"""生成文本向量 (Bi-Encoder)"""
try:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v4,
model=dashscope.TextEmbedding.Models.text_embedding_v4, # 或 v4视你的数据库维度而定
input=text,
dimension=dimension
)
@@ -27,4 +27,67 @@ class LLMService:
print(f"[ERROR] Embedding Exception: {e}")
return None
def rerank(self, query: str, documents: list, top_n: int = 5):
"""
执行重排序 (Cross-Encoder)
Args:
query: 用户问题
documents: 粗排召回的切片列表 (List[dict]),必须包含 'content' 字段
top_n: 最终返回多少个结果
Returns:
List[dict]: 排序后并截取 Top N 的文档列表,包含新的 'score'
"""
if not documents:
return []
# 1. 准备输入数据
# Rerank API 需要纯文本列表,但我们需要保留 documents 里的 meta_info 和 id
# 所以我们提取 content 给 API拿到 index 后再映射回去
doc_contents = [doc.get('content', '') for doc in documents]
# 如果文档太多(比如超过 100 个),建议先截断,避免 API 超时或报错
if len(doc_contents) > 50:
doc_contents = doc_contents[:50]
documents = documents[:50]
try:
# 2. 调用 DashScope GTE-Rerank
resp = dashscope.TextReRank.call(
model='gte-rerank',
query=query,
documents=doc_contents,
top_n=top_n,
return_documents=False # 我们只需要索引和分数,不需要它把文本再传回来
)
if resp.status_code == HTTPStatus.OK:
# 3. 结果重组
# API 返回结构示例: output.results = [{'index': 2, 'relevance_score': 0.98}, {'index': 0, ...}]
reranked_results = []
for item in resp.output.results:
# 根据 API 返回的 index 找到原始文档对象
original_doc = documents[item.index]
# 更新分数为 Rerank 的精准分数 (通常是 0~1 之间的置信度)
original_doc['score'] = item.relevance_score
# 标记来源,方便调试知道这是 Rerank 过的
original_doc['reranked'] = True
reranked_results.append(original_doc)
return reranked_results
else:
print(f"[ERROR] Rerank API Error: {resp}")
# 降级策略:如果 Rerank 挂了,直接返回粗排的前 N 个
return documents[:top_n]
except Exception as e:
print(f"[ERROR] Rerank Exception: {e}")
# 降级策略
return documents[:top_n]
llm_service = LLMService()