新增获取全部知识库的接口,API 版本号回归到 v1

This commit is contained in:
2026-01-20 02:47:03 +08:00
parent 860ada3334
commit 155974572c
10 changed files with 130 additions and 184 deletions

View File

@@ -1,16 +1,12 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from backend.routers import v3
from backend.core.logger import setup_logging
# 程序启动第一件事:初始化日志
setup_logging()
# 引入新路由
from backend.routers import v1
app = FastAPI(
title="Wiki Crawler System V3",
version="3.0.0",
description="Enterprise-grade RAG Knowledge Base API with Real-time Monitoring"
title="Wiki Crawler API",
version="1.0.0", # 版本号回归
description="RAG Knowledge Base Service"
)
app.add_middleware(
@@ -21,7 +17,8 @@ app.add_middleware(
allow_headers=["*"],
)
app.include_router(v3.router)
# 挂载 V1
app.include_router(v1.router)
if __name__ == "__main__":
import uvicorn

View File

@@ -1,30 +1,67 @@
from fastapi import APIRouter
# backend/routers/v1.py
from fastapi import APIRouter, BackgroundTasks, status
from backend.services.crawler_service import crawler_service
from backend.services.data_service import data_service
from backend.schemas.v1 import (
TaskCreateRequest, TaskExecuteRequest, SearchRequest,
ResponseBase, KnowledgeBaseListResponse
)
from backend.utils.common import make_response
from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest
router = APIRouter(prefix="/api/v1", tags=["V1 Manual"])
# 【改动】前缀变更为 v1
router = APIRouter(prefix="/api/v1", tags=["Knowledge Base API"])
@router.post("/register")
async def register(req: RegisterRequest):
# =======================================================
# 1. 获取知识库列表 (核心新功能)
# =======================================================
@router.get("/knowledge-bases", response_model=ResponseBase)
async def list_knowledge_bases():
    """List every existing knowledge base (task).

    Workflows call this to obtain the task_id list so an LLM can choose
    which base to query.
    """
    bases = crawler_service.get_knowledge_base_list()
    payload = {"total": len(bases), "list": bases}
    return ResponseBase(code=1, msg="Success", data=payload)
# =======================================================
# 2. 任务管理
# =======================================================
@router.post("/tasks", status_code=status.HTTP_201_CREATED, response_model=ResponseBase)
async def create_task(req: TaskCreateRequest):
try:
res = data_service.register_task(req.url)
return make_response(1, res.pop("msg", "Success"), res)
res = crawler_service.map_site(req.url)
return ResponseBase(code=1, msg="Task Created", data=res)
except Exception as e:
return make_response(0, str(e))
return ResponseBase(code=0, msg=str(e))
@router.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
try:
res = data_service.add_urls(req.task_id, req.urls_obj["urls"])
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router.get("/tasks/{task_id}", response_model=ResponseBase)
async def get_task_status(task_id: int):
    """Return the aggregated status for one task; code=0 when unknown."""
    status_data = crawler_service.get_task_status(task_id)
    if status_data:
        return ResponseBase(code=1, msg="Success", data=status_data)
    return ResponseBase(code=0, msg="Task not found")
@router.post("/search")
async def search_manual(req: SearchRequest):
@router.post("/tasks/{task_id}/run", status_code=status.HTTP_202_ACCEPTED, response_model=ResponseBase)
async def run_task(task_id: int, req: TaskExecuteRequest, bg_tasks: BackgroundTasks):
    """Start background crawling for *task_id*; replies 202 immediately."""
    # Simple existence check against the persisted monitor data.
    if not data_service.get_task_monitor_data(task_id):
        return ResponseBase(code=0, msg="Task not found")
    # Hand off to FastAPI's background tasks; the crawl runs after the response is sent.
    bg_tasks.add_task(crawler_service.process_queue_concurrent, task_id, req.batch_size)
    return ResponseBase(code=1, msg="Execution Started", data={"task_id": task_id})
# =======================================================
# 3. 搜索
# =======================================================
@router.post("/search", response_model=ResponseBase)
async def search_knowledge(req: SearchRequest):
try:
res = data_service.search(req.query_embedding['vector'], req.task_id, req.limit)
return make_response(1, res.pop("msg", "Success"), res)
# req.limit 映射到 return_num
res = crawler_service.search(req.query, req.task_id, req.limit)
return ResponseBase(code=1, msg="Search Completed", data=res)
except Exception as e:
return make_response(0, str(e))
return ResponseBase(code=0, msg=str(e))

View File

@@ -1,30 +0,0 @@
from fastapi import APIRouter, BackgroundTasks
from backend.services.crawler_service import crawler_service
from backend.utils.common import make_response
from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest
router = APIRouter(prefix="/api/v2", tags=["V2 Automated"])
@router.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
    """Start mapping the site at req.url and relay the service outcome."""
    try:
        result = crawler_service.map_site(req.url)
        return make_response(1, result.pop("msg", "Started"), result)
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
    """Schedule background queue processing for the given task."""
    try:
        bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size)
        return make_response(1, "Background processing started", {"task_id": req.task_id})
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/search")
async def search_smart(req: TextSearchRequest):
    """Text search over the knowledge base; relays the service result."""
    try:
        result = crawler_service.search(req.query, req.task_id, req.return_num)
        return make_response(1, result.pop("msg", "Success"), result)
    except Exception as exc:
        return make_response(0, str(exc))

