完成RAG测试

This commit is contained in:
2026-01-27 01:41:45 +08:00
parent 155974572c
commit f78efc7125
16 changed files with 1434 additions and 66 deletions

View File

@@ -143,7 +143,7 @@ class DataService:
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
).values(status='completed'))
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50, vector_weight: float = 0.7):
# 向量格式清洗
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
@@ -151,26 +151,52 @@ class DataService:
results = []
with self.db.engine.connect() as conn:
# 1. 构造 Query 对象 (这是 tsquery 类型)
keyword_query = func.websearch_to_tsquery('english', query_text)
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
# 计算分数 (逻辑不变)
vector_dist = self.db.chunks.c.embedding.cosine_distance(query_vector)
vector_score = (1 - vector_dist)
# 注意ts_rank 需要 (tsvector, tsquery)
keyword_rank = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
keyword_score = func.coalesce(keyword_rank, 0)
keyword_weight = 1.0 - vector_weight
final_score = (vector_score * vector_weight + keyword_score * keyword_weight).label("score")
stmt = select(
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
)
# ================= 修复点开始 =================
# 只有当 vector_weight 为 0 (纯关键词模式) 时,才强制加 WHERE 过滤
if vector_weight == 0:
# 错误写法 (SQLAlchemy 自动生成 plainto_tsquery 导致报错):
# stmt = stmt.where(keyword_query.match(self.db.chunks.c.content_tsvector))
# 正确写法 (直接使用 PG 的 @@ 操作符):
# 含义: content_tsvector @@ keyword_query
stmt = stmt.where(self.db.chunks.c.content_tsvector.op('@@')(keyword_query))
# ================= 修复点结束 =================
if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id)
stmt = stmt.order_by(desc("score")).limit(candidates_num)
try:
rows = conn.execute(stmt).fetchall()
results = [{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4], "score": float(r[5])} for r in rows]
except Exception as e:
logger.error(f"Hybrid search failed: {e}")
return self._fallback_vector_search(query_vector, task_id, candidates_num)
# 打印详细错误方便调试
logger.error(f"Search failed: {e}")
# 只有混合或向量模式才回退,如果是纯关键词模式报错,回退也没用
if vector_weight > 0:
return self._fallback_vector_search(query_vector, task_id, candidates_num)
return {"results": [], "msg": "Keyword search failed"}
return {"results": results, "msg": f"Hybrid found {len(results)}"}
return {"results": results, "msg": f"Found {len(results)}"}
def _fallback_vector_search(self, vector, task_id, limit):
logger.warning("Fallback to pure vector search")

View File

@@ -93,5 +93,28 @@ class LLMService:
logger.error(f"Rerank Exception: {e}")
# 降级策略
return documents[:top_n]
def chat(self, prompt: str, system_prompt: str = None, model: str = "qwen-turbo") -> str:
"""
[新增] 通用对话生成接口用于RAG的最终回答或作为测试裁判(Judge)
"""
messages = []
if system_prompt:
messages.append({'role': 'system', 'content': system_prompt})
messages.append({'role': 'user', 'content': prompt})
try:
resp = dashscope.Generation.call(
model=model,
messages=messages,
result_format='message'
)
if resp.status_code == HTTPStatus.OK:
return resp.output.choices[0].message.content
else:
logger.error(f"Chat API Error: {resp}")
return "Error generating response."
except Exception as e:
logger.error(f"Chat Exception: {e}")
return "Error generating response."
llm_service = LLMService()