import concurrent.futures
import threading
import logging
from typing import Dict, Any, List, Optional, Union

from firecrawl import FirecrawlApp

from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor

# Module-level logger
logger = logging.getLogger(__name__)


class CrawlerService:
    """
    Crawler service layer.

    Responsibilities:
    1. Coordinate the external API (Firecrawl) with internal services
       (DataService, LLMService).
    2. Manage multi-threaded crawl tasks and their state.
    3. Provide a unified search entry point (hybrid retrieval + rerank).
    """

    def __init__(self):
        # Instantiate FirecrawlApp, passing the API key only when one is configured
        if settings.FIRECRAWL_API_KEY_EXSIST:
            self.firecrawl = FirecrawlApp(
                api_key=settings.FIRECRAWL_API_KEY,
                api_url=settings.FIRECRAWL_API_URL
            )
        else:
            self.firecrawl = FirecrawlApp(api_url=settings.FIRECRAWL_API_URL)

        self.max_workers = 5  # Maximum concurrency of the thread pool

        # In-memory state tracking: { task_id: set([url1, url2]) }
        self._active_workers: Dict[int, set] = {}
        self._lock = threading.Lock()

    def get_knowledge_base_list(self):
        """Return the list of knowledge bases (one per task)."""
        return data_service.get_all_tasks()

    def _track_start(self, task_id: int, url: str):
        """[Internal] Mark a URL as being processed."""
        with self._lock:
            if task_id not in self._active_workers:
                self._active_workers[task_id] = set()
            self._active_workers[task_id].add(url)

    def _track_end(self, task_id: int, url: str):
        """[Internal] Mark a URL as finished."""
        with self._lock:
            if task_id in self._active_workers:
                self._active_workers[task_id].discard(url)

    def get_task_status(self, task_id: int) -> Optional[Dict[str, Any]]:
        """
        Return the live, aggregated status of a task.

        Args:
            task_id (int): Task ID

        Returns:
            dict: Database statistics combined with live thread information,
                  or None if the task does not exist.
            Example structure:
            {
                "root_url": "https://example.com",
                "stats": {"pending": 10, "processing": 2, "completed": 5, "failed": 0},
                "active_threads": ["https://example.com/page1"],
                "active_thread_count": 1
            }
        """
        # 1. Database-level statistics (macro view)
        db_data = data_service.get_task_monitor_data(task_id)
        if not db_data:
            return None

        # 2. In-memory active threads (micro view)
        with self._lock:
            active_urls = list(self._active_workers.get(task_id, []))

        # Log the current state
        logger.info(f"Task {task_id} active threads: {active_urls}")
        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")

        return {
            "root_url": db_data["root_url"],
            "stats": db_data["db_stats"],
            "active_threads": active_urls,
            "active_thread_count": len(active_urls)
        }

    def map_site(self, start_url: str, persist: bool = True) -> Dict[str, Any]:
        """
        Stage 1: site map scan (Map).

        Key points:
        - Run the external map first and only register/write to the database once
          links have been fetched successfully, avoiding half-finished tasks that
          are registered but whose map never completed.
        - New parameter `persist` (default True). With persist=False only the list
          of discovered links is returned and nothing is written to the database
          (useful for dry-run / staging flows).
        - Uses data_service.create_task_with_urls to create the task and bulk-insert
          the URLs (deduplicated) in a single transaction, improving atomicity.

        Args:
            start_url (str): Root URL of the target site
            persist (bool): Whether to persist the discovered URLs to the database.
                Allows a dry-run scan first, with persistence or rollback decided later.

        Returns:
            dict: Task ID and number of discovered links.
            {
                "task_id": 123 | None,
                "count": 50,
                "is_new": True | False | None,
                "urls": [ ... ],
                "persisted": True | False
            }
        """
        logger.info(f"Mapping (persist={persist}): {start_url}")
        try:
            # 0. Run the external map first (no database writes yet)
            try:
                map_res = self.firecrawl.map(start_url)
                # Tolerate the return shapes of different SDK versions
                found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
                urls_to_add = [start_url]
                for link in found_links:
                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                    urls_to_add.append(u)
                logger.info(f"Map found {len(found_links)} links for {start_url}")
            except Exception as e:
                # If the map fails, do not create a task; re-raise and let the
                # caller decide the rollback strategy.
                logger.error(f"Map failed for {start_url}, aborting register: {e}", exc_info=True)
                raise

            # 1. Dry-run scan (no persistence): return the discovered links directly
            #    so the caller can persist or roll back later.
            if not persist:
                return {
                    "task_id": None,
                    "count": len(urls_to_add),
                    "is_new": None,
                    "urls": urls_to_add,
                    "persisted": False
                }

            # 2. Map succeeded and persistence requested: create the task and fill
            #    the queue atomically in a single transaction.
            try:
                create_res = data_service.create_task_with_urls(start_url, urls_to_add)
                return {
                    "task_id": create_res.get('task_id'),
                    "count": create_res.get('added', 0),
                    "is_new": create_res.get('is_new_task', False),
                    "urls": urls_to_add,
                    "persisted": True
                }
            except Exception as e:
                logger.error(f"Atomic create_task_with_urls failed for {start_url}: {e}", exc_info=True)
                raise
        except Exception as e:
            logger.error(f"Map+Register failed for {start_url}: {e}")
            raise
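    # NOTE: as used above, data_service.create_task_with_urls is assumed to run
    # "create task + bulk insert URLs" as one transaction and to return a dict
    # of the form {"task_id": int, "added": int, "is_new_task": bool}; this
    # contract is inferred from the call site, not from DataService itself.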
    def _process_single_url(self, task_id: int, url: str):
        """[Internal Worker] Processing logic for a single URL (runs in a worker thread)."""
        # 1. In-memory marker: started
        self._track_start(task_id, url)
        logger.info(f"[THREAD START] {url}")

        try:
            # 2. Scrape
            scrape_res = self.firecrawl.scrape(
                url,
                formats=['markdown'],
                only_main_content=True
            )
            # Extraction that tolerates both dict and object return shapes
            raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
            metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
            title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)

            if not raw_md:
                data_service.mark_url_status(task_id, url, 'failed')
                return

            # 3. Clean & split
            clean_md = text_processor.clean_markdown(raw_md)
            chunks = text_processor.split_markdown(clean_md)

            chunks_data = []
            for i, chunk in enumerate(chunks):
                headers = chunk['metadata']
                path = " > ".join(headers.values())
                emb_input = f"{title}\n{path}\n{chunk['content']}"
                vector = llm_service.get_embedding(emb_input)
                if vector:
                    chunks_data.append({
                        "index": i,
                        "content": chunk['content'],
                        "embedding": vector,
                        "meta_info": {"header_path": path, "headers": headers}
                    })

            # 4. Persist (automatically marks the URL as completed)
            if chunks_data:
                data_service.save_chunks(task_id, url, title, chunks_data)
            else:
                data_service.mark_url_status(task_id, url, 'failed')

        except Exception as e:
            logger.error(f"[THREAD ERROR] {url}: {e}")
            data_service.mark_url_status(task_id, url, 'failed')
        finally:
            # 5. In-memory marker: finished (removed on success and failure alike)
            self._track_end(task_id, url)

    def process_queue_concurrent(self, task_id: int, batch_size: int = 10) -> Dict[str, Any]:
        """
        Stage 2: concurrent, multi-threaded processing (Process).

        Args:
            task_id (int): Task ID
            batch_size (int): Number of URLs to process in this batch
                (distributed across the thread pool).

        Returns:
            dict: Summary of the batch run
            {
                "msg": "Batch completed",
                "count": 10
            }
        """
        urls = data_service.get_pending_urls(task_id, limit=batch_size)
        if not urls:
            return {"msg": "No pending urls", "count": 0}

        logger.info(f"Batch started: {len(urls)} urls")

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit the URLs to the thread pool
            futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
            # Block until every worker in the batch has finished
            concurrent.futures.wait(futures)

        return {"msg": "Batch completed", "count": len(urls)}
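    # Design note: each call drains exactly one bounded batch and blocks until
    # it finishes, so a caller can poll get_task_status between batches rather
    # than waiting on a single long-running job.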
    def search(self, query: str, task_id: Optional[int], return_num: int) -> Dict[str, Any]:
        """
        Stage 3: intelligent search (Search).

        Pipeline: user question -> embedding -> hybrid database retrieval
        (coarse ranking) -> rerank model (fine ranking) -> results.

        Args:
            query (str): User question
            task_id (Optional[int]): Task ID to search within; None searches the whole store
            return_num (int): Number of results returned to the user (top K)

        Returns:
            dict: Search results
            {
                "results": [
                    {"content": "...", "score": 0.98, "meta_info": {...}},
                    ...
                ]
            }
        """
        # 1. Embed the query
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}

        # 2. Coarse database ranking: recall 10x the requested count,
        #    clamped to the range [50, 100]
        coarse_limit = min(return_num * 10, 100)
        coarse_limit = max(coarse_limit, 50)

        coarse_res = data_service.search(
            query_text=query,
            query_vector=vector,
            task_id=task_id,
            candidates_num=coarse_limit
        )
        candidates = coarse_res.get('results', [])
        if not candidates:
            return {"results": []}

        # 3. Fine ranking with the rerank model
        final_res = llm_service.rerank(query, candidates, return_num)

        return {"results": final_res}


crawler_service = CrawlerService()
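
# ---------------------------------------------------------------------------
# Usage sketch of the three-stage pipeline (map -> process -> search).
# Illustrative only, not part of the service: it assumes a reachable Firecrawl
# instance, a configured database behind data_service, and working LLM
# credentials; "https://example.com" and the query string are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Stage 1: discover the site's URLs and register them as a task
    map_result = crawler_service.map_site("https://example.com")
    task_id = map_result["task_id"]

    # Stage 2: drain the pending queue in concurrent batches until empty
    while True:
        batch = crawler_service.process_queue_concurrent(task_id, batch_size=10)
        if batch["count"] == 0:
            break

    # Stage 3: hybrid retrieval + rerank over the ingested chunks
    hits = crawler_service.search("example question", task_id=task_id, return_num=5)
    for hit in hits["results"]:
        print(hit.get("score"), hit.get("content", "")[:80])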