80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
from fastapi import FastAPI
|
||
from .service import crawler_service
|
||
from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest
|
||
from .utils import make_response
|
||
|
||
# Human-readable title shown in the generated OpenAPI docs.
_APP_TITLE = "Wiki Crawler API"

# The single FastAPI application; every endpoint below is registered on it.
app = FastAPI(title=_APP_TITLE)
|
||
|
||
@app.post("/register")
async def register(req: RegisterRequest):
    """Register a crawl task for ``req.url``.

    Delegates to ``crawler_service.register_task`` and wraps the outcome in the
    standard ``make_response`` envelope:

    * code 1, msg "Success", data = the service's return value on success;
    * code 0, msg = the exception text if anything in the ``try`` raises.
    """
    # NOTE(review): crawler_service methods are called without `await`; if they
    # perform blocking I/O they will stall the event loop inside this
    # `async def` — confirm, and consider a plain `def` endpoint if so.
    try:
        task_info = crawler_service.register_task(req.url)
        return make_response(code=1, msg="Success", data=task_info)
    except Exception as exc:
        return make_response(code=0, msg=str(exc))
|
||
|
||
@app.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
    """Queue the URLs listed under ``req.urls_obj["urls"]`` for task ``req.task_id``.

    Returns code 1 with the result of ``crawler_service.add_urls`` on success,
    or code 0 with the exception text on failure.

    NOTE(review): a payload without a ``"urls"`` key surfaces as code 0 with the
    bare message ``'urls'`` (the ``str`` of the KeyError), which is not very
    descriptive for clients.
    """
    try:
        url_list = req.urls_obj["urls"]
        added = crawler_service.add_urls(req.task_id, urls=url_list)
        return make_response(code=1, msg="Success", data=added)
    except Exception as exc:
        return make_response(code=0, msg=str(exc))
|
||
|
||
@app.post("/pending_urls")
async def pending_urls(req: PendingRequest):
    """Fetch pending URLs for task ``req.task_id`` (``req.limit`` is passed through).

    The response is always code 1 when the service call succeeds; the message
    distinguishes a non-empty batch ("Success") from an empty one
    ("Queue Empty"), judged by the truthiness of ``data["urls"]``. Any
    exception yields code 0 with the exception text.
    """
    try:
        batch = crawler_service.get_pending_urls(req.task_id, req.limit)
        if batch["urls"]:
            status = "Success"
        else:
            status = "Queue Empty"
        return make_response(code=1, msg=status, data=batch)
    except Exception as exc:
        return make_response(code=0, msg=str(exc))
|
||
|
||
@app.post("/save_results")
async def save_results(req: SaveResultsRequest):
    """Hand ``req.results`` for task ``req.task_id`` to ``crawler_service.save_results``.

    Returns code 1 with the service's return value on success, or code 0 with
    the exception text if anything raises.
    """
    try:
        saved = crawler_service.save_results(req.task_id, req.results)
        return make_response(code=1, msg="Success", data=saved)
    except Exception as exc:
        return make_response(code=0, msg=str(exc))
|
||
|
||
# Required length of ``query_embedding["vector"]``. The original code notes
# that Alibaba's v4 embedding model typically produces 1536-dim vectors;
# update this if the embedding model changes.
EXPECTED_EMBEDDING_DIM = 1536


@app.post("/search")
async def search(req: SearchRequest):
    """General-purpose knowledge search endpoint.

    Supports a search scoped to one task (when ``req.task_id`` is provided) as
    well as a search over the whole store (when it is not); ``task_id`` and
    ``limit`` are passed straight through to
    ``crawler_service.search_knowledge``, which decides the scoping.

    Response codes (via ``make_response``):
        1 -- search completed; ``data`` is the service result.
        2 -- ``req.query_embedding["vector"]`` is missing, empty, or its length
             is not ``EXPECTED_EMBEDDING_DIM``.
        0 -- any other exception raised while searching.
    """
    import logging

    try:
        # 1. Basic validation: the vector must be present, non-empty and of the
        #    expected dimension. A missing key / missing embedding object is a
        #    client input error (code 2), not an internal search failure.
        try:
            vector = req.query_embedding["vector"]
        except (KeyError, TypeError):
            vector = None

        # Report the length of the vector itself (previously this measured
        # the enclosing dict, i.e. its number of keys).
        actual_dim = len(vector) if vector else 0
        if actual_dim != EXPECTED_EMBEDDING_DIM:
            return make_response(
                code=2,
                msg=f"向量维度错误。期望 {EXPECTED_EMBEDDING_DIM}, 实际收到 {actual_dim}",
                data=None,
            )

        # 2. Delegate the actual search to the service layer.
        # NOTE(review): this call is not awaited; if it blocks on I/O it will
        # stall the event loop inside this `async def` — confirm.
        data = crawler_service.search_knowledge(
            query_embedding=vector,
            task_id=req.task_id,
            limit=req.limit,
        )

        # 3. Uniform success envelope.
        return make_response(code=1, msg="搜索完成", data=data)

    except Exception as e:
        # Boundary handler: record the full traceback server-side and return a
        # failure envelope to the client instead of a 500.
        logging.getLogger(__name__).exception("搜索接口异常: %s", e)
        return make_response(code=0, msg=f"搜索失败: {e}")
|
||
|
||
if __name__ == "__main__":
    import uvicorn

    # uvicorn only honours reload=True when the application is given as a
    # "module:attribute" import string; passing the `app` object together with
    # reload=True makes uvicorn log a warning and exit. This module uses
    # relative imports, so it must be launched as `python -m <package>.<module>`,
    # in which case __spec__.name is the importable dotted module path.
    if __spec__ is not None:
        uvicorn.run(f"{__spec__.name}:app", host="0.0.0.0", port=8000, reload=True)
    else:
        # No import path available, so auto-reload cannot be used; serve the
        # app object directly.
        uvicorn.run(app, host="0.0.0.0", port=8000)