添加search方法

This commit is contained in:
2025-12-23 00:36:49 +08:00
parent 1585b2c31b
commit 9b283d2f72
5 changed files with 134 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
from fastapi import FastAPI
from .service import crawler_service
from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest
from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest
from .utils import make_response
app = FastAPI(title="Wiki Crawler API")
@@ -37,7 +37,41 @@ async def save_results(req: SaveResultsRequest):
return make_response(1, "Success", data)
except Exception as e:
return make_response(0, str(e))
@app.post("/search")
async def search(req: SearchRequest):
    """
    General-purpose search endpoint.

    Performs a task-scoped search when ``req.task_id`` is provided, or a
    library-wide search when it is omitted.

    Response envelope codes: 1 = success, 2 = validation error, 0 = server error.
    """
    try:
        # 1. Basic validation: the embedding must be present and have the
        #    expected dimensionality (Alibaba v4 models typically emit 1536).
        if not req.query_embedding or len(req.query_embedding) != 1536:
            received = len(req.query_embedding) if req.query_embedding else 0
            return make_response(
                code=2,
                msg=f"向量维度错误。期望 1536, 实际收到 {received}",
                data=None
            )
        # 2. Delegate the actual vector search to the service layer.
        data = crawler_service.search_knowledge(
            query_embedding=req.query_embedding,
            task_id=req.task_id,
            limit=req.limit
        )
        # 3. Uniform success envelope.
        return make_response(
            code=1,
            msg="搜索完成",
            data=data
        )
    except Exception as e:
        # Log with full traceback instead of print() so the failure reaches
        # the configured logging backend, then return a failure envelope.
        import logging
        logging.getLogger(__name__).exception("搜索接口异常: %s", e)
        return make_response(code=0, msg=f"搜索失败: {str(e)}")
if __name__ == "__main__":
    import uvicorn

    # NOTE: uvicorn ignores reload=True when handed an app *object* — hot
    # reload requires an import string like "module:app" — so the flag is
    # dropped here rather than left in place silently doing nothing.
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -22,4 +22,10 @@ class CrawlResult(BaseModel):
class SaveResultsRequest(BaseModel):
    """Request body for persisting a batch of crawl results under one task."""
    # Identifier of the crawl task these results belong to.
    task_id: int
    results: List[CrawlResult]
    # NOTE(review): duplicated field declaration — looks like a diff artifact
    # (old + new hunk lines captured together); only one should remain.
    results: List[CrawlResult]
class SearchRequest(BaseModel):
    """Request body for the /search endpoint."""
    # When task_id is omitted, the search spans the whole library.
    task_id: Optional[int] = None
    # Pre-computed embedding of the query (the endpoint validates a 1536 dim).
    query_embedding: List[float]
    # Maximum number of chunks to return.
    limit: Optional[int] = 5

View File

@@ -156,6 +156,50 @@ class CrawlerService:
"failed": len(failed_chunks)
}
}
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
    """
    Vector-similarity search over stored knowledge chunks.

    :param query_embedding: embedding vector of the question
    :param task_id: optional task id; when None the whole table is searched
    :param limit: maximum number of rows to return
    :return: list of dicts with task_id, source_url, title, content, chunk_index
    """
    chunks = self.db.chunks
    # Project the columns we expose; task_id is included so a library-wide
    # search still tells the caller which task each hit came from.
    query = select(
        chunks.c.task_id,
        chunks.c.source_url,
        chunks.c.title,
        chunks.c.content,
        chunks.c.chunk_index,
    )
    # Narrow to one task only when requested.
    if task_id is not None:
        query = query.where(chunks.c.task_id == task_id)
    # Rank by cosine distance (1 - cosine similarity): smaller is closer.
    query = query.order_by(
        chunks.c.embedding.cosine_distance(query_embedding)
    ).limit(limit)
    with self.db.engine.connect() as conn:
        hits = conn.execute(query).fetchall()
    field_names = ("task_id", "source_url", "title", "content", "chunk_index")
    return [dict(zip(field_names, row)) for row in hits]
# Module-level singleton shared by the FastAPI route handlers.
crawler_service = CrawlerService()