Add an endpoint to fetch all knowledge bases; roll the API version back to v1
@@ -1,16 +1,12 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from backend.routers import v3
 from backend.core.logger import setup_logging

 # First thing on startup: initialize logging
 setup_logging()

+# Import the new router
+from backend.routers import v1

 app = FastAPI(
-    title="Wiki Crawler System V3",
-    version="3.0.0",
-    description="Enterprise-grade RAG Knowledge Base API with Real-time Monitoring"
+    title="Wiki Crawler API",
+    version="1.0.0",  # version rolled back
+    description="RAG Knowledge Base Service"
 )

 app.add_middleware(

@@ -21,7 +17,8 @@ app.add_middleware(
     allow_headers=["*"],
 )

-app.include_router(v3.router)
+# Mount V1
+app.include_router(v1.router)

 if __name__ == "__main__":
     import uvicorn
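Note: the `__main__` block is truncated in the hunk above, so the actual launch code is not shown. A minimal sketch of how the rewired app could be started and checked, assuming the file is backend/main.py (consistent with the backend.* imports) and a conventional uvicorn entry point:

# Minimal launch sketch; the module path and port are assumptions, not shown in the diff.
import uvicorn
from backend.main import app  # assumed location of the FastAPI app above

if __name__ == "__main__":
    # Once running, http://127.0.0.1:8000/docs should list the /api/v1/* routes.
    uvicorn.run(app, host="127.0.0.1", port=8000)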
@@ -1,30 +1,67 @@
-from fastapi import APIRouter
+# backend/routers/v1.py
+from fastapi import APIRouter, BackgroundTasks, status
+from backend.services.crawler_service import crawler_service
 from backend.services.data_service import data_service
+from backend.schemas.v1 import (
+    TaskCreateRequest, TaskExecuteRequest, SearchRequest,
+    ResponseBase, KnowledgeBaseListResponse
+)
-from backend.utils.common import make_response
-from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest

-router = APIRouter(prefix="/api/v1", tags=["V1 Manual"])
+# [Changed] prefix moved to v1
+router = APIRouter(prefix="/api/v1", tags=["Knowledge Base API"])

-@router.post("/register")
-async def register(req: RegisterRequest):
+# =======================================================
+# 1. List knowledge bases (core new feature)
+# =======================================================
+@router.get("/knowledge-bases", response_model=ResponseBase)
+async def list_knowledge_bases():
+    """
+    List all existing knowledge bases (tasks).
+    A workflow can use this endpoint to fetch the task_id list and let the LLM choose which base to query.
+    """
+    kb_list = crawler_service.get_knowledge_base_list()
+    return ResponseBase(
+        code=1,
+        msg="Success",
+        data={"total": len(kb_list), "list": kb_list}
+    )
+
+# =======================================================
+# 2. Task management
+# =======================================================
+@router.post("/tasks", status_code=status.HTTP_201_CREATED, response_model=ResponseBase)
+async def create_task(req: TaskCreateRequest):
     try:
-        res = data_service.register_task(req.url)
-        return make_response(1, res.pop("msg", "Success"), res)
+        res = crawler_service.map_site(req.url)
+        return ResponseBase(code=1, msg="Task Created", data=res)
     except Exception as e:
-        return make_response(0, str(e))
+        return ResponseBase(code=0, msg=str(e))

-@router.post("/add_urls")
-async def add_urls(req: AddUrlsRequest):
-    try:
-        res = data_service.add_urls(req.task_id, req.urls_obj["urls"])
-        return make_response(1, res.pop("msg", "Success"), res)
-    except Exception as e:
-        return make_response(0, str(e))
+@router.get("/tasks/{task_id}", response_model=ResponseBase)
+async def get_task_status(task_id: int):
+    data = crawler_service.get_task_status(task_id)
+    if not data:
+        return ResponseBase(code=0, msg="Task not found")
+    return ResponseBase(code=1, msg="Success", data=data)

-@router.post("/search")
-async def search_manual(req: SearchRequest):
+@router.post("/tasks/{task_id}/run", status_code=status.HTTP_202_ACCEPTED, response_model=ResponseBase)
+async def run_task(task_id: int, req: TaskExecuteRequest, bg_tasks: BackgroundTasks):
+    # Quick existence check
+    if not data_service.get_task_monitor_data(task_id):
+        return ResponseBase(code=0, msg="Task not found")
+
+    bg_tasks.add_task(crawler_service.process_queue_concurrent, task_id, req.batch_size)
+    return ResponseBase(code=1, msg="Execution Started", data={"task_id": task_id})
+
+# =======================================================
+# 3. Search
+# =======================================================
+@router.post("/search", response_model=ResponseBase)
+async def search_knowledge(req: SearchRequest):
     try:
-        res = data_service.search(req.query_embedding['vector'], req.task_id, req.limit)
-        return make_response(1, res.pop("msg", "Success"), res)
+        # req.limit maps to return_num
+        res = crawler_service.search(req.query, req.task_id, req.limit)
+        return ResponseBase(code=1, msg="Search Completed", data=res)
     except Exception as e:
-        return make_response(0, str(e))
+        return ResponseBase(code=0, msg=str(e))
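The docstring on list_knowledge_bases describes the intended workflow: fetch the knowledge-base list, let an LLM (or a user) pick a task_id, then search that base. A client-side sketch of that flow, assuming the service runs locally on port 8000 and the requests package is available (host, port, and the sample query are assumptions):

import requests

BASE = "http://127.0.0.1:8000/api/v1"  # assumed host/port

# Step 1: list knowledge bases; the response envelope is ResponseBase,
# so the payload sits under data.total / data.list.
kbs = requests.get(f"{BASE}/knowledge-bases").json()["data"]["list"]

# Step 2: search the chosen base; omitting task_id searches all bases.
chosen = kbs[0]["task_id"]
resp = requests.post(
    f"{BASE}/search",
    json={"query": "how to configure the crawler", "task_id": chosen, "limit": 5},
)
print(resp.json())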
@@ -1,30 +0,0 @@
-from fastapi import APIRouter, BackgroundTasks
-from backend.services.crawler_service import crawler_service
-from backend.utils.common import make_response
-from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest
-
-router = APIRouter(prefix="/api/v2", tags=["V2 Automated"])
-
-@router.post("/crawler/map")
-async def auto_map(req: AutoMapRequest):
-    try:
-        res = crawler_service.map_site(req.url)
-        return make_response(1, res.pop("msg", "Started"), res)
-    except Exception as e:
-        return make_response(0, str(e))
-
-@router.post("/crawler/process")
-async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
-    try:
-        bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size)
-        return make_response(1, "Background processing started", {"task_id": req.task_id})
-    except Exception as e:
-        return make_response(0, str(e))
-
-@router.post("/search")
-async def search_smart(req: TextSearchRequest):
-    try:
-        res = crawler_service.search(req.query, req.task_id, req.return_num)
-        return make_response(1, res.pop("msg", "Success"), res)
-    except Exception as e:
-        return make_response(0, str(e))
@@ -1,57 +0,0 @@
-from fastapi import APIRouter, BackgroundTasks, status
-from backend.services.crawler_service import crawler_service
-from backend.services.data_service import data_service
-from backend.schemas.v3 import (
-    TaskCreateRequest, TaskExecuteRequest, SearchRequest,
-    ResponseBase, TaskStatusData
-)
-
-router = APIRouter(prefix="/api/v3", tags=["V3 Knowledge Base"])
-
-@router.post("/tasks", status_code=status.HTTP_201_CREATED, response_model=ResponseBase)
-async def create_task(req: TaskCreateRequest):
-    """Create a new task (Map)"""
-    try:
-        res = crawler_service.map_site(req.url)
-        return ResponseBase(code=1, msg="Task Created", data=res)
-    except Exception as e:
-        return ResponseBase(code=0, msg=f"Map Failed: {str(e)}")
-
-@router.get("/tasks/{task_id}", response_model=ResponseBase)
-async def get_task_status(task_id: int):
-    """
-    Real-time monitoring:
-    returns the persisted DB state plus the threads currently running in memory
-    """
-    # Call the aggregation method on crawler_service
-    data = crawler_service.get_task_status(task_id)
-
-    if not data:
-        return ResponseBase(code=0, msg="Task not found")
-
-    return ResponseBase(code=1, msg="Success", data=data)
-
-@router.post("/tasks/{task_id}/run", status_code=status.HTTP_202_ACCEPTED, response_model=ResponseBase)
-async def run_task(task_id: int, req: TaskExecuteRequest, bg_tasks: BackgroundTasks):
-    """Kick off multi-threaded crawling in the background"""
-    # Quick existence check (looking up the DB monitoring data is enough)
-    if not data_service.get_task_monitor_data(task_id):
-        return ResponseBase(code=0, msg="Task not found")
-
-    # Hand off to a background task
-    bg_tasks.add_task(crawler_service.process_queue_concurrent, task_id, req.batch_size)
-
-    return ResponseBase(
-        code=1,
-        msg="Background Execution Started",
-        data={"task_id": task_id, "mode": "concurrent_thread_pool"}
-    )
-
-@router.post("/search", response_model=ResponseBase)
-async def search_knowledge(req: SearchRequest):
-    """Hybrid retrieval + rerank"""
-    try:
-        res = crawler_service.search(req.query, req.task_id, req.return_num)
-        return ResponseBase(code=1, msg="Search Completed", data=res)
-    except Exception as e:
-        return ResponseBase(code=0, msg=f"Search Failed: {str(e)}")
@@ -1,44 +0,0 @@
-from pydantic import BaseModel
-from typing import Optional, List, Any
-
-
-class RegisterRequest(BaseModel):
-    url: str
-
-class PendingRequest(BaseModel):
-    task_id: int
-    limit: Optional[int] = 10
-
-class AddUrlsRequest(BaseModel):
-    task_id: int
-    urls_obj: dict
-
-class CrawlResult(BaseModel):
-    source_url: str
-    chunk_index: int  # new field
-    title: Optional[str] = None
-    content: Optional[str] = None
-    embedding: Optional[List[float]] = None
-
-class SaveResultsRequest(BaseModel):
-    task_id: int
-    results: List[CrawlResult]
-
-class SearchRequest(BaseModel):
-    # If task_id is omitted, search across all bases
-    task_id: Optional[int] = None
-    query_embedding: dict
-    limit: Optional[int] = 5
-
-# === V2 New Schemas ===
-class AutoMapRequest(BaseModel):
-    url: str
-
-class AutoProcessRequest(BaseModel):
-    task_id: int
-    batch_size: Optional[int] = 5
-
-class TextSearchRequest(BaseModel):
-    query: str  # the user passes text directly; no need to send a vector
-    task_id: Optional[int] = None
-    return_num: Optional[int] = 5
@@ -7,27 +7,32 @@ class ResponseBase(BaseModel):
     msg: str
     data: Optional[Any] = None

+# --- [GET] Knowledge base list (new feature) ---
+class KnowledgeBaseItem(BaseModel):
+    task_id: int
+    root_url: str
+    name: str  # short name extracted from the URL, easy for an LLM to recognize
+
+class KnowledgeBaseListResponse(BaseModel):
+    total: int
+    list: List[KnowledgeBaseItem]
+
 # --- [POST] Create task ---
 class TaskCreateRequest(BaseModel):
-    url: str = Field(..., description="Root URL of the target site", example="https://docs.firecrawl.dev")
+    url: str = Field(..., description="Root URL of the target site")

 # --- [POST] Execute task ---
 class TaskExecuteRequest(BaseModel):
-    batch_size: int = Field(10, ge=1, le=50, description="Concurrent thread count / batch size")
+    batch_size: int = Field(10, le=50)

 # --- [GET] Monitoring data ---
 class TaskStatusData(BaseModel):
     root_url: str
     stats: Dict[str, int] = Field(..., description="DB stats: pending/processing/completed")
     active_threads: List[str] = Field(..., description="Real-time, in-memory: URLs currently being crawled")
     active_thread_count: int

 # --- [POST] Search ---
 class SearchRequest(BaseModel):
     query: str
-    task_id: Optional[int] = None
-    return_num: int = Field(5, description="Number of results to return")
+    # Explicitly allow None to mean global search
+    task_id: Optional[int] = Field(None, description="Task ID; omit to search all bases")
+    limit: int = Field(5, description="Number of results to return")

+# ... (SearchResultItem etc. unchanged) ...
 class SearchResultItem(BaseModel):
     task_id: int
     source_url: str
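To illustrate the new response shape, a self-contained sketch that validates a sample /knowledge-bases payload against local copies of the two new models (the sample row is illustrative; the firecrawl URL reuses the example from the old schema):

from typing import List
from pydantic import BaseModel

class KnowledgeBaseItem(BaseModel):
    task_id: int
    root_url: str
    name: str

class KnowledgeBaseListResponse(BaseModel):
    total: int
    list: List[KnowledgeBaseItem]

# Illustrative payload matching what list_knowledge_bases puts in `data`.
payload = {
    "total": 1,
    "list": [{"task_id": 3, "root_url": "https://docs.firecrawl.dev", "name": "firecrawl"}],
}
print(KnowledgeBaseListResponse(**payload))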
@@ -30,6 +30,11 @@ class CrawlerService:
         self._active_workers: Dict[int, set] = {}
         self._lock = threading.Lock()

+    def get_knowledge_base_list(self):
+        """Fetch the knowledge base list"""
+        return data_service.get_all_tasks()
+
     def _track_start(self, task_id: int, url: str):
         """[Internal] Mark a URL as having started processing"""
         with self._lock:
@@ -63,6 +63,38 @@ class DataService:
             and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
         ).values(status=status))

+    def get_all_tasks(self):
+        """
+        [New] Fetch all registered tasks (the knowledge base list).
+        Used for frontend display or for routing choices in a workflow.
+        """
+        with self.db.engine.connect() as conn:
+            # Select id, root_url, created_at (if available)
+            # Assumes the tasks table has id and root_url columns
+            stmt = select(self.db.tasks.c.id, self.db.tasks.c.root_url).order_by(self.db.tasks.c.id)
+            rows = conn.execute(stmt).fetchall()
+
+        # Return a compact list
+        return [
+            {"task_id": r[0], "root_url": r[1], "name": self._extract_name(r[1])}
+            for r in rows
+        ]
+
+    def _extract_name(self, url: str) -> str:
+        """Helper: extract a short name from the URL to use as an alias"""
+        try:
+            from urllib.parse import urlparse
+            domain = urlparse(url).netloc
+            # e.g. docs.firecrawl.dev -> firecrawl
+            parts = domain.split('.')
+            if len(parts) >= 2:
+                return parts[-2]
+            return domain
+        except:
+            return url
+
+    # ... (get_task_monitor_data, save_chunks, search, etc. unchanged) ...

     def get_task_monitor_data(self, task_id: int):
         """[DB-layer monitoring] Fetch the persisted task state"""
         with self.db.engine.connect() as conn:
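The alias logic in _extract_name is easy to sanity-check in isolation. A standalone copy with a couple of illustrative inputs (the second URL is hypothetical):

from urllib.parse import urlparse

def extract_name(url: str) -> str:
    # Take the second-to-last label of the hostname, e.g. docs.firecrawl.dev -> firecrawl.
    try:
        domain = urlparse(url).netloc
        parts = domain.split('.')
        return parts[-2] if len(parts) >= 2 else domain
    except Exception:
        return url

print(extract_name("https://docs.firecrawl.dev"))  # firecrawl
print(extract_name("http://localhost:8000"))       # localhost:8000 (no dot, falls through)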