diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 1b22d63..0000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "configurations": [ - { - "name": "Python Debugger: FastAPI", - "type": "debugpy", - "request": "launch", - "module": "uvicorn", - "args": [ - "backend.main:app", - "--reload" - ], - "jinja": true - } - ] -} \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index b8cec64..349ce7c 100644 --- a/backend/main.py +++ b/backend/main.py @@ -16,7 +16,8 @@ async def register(req: RegisterRequest): @app.post("/add_urls") async def add_urls(req: AddUrlsRequest): try: - data = crawler_service.add_urls(req.task_id, req.urls_obj) + urls = req.urls_obj["urls"] + data = crawler_service.add_urls(req.task_id, urls=urls) return make_response(1, "Success", data) except Exception as e: return make_response(0, str(e)) @@ -46,7 +47,9 @@ async def search(req: SearchRequest): """ try: # 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536) - if not req.query_embedding or len(req.query_embedding) != 1536: + vector = req.query_embedding['vector'] + + if not vector or len(vector) != 1536: return make_response( code=2, msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}", @@ -55,7 +58,7 @@ async def search(req: SearchRequest): # 2. 调用业务类执行搜索 data = crawler_service.search_knowledge( - query_embedding=req.query_embedding, + query_embedding=vector, task_id=req.task_id, limit=req.limit ) diff --git a/backend/schemas.py b/backend/schemas.py index c1c4d92..4a5cbad 100644 --- a/backend/schemas.py +++ b/backend/schemas.py @@ -27,5 +27,5 @@ class SaveResultsRequest(BaseModel): class SearchRequest(BaseModel): # 如果不传 task_id,则进行全库搜索 task_id: Optional[int] = None - query_embedding: List[float] + query_embedding: dict limit: Optional[int] = 5 \ No newline at end of file diff --git a/backend/service.py b/backend/service.py index 4c535fa..1aacbaa 100644 --- a/backend/service.py +++ b/backend/service.py @@ -23,11 +23,9 @@ class CrawlerService: new_task = conn.execute(stmt).fetchone() return {"task_id": new_task[0], "is_new_task": True} - def add_urls(self, task_id: int, urls_obj: dict): + def add_urls(self, task_id: int, urls: list[str]): """通用 API 实现的批量添加(含详细返回)""" success_urls, skipped_urls, failed_urls = [], [], [] - # 从 urls_obj 中提取 urls 列表 - urls = urls_obj.get("urls", []) with self.db.engine.begin() as conn: for url in urls: diff --git a/main.py b/main.py deleted file mode 100644 index a580541..0000000 --- a/main.py +++ /dev/null @@ -1,6 +0,0 @@ -def main(): - print("Hello from wiki-crawler!") - - -if __name__ == "__main__": - main() diff --git a/nodes/embedding.py b/nodes/embedding.py index cff4507..0bf41b5 100644 --- a/nodes/embedding.py +++ b/nodes/embedding.py @@ -41,6 +41,9 @@ def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]: def main(text: str, api_key: str): vector = chunks_embedding([text], api_key)[0] + return { - 'vector': vector + 'vector': { + 'vector': vector + } } \ No newline at end of file diff --git a/nodes/parse_pending_urls.py b/nodes/parse_pending_urls.py index 75613bf..c198791 100644 --- a/nodes/parse_pending_urls.py +++ b/nodes/parse_pending_urls.py @@ -1,19 +1,27 @@ -def check_status(status_code: float, body: str): +import json + +def parse_response(status_code: float, body: str): ''' 检查状态码和约定的返回值 + 并且返回正确的body ''' if status_code != 200: raise Exception(f"注册任务失败,状态码:{status_code}") - if "code" not in body or body["code"] != 1: + + data = json.loads(body) + + if "code" not in data or data["code"] != 1: raise Exception(f"注册任务失败,返回值:{body}") + return data["data"] + def main(status_code: float, body: str): try: - check_status(status_code, body) + data = parse_response(status_code, body) except Exception as e: raise e - urls = body["data"]["urls"] + urls = data["urls"] return { "urls": urls, diff --git a/nodes/parse_register.py b/nodes/parse_register.py index 7c401b4..407bfe4 100644 --- a/nodes/parse_register.py +++ b/nodes/parse_register.py @@ -26,15 +26,4 @@ def main(status_code: float, body: str): return { "task_id": task_id, "is_new_task": is_new_task - } - -def test(): - import json - with open("nodes\parse_register.json", "r") as f: - data = json.load(f) - status_code = data["status_code"] - body = data["body"] - res = main(status_code, body) - print(res) - -test() \ No newline at end of file + } \ No newline at end of file diff --git a/nodes/parse_save_urls.py b/nodes/parse_save_urls.py index 00968ab..b72de8d 100644 --- a/nodes/parse_save_urls.py +++ b/nodes/parse_save_urls.py @@ -1,20 +1,27 @@ -def check_status(status_code: float, body: str): +import json + +def parse_response(status_code: float, body: str): ''' 检查状态码和约定的返回值 + 并且返回正确的body ''' if status_code != 200: raise Exception(f"注册任务失败,状态码:{status_code}") - if "code" not in body or body["code"] != 1: + + data = json.loads(body) + + if "code" not in data or data["code"] != 1: raise Exception(f"注册任务失败,返回值:{body}") + + return data["data"] def main(status_code: float, body: str): try: - check_status(status_code, body) + data = parse_response(status_code, body) except Exception as e: raise e - urls_result = body["data"] return { - "add_urls_result": urls_result + "add_urls_result": data }