This commit is contained in:
2025-12-29 14:42:33 +08:00
parent 9f636d1c31
commit 8c4491b383
9 changed files with 37 additions and 50 deletions

15
.vscode/launch.json vendored
View File

@@ -1,15 +0,0 @@
{
"configurations": [
{
"name": "Python Debugger: FastAPI",
"type": "debugpy",
"request": "launch",
"module": "uvicorn",
"args": [
"backend.main:app",
"--reload"
],
"jinja": true
}
]
}

View File

@@ -16,7 +16,8 @@ async def register(req: RegisterRequest):
@app.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
try:
data = crawler_service.add_urls(req.task_id, req.urls_obj)
urls = req.urls_obj["urls"]
data = crawler_service.add_urls(req.task_id, urls=urls)
return make_response(1, "Success", data)
except Exception as e:
return make_response(0, str(e))
@@ -46,7 +47,9 @@ async def search(req: SearchRequest):
"""
try:
# 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536)
if not req.query_embedding or len(req.query_embedding) != 1536:
vector = req.query_embedding['vector']
if not vector or len(vector) != 1536:
return make_response(
code=2,
msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
@@ -55,7 +58,7 @@ async def search(req: SearchRequest):
# 2. 调用业务类执行搜索
data = crawler_service.search_knowledge(
query_embedding=req.query_embedding,
query_embedding=vector,
task_id=req.task_id,
limit=req.limit
)

View File

@@ -27,5 +27,5 @@ class SaveResultsRequest(BaseModel):
class SearchRequest(BaseModel):
    """Request body for the knowledge-search endpoint."""

    # If task_id is omitted, search across the whole knowledge base.
    task_id: Optional[int] = None
    # Wrapper dict around the embedding; downstream code reads
    # query_embedding['vector'] (a 1536-dim float list) -- TODO confirm shape.
    query_embedding: dict
    limit: Optional[int] = 5

View File

@@ -23,11 +23,9 @@ class CrawlerService:
new_task = conn.execute(stmt).fetchone()
return {"task_id": new_task[0], "is_new_task": True}
def add_urls(self, task_id: int, urls_obj: dict):
def add_urls(self, task_id: int, urls: list[str]):
"""通用 API 实现的批量添加(含详细返回)"""
success_urls, skipped_urls, failed_urls = [], [], []
# 从 urls_obj 中提取 urls 列表
urls = urls_obj.get("urls", [])
with self.db.engine.begin() as conn:
for url in urls:

View File

@@ -1,6 +0,0 @@
def main():
    """Entry point: print the project's hello message to stdout."""
    greeting = "Hello from wiki-crawler!"
    print(greeting)


if __name__ == "__main__":
    main()

View File

@@ -41,6 +41,9 @@ def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
def main(text: str, api_key: str):
    """Embed *text* and return it wrapped as {'vector': {'vector': [...]}}.

    The double 'vector' nesting appears to match the dict shape the
    search API now expects for query_embedding -- TODO confirm with
    the consumer before flattening.
    """
    # chunks_embedding returns one embedding per input text; we sent one.
    vector = chunks_embedding([text], api_key)[0]
    return {
        'vector': {
            'vector': vector
        }
    }

View File

@@ -1,19 +1,27 @@
import json


def parse_response(status_code: float, body: str):
    """Validate an HTTP response and return its "data" payload.

    Args:
        status_code: HTTP status code of the response.
        body: raw JSON response body.

    Raises:
        Exception: when the status code is not 200, or the decoded
            body's "code" field is missing or != 1 (the API's success
            marker).

    Returns:
        The decoded body's "data" field.
    """
    if status_code != 200:
        raise Exception(f"注册任务失败,状态码:{status_code}")
    data = json.loads(body)
    if "code" not in data or data["code"] != 1:
        raise Exception(f"注册任务失败,返回值:{body}")
    return data["data"]
def main(status_code: float, body: str):
try:
check_status(status_code, body)
data = parse_response(status_code, body)
except Exception as e:
raise e
urls = body["data"]["urls"]
urls = data["urls"]
return {
"urls": urls,

View File

@@ -26,15 +26,4 @@ def main(status_code: float, body: str):
return {
"task_id": task_id,
"is_new_task": is_new_task
}
def test():
import json
with open("nodes\parse_register.json", "r") as f:
data = json.load(f)
status_code = data["status_code"]
body = data["body"]
res = main(status_code, body)
print(res)
test()
}

View File

@@ -1,20 +1,27 @@
import json


def parse_response(status_code: float, body: str):
    """Validate an HTTP response and return its "data" payload.

    Args:
        status_code: HTTP status code of the response.
        body: raw JSON response body.

    Raises:
        Exception: when the status code is not 200, or the decoded
            body's "code" field is missing or != 1 (the API's success
            marker).

    Returns:
        The decoded body's "data" field.
    """
    if status_code != 200:
        raise Exception(f"注册任务失败,状态码:{status_code}")
    data = json.loads(body)
    if "code" not in data or data["code"] != 1:
        raise Exception(f"注册任务失败,返回值:{body}")
    return data["data"]
def main(status_code: float, body: str):
    """Parse the add-urls API response and wrap its payload.

    Any exception from parse_response propagates unchanged to the
    caller; the previous `except Exception as e: raise e` wrapper was a
    no-op re-raise and has been removed.

    Returns:
        {"add_urls_result": <the response's "data" payload>}
    """
    data = parse_response(status_code, body)
    return {
        "add_urls_result": data
    }