完成RAG测试
This commit is contained in:
@@ -143,7 +143,7 @@ class DataService:
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
).values(status='completed'))
|
||||
|
||||
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
|
||||
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50, vector_weight: float = 0.7):
|
||||
# 向量格式清洗
|
||||
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
|
||||
if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
|
||||
@@ -151,26 +151,52 @@ class DataService:
|
||||
|
||||
results = []
|
||||
with self.db.engine.connect() as conn:
|
||||
# 1. 构造 Query 对象 (这是 tsquery 类型)
|
||||
keyword_query = func.websearch_to_tsquery('english', query_text)
|
||||
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
|
||||
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
|
||||
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
|
||||
|
||||
# 计算分数 (逻辑不变)
|
||||
vector_dist = self.db.chunks.c.embedding.cosine_distance(query_vector)
|
||||
vector_score = (1 - vector_dist)
|
||||
|
||||
# 注意:ts_rank 需要 (tsvector, tsquery)
|
||||
keyword_rank = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
|
||||
keyword_score = func.coalesce(keyword_rank, 0)
|
||||
|
||||
keyword_weight = 1.0 - vector_weight
|
||||
final_score = (vector_score * vector_weight + keyword_score * keyword_weight).label("score")
|
||||
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
|
||||
)
|
||||
|
||||
# ================= 修复点开始 =================
|
||||
# 只有当 vector_weight 为 0 (纯关键词模式) 时,才强制加 WHERE 过滤
|
||||
if vector_weight == 0:
|
||||
# 错误写法 (SQLAlchemy 自动生成 plainto_tsquery 导致报错):
|
||||
# stmt = stmt.where(keyword_query.match(self.db.chunks.c.content_tsvector))
|
||||
|
||||
# 正确写法 (直接使用 PG 的 @@ 操作符):
|
||||
# 含义: content_tsvector @@ keyword_query
|
||||
stmt = stmt.where(self.db.chunks.c.content_tsvector.op('@@')(keyword_query))
|
||||
# ================= 修复点结束 =================
|
||||
|
||||
if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
stmt = stmt.order_by(desc("score")).limit(candidates_num)
|
||||
|
||||
try:
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
results = [{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4], "score": float(r[5])} for r in rows]
|
||||
except Exception as e:
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
return self._fallback_vector_search(query_vector, task_id, candidates_num)
|
||||
# 打印详细错误方便调试
|
||||
logger.error(f"Search failed: {e}")
|
||||
# 只有混合或向量模式才回退,如果是纯关键词模式报错,回退也没用
|
||||
if vector_weight > 0:
|
||||
return self._fallback_vector_search(query_vector, task_id, candidates_num)
|
||||
return {"results": [], "msg": "Keyword search failed"}
|
||||
|
||||
return {"results": results, "msg": f"Hybrid found {len(results)}"}
|
||||
return {"results": results, "msg": f"Found {len(results)}"}
|
||||
|
||||
def _fallback_vector_search(self, vector, task_id, limit):
|
||||
logger.warning("Fallback to pure vector search")
|
||||
|
||||
@@ -93,5 +93,28 @@ class LLMService:
|
||||
logger.error(f"Rerank Exception: {e}")
|
||||
# 降级策略
|
||||
return documents[:top_n]
|
||||
def chat(self, prompt: str, system_prompt: str = None, model: str = "qwen-turbo") -> str:
|
||||
"""
|
||||
[新增] 通用对话生成接口,用于RAG的最终回答或作为测试裁判(Judge)
|
||||
"""
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({'role': 'system', 'content': system_prompt})
|
||||
messages.append({'role': 'user', 'content': prompt})
|
||||
|
||||
try:
|
||||
resp = dashscope.Generation.call(
|
||||
model=model,
|
||||
messages=messages,
|
||||
result_format='message'
|
||||
)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
return resp.output.choices[0].message.content
|
||||
else:
|
||||
logger.error(f"Chat API Error: {resp}")
|
||||
return "Error generating response."
|
||||
except Exception as e:
|
||||
logger.error(f"Chat Exception: {e}")
|
||||
return "Error generating response."
|
||||
|
||||
llm_service = LLMService()
|
||||
Reference in New Issue
Block a user