aaaa
This commit is contained in:
15
.vscode/launch.json
vendored
15
.vscode/launch.json
vendored
@@ -1,15 +0,0 @@
|
||||
{
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: FastAPI",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"module": "uvicorn",
|
||||
"args": [
|
||||
"backend.main:app",
|
||||
"--reload"
|
||||
],
|
||||
"jinja": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -16,7 +16,8 @@ async def register(req: RegisterRequest):
|
||||
@app.post("/add_urls")
|
||||
async def add_urls(req: AddUrlsRequest):
|
||||
try:
|
||||
data = crawler_service.add_urls(req.task_id, req.urls_obj)
|
||||
urls = req.urls_obj["urls"]
|
||||
data = crawler_service.add_urls(req.task_id, urls=urls)
|
||||
return make_response(1, "Success", data)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
@@ -46,7 +47,9 @@ async def search(req: SearchRequest):
|
||||
"""
|
||||
try:
|
||||
# 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536)
|
||||
if not req.query_embedding or len(req.query_embedding) != 1536:
|
||||
vector = req.query_embedding['vector']
|
||||
|
||||
if not vector or len(vector) != 1536:
|
||||
return make_response(
|
||||
code=2,
|
||||
msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
|
||||
@@ -55,7 +58,7 @@ async def search(req: SearchRequest):
|
||||
|
||||
# 2. 调用业务类执行搜索
|
||||
data = crawler_service.search_knowledge(
|
||||
query_embedding=req.query_embedding,
|
||||
query_embedding=vector,
|
||||
task_id=req.task_id,
|
||||
limit=req.limit
|
||||
)
|
||||
|
||||
@@ -27,5 +27,5 @@ class SaveResultsRequest(BaseModel):
|
||||
class SearchRequest(BaseModel):
|
||||
# 如果不传 task_id,则进行全库搜索
|
||||
task_id: Optional[int] = None
|
||||
query_embedding: List[float]
|
||||
query_embedding: dict
|
||||
limit: Optional[int] = 5
|
||||
@@ -23,11 +23,9 @@ class CrawlerService:
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
return {"task_id": new_task[0], "is_new_task": True}
|
||||
|
||||
def add_urls(self, task_id: int, urls_obj: dict):
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
"""通用 API 实现的批量添加(含详细返回)"""
|
||||
success_urls, skipped_urls, failed_urls = [], [], []
|
||||
# 从 urls_obj 中提取 urls 列表
|
||||
urls = urls_obj.get("urls", [])
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
for url in urls:
|
||||
|
||||
6
main.py
6
main.py
@@ -1,6 +0,0 @@
|
||||
def main():
|
||||
print("Hello from wiki-crawler!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -41,6 +41,9 @@ def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
|
||||
def main(text: str, api_key: str):
|
||||
|
||||
vector = chunks_embedding([text], api_key)[0]
|
||||
|
||||
return {
|
||||
'vector': vector
|
||||
'vector': {
|
||||
'vector': vector
|
||||
}
|
||||
}
|
||||
@@ -1,19 +1,27 @@
|
||||
def check_status(status_code: float, body: str):
|
||||
import json
|
||||
|
||||
def parse_response(status_code: float, body: str):
|
||||
'''
|
||||
检查状态码和约定的返回值
|
||||
并且返回正确的body
|
||||
'''
|
||||
if status_code != 200:
|
||||
raise Exception(f"注册任务失败,状态码:{status_code}")
|
||||
if "code" not in body or body["code"] != 1:
|
||||
|
||||
data = json.loads(body)
|
||||
|
||||
if "code" not in data or data["code"] != 1:
|
||||
raise Exception(f"注册任务失败,返回值:{body}")
|
||||
|
||||
return data["data"]
|
||||
|
||||
def main(status_code: float, body: str):
|
||||
try:
|
||||
check_status(status_code, body)
|
||||
data = parse_response(status_code, body)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
urls = body["data"]["urls"]
|
||||
urls = data["urls"]
|
||||
|
||||
return {
|
||||
"urls": urls,
|
||||
|
||||
@@ -27,14 +27,3 @@ def main(status_code: float, body: str):
|
||||
"task_id": task_id,
|
||||
"is_new_task": is_new_task
|
||||
}
|
||||
|
||||
def test():
|
||||
import json
|
||||
with open("nodes\parse_register.json", "r") as f:
|
||||
data = json.load(f)
|
||||
status_code = data["status_code"]
|
||||
body = data["body"]
|
||||
res = main(status_code, body)
|
||||
print(res)
|
||||
|
||||
test()
|
||||
@@ -1,20 +1,27 @@
|
||||
def check_status(status_code: float, body: str):
|
||||
import json
|
||||
|
||||
def parse_response(status_code: float, body: str):
|
||||
'''
|
||||
检查状态码和约定的返回值
|
||||
并且返回正确的body
|
||||
'''
|
||||
if status_code != 200:
|
||||
raise Exception(f"注册任务失败,状态码:{status_code}")
|
||||
if "code" not in body or body["code"] != 1:
|
||||
|
||||
data = json.loads(body)
|
||||
|
||||
if "code" not in data or data["code"] != 1:
|
||||
raise Exception(f"注册任务失败,返回值:{body}")
|
||||
|
||||
return data["data"]
|
||||
|
||||
def main(status_code: float, body: str):
|
||||
try:
|
||||
check_status(status_code, body)
|
||||
data = parse_response(status_code, body)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
urls_result = body["data"]
|
||||
|
||||
return {
|
||||
"add_urls_result": urls_result
|
||||
"add_urls_result": data
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user