View File

@@ -1,57 +0,0 @@
from fastapi import APIRouter, BackgroundTasks, status
from backend.services.crawler_service import crawler_service
from backend.services.data_service import data_service
from backend.schemas.v3 import (
TaskCreateRequest, TaskExecuteRequest, SearchRequest,
ResponseBase, TaskStatusData
)
router = APIRouter(prefix="/api/v3", tags=["V3 Knowledge Base"])
@router.post("/tasks", status_code=status.HTTP_201_CREATED, response_model=ResponseBase)
async def create_task(req: TaskCreateRequest):
    """Create a new task by mapping the target site (Map step)."""
    try:
        created = crawler_service.map_site(req.url)
        return ResponseBase(code=1, msg="Task Created", data=created)
    except Exception as exc:
        return ResponseBase(code=0, msg=f"Map Failed: {str(exc)}")
@router.get("/tasks/{task_id}", response_model=ResponseBase)
async def get_task_status(task_id: int):
    """
    Real-time monitoring:
    returns the persisted database state plus the in-memory worker threads.
    """
    # Delegate to crawler_service's aggregation helper.
    data = crawler_service.get_task_status(task_id)
    if not data:
        return ResponseBase(code=0, msg="Task not found")
    return ResponseBase(code=1, msg="Success", data=data)
@router.post("/tasks/{task_id}/run", status_code=status.HTTP_202_ACCEPTED, response_model=ResponseBase)
async def run_task(task_id: int, req: TaskExecuteRequest, bg_tasks: BackgroundTasks):
    """Trigger background multi-threaded crawling for an existing task."""
    # Cheap existence check: looking up the DB monitor row is sufficient here.
    if not data_service.get_task_monitor_data(task_id):
        return ResponseBase(code=0, msg="Task not found")
    # Queue the crawl as a background task so the request returns 202 immediately.
    bg_tasks.add_task(crawler_service.process_queue_concurrent, task_id, req.batch_size)
    return ResponseBase(
        code=1,
        msg="Background Execution Started",
        data={"task_id": task_id, "mode": "concurrent_thread_pool"}
    )
@router.post("/search", response_model=ResponseBase)
async def search_knowledge(req: SearchRequest):
    """Hybrid retrieval followed by rerank."""
    try:
        hits = crawler_service.search(req.query, req.task_id, req.return_num)
        return ResponseBase(code=1, msg="Search Completed", data=hits)
    except Exception as exc:
        return ResponseBase(code=0, msg=f"Search Failed: {str(exc)}")

View File

@@ -1,44 +0,0 @@
from pydantic import BaseModel
from typing import Optional, List, Any
class RegisterRequest(BaseModel):
    """Request body for registering a new crawl task by its root URL."""
    url: str
class PendingRequest(BaseModel):
    """Request body for fetching pending queue entries of a task."""
    task_id: int
    limit: Optional[int] = 10  # presumably caps the returned rows — confirm against the service
class AddUrlsRequest(BaseModel):
    """Request body for appending URLs to an existing task's queue."""
    task_id: int
    urls_obj: dict  # expected shape: {"urls": [...]} (the router reads urls_obj["urls"])
class CrawlResult(BaseModel):
    """One crawled chunk of a page, optionally carrying its embedding."""
    source_url: str
    chunk_index: int  # chunk ordinal (field added in this revision)
    title: Optional[str] = None
    content: Optional[str] = None
    embedding: Optional[List[float]] = None
