diff --git a/backend/core/config.py b/backend/core/config.py
index 310ae34..1bcb648 100644
--- a/backend/core/config.py
+++ b/backend/core/config.py
@@ -9,9 +9,14 @@ class Settings(BaseSettings):
     DB_PORT: str = "5432"
     DB_NAME: str
     DASHSCOPE_API_KEY: str
-    FIRECRAWL_API_KEY: str
-    
-    CANDIDATE_NUM: int = 50
+    # Set to False for Firecrawl deployments that are used without an API key.
+    # NOTE(review): "EXSIST" is a misspelling of "EXIST"; consider renaming.
+    FIRECRAWL_API_KEY_EXSIST: bool = True
+    # Defaults to empty so the settings still load when no key is configured.
+    FIRECRAWL_API_KEY: str = ""
+    FIRECRAWL_API_URL: str = "https://api.firecrawl.dev/"  # official hosted API
+
+    CANDIDATE_NUM: int = 50
 
     # =========================================================
     # 【核心修复】加上 ClassVar 类型注解
diff --git a/backend/services/crawler_service.py b/backend/services/crawler_service.py
index e8948a7..8f3593d 100644
--- a/backend/services/crawler_service.py
+++ b/backend/services/crawler_service.py
@@ -23,7 +23,11 @@ class CrawlerService:
     """
 
     def __init__(self):
-        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
+        # The key is only forwarded when the settings say one exists; the
+        # SDK's api_key parameter defaults to None, so passing None is the
+        # same as omitting it.
+        api_key = settings.FIRECRAWL_API_KEY if settings.FIRECRAWL_API_KEY_EXSIST else None
+        self.firecrawl = FirecrawlApp(api_key=api_key, api_url=settings.FIRECRAWL_API_URL)
         self.max_workers = 5  # 线程池最大并发数
 
         # 内存状态追踪: { task_id: set([url1, url2]) }
@@ -85,59 +89,83 @@ class CrawlerService:
             "active_thread_count": len(active_urls)
         }
 
-    def map_site(self, start_url: str) -> Dict[str, Any]:
-        """
-        第一阶段:站点地图扫描 (Map)
-        
-        Args:
-            start_url (str): 目标网站的根 URL
-            
-        Returns:
-            dict: 包含任务 ID 和发现链接数的字典。
-            {
-                "task_id": 123,
-                "count": 50,
-                "is_new": True
-            }
-        """
-        logger.info(f"Mapping: {start_url}")
-        try:
-            task_res = data_service.register_task(start_url)
-            urls_to_add = [start_url]
-            
-            # 如果任务已存在,不再重新 Map,直接返回
-            if not task_res['is_new_task']:
-                logger.info(f"Task {task_res['task_id']} exists, skipping map.")
-                return {
-                    "task_id": task_res['task_id'],
-                    "count": 0,
-                    "is_new": False
-                }
-
-            # 新任务执行 Map
-            try:
-                map_res = self.firecrawl.map(start_url)
-                # 兼容不同版本的 SDK 返回结构
-                found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
-                
-                for link in found_links:
-                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
-                    urls_to_add.append(u)
-                logger.info(f"Map found {len(found_links)} links")
-            except Exception as e:
-                logger.warning(f"Map failed, proceeding with seed only: {e}")
-
-            if urls_to_add:
-                data_service.add_urls(task_res['task_id'], urls_to_add)
-            
-            return {
-                "task_id": task_res['task_id'],
-                "count": len(urls_to_add),
-                "is_new": True
-            }
-        except Exception as e:
-            logger.error(f"Map failed: {e}")
-            raise e
+    def map_site(self, start_url: str, persist: bool = True) -> Dict[str, Any]:
+        """
+        Stage 1: discover a site's URLs through the Firecrawl map call.
+
+        The external map call runs before any database write, so a failed map
+        never leaves behind a registered task without queued URLs.
+
+        Args:
+            start_url (str): Root URL of the target site.
+            persist (bool): When True (default), create the task and queue the
+                discovered URLs via data_service.create_task_with_urls. When
+                False, only return the URLs and skip every database write.
+
+        Returns:
+            dict: Outcome of the scan.
+            {
+                "task_id": 123 | None,          # None when persist=False
+                "count": 50,                    # rows queued, or URLs found when persist=False
+                "is_new": True | False | None,  # None when persist=False
+                "urls": [ ... ],                # seed URL first, duplicates removed
+                "persisted": True | False
+            }
+
+        Raises:
+            Exception: Errors from the Firecrawl call or the database write are
+                logged once and re-raised unchanged.
+        """
+        logger.info(f"Mapping (persist={persist}): {start_url}")
+
+        # 1. Run the external map first; nothing touches the database yet.
+        try:
+            map_res = self.firecrawl.map(start_url)
+        except Exception as e:
+            logger.error(f"Map failed for {start_url}, aborting register: {e}", exc_info=True)
+            raise
+
+        # SDK versions differ: the result is either a dict or an object with a
+        # `links` attribute. A missing or None value means no links were found.
+        if isinstance(map_res, dict):
+            found_links = map_res.get('links') or []
+        else:
+            found_links = getattr(map_res, 'links', None) or []
+        logger.info(f"Map found {len(found_links)} links for {start_url}")
+
+        # Seed URL first, then the discovered links; dict.fromkeys drops exact
+        # duplicates (e.g. the seed reappearing in the map result) in order.
+        candidates = [start_url]
+        candidates.extend(
+            link if isinstance(link, str) else getattr(link, 'url', str(link))
+            for link in found_links
+        )
+        urls_to_add = list(dict.fromkeys(candidates))
+
+        # 2. Dry run: hand the URLs back so the caller decides what to persist.
+        if not persist:
+            return {
+                "task_id": None,
+                "count": len(urls_to_add),
+                "is_new": None,
+                "urls": urls_to_add,
+                "persisted": False
+            }
+
+        # 3. Create the task and queue its URLs in a single transaction.
+        try:
+            create_res = data_service.create_task_with_urls(start_url, urls_to_add)
+        except Exception as e:
+            logger.error(f"Atomic create_task_with_urls failed for {start_url}: {e}", exc_info=True)
+            raise
+
+        return {
+            "task_id": create_res.get('task_id'),
+            "count": create_res.get('added', 0),
+            "is_new": create_res.get('is_new_task', False),
+            "urls": urls_to_add,
+            "persisted": True
+        }
 
     def _process_single_url(self, task_id: int, url: str):
         """[Internal Worker] 单个 URL 处理线程逻辑"""
diff --git a/backend/services/data_service.py b/backend/services/data_service.py
index 10fd24b..d61d5d1 100644
--- a/backend/services/data_service.py
+++ b/backend/services/data_service.py
@@ -93,7 +93,94 @@ class DataService:
         except:
             return url
 
-    # ... (保持 get_task_monitor_data, save_chunks, search 等方法不变) ...
+    def get_task_by_root_url(self, url: str):
+        """Return the id of the task whose root URL matches *url*, or None."""
+        clean_url = normalize_url(url)
+        query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
+        with self.db.engine.connect() as conn:
+            row = conn.execute(query).fetchone()
+        return row[0] if row else None
+
+    def create_task_with_urls(self, url: str, urls: list[str]):
+        """
+        Create a task (if needed) and queue its URLs in one transaction.
+
+        If a task with the same normalized root URL already exists it is
+        reused, and only URLs not yet queued for it are inserted. Any database
+        error propagates, which rolls the whole transaction back, so the task
+        row and its queue rows are committed together or not at all.
+
+        Args:
+            url (str): Root URL of the site; normalized before lookup.
+            urls (list[str]): URLs to queue; normalized and de-duplicated.
+
+        Returns:
+            dict: {"task_id": int, "is_new_task": bool, "added": int}
+        """
+        clean_root = normalize_url(url)
+        # dict.fromkeys removes duplicates while keeping first-seen order.
+        clean_urls = list(dict.fromkeys(normalize_url(u) for u in urls))
+        tasks, queue = self.db.tasks, self.db.queue
+
+        with self.db.engine.begin() as conn:
+            # 1. Reuse the existing task or create a new one.
+            # NOTE(review): concurrent callers can both miss this lookup; a
+            # unique constraint on tasks.root_url would close that gap - confirm schema.
+            existing = conn.execute(
+                select(tasks.c.id).where(tasks.c.root_url == clean_root)
+            ).fetchone()
+            if existing:
+                task_id = existing[0]
+                is_new = False
+            else:
+                stmt = insert(tasks).values(root_url=clean_root).returning(tasks.c.id)
+                task_id = conn.execute(stmt).fetchone()[0]
+                is_new = True
+
+            # 2. Work out which URLs are not queued yet. A task created just
+            # above has no queue rows, so the lookup is skipped for it.
+            already_queued = set()
+            if clean_urls and not is_new:
+                rows = conn.execute(
+                    select(queue.c.url)
+                    .where(queue.c.task_id == task_id)
+                    .where(queue.c.url.in_(clean_urls))
+                )
+                already_queued = {row[0] for row in rows}
+            new_rows = [
+                {"task_id": task_id, "url": u, "status": "pending"}
+                for u in clean_urls
+                if u not in already_queued
+            ]
+
+            # 3. Insert the missing URLs in one batch. No per-row try/except:
+            # a failed statement must abort the transaction instead of being
+            # hidden behind a success result.
+            if new_rows:
+                conn.execute(insert(queue), new_rows)
+
+        return {"task_id": task_id, "is_new_task": is_new, "added": len(new_rows)}
+
+    def delete_task(self, task_id: int):
+        """
+        Delete a task together with its queue rows and chunks.
+
+        Intended mainly for rollback. The three deletes run in one
+        transaction; the try block wraps the transaction so that a failure
+        rolls everything back before False is returned.
+
+        Returns:
+            bool: True when all deletes were committed, False otherwise.
+        """
+        try:
+            with self.db.engine.begin() as conn:
+                conn.execute(text("DELETE FROM chunks WHERE task_id = :tid"), {"tid": task_id})
+                conn.execute(text("DELETE FROM queue WHERE task_id = :tid"), {"tid": task_id})
+                conn.execute(text("DELETE FROM tasks WHERE id = :tid"), {"tid": task_id})
+        except Exception as e:
+            logger.error(f"Failed to delete task {task_id}: {e}")
+            return False
+        return True
 
     def get_task_monitor_data(self, task_id: int):
         """[数据库层监控] 获取持久化的任务状态"""