v3接口restful风格,规范化接口;添加mcp服务器;新增log模块
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
import logging
|
||||
|
||||
# 获取当前模块的专用 Logger
|
||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
||||
logger = logging.getLogger(__name__)
|
||||
class Settings(BaseSettings):
|
||||
"""
|
||||
系统配置类
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
from sqlalchemy import create_engine, MetaData, Table
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from .config import settings
|
||||
import logging
|
||||
|
||||
# 获取当前模块的专用 Logger
|
||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
||||
logger = logging.getLogger(__name__)
|
||||
class Database:
|
||||
"""
|
||||
数据库单例类
|
||||
@@ -30,9 +34,9 @@ class Database:
|
||||
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
|
||||
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
|
||||
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
|
||||
print("[INFO] Database tables reflected successfully.")
|
||||
logger.info("Database tables reflected successfully.")
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Failed to reflect tables: {e}")
|
||||
logger.error(f"Failed to reflect tables: {e}")
|
||||
|
||||
# 全局数据库实例
|
||||
db = Database()
|
||||
24
backend/core/logger.py
Normal file
24
backend/core/logger.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# backend/core/logger.py
|
||||
import logging
|
||||
import sys
|
||||
|
||||
def setup_logging(level=logging.INFO):
|
||||
"""
|
||||
全局日志配置
|
||||
关键点:强制将日志输出到 sys.stderr,防止污染 sys.stdout 导致 MCP 协议崩溃。
|
||||
"""
|
||||
# 定义日志格式:时间 - 模块名 - 级别 - 内容
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
# 配置根记录器
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format=log_format,
|
||||
handlers=[
|
||||
# 【绝对关键】使用 StreamHandler(sys.stderr)
|
||||
# 这样日志会走标准错误通道,不会干扰 MCP 的标准输出通信
|
||||
logging.StreamHandler(sys.stderr)
|
||||
],
|
||||
# 强制重新配置,防止被第三方库覆盖
|
||||
force=True
|
||||
)
|
||||
@@ -1,13 +1,28 @@
|
||||
from fastapi import FastAPI
|
||||
from backend.routers import v1, v2
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from backend.routers import v3
|
||||
from backend.core.logger import setup_logging
|
||||
|
||||
app = FastAPI(title="Wiki Crawler API")
|
||||
# 程序启动第一件事:初始化日志
|
||||
setup_logging()
|
||||
|
||||
# 挂载路由
|
||||
app.include_router(v1.router)
|
||||
app.include_router(v2.router)
|
||||
|
||||
app = FastAPI(
|
||||
title="Wiki Crawler System V3",
|
||||
version="3.0.0",
|
||||
description="Enterprise-grade RAG Knowledge Base API with Real-time Monitoring"
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(v3.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 提示:运行方式 uv run backend/main.py 或 uvicorn backend.main:app --reload
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
64
backend/mcp_server.py
Normal file
64
backend/mcp_server.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import sys
|
||||
import os
|
||||
import asyncio
|
||||
# 路径兼容
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
from backend.core.logger import setup_logging
|
||||
|
||||
# 初始化日志 (写在 FastMCP 初始化之前)
|
||||
setup_logging()
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from backend.services.crawler_service import crawler_service
|
||||
|
||||
mcp = FastMCP("WikiCrawler-V3")
|
||||
|
||||
@mcp.tool()
|
||||
async def kb_add_website(url: str) -> str:
|
||||
"""[Admin] Add a website map task."""
|
||||
try:
|
||||
res = crawler_service.map_site(url)
|
||||
return f"Task Registered. ID: {res['task_id']}, Links Found: {res['count']}"
|
||||
except Exception as e:
|
||||
return f"Error: {e}"
|
||||
|
||||
@mcp.tool()
|
||||
async def kb_check_status(task_id: int) -> str:
|
||||
"""[Monitor] Check detailed progress and active threads."""
|
||||
data = crawler_service.get_task_status(task_id)
|
||||
if not data: return "Task not found."
|
||||
|
||||
s = data['stats']
|
||||
threads = data['active_threads']
|
||||
|
||||
report = (
|
||||
f"Progress: {s['completed']}/{s['total']} (Pending: {s['pending']})\n"
|
||||
f"Active Threads: {len(threads)}\n"
|
||||
)
|
||||
if threads:
|
||||
report += "Currently Crawling:\n" + "\n".join([f"- {t}" for t in threads[:5]])
|
||||
return report
|
||||
|
||||
@mcp.tool()
|
||||
async def kb_run_crawler(task_id: int, batch_size: int = 5) -> str:
|
||||
"""[Action] Trigger crawler batch."""
|
||||
# MCP 同步调用以获得反馈
|
||||
res = crawler_service.process_queue_concurrent(task_id, batch_size)
|
||||
return f"Batch Finished. Count: {res.get('count', 0)}"
|
||||
|
||||
@mcp.tool()
|
||||
async def kb_search(query: str, task_id: int = None) -> str:
|
||||
"""[User] Search knowledge base."""
|
||||
res = crawler_service.search(query, task_id, 5)
|
||||
results = res.get('results', [])
|
||||
if not results: return "No results."
|
||||
|
||||
output = []
|
||||
for i, r in enumerate(results):
|
||||
score_display = f"{r['score']:.4f}" + (" (Reranked)" if r.get('reranked') else "")
|
||||
meta = r.get('meta_info', {})
|
||||
path = meta.get('header_path', 'Root')
|
||||
output.append(f"[{i+1}] Score: {score_display}\nPath: {path}\nContent: {r['content'][:200]}...")
|
||||
return "\n\n".join(output)
|
||||
|
||||
if __name__ == "__main__":
|
||||
mcp.run()
|
||||
@@ -1,38 +1,57 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
|
||||
from fastapi import APIRouter, BackgroundTasks, status
|
||||
from backend.services.crawler_service import crawler_service
|
||||
from backend.utils.common import make_response
|
||||
from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest
|
||||
from backend.services.data_service import data_service
|
||||
from backend.schemas.v3 import (
|
||||
TaskCreateRequest, TaskExecuteRequest, SearchRequest,
|
||||
ResponseBase, TaskStatusData
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/v3", tags=["V3 Service"])
|
||||
router = APIRouter(prefix="/api/v3", tags=["V3 Knowledge Base"])
|
||||
|
||||
@router.post("/add_task")
|
||||
async def add_task(req: AutoMapRequest):
|
||||
@router.post("/tasks", status_code=status.HTTP_201_CREATED, response_model=ResponseBase)
|
||||
async def create_task(req: TaskCreateRequest):
|
||||
"""创建新任务 (Map)"""
|
||||
try:
|
||||
res = crawler_service.map_site(req.url)
|
||||
return make_response(1, res.pop("msg", "Started"), res)
|
||||
return ResponseBase(code=1, msg="Task Created", data=res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
return ResponseBase(code=0, msg=f"Map Failed: {str(e)}")
|
||||
|
||||
@router.post("/process_task")
|
||||
async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
|
||||
try:
|
||||
bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size)
|
||||
return make_response(1, "Background processing started", {"task_id": req.task_id})
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
@router.get("/tasks/{task_id}", response_model=ResponseBase)
|
||||
async def get_task_status(task_id: int):
|
||||
"""
|
||||
实时监控:
|
||||
返回数据库持久化状态 + 内存中正在运行的线程
|
||||
"""
|
||||
# 调用 crawler_service 的聚合方法
|
||||
data = crawler_service.get_task_status(task_id)
|
||||
|
||||
if not data:
|
||||
return ResponseBase(code=0, msg="Task not found")
|
||||
|
||||
return ResponseBase(code=1, msg="Success", data=data)
|
||||
|
||||
@router.post("/task_status")
|
||||
async def get_task_status(req):
|
||||
try:
|
||||
res = crawler_service.get_task_status(req.task_id)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
pass
|
||||
@router.post("/search")
|
||||
async def search_smart(req: TextSearchRequest):
|
||||
@router.post("/tasks/{task_id}/run", status_code=status.HTTP_202_ACCEPTED, response_model=ResponseBase)
|
||||
async def run_task(task_id: int, req: TaskExecuteRequest, bg_tasks: BackgroundTasks):
|
||||
"""触发后台多线程爬取"""
|
||||
# 简单检查任务是否存在 (查一下数据库监控数据即可)
|
||||
if not data_service.get_task_monitor_data(task_id):
|
||||
return ResponseBase(code=0, msg="Task not found")
|
||||
|
||||
# 放入后台任务
|
||||
bg_tasks.add_task(crawler_service.process_queue_concurrent, task_id, req.batch_size)
|
||||
|
||||
return ResponseBase(
|
||||
code=1,
|
||||
msg="Background Execution Started",
|
||||
data={"task_id": task_id, "mode": "concurrent_thread_pool"}
|
||||
)
|
||||
|
||||
@router.post("/search", response_model=ResponseBase)
|
||||
async def search_knowledge(req: SearchRequest):
|
||||
"""混合检索 + Rerank"""
|
||||
try:
|
||||
res = crawler_service.search(req.query, req.task_id, req.return_num)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
return ResponseBase(code=1, msg="Search Completed", data=res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
return ResponseBase(code=0, msg=f"Search Failed: {str(e)}")
|
||||
38
backend/schemas/v3.py
Normal file
38
backend/schemas/v3.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
# --- 通用响应 ---
|
||||
class ResponseBase(BaseModel):
|
||||
code: int = Field(..., description="1: 成功, 0: 失败")
|
||||
msg: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
# --- [POST] 创建任务 ---
|
||||
class TaskCreateRequest(BaseModel):
|
||||
url: str = Field(..., description="目标网站根URL", example="https://docs.firecrawl.dev")
|
||||
|
||||
# --- [POST] 执行任务 ---
|
||||
class TaskExecuteRequest(BaseModel):
|
||||
batch_size: int = Field(10, ge=1, le=50, description="并发线程数/批次大小")
|
||||
|
||||
# --- [GET] 监控数据 ---
|
||||
class TaskStatusData(BaseModel):
|
||||
root_url: str
|
||||
stats: Dict[str, int] = Field(..., description="数据库统计: pending/processing/completed")
|
||||
active_threads: List[str] = Field(..., description="内存实时: 当前正在爬取的URL列表")
|
||||
active_thread_count: int
|
||||
|
||||
# --- [POST] 搜索 ---
|
||||
class SearchRequest(BaseModel):
|
||||
query: str
|
||||
task_id: Optional[int] = None
|
||||
return_num: int = Field(5, description="返回结果数量")
|
||||
|
||||
class SearchResultItem(BaseModel):
|
||||
task_id: int
|
||||
source_url: str
|
||||
title: Optional[str] = None
|
||||
content: str
|
||||
score: float
|
||||
meta_info: Dict = {}
|
||||
reranked: Optional[bool] = False
|
||||
@@ -1,182 +1,179 @@
|
||||
# backend/services/crawler_service.py
|
||||
import concurrent.futures
|
||||
import threading
|
||||
from firecrawl import FirecrawlApp
|
||||
from backend.core.config import settings
|
||||
from backend.services.data_service import data_service
|
||||
from backend.services.llm_service import llm_service
|
||||
from backend.utils.text_process import text_processor
|
||||
import logging
|
||||
|
||||
# 获取当前模块的专用 Logger
|
||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
||||
logger = logging.getLogger(__name__)
|
||||
class CrawlerService:
|
||||
"""
|
||||
爬虫编排服务
|
||||
协调 Firecrawl, LLM, 和 DataService
|
||||
"""
|
||||
def __init__(self):
|
||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||
self.max_workers = 5
|
||||
|
||||
# [新增] 内存状态追踪
|
||||
# 结构: { task_id: { url: "status_desc" } }
|
||||
self._active_workers = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _track_start(self, task_id, url):
|
||||
"""开始追踪某个URL"""
|
||||
with self._lock:
|
||||
if task_id not in self._active_workers:
|
||||
self._active_workers[task_id] = set()
|
||||
self._active_workers[task_id].add(url)
|
||||
|
||||
def _track_end(self, task_id, url):
|
||||
"""结束追踪某个URL"""
|
||||
with self._lock:
|
||||
if task_id in self._active_workers:
|
||||
self._active_workers[task_id].discard(url)
|
||||
|
||||
def get_task_status(self, task_id: int):
|
||||
"""
|
||||
[综合监控] 获取全量状态 = 数据库统计 + 实时线程列表
|
||||
"""
|
||||
# 1. 获取数据库层面的统计 (宏观)
|
||||
db_data = data_service.get_task_monitor_data(task_id)
|
||||
if not db_data:
|
||||
return None
|
||||
|
||||
# 2. 获取内存层面的活跃线程 (微观)
|
||||
with self._lock:
|
||||
active_urls = list(self._active_workers.get(task_id, []))
|
||||
# 输出情况
|
||||
logger.info(f"Task {task_id} active threads: {active_urls}")
|
||||
logger.info(f"Task {task_id} stats: {db_data['db_stats']}") # 打印数据库统计信息
|
||||
|
||||
return {
|
||||
"root_url": db_data["root_url"],
|
||||
"stats": db_data["db_stats"], # Pending, Completed, Failed 等
|
||||
"active_threads": active_urls, # 当前 CPU/网络 正在处理的 URL
|
||||
"active_thread_count": len(active_urls)
|
||||
}
|
||||
|
||||
def map_site(self, start_url: str):
|
||||
print(f"[INFO] Mapping: {start_url}")
|
||||
"""阶段1:站点地图扫描"""
|
||||
logger.info(f"Mapping: {start_url}")
|
||||
try:
|
||||
# 1. 注册任务
|
||||
task_res = data_service.register_task(start_url)
|
||||
urls_to_add = [start_url]
|
||||
|
||||
# 2. 无论是否新任务,都尝试把 start_url 加入队列
|
||||
urls_to_add = [start_url]
|
||||
|
||||
# 3. 调用 Firecrawl Map
|
||||
# 如果任务已存在,不再重新 Map,直接返回
|
||||
if not task_res['is_new_task']:
|
||||
logger.info(f"Task {task_res['task_id']} exists, skipping map.")
|
||||
return {
|
||||
"task_id": task_res['task_id'],
|
||||
"count": 0,
|
||||
"is_new": False
|
||||
}
|
||||
|
||||
# 新任务执行 Map
|
||||
try:
|
||||
map_res = self.firecrawl.map(start_url)
|
||||
|
||||
# 兼容不同 SDK 版本的返回 (对象 或 字典)
|
||||
found_links = []
|
||||
if isinstance(map_res, dict):
|
||||
found_links = map_res.get('links', [])
|
||||
else:
|
||||
found_links = getattr(map_res, 'links', [])
|
||||
|
||||
found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
|
||||
for link in found_links:
|
||||
# 链接可能是字符串,也可能是对象
|
||||
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
|
||||
urls_to_add.append(u)
|
||||
|
||||
print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
|
||||
logger.info(f"Map found {len(found_links)} links")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")
|
||||
logger.warning(f"Map failed, proceeding with seed only: {e}")
|
||||
|
||||
# 4. 批量入库
|
||||
if urls_to_add:
|
||||
data_service.add_urls(task_res['task_id'], urls_to_add)
|
||||
|
||||
return {
|
||||
"msg": "Mapped successfully",
|
||||
"count": len(urls_to_add),
|
||||
"task_id": task_res['task_id']
|
||||
"task_id": task_res['task_id'],
|
||||
"count": len(urls_to_add),
|
||||
"is_new": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Map failed: {e}")
|
||||
logger.error(f"Map failed: {e}")
|
||||
raise e
|
||||
|
||||
def process_queue(self, task_id: int, batch_size: int = 5):
|
||||
"""
|
||||
处理队列:爬取 -> 清洗 -> 切分 -> 向量化 -> 存储
|
||||
"""
|
||||
pending = data_service.get_pending_urls(task_id, batch_size)
|
||||
urls = pending['urls']
|
||||
if not urls: return {"msg": "Queue empty"}
|
||||
|
||||
processed = 0
|
||||
for url in urls:
|
||||
try:
|
||||
print(f"[INFO] Scraping: {url}")
|
||||
|
||||
# 调用 Firecrawl
|
||||
scrape_res = self.firecrawl.scrape(
|
||||
url,
|
||||
formats=['markdown'],
|
||||
only_main_content=True
|
||||
)
|
||||
|
||||
# =====================================================
|
||||
# 【核心修复】兼容字典和对象两种返回格式
|
||||
# =====================================================
|
||||
raw_md = ""
|
||||
metadata = None
|
||||
|
||||
# 1. 提取 markdown 和 metadata 本体
|
||||
if isinstance(scrape_res, dict):
|
||||
raw_md = scrape_res.get('markdown', '')
|
||||
metadata = scrape_res.get('metadata', {})
|
||||
else:
|
||||
# 如果是对象 (ScrapeResponse)
|
||||
raw_md = getattr(scrape_res, 'markdown', '')
|
||||
metadata = getattr(scrape_res, 'metadata', {})
|
||||
|
||||
# 2. 从 metadata 中安全提取 title
|
||||
title = url # 默认用 url 当标题
|
||||
if metadata:
|
||||
if isinstance(metadata, dict):
|
||||
# 如果 metadata 是字典
|
||||
title = metadata.get('title', url)
|
||||
else:
|
||||
# 如果 metadata 是对象 (DocumentMetadata) -> 这里就是你报错的地方
|
||||
title = getattr(metadata, 'title', url)
|
||||
|
||||
# =====================================================
|
||||
|
||||
if not raw_md:
|
||||
print(f"[WARN] Empty content for {url}")
|
||||
continue
|
||||
|
||||
# 1. 清洗
|
||||
clean_md = text_processor.clean_markdown(raw_md)
|
||||
if not clean_md: continue
|
||||
|
||||
# 2. 切分
|
||||
chunks = text_processor.split_markdown(clean_md)
|
||||
|
||||
chunks_data = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
content = chunk['content']
|
||||
headers = chunk['metadata']
|
||||
header_path = " > ".join(headers.values())
|
||||
|
||||
# 3. 向量化 (Title + Header + Content)
|
||||
emb_input = f"{title}\n{header_path}\n{content}"
|
||||
vector = llm_service.get_embedding(emb_input)
|
||||
|
||||
if vector:
|
||||
chunks_data.append({
|
||||
"index": i,
|
||||
"content": content,
|
||||
"embedding": vector,
|
||||
"meta_info": {"header_path": header_path, "headers": headers}
|
||||
})
|
||||
|
||||
# 4. 存储
|
||||
if chunks_data:
|
||||
data_service.save_chunks(task_id, url, title, chunks_data)
|
||||
processed += 1
|
||||
except Exception as e:
|
||||
# 打印详细错误方便调试
|
||||
print(f"[ERROR] Process URL {url} failed: {e}")
|
||||
def _process_single_url(self, task_id: int, url: str):
|
||||
"""[Worker] 单个 URL 处理线程"""
|
||||
# 1. 内存标记:开始
|
||||
self._track_start(task_id, url)
|
||||
logger.info(f"[THREAD START] {url}")
|
||||
|
||||
return {"msg": "Batch processed", "count": processed}
|
||||
try:
|
||||
# 2. 爬取
|
||||
scrape_res = self.firecrawl.scrape(
|
||||
url, formats=['markdown'], only_main_content=True
|
||||
)
|
||||
|
||||
raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
|
||||
metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
|
||||
title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
|
||||
|
||||
if not raw_md:
|
||||
data_service.mark_url_status(task_id, url, 'failed')
|
||||
return
|
||||
|
||||
# 3. 清洗 & 切分
|
||||
clean_md = text_processor.clean_markdown(raw_md)
|
||||
chunks = text_processor.split_markdown(clean_md)
|
||||
|
||||
chunks_data = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
headers = chunk['metadata']
|
||||
path = " > ".join(headers.values())
|
||||
emb_input = f"{title}\n{path}\n{chunk['content']}"
|
||||
vector = llm_service.get_embedding(emb_input)
|
||||
if vector:
|
||||
chunks_data.append({
|
||||
"index": i, "content": chunk['content'], "embedding": vector,
|
||||
"meta_info": {"header_path": path, "headers": headers}
|
||||
})
|
||||
|
||||
# 4. 入库 (会自动标记 completed)
|
||||
if chunks_data:
|
||||
data_service.save_chunks(task_id, url, title, chunks_data)
|
||||
else:
|
||||
data_service.mark_url_status(task_id, url, 'failed')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[THREAD ERROR] {url}: {e}")
|
||||
data_service.mark_url_status(task_id, url, 'failed')
|
||||
|
||||
finally:
|
||||
# 5. 内存标记:结束 (无论成功失败都要移除)
|
||||
self._track_end(task_id, url)
|
||||
|
||||
def process_queue_concurrent(self, task_id: int, batch_size: int = 10):
|
||||
"""阶段2:多线程并发处理"""
|
||||
urls = data_service.get_pending_urls(task_id, limit=batch_size)
|
||||
if not urls: return {"msg": "No pending urls"}
|
||||
|
||||
logger.info(f"Batch started: {len(urls)} urls")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
# 提交任务到线程池
|
||||
futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
|
||||
# 等待完成
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
return {"msg": "Batch completed", "count": len(urls)}
|
||||
|
||||
def search(self, query: str, task_id, return_num: int):
|
||||
"""
|
||||
全链路搜索:向量生成 -> 混合检索(粗排) -> 重排序(精排)
|
||||
"""
|
||||
# 1. 生成查询向量
|
||||
"""阶段3:搜索"""
|
||||
vector = llm_service.get_embedding(query)
|
||||
if not vector: return {"msg": "Embedding failed", "results": []}
|
||||
|
||||
# 2. 计算粗排召回数量
|
||||
# 逻辑:至少召回 50 个,如果用户要很多,则召回 10 倍
|
||||
coarse_limit = return_num * 10 if return_num * 10 > settings.CANDIDATE_NUM else settings.CANDIDATE_NUM
|
||||
coarse_limit = min(return_num * 10, 100)
|
||||
coarse_limit = max(coarse_limit, 50)
|
||||
|
||||
# 3. 执行混合检索 (粗排)
|
||||
coarse_results = data_service.search(
|
||||
query_text=query,
|
||||
query_vector=vector,
|
||||
task_id=task_id,
|
||||
candidates_num=coarse_limit # 使用计算出的粗排数量
|
||||
)
|
||||
|
||||
candidates = coarse_results.get('results', [])
|
||||
|
||||
if not candidates:
|
||||
return {"msg": "No documents found", "results": []}
|
||||
coarse_res = data_service.search(query, vector, task_id, coarse_limit)
|
||||
candidates = coarse_res.get('results', [])
|
||||
|
||||
# 4. 执行重排序 (精排)
|
||||
final_results = llm_service.rerank(
|
||||
query=query,
|
||||
documents=candidates,
|
||||
top_n=return_num # 最终返回用户需要的数量
|
||||
)
|
||||
if not candidates: return {"results": []}
|
||||
|
||||
final_res = llm_service.rerank(query, candidates, return_num)
|
||||
return {"results": final_res}
|
||||
|
||||
return {
|
||||
"results": final_results,
|
||||
"msg": f"Reranked {len(final_results)} from {len(candidates)} candidates"
|
||||
}
|
||||
|
||||
crawler_service = CrawlerService()
|
||||
@@ -1,11 +1,14 @@
|
||||
from sqlalchemy import select, insert, update, and_, text, func, desc
|
||||
from backend.core.database import db
|
||||
from backend.utils.common import normalize_url
|
||||
import logging
|
||||
|
||||
# 获取当前模块的专用 Logger
|
||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
||||
logger = logging.getLogger(__name__)
|
||||
class DataService:
|
||||
"""
|
||||
数据持久化服务层
|
||||
只负责数据库 CRUD 操作,不包含外部 API 调用逻辑
|
||||
"""
|
||||
def __init__(self):
|
||||
self.db = db
|
||||
@@ -17,11 +20,11 @@ class DataService:
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
|
||||
return {"task_id": existing[0], "is_new_task": False}
|
||||
else:
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}
|
||||
return {"task_id": new_task[0], "is_new_task": True}
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
success_urls = []
|
||||
@@ -29,7 +32,7 @@ class DataService:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
check_q = select(self.db.queue).where(
|
||||
check_q = select(self.db.queue.c.id).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
if not conn.execute(check_q).fetchone():
|
||||
@@ -41,115 +44,111 @@ class DataService:
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
# 原子锁定:获取并标记为 processing
|
||||
subquery = select(self.db.queue.c.id).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
).limit(limit)
|
||||
urls = [r[0] for r in conn.execute(query).fetchall()]
|
||||
).limit(limit).with_for_update(skip_locked=True)
|
||||
|
||||
if urls:
|
||||
conn.execute(update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing'))
|
||||
|
||||
return {"urls": urls, "msg": "Fetched pending urls"}
|
||||
|
||||
stmt = update(self.db.queue).where(
|
||||
self.db.queue.c.id.in_(subquery)
|
||||
).values(status='processing').returning(self.db.queue.c.url)
|
||||
|
||||
result = conn.execute(stmt).fetchall()
|
||||
return [r[0] for r in result]
|
||||
|
||||
def mark_url_status(self, task_id: int, url: str, status: str):
|
||||
clean_url = normalize_url(url)
|
||||
with self.db.engine.begin() as conn:
|
||||
conn.execute(update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
).values(status=status))
|
||||
|
||||
def get_task_monitor_data(self, task_id: int):
|
||||
"""[数据库层监控] 获取持久化的任务状态"""
|
||||
with self.db.engine.connect() as conn:
|
||||
# 1. 检查任务是否存在
|
||||
task_exists = conn.execute(select(self.db.tasks.c.root_url).where(self.db.tasks.c.id == task_id)).fetchone()
|
||||
if not task_exists:
|
||||
return None
|
||||
|
||||
# 2. 统计各状态数量
|
||||
stats_rows = conn.execute(select(
|
||||
self.db.queue.c.status, func.count(self.db.queue.c.id)
|
||||
).where(self.db.queue.c.task_id == task_id).group_by(self.db.queue.c.status)).fetchall()
|
||||
|
||||
stats = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
|
||||
for status, count in stats_rows:
|
||||
if status in stats: stats[status] = count
|
||||
stats["total"] = sum(stats.values())
|
||||
|
||||
return {
|
||||
"root_url": task_exists[0],
|
||||
"db_stats": stats
|
||||
}
|
||||
|
||||
def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
|
||||
"""
|
||||
保存切片数据 (Phase 1.5: 支持 meta_info)
|
||||
"""
|
||||
clean_url = normalize_url(source_url)
|
||||
count = 0
|
||||
with self.db.engine.begin() as conn:
|
||||
for item in chunks_data:
|
||||
# item 结构: {'index': int, 'content': str, 'embedding': list, 'meta_info': dict}
|
||||
idx = item['index']
|
||||
meta = item.get('meta_info', {})
|
||||
|
||||
# 检查是否存在
|
||||
existing = conn.execute(select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id,
|
||||
self.db.chunks.c.source_url == clean_url,
|
||||
self.db.chunks.c.chunk_index == idx)
|
||||
)).fetchone()
|
||||
|
||||
values = {
|
||||
"task_id": task_id, "source_url": clean_url, "chunk_index": idx,
|
||||
"title": title, "content": item['content'], "embedding": item['embedding'],
|
||||
"meta_info": meta
|
||||
}
|
||||
|
||||
if existing:
|
||||
conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
|
||||
else:
|
||||
conn.execute(insert(self.db.chunks).values(**values))
|
||||
count += 1
|
||||
|
||||
# 标记队列完成
|
||||
conn.execute(update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
).values(status='completed'))
|
||||
|
||||
return {"msg": f"Saved {count} chunks", "count": count}
|
||||
|
||||
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 5):
|
||||
"""
|
||||
Phase 2: 混合检索 (Hybrid Search)
|
||||
"""
|
||||
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
|
||||
# 向量格式清洗
|
||||
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
|
||||
if query_vector and isinstance(query_vector, list) and len(query_vector) > 0:
|
||||
if isinstance(query_vector[0], list): query_vector = query_vector[0]
|
||||
if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
|
||||
query_vector = query_vector[0]
|
||||
|
||||
results = []
|
||||
with self.db.engine.connect() as conn:
|
||||
keyword_query = func.websearch_to_tsquery('english', query_text) # 转换为 tsquery
|
||||
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))# 计算向量相似度
|
||||
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query) # 计算关键词相似度
|
||||
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")# 计算最终分数
|
||||
keyword_query = func.websearch_to_tsquery('english', query_text)
|
||||
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
|
||||
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
|
||||
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
|
||||
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id,
|
||||
self.db.chunks.c.source_url,
|
||||
self.db.chunks.c.title,
|
||||
self.db.chunks.c.content,
|
||||
self.db.chunks.c.meta_info,
|
||||
final_score
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
|
||||
)
|
||||
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
# 使用 candidates_num 控制召回数量
|
||||
if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
stmt = stmt.order_by(desc("score")).limit(candidates_num)
|
||||
|
||||
try:
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
results = [
|
||||
{
|
||||
"task_id": r[0],
|
||||
"source_url": r[1],
|
||||
"title": r[2],
|
||||
"content": r[3],
|
||||
"meta_info": r[4],
|
||||
"score": float(r[5])
|
||||
}
|
||||
for r in rows
|
||||
]
|
||||
results = [{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4], "score": float(r[5])} for r in rows]
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Hybrid search failed: {e}")
|
||||
logger.error(f"Hybrid search failed: {e}")
|
||||
return self._fallback_vector_search(query_vector, task_id, candidates_num)
|
||||
|
||||
return {"results": results, "msg": f"Hybrid found {len(results)}"}
|
||||
|
||||
def _fallback_vector_search(self, vector, task_id, limit):
|
||||
print("[WARN] Fallback to pure vector search")
|
||||
logger.warning("Fallback to pure vector search")
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info
|
||||
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
|
||||
return {"results": [{"content": r[3], "meta_info": r[4], "score": 0.0} for r in rows], "msg": "Fallback found"}
|
||||
|
||||
data_service = DataService()
|
||||
@@ -1,7 +1,11 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from backend.core.config import settings
|
||||
import logging
|
||||
|
||||
# 获取当前模块的专用 Logger
|
||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
||||
logger = logging.getLogger(__name__)
|
||||
class LLMService:
|
||||
"""
|
||||
LLM 服务封装层
|
||||
@@ -21,10 +25,10 @@ class LLMService:
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
return resp.output['embeddings'][0]['embedding']
|
||||
else:
|
||||
print(f"[ERROR] Embedding API Error: {resp}")
|
||||
logger.error(f"Embedding API Error: {resp}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Embedding Exception: {e}")
|
||||
logger.error(f"Embedding Exception: {e}")
|
||||
return None
|
||||
|
||||
def rerank(self, query: str, documents: list, top_n: int = 5):
|
||||
@@ -81,12 +85,12 @@ class LLMService:
|
||||
|
||||
return reranked_results
|
||||
else:
|
||||
print(f"[ERROR] Rerank API Error: {resp}")
|
||||
logger.error(f"Rerank API Error: {resp}")
|
||||
# 降级策略:如果 Rerank 挂了,直接返回粗排的前 N 个
|
||||
return documents[:top_n]
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Rerank Exception: {e}")
|
||||
logger.error(f"Rerank Exception: {e}")
|
||||
# 降级策略
|
||||
return documents[:top_n]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user