From 9b283d2f72a12a9a6b738df9d390b916107fb2a3 Mon Sep 17 00:00:00 2001
From: QingGang
Date: Tue, 23 Dec 2025 00:36:49 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0search=E6=96=B9=E6=B3=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 Review notes. Commentary only: text between the three-dash separator above
 and the first "diff --git" header is not part of the commit message and is
 not part of the change.

 Line breaks and indentation in this file were restored from a copy in which
 all whitespace had been collapsed. The added/removed/context text itself is
 unchanged, but whitespace-only details (blank context lines, trailing
 spaces, exact indentation) are a best-effort reconstruction. Verify against
 the original commit before applying.

 1. backend/schemas.py: SearchRequest.limit is Optional[int] with no bounds.
    backend/main.py forwards req.limit unchanged and search_knowledge() hands
    it straight to .limit(limit), so an explicit null, a negative value or a
    very large value is never rejected.
    NOTE(review): SQLAlchemy is understood to drop the LIMIT clause for
    .limit(None), which would rank the whole table - confirm, and consider a
    bounded, non-optional int in the schema.
 2. backend/main.py: search() is declared "async def" but calls
    crawler_service.search_knowledge(), which uses a plain
    "with self.db.engine.connect() as conn" block and an un-awaited
    conn.execute(...).fetchall(). As written this is a synchronous database
    call made from inside a coroutine.
    NOTE(review): confirm whether the handler should be a plain "def" or the
    call should be moved to a worker thread.
 3. backend/main.py: the except branch print()s the error and returns
    str(e) to the caller in msg. This matches the existing save_results
    handler visible in the hunk context, but it exposes internal error text
    to clients and bypasses the logging module.
 4. The vector size 1536 is hard-coded twice: the length check in
    backend/main.py and "dimension": 1536 in nodes/embedding.py. The two
    must be changed together.
 5. nodes/embedding.py: chunks_embedding() has the same signature as the
    function named in the nodes/chunk_and_embedding.py hunk header, and the
    lines visible in that hunk match it, so it appears to be a copy.
    NOTE(review): confirm whether node files must be self-contained;
    otherwise import the existing function instead of duplicating it.
 6. nodes/embedding.py: chunks_embedding() is annotated
    "-> list[list[float]]" but both failure paths return
    [None] * len(texts), so main() can return {'vector': None}. If the
    response carries an empty "embeddings" list, final_vectors stays empty,
    chunks_embedding() returns [] and the [0] in main() raises IndexError
    (main() has no try/except).
 7. nodes/embedding.py: main() embeds a single search text but the request
    payload sends "text_type": "document".
    NOTE(review): check the embedding service documentation for whether a
    query-specific text_type should be used for search queries.
 8. backend/service.py: search_knowledge() calls select(...) and
    .cosine_distance(...). Neither the import of select nor the column type
    of chunks.embedding is visible in this patch; presumably both already
    exist in the target tree - verify.

 backend/main.py              | 38 +++++++++++++++++++++++++++--
 backend/schemas.py           |  8 ++++++-
 backend/service.py           | 44 ++++++++++++++++++++++++++++++++++
 nodes/chunk_and_embedding.py |  2 +-
 nodes/embedding.py           | 46 ++++++++++++++++++++++++++++++++++++
 5 files changed, 134 insertions(+), 4 deletions(-)
 create mode 100644 nodes/embedding.py

diff --git a/backend/main.py b/backend/main.py
index 2b405c0..7666d91 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 from .service import crawler_service
-from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest
+from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest
 from .utils import make_response
 
 app = FastAPI(title="Wiki Crawler API")
@@ -37,7 +37,41 @@ async def save_results(req: SaveResultsRequest):
         return make_response(1, "Success", data)
     except Exception as e:
         return make_response(0, str(e))
-    
+
+@app.post("/search")
+async def search(req: SearchRequest):
+    """
+    通用搜索接口:
+    支持基于 task_id 的局部搜索,也支持不传 task_id 的全库搜索。
+    """
+    try:
+        # 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536)
+        if not req.query_embedding or len(req.query_embedding) != 1536:
+            return make_response(
+                code=2,
+                msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
+                data=None
+            )
+
+        # 2. 调用业务类执行搜索
+        data = crawler_service.search_knowledge(
+            query_embedding=req.query_embedding,
+            task_id=req.task_id,
+            limit=req.limit
+        )
+
+        # 3. 统一返回
+        return make_response(
+            code=1,
+            msg="搜索完成",
+            data=data
+        )
+
+    except Exception as e:
+        # 记录日志并返回失败信息
+        print(f"搜索接口异常: {str(e)}")
+        return make_response(code=0, msg=f"搜索失败: {str(e)}")
+
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
\ No newline at end of file
diff --git a/backend/schemas.py b/backend/schemas.py
index 041abd5..bc26b21 100644
--- a/backend/schemas.py
+++ b/backend/schemas.py
@@ -22,4 +22,10 @@ class CrawlResult(BaseModel):
 
 class SaveResultsRequest(BaseModel):
     task_id: int
-    results: List[CrawlResult]
\ No newline at end of file
+    results: List[CrawlResult]
+
+class SearchRequest(BaseModel):
+    # 如果不传 task_id,则进行全库搜索
+    task_id: Optional[int] = None
+    query_embedding: List[float]
+    limit: Optional[int] = 5
\ No newline at end of file
diff --git a/backend/service.py b/backend/service.py
index c7874a5..281b003 100644
--- a/backend/service.py
+++ b/backend/service.py
@@ -156,6 +156,50 @@ class CrawlerService:
                 "failed": len(failed_chunks)
             }
         }
+    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
+        """
+        高性能向量搜索方法
+        :param query_embedding: 问题的向量
+        :param task_id: 可选的任务ID,不传则搜全表
+        :param limit: 返回结果数量
+        """
+
+
+        with self.db.engine.connect() as conn:
+            # 1. 选择需要的字段
+            # 我们同时返回 task_id,方便在全库搜索时知道来源哪个任务
+            stmt = select(
+                self.db.chunks.c.task_id,
+                self.db.chunks.c.source_url,
+                self.db.chunks.c.title,
+                self.db.chunks.c.content,
+                self.db.chunks.c.chunk_index
+            )
+
+            # 2. 动态添加过滤条件
+            if task_id is not None:
+                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
+
+            # 3. 按余弦距离排序(1 - 余弦相似度)
+            # 距离越小,相似度越高
+            stmt = stmt.order_by(
+                self.db.chunks.c.embedding.cosine_distance(query_embedding)
+            ).limit(limit)
+
+            # 4. 执行并解析结果
+            rows = conn.execute(stmt).fetchall()
+
+            results = []
+            for r in rows:
+                results.append({
+                    "task_id": r[0],
+                    "source_url": r[1],
+                    "title": r[2],
+                    "content": r[3],
+                    "chunk_index": r[4]
+                })
+
+            return results
 
 
 crawler_service = CrawlerService()
\ No newline at end of file
diff --git a/nodes/chunk_and_embedding.py b/nodes/chunk_and_embedding.py
index 597db82..6f932d3 100644
--- a/nodes/chunk_and_embedding.py
+++ b/nodes/chunk_and_embedding.py
@@ -57,7 +57,7 @@ def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
 
         if "output" in result and "embeddings" in result["output"]:
             embeddings_list = result["output"]["embeddings"]
-            embeddings_list.sort(key=lambda x: x["text_index"])
+            embeddings_list.sort(key=lambda x: x["text_index"]) # 按文本索引排序,确保顺序一致
 
             # --- 核心修复:对每个浮点数保留 8 位小数,解决精度过高报错 ---
             final_vectors = []
diff --git a/nodes/embedding.py b/nodes/embedding.py
new file mode 100644
index 0000000..cff4507
--- /dev/null
+++ b/nodes/embedding.py
@@ -0,0 +1,46 @@
+import requests
+
+def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
+    if not texts:
+        return []
+
+    MODEL_NAME = "text-embedding-v4"
+    url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json"
+    }
+    payload = {
+        "model": MODEL_NAME,
+        "input": {"texts": texts},
+        "parameters": {"text_type": "document", "dimension": 1536}
+    }
+
+    try:
+        response = requests.post(url, headers=headers, json=payload, timeout=60)
+        response.raise_for_status()
+        result = response.json()
+
+        if "output" in result and "embeddings" in result["output"]:
+            embeddings_list = result["output"]["embeddings"]
+            embeddings_list.sort(key=lambda x: x["text_index"])
+
+            # --- 核心修复:对每个浮点数保留 8 位小数,解决精度过高报错 ---
+            final_vectors = []
+            for item in embeddings_list:
+                # 将每个 float 限制在 8 位精度以内
+                rounded_vector = [round(float(val), 8) for val in item["embedding"]]
+                final_vectors.append(rounded_vector)
+            return final_vectors
+        else:
+            return [None] * len(texts)
+    except Exception as e:
+        print(f"Alibaba Embedding Error: {e}")
+        return [None] * len(texts)
+
+def main(text: str, api_key: str):
+
+    vector = chunks_embedding([text], api_key)[0]
+    return {
+        'vector': vector
+    }
\ No newline at end of file