class SaveResultsRequest(BaseModel):
    """Request body for persisting a batch of crawl results for a task."""
    task_id: int
    results: List[CrawlResult]
class SearchRequest(BaseModel):
    """Vector search request."""
    # When task_id is omitted, the search spans all knowledge bases.
    task_id: Optional[int] = None
    query_embedding: dict  # expected shape: {"vector": [...]} (the router reads query_embedding['vector'])
    limit: Optional[int] = 5
# === V2 New Schemas ===
class AutoMapRequest(BaseModel):
    """V2 request body: map a site starting from its root URL."""
    url: str
class AutoProcessRequest(BaseModel):
    """V2 request body: process a task's queue in batches."""
    task_id: int
    batch_size: Optional[int] = 5
class TextSearchRequest(BaseModel):
    """V2 request body: the caller sends plain text, not a vector."""
    query: str  # user sends raw text directly; no client-side embedding needed
    task_id: Optional[int] = None
    return_num: Optional[int] = 5

View File

@@ -7,27 +7,32 @@ class ResponseBase(BaseModel):
msg: str
data: Optional[Any] = None
# --- [GET] 知识库列表 (新功能) ---
class KnowledgeBaseItem(BaseModel):
    """One knowledge base entry, keyed by its crawl task."""
    task_id: int
    root_url: str
    name: str  # short alias extracted from the URL so an LLM can identify the base
class KnowledgeBaseListResponse(BaseModel):
    """Envelope for the knowledge-base listing endpoint."""
    total: int
    # NOTE(review): "list" shadows the builtin; a rename would change the JSON schema,
    # so consider a serialization alias instead.
    list: List[KnowledgeBaseItem]
# --- [POST] 创建任务 ---
class TaskCreateRequest(BaseModel):
url: str = Field(..., description="目标网站根URL", example="https://docs.firecrawl.dev")
url: str = Field(..., description="目标网站根URL")
# --- [POST] 执行任务 ---
class TaskExecuteRequest(BaseModel):
batch_size: int = Field(10, ge=1, le=50, description="并发线程数/批次大小")
# --- [GET] 监控数据 ---
class TaskStatusData(BaseModel):
root_url: str
stats: Dict[str, int] = Field(..., description="数据库统计: pending/processing/completed")
active_threads: List[str] = Field(..., description="内存实时: 当前正在爬取的URL列表")
active_thread_count: int
batch_size: int = Field(10, le=50)
# --- [POST] 搜索 ---
class SearchRequest(BaseModel):
query: str
task_id: Optional[int] = None
return_num: int = Field(5, description="返回结果数量")
# 明确支持 None 为全局搜索
task_id: Optional[int] = Field(None, description="任务ID不传则搜全库")
limit: int = Field(5, description="返回数量")
# ... (SearchResultItem 等保持不变) ...
class SearchResultItem(BaseModel):
task_id: int
source_url: str

View File

@@ -30,6 +30,11 @@ class CrawlerService:
self._active_workers: Dict[int, set] = {}
self._lock = threading.Lock()
def get_knowledge_base_list(self):
    """Return all registered knowledge bases (delegates to data_service.get_all_tasks)."""
    return data_service.get_all_tasks()
def _track_start(self, task_id: int, url: str):
"""[Internal] 标记某个URL开始处理"""
with self._lock:

View File

@@ -63,6 +63,38 @@ class DataService:
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
).values(status=status))
def get_all_tasks(self):
    """Return every registered task (knowledge base).

    Used for frontend listings and workflow routing. Each entry is a dict
    with task_id, root_url and a short derived name.
    """
    query = select(self.db.tasks.c.id, self.db.tasks.c.root_url).order_by(self.db.tasks.c.id)
    with self.db.engine.connect() as conn:
        records = conn.execute(query).fetchall()
    return [
        {"task_id": task_id, "root_url": root, "name": self._extract_name(root)}
        for task_id, root in records
    ]
def _extract_name(self, url: str) -> str:
"""辅助方法:从 URL 提取一个简短的名字作为 Alias"""
try:
from urllib.parse import urlparse
domain = urlparse(url).netloc
# 比如 docs.firecrawl.dev -> firecrawl
parts = domain.split('.')
if len(parts) >= 2:
return parts[-2]
return domain
except:
return url
# ... (保持 get_task_monitor_data, save_chunks, search 等方法不变) ...
def get_task_monitor_data(self, task_id: int):
"""[数据库层监控] 获取持久化的任务状态"""
with self.db.engine.connect() as conn: