添加search方法
This commit is contained in:
@@ -156,6 +156,50 @@ class CrawlerService:
|
||||
"failed": len(failed_chunks)
|
||||
}
|
||||
}
|
||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
    """
    Search the chunks table for the rows closest to a query embedding.

    :param query_embedding: embedding vector of the question
    :param task_id: optional task ID; when None no task filter is applied
        and every row of the chunks table is considered
    :param limit: maximum number of rows to return
    :return: list of dicts with the keys ``task_id``, ``source_url``,
        ``title``, ``content`` and ``chunk_index``, in the order produced
        by sorting on cosine distance to ``query_embedding``
    """
    with self.db.engine.connect() as conn:
        chunks = self.db.chunks

        # task_id is selected as well so that an unfiltered search can
        # still tell which task each hit came from.
        query = select(
            chunks.c.task_id,
            chunks.c.source_url,
            chunks.c.title,
            chunks.c.content,
            chunks.c.chunk_index,
        )

        # Restrict to a single task only when the caller asked for one.
        if task_id is not None:
            query = query.where(chunks.c.task_id == task_id)

        # Cosine distance is 1 - cosine similarity, so the smallest
        # distances (sorted first) are the most similar chunks.
        distance = chunks.c.embedding.cosine_distance(query_embedding)
        query = query.order_by(distance).limit(limit)

        fetched = conn.execute(query).fetchall()

        # Each row carries the five selected columns, in selection order.
        results = [
            {
                "task_id": row_task_id,
                "source_url": source_url,
                "title": title,
                "content": content,
                "chunk_index": chunk_index,
            }
            for row_task_id, source_url, title, content, chunk_index in fetched
        ]

    return results
|
||||
|
||||
|
||||
# Module-level CrawlerService instance, created when this module is loaded.
crawler_service = CrawlerService()
|
||||
Reference in New Issue
Block a user