新增RAG测试脚本
This commit is contained in:
@@ -91,24 +91,20 @@ class DataService:
|
||||
|
||||
return {"msg": f"Saved {count} chunks", "count": count}
|
||||
|
||||
def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5):
|
||||
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 5):
|
||||
"""
|
||||
Phase 2: 混合检索 (Hybrid Search)
|
||||
综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword)
|
||||
"""
|
||||
# 向量格式清洗
|
||||
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
|
||||
if query_vector and isinstance(query_vector, list) and len(query_vector) > 0:
|
||||
if isinstance(query_vector[0], list): query_vector = query_vector[0]
|
||||
|
||||
results = []
|
||||
with self.db.engine.connect() as conn:
|
||||
# 定义混合检索的 SQL 逻辑
|
||||
|
||||
# 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法)
|
||||
keyword_query = func.websearch_to_tsquery('english', query_text)
|
||||
|
||||
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
|
||||
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
|
||||
|
||||
# 综合打分列: 0.7 * Vector + 0.3 * Keyword
|
||||
# coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0
|
||||
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
|
||||
|
||||
stmt = select(
|
||||
@@ -123,8 +119,8 @@ class DataService:
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
# 按综合分数倒序
|
||||
stmt = stmt.order_by(desc("score")).limit(limit)
|
||||
# 使用 candidates_num 控制召回数量
|
||||
stmt = stmt.order_by(desc("score")).limit(candidates_num)
|
||||
|
||||
try:
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
@@ -141,23 +137,19 @@ class DataService:
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Hybrid search failed: {e}")
|
||||
return self._fallback_vector_search(query_vector, task_id, limit)
|
||||
return self._fallback_vector_search(query_vector, task_id, candidates_num)
|
||||
|
||||
return {"results": results, "msg": f"Hybrid found {len(results)}"}
|
||||
|
||||
def _fallback_vector_search(self, vector, task_id, limit):
|
||||
"""降级兜底:纯向量搜索"""
|
||||
print("[WARN] Fallback to pure vector search")
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info
|
||||
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
|
||||
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
|
||||
|
||||
data_service = DataService()
|
||||
Reference in New Issue
Block a user