From d5ee00d40407f4c5fdb7edae83379859db429e87 Mon Sep 17 00:00:00 2001 From: QingGang Date: Tue, 13 Jan 2026 02:23:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=B7=B7=E5=90=88=E6=A3=80?= =?UTF-8?q?=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/crawler_service.py | 4 +- backend/services/data_service.py | 68 +++++++++++++++++++++++++---- scripts/test_apis.py | 2 +- 3 files changed, 63 insertions(+), 11 deletions(-) diff --git a/backend/services/crawler_service.py b/backend/services/crawler_service.py index 0293c22..3eed083 100644 --- a/backend/services/crawler_service.py +++ b/backend/services/crawler_service.py @@ -142,9 +142,9 @@ class CrawlerService: return {"msg": "Batch processed", "count": processed} - def search(self, query: str, task_id: int, limit: int): + def search(self, query: str, task_id, limit: int): vector = llm_service.get_embedding(query) if not vector: return {"msg": "Embedding failed", "results": []} - return data_service.search(vector, task_id, limit) + return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit) crawler_service = CrawlerService() \ No newline at end of file diff --git a/backend/services/data_service.py b/backend/services/data_service.py index a52d5a5..1e75700 100644 --- a/backend/services/data_service.py +++ b/backend/services/data_service.py @@ -1,4 +1,4 @@ -from sqlalchemy import select, insert, update, and_ +from sqlalchemy import select, insert, update, and_, text, func, desc from backend.core.database import db from backend.utils.common import normalize_url @@ -91,7 +91,63 @@ class DataService: return {"msg": f"Saved {count} chunks", "count": count} - def search(self, vector: list, task_id: int = None, limit: int = 5): + def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5): + """ + Phase 2: 混合检索 (Hybrid Search) + 综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword) + """ + + results = [] + with self.db.engine.connect() as conn: + # 定义混合检索的 SQL 逻辑 + + # 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法) + keyword_query = func.websearch_to_tsquery('english', query_text) + + vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector)) + keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query) + + # 综合打分列: 0.7 * Vector + 0.3 * Keyword + # coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0 + final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score") + + stmt = select( + self.db.chunks.c.task_id, + self.db.chunks.c.source_url, + self.db.chunks.c.title, + self.db.chunks.c.content, + self.db.chunks.c.meta_info, + final_score + ) + + if task_id: + stmt = stmt.where(self.db.chunks.c.task_id == task_id) + + # 按综合分数倒序 + stmt = stmt.order_by(desc("score")).limit(limit) + + try: + rows = conn.execute(stmt).fetchall() + results = [ + { + "task_id": r[0], + "source_url": r[1], + "title": r[2], + "content": r[3], + "meta_info": r[4], + "score": float(r[5]) + } + for r in rows + ] + except Exception as e: + print(f"[ERROR] Hybrid search failed: {e}") + return self._fallback_vector_search(query_vector, task_id, limit) + + return {"results": results, "msg": f"Hybrid found {len(results)}"} + + def _fallback_vector_search(self, vector, task_id, limit): + """降级兜底:纯向量搜索""" + print("[WARN] Fallback to pure vector search") with self.db.engine.connect() as conn: stmt = select( self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title, @@ -100,12 +156,8 @@ class DataService: if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id) - + rows = conn.execute(stmt).fetchall() - results = [ - {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]} - for r in rows - ] - return {"results": results, "msg": f"Found {len(results)}"} + return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"} data_service = DataService() \ No newline at end of file diff --git a/scripts/test_apis.py b/scripts/test_apis.py index 7503fda..4a1a132 100644 --- a/scripts/test_apis.py +++ b/scripts/test_apis.py @@ -86,7 +86,7 @@ def run_e2e_test(): # Step 3: 轮询搜索结果 (Polling) # --------------------------------------------------------- log("STEP 3", "轮询搜索接口,等待数据入库...") - + task_id = 6 max_retries = 12 found_data = False search_results = []