diff --git a/README.md b/README.md
index 125df1b..e4c1fa0 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,15 @@
 Implements wiki page crawling, vectorization, and knowledge-base search
 
+MCP debugging command:
+
+```bash
+npx @modelcontextprotocol/inspector uv run backend/mcp_server.py
+```
+
+Requires a Node.js environment.
+After the page opens, add `PYTHONIOENCODING = utf-8` under Environment Variables to avoid encoding issues (optional: if everything already runs correctly, it can be omitted).
+
 ## Current Status
 
 1. Chunking logic: split the returned markdown by # and ## headings; add a JSONB field meta_info with the two fields below, usable for database queries and for giving the LLM context about the source of the material
diff --git a/backend/core/config.py b/backend/core/config.py
index ef19971..310ae34 100644
--- a/backend/core/config.py
+++ b/backend/core/config.py
@@ -1,27 +1,26 @@
+import os
+from typing import ClassVar  # <--- 1. Import this
 from pydantic_settings import BaseSettings, SettingsConfigDict
-import logging
 
-# Get a dedicated logger for this module
-# __name__ automatically resolves to a path like "backend.services.crawler_service"
-logger = logging.getLogger(__name__)
 
 class Settings(BaseSettings):
-    """
-    System configuration class
-    Automatically reads environment variables or a .env file
-    """
-    CANDIDATE_NUM: int = 10
-
     DB_USER: str
     DB_PASS: str
     DB_HOST: str
     DB_PORT: str = "5432"
     DB_NAME: str
-
     DASHSCOPE_API_KEY: str
     FIRECRAWL_API_KEY: str
+
+    CANDIDATE_NUM: int = 50
 
-    # Config: ignore extra environment variables and specify the encoding
-    model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8')
+    # =========================================================
+    # [Core fix] Add the ClassVar type annotation
+    # =========================================================
+    BASE_DIR: ClassVar[str] = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    ENV_PATH: ClassVar[str] = os.path.join(BASE_DIR, ".env")
+
+    # Load the .env file via an absolute path
+    model_config = SettingsConfigDict(env_file=ENV_PATH, extra="ignore")
 
     @property
     def DATABASE_URL(self) -> str:
diff --git a/backend/mcp_server.py b/backend/mcp_server.py
index ea92425..3a5beb5 100644
--- a/backend/mcp_server.py
+++ b/backend/mcp_server.py
@@ -1,64 +1,145 @@
 import sys
 import os
-import asyncio
 
-# Path compatibility
+import logging
+from typing import Optional  # Make sure Optional is imported
+import threading
+# 1. Path compatibility (make sure the backend package can be found)
 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
-from backend.core.logger import setup_logging
-# Initialize logging (placed before FastMCP initialization)
-setup_logging()
 
 from mcp.server.fastmcp import FastMCP
+from backend.core.logger import setup_logging
 from backend.services.crawler_service import crawler_service
 
+# 2. Initialize logging (must go to stderr)
+setup_logging()
+logger = logging.getLogger("mcp_server")
+
+# 3. Initialize the MCP service
 mcp = FastMCP("WikiCrawler-V3")
 
 @mcp.tool()
async def kb_add_website(url: str) -> str:
-    """[Admin] Add a website map task."""
+    """
+    [Admin] Input a URL to map and register a task.
+    This is the first step in adding a knowledge base.
+
+    Args:
+        url: The root URL of the website (e.g., https://docs.firecrawl.dev).
+
+    Returns:
+        Task ID and count of found links.
+    """
     try:
         res = crawler_service.map_site(url)
-        return f"Task Registered. ID: {res['task_id']}, Links Found: {res['count']}"
+        return f"Task Registered. ID: {res['task_id']}, Links Found: {res['count']}, Is New: {res['is_new']}"
     except Exception as e:
+        logger.error(f"Add website failed: {e}", exc_info=True)
         return f"Error: {e}"
 
 @mcp.tool()
 async def kb_check_status(task_id: int) -> str:
-    """[Monitor] Check detailed progress and active threads."""
+    """
+    [Monitor] Check detailed progress and active threads.
+    Use this to see if the crawler is still running or finished.
+
+    Args:
+        task_id: The ID of the task to check.
+
+    Returns:
+        A formatted report including progress stats and currently crawling URLs.
+    """
     data = crawler_service.get_task_status(task_id)
     if not data: return "Task not found."
 
     s = data['stats']
     threads = data['active_threads']
 
+    # Format the output for the LLM to read
     report = (
+        f"--- Task {task_id} Status ---\n"
+        f"Root URL: {data['root_url']}\n"
         f"Progress: {s['completed']}/{s['total']} (Pending: {s['pending']})\n"
-        f"Active Threads: {len(threads)}\n"
+        f"Active Threads (Running): {len(threads)}\n"
     )
+
     if threads:
         report += "Currently Crawling:\n" + "\n".join([f"- {t}" for t in threads[:5]])
+        if len(threads) > 5:
+            report += f"\n... and {len(threads)-5} more."
+
     return report
 
 @mcp.tool()
-async def kb_run_crawler(task_id: int, batch_size: int = 5) -> str:
-    """[Action] Trigger crawler batch."""
-    # Synchronous MCP call to get feedback
-    res = crawler_service.process_queue_concurrent(task_id, batch_size)
-    return f"Batch Finished. Count: {res.get('count', 0)}"
+async def kb_run_crawler(task_id: int, batch_size: int = 20) -> str:
+    """
+    [Action] Trigger the crawler in BACKGROUND mode.
+    This returns immediately, so you can use 'kb_check_status' to monitor progress.
+
+    Args:
+        task_id: The ID of the task.
+        batch_size: Number of URLs to process (suggest 10-20).
+
+    Returns:
+        Status message confirming the start.
+    """
+    # 1. Define a wrapper function that runs in the background
+    def background_task():
+        try:
+            logger.info(f"Background batch started for Task {task_id}")
+            # This is a blocking operation, but it now runs in a separate thread
+            crawler_service.process_queue_concurrent(task_id, batch_size)
+            logger.info(f"Background batch finished for Task {task_id}")
+        except Exception as e:
+            logger.error(f"Background task failed: {e}", exc_info=True)
+
+    # 2. Create and start the thread
+    thread = threading.Thread(target=background_task)
+    thread.daemon = True  # Daemon thread, so it does not block the main process from exiting
+    thread.start()
+
+    # 3. Return immediately without waiting for the crawl to finish
+    return f"🚀 Background crawler started for Task {task_id} (Batch Size: {batch_size}). You can now check status."
+
 
 @mcp.tool()
-async def kb_search(query: str, task_id: int = None) -> str:
-    """[User] Search knowledge base."""
-    res = crawler_service.search(query, task_id, 5)
-    results = res.get('results', [])
-    if not results: return "No results."
+async def kb_search(query: str, task_id: Optional[int] = None, limit: int = 5) -> str:
+    """
+    [User] Search the knowledge base with Hybrid Search & Rerank.
 
-    output = []
-    for i, r in enumerate(results):
-        score_display = f"{r['score']:.4f}" + (" (Reranked)" if r.get('reranked') else "")
-        meta = r.get('meta_info', {})
-        path = meta.get('header_path', 'Root')
-        output.append(f"[{i+1}] Score: {score_display}\nPath: {path}\nContent: {r['content'][:200]}...")
-    return "\n\n".join(output)
+    Args:
+        query: The user's question or search keywords.
+        task_id: (Optional) Limit the search to a specific task ID.
+        limit: (Optional) Number of results to return (default 5).
+
+    Returns:
+        Ranked content blocks with source paths.
+    """
+    try:
+        res = crawler_service.search(query, task_id, limit)
+        results = res.get('results', [])
+
+        if not results: return "No results found."
+
+        output = []
+        for i, r in enumerate(results):
+            score_display = f"{r['score']:.4f}" + (" (Reranked)" if r.get('reranked') else "")
+            meta = r.get('meta_info', {})
+            path = meta.get('header_path', 'Root')
+
+            # Format a single result block
+            block = (
+                f"[{i+1}] Score: {score_display}\n"
+                f"Path: {path}\n"
+                f"Content: {r['content'][:300]}..."  # Cap the length to avoid context overflow
+            )
+            output.append(block)
+
+        return "\n\n".join(output)
+
+    except Exception as e:
+        logger.error(f"Search failed: {e}", exc_info=True)
+        return f"Search Error: {e}"
 
 if __name__ == "__main__":
+    # Start the MCP service
     mcp.run()
\ No newline at end of file
diff --git a/backend/services/crawler_service.py b/backend/services/crawler_service.py
index 49ed0c6..2507fd0 100644
--- a/backend/services/crawler_service.py
+++ b/backend/services/crawler_service.py
@@ -1,41 +1,64 @@
 import concurrent.futures
 import threading
+import logging
+from typing import Dict, Any, List, Optional, Union
+
 from firecrawl import FirecrawlApp
 from backend.core.config import settings
 from backend.services.data_service import data_service
 from backend.services.llm_service import llm_service
 from backend.utils.text_process import text_processor
-import logging
 
 # Get a dedicated logger for this module
-# __name__ automatically resolves to a path like "backend.services.crawler_service"
 logger = logging.getLogger(__name__)
 
+
 class CrawlerService:
+    """
+    Crawler business service layer.
+
+    Responsibilities:
+    1. Coordinate the external API (Firecrawl) and internal services (DataService, LLMService).
+    2. Manage multi-threaded crawl tasks and their state.
+    3. Provide a unified search entry point (hybrid retrieval + rerank).
+    """
+
     def __init__(self):
         self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
-        self.max_workers = 5
+        self.max_workers = 5  # Maximum thread-pool concurrency
 
-        # [New] In-memory state tracking
-        # Structure: { task_id: { url: "status_desc" } }
-        self._active_workers = {}
+        # In-memory state tracking: { task_id: set([url1, url2]) }
+        self._active_workers: Dict[int, set] = {}
         self._lock = threading.Lock()
 
-    def _track_start(self, task_id, url):
-        """Start tracking a URL"""
+    def _track_start(self, task_id: int, url: str):
+        """[Internal] Mark a URL as having started processing"""
         with self._lock:
             if task_id not in self._active_workers:
                 self._active_workers[task_id] = set()
             self._active_workers[task_id].add(url)
 
-    def _track_end(self, task_id, url):
-        """Stop tracking a URL"""
+    def _track_end(self, task_id: int, url: str):
+        """[Internal] Mark a URL as finished processing"""
         with self._lock:
             if task_id in self._active_workers:
                 self._active_workers[task_id].discard(url)
 
-    def get_task_status(self, task_id: int):
+    def get_task_status(self, task_id: int) -> Optional[Dict[str, Any]]:
         """
-        [Monitoring] Get the full status = database stats + live thread list
+        Get the real-time aggregate status of a task.
+
+        Args:
+            task_id (int): Task ID
+
+        Returns:
+            dict: Database statistics plus live thread information; None if the task does not exist.
+            Example structure:
+            {
+                "root_url": "https://example.com",
+                "stats": {"pending": 10, "processing": 2, "completed": 5, "failed": 0},
+                "active_threads": ["https://example.com/page1"],
+                "active_thread_count": 1
+            }
         """
         # 1. Fetch database-level statistics (macro view)
         db_data = data_service.get_task_monitor_data(task_id)
@@ -45,19 +68,33 @@ class CrawlerService:
 
         # 2. Fetch in-memory active threads (micro view)
         with self._lock:
             active_urls = list(self._active_workers.get(task_id, []))
-        # Output the status
+
+        # Log the current state
         logger.info(f"Task {task_id} active threads: {active_urls}")
-        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")  # Print database statistics
+        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")
 
         return {
             "root_url": db_data["root_url"],
-            "stats": db_data["db_stats"],      # Pending, Completed, Failed, etc.
-            "active_threads": active_urls,     # URLs currently being processed (CPU/network)
+            "stats": db_data["db_stats"],
+            "active_threads": active_urls,
             "active_thread_count": len(active_urls)
         }
 
-    def map_site(self, start_url: str):
-        """Stage 1: site map scan"""
+    def map_site(self, start_url: str) -> Dict[str, Any]:
+        """
+        Stage 1: Site map scan (Map)
+
+        Args:
+            start_url (str): Root URL of the target website
+
+        Returns:
+            dict: Task ID and the number of links found.
+            {
+                "task_id": 123,
+                "count": 50,
+                "is_new": True
+            }
+        """
         logger.info(f"Mapping: {start_url}")
         try:
             task_res = data_service.register_task(start_url)
@@ -75,7 +112,9 @@ class CrawlerService:
             # Run Map for a new task
             try:
                 map_res = self.firecrawl.map(start_url)
+                # Compatible with the return structures of different SDK versions
                 found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
+
                 for link in found_links:
                     u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                     urls_to_add.append(u)
@@ -96,7 +135,7 @@ class CrawlerService:
             raise e
 
     def _process_single_url(self, task_id: int, url: str):
-        """[Worker] Single-URL processing thread"""
+        """[Internal Worker] Single-URL processing thread logic"""
         # 1. Memory tracking: start
         self._track_start(task_id, url)
         logger.info(f"[THREAD START] {url}")
@@ -107,6 +146,7 @@ class CrawlerService:
                 url, formats=['markdown'], only_main_content=True
             )
 
+            # Compatibility extraction
            raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
             metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
             title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
@@ -124,6 +164,7 @@ class CrawlerService:
                 headers = chunk['metadata']
                 path = " > ".join(headers.values())
                 emb_input = f"{title}\n{path}\n{chunk['content']}"
+
                 vector = llm_service.get_embedding(emb_input)
                 if vector:
                     chunks_data.append({
@@ -145,34 +186,72 @@ class CrawlerService:
             # 5. Memory tracking: end (always remove, whether the URL succeeded or failed)
             self._track_end(task_id, url)
 
-    def process_queue_concurrent(self, task_id: int, batch_size: int = 10):
-        """Stage 2: concurrent multi-threaded processing"""
+    def process_queue_concurrent(self, task_id: int, batch_size: int = 10) -> Dict[str, Any]:
+        """
+        Stage 2: Concurrent multi-threaded processing (Process)
+
+        Args:
+            task_id (int): Task ID
+            batch_size (int): Number of URLs to process in this batch (dispatched concurrently to the thread pool)
+
+        Returns:
+            dict: Overview of the batch result
+            {
+                "msg": "Batch completed",
+                "count": 10
+            }
+        """
         urls = data_service.get_pending_urls(task_id, limit=batch_size)
-        if not urls: return {"msg": "No pending urls"}
+        if not urls: return {"msg": "No pending urls", "count": 0}
 
         logger.info(f"Batch started: {len(urls)} urls")
 
         with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
             # Submit tasks to the thread pool
             futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
 
-            # Wait for completion
+            # Wait for completion (blocks until all threads finish)
             concurrent.futures.wait(futures)
 
         return {"msg": "Batch completed", "count": len(urls)}
 
-    def search(self, query: str, task_id, return_num: int):
-        """Stage 3: search"""
+    def search(self, query: str, task_id: Optional[int], return_num: int) -> Dict[str, Any]:
+        """
+        Stage 3: Smart search (Search)
+        Flow: user question -> embedding -> hybrid database retrieval (coarse ranking) -> rerank model (fine ranking) -> results
+
+        Args:
+            query (str): The user's question
+            task_id (Optional[int]): Task ID to restrict the search to; None searches the whole knowledge base
+            return_num (int): Number of results finally returned to the user (top K)
+
+        Returns:
+            dict: List of search results
+            {
+                "results": [
+                    {"content": "...", "score": 0.98, "meta_info": {...}},
+                    ...
+                ]
+            }
+        """
+        # 1. Generate the query vector
         vector = llm_service.get_embedding(query)
         if not vector: return {"msg": "Embedding failed", "results": []}
 
+        # 2. Coarse database ranking (recall 10x the requested number, at least 50)
         coarse_limit = min(return_num * 10, 100)
         coarse_limit = max(coarse_limit, 50)
 
-        coarse_res = data_service.search(query, vector, task_id, coarse_limit)
+        coarse_res = data_service.search(
+            query_text=query,
+            query_vector=vector,
+            task_id=task_id,
+            candidates_num=coarse_limit
+        )
         candidates = coarse_res.get('results', [])
 
         if not candidates: return {"results": []}
 
+        # 3. LLM fine ranking (rerank)
         final_res = llm_service.rerank(query, candidates, return_num)
         return {"results": final_res}
diff --git a/backend/utils/common.py b/backend/utils/common.py
index 6313dd6..0b05fad 100644
--- a/backend/utils/common.py
+++ b/backend/utils/common.py
@@ -11,6 +11,9 @@ def normalize_url(url: str) -> str:
     2. Remove the fragment (everything after #)
     3. Remove query parameters (business-dependent; here different queries are assumed to be the same page)
     4. Remove the trailing slash
+    Examples:
+    "https://www.example.com/path/" -> "https://www.example.com/path"
+    "https://www.example.com/path?query=1" -> "https://www.example.com/path"
     """
     if not url:
         return ""
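
For reference, the `PYTHONIOENCODING` workaround mentioned in the README can also be applied when launching the Inspector from a shell instead of through its UI. This is a minimal sketch assuming a POSIX shell; the inline variable is inherited by the Python process that `uv run` starts:

```bash
# Same effect as adding PYTHONIOENCODING = utf-8 under "Environment Variables" in the Inspector page
PYTHONIOENCODING=utf-8 npx @modelcontextprotocol/inspector uv run backend/mcp_server.py
```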