完成混合检索
This commit is contained in:
@@ -142,9 +142,9 @@ class CrawlerService:
|
|||||||
|
|
||||||
return {"msg": "Batch processed", "count": processed}
|
return {"msg": "Batch processed", "count": processed}
|
||||||
|
|
||||||
def search(self, query: str, task_id: int, limit: int):
|
def search(self, query: str, task_id, limit: int):
|
||||||
vector = llm_service.get_embedding(query)
|
vector = llm_service.get_embedding(query)
|
||||||
if not vector: return {"msg": "Embedding failed", "results": []}
|
if not vector: return {"msg": "Embedding failed", "results": []}
|
||||||
return data_service.search(vector, task_id, limit)
|
return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit)
|
||||||
|
|
||||||
crawler_service = CrawlerService()
|
crawler_service = CrawlerService()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from sqlalchemy import select, insert, update, and_
|
from sqlalchemy import select, insert, update, and_, text, func, desc
|
||||||
from backend.core.database import db
|
from backend.core.database import db
|
||||||
from backend.utils.common import normalize_url
|
from backend.utils.common import normalize_url
|
||||||
|
|
||||||
@@ -91,7 +91,63 @@ class DataService:
|
|||||||
|
|
||||||
return {"msg": f"Saved {count} chunks", "count": count}
|
return {"msg": f"Saved {count} chunks", "count": count}
|
||||||
|
|
||||||
def search(self, vector: list, task_id: int = None, limit: int = 5):
|
def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5):
|
||||||
|
"""
|
||||||
|
Phase 2: 混合检索 (Hybrid Search)
|
||||||
|
综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword)
|
||||||
|
"""
|
||||||
|
|
||||||
|
results = []
|
||||||
|
with self.db.engine.connect() as conn:
|
||||||
|
# 定义混合检索的 SQL 逻辑
|
||||||
|
|
||||||
|
# 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法)
|
||||||
|
keyword_query = func.websearch_to_tsquery('english', query_text)
|
||||||
|
|
||||||
|
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
|
||||||
|
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
|
||||||
|
|
||||||
|
# 综合打分列: 0.7 * Vector + 0.3 * Keyword
|
||||||
|
# coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0
|
||||||
|
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
|
||||||
|
|
||||||
|
stmt = select(
|
||||||
|
self.db.chunks.c.task_id,
|
||||||
|
self.db.chunks.c.source_url,
|
||||||
|
self.db.chunks.c.title,
|
||||||
|
self.db.chunks.c.content,
|
||||||
|
self.db.chunks.c.meta_info,
|
||||||
|
final_score
|
||||||
|
)
|
||||||
|
|
||||||
|
if task_id:
|
||||||
|
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||||
|
|
||||||
|
# 按综合分数倒序
|
||||||
|
stmt = stmt.order_by(desc("score")).limit(limit)
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = conn.execute(stmt).fetchall()
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"task_id": r[0],
|
||||||
|
"source_url": r[1],
|
||||||
|
"title": r[2],
|
||||||
|
"content": r[3],
|
||||||
|
"meta_info": r[4],
|
||||||
|
"score": float(r[5])
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[ERROR] Hybrid search failed: {e}")
|
||||||
|
return self._fallback_vector_search(query_vector, task_id, limit)
|
||||||
|
|
||||||
|
return {"results": results, "msg": f"Hybrid found {len(results)}"}
|
||||||
|
|
||||||
|
def _fallback_vector_search(self, vector, task_id, limit):
|
||||||
|
"""降级兜底:纯向量搜索"""
|
||||||
|
print("[WARN] Fallback to pure vector search")
|
||||||
with self.db.engine.connect() as conn:
|
with self.db.engine.connect() as conn:
|
||||||
stmt = select(
|
stmt = select(
|
||||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||||
@@ -100,12 +156,8 @@ class DataService:
|
|||||||
|
|
||||||
if task_id:
|
if task_id:
|
||||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||||
|
|
||||||
rows = conn.execute(stmt).fetchall()
|
rows = conn.execute(stmt).fetchall()
|
||||||
results = [
|
return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
|
||||||
{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]}
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return {"results": results, "msg": f"Found {len(results)}"}
|
|
||||||
|
|
||||||
data_service = DataService()
|
data_service = DataService()
|
||||||
@@ -86,7 +86,7 @@ def run_e2e_test():
|
|||||||
# Step 3: 轮询搜索结果 (Polling)
|
# Step 3: 轮询搜索结果 (Polling)
|
||||||
# ---------------------------------------------------------
|
# ---------------------------------------------------------
|
||||||
log("STEP 3", "轮询搜索接口,等待数据入库...")
|
log("STEP 3", "轮询搜索接口,等待数据入库...")
|
||||||
|
task_id = 6
|
||||||
max_retries = 12
|
max_retries = 12
|
||||||
found_data = False
|
found_data = False
|
||||||
search_results = []
|
search_results = []
|
||||||
|
|||||||
Reference in New Issue
Block a user