# wiki_crawler/backend/services/crawler_service.py
import concurrent.futures
import threading
import logging
from typing import Dict, Any, Optional
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor
# Dedicated logger for this module
logger = logging.getLogger(__name__)
class CrawlerService:
"""
爬虫业务服务层 (Crawler Service)
职责:
1. 协调外部 API (Firecrawl) 和内部服务 (DataService, LLMService)。
2. 管理多线程爬取任务及其状态。
3. 提供统一的搜索入口 (混合检索 + Rerank)。
"""
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        self.max_workers = 5  # Maximum number of concurrent worker threads
        # In-memory state tracking: { task_id: set([url1, url2]) }
self._active_workers: Dict[int, set] = {}
self._lock = threading.Lock()
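    # Design note: ThreadPoolExecutor workers in process_queue_concurrent call the
    # two tracking helpers below concurrently, so every access to _active_workers
    # is guarded by _lock; a plain set per task suffices since only URL membership
    # matters, not ordering.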
def _track_start(self, task_id: int, url: str):
"""[Internal] 标记某个URL开始处理"""
with self._lock:
if task_id not in self._active_workers:
self._active_workers[task_id] = set()
self._active_workers[task_id].add(url)
def _track_end(self, task_id: int, url: str):
"""[Internal] 标记某个URL处理结束"""
with self._lock:
if task_id in self._active_workers:
self._active_workers[task_id].discard(url)
def get_task_status(self, task_id: int) -> Optional[Dict[str, Any]]:
"""
        Get the real-time combined status of a task.
        Args:
            task_id (int): Task ID
        Returns:
            dict: Database statistics combined with live thread info. Returns None if the task does not exist.
            Example structure:
            {
                "root_url": "https://example.com",
                "stats": {"pending": 10, "processing": 2, "completed": 5, "failed": 0},
                "active_threads": ["https://example.com/page1"],
                "active_thread_count": 1
            }
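        Usage (illustrative; assumes a task with ID 1 has already been registered):
            status = crawler_service.get_task_status(1)
            if status:
                print(status["stats"], status["active_thread_count"])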
"""
        # 1. Database-level statistics (macro view)
db_data = data_service.get_task_monitor_data(task_id)
if not db_data:
return None
        # 2. In-memory active threads (micro view)
with self._lock:
active_urls = list(self._active_workers.get(task_id, []))
        # Log the current status
logger.info(f"Task {task_id} active threads: {active_urls}")
logger.info(f"Task {task_id} stats: {db_data['db_stats']}")
return {
"root_url": db_data["root_url"],
"stats": db_data["db_stats"],
"active_threads": active_urls,
"active_thread_count": len(active_urls)
}
def map_site(self, start_url: str) -> Dict[str, Any]:
"""
        Phase 1: Site map scan (Map).
        Args:
            start_url (str): Root URL of the target site
        Returns:
            dict: Task ID and the number of discovered links.
            {
                "task_id": 123,
                "count": 50,
                "is_new": True
            }
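        Usage (illustrative; the URL is a placeholder):
            res = crawler_service.map_site("https://docs.example.com")
            task_id = res["task_id"]  # calling again returns the same task with is_new=False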
"""
logger.info(f"Mapping: {start_url}")
try:
task_res = data_service.register_task(start_url)
urls_to_add = [start_url]
            # If the task already exists, skip re-mapping and return immediately
if not task_res['is_new_task']:
logger.info(f"Task {task_res['task_id']} exists, skipping map.")
return {
"task_id": task_res['task_id'],
"count": 0,
"is_new": False
}
            # New task: run the map step
try:
map_res = self.firecrawl.map(start_url)
                # Accommodate return shapes from different SDK versions (dict vs. object)
found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
for link in found_links:
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
urls_to_add.append(u)
logger.info(f"Map found {len(found_links)} links")
except Exception as e:
logger.warning(f"Map failed, proceeding with seed only: {e}")
if urls_to_add:
data_service.add_urls(task_res['task_id'], urls_to_add)
return {
"task_id": task_res['task_id'],
"count": len(urls_to_add),
"is_new": True
}
except Exception as e:
logger.error(f"Map failed: {e}")
            raise
def _process_single_url(self, task_id: int, url: str):
"""[Internal Worker] 单个 URL 处理线程逻辑"""
# 1. 内存标记:开始
self._track_start(task_id, url)
logger.info(f"[THREAD START] {url}")
try:
            # 2. Scrape
scrape_res = self.firecrawl.scrape(
url, formats=['markdown'], only_main_content=True
)
            # Extract fields tolerantly (dict vs. object results)
raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
if not raw_md:
data_service.mark_url_status(task_id, url, 'failed')
return
            # 3. Clean & split
clean_md = text_processor.clean_markdown(raw_md)
chunks = text_processor.split_markdown(clean_md)
chunks_data = []
for i, chunk in enumerate(chunks):
headers = chunk['metadata']
path = " > ".join(headers.values())
emb_input = f"{title}\n{path}\n{chunk['content']}"
vector = llm_service.get_embedding(emb_input)
if vector:
chunks_data.append({
"index": i, "content": chunk['content'], "embedding": vector,
"meta_info": {"header_path": path, "headers": headers}
})
            # 4. Persist (marks the URL completed automatically)
if chunks_data:
data_service.save_chunks(task_id, url, title, chunks_data)
else:
data_service.mark_url_status(task_id, url, 'failed')
except Exception as e:
logger.error(f"[THREAD ERROR] {url}: {e}")
data_service.mark_url_status(task_id, url, 'failed')
finally:
            # 5. In-memory tracking: end (always remove, whether success or failure)
self._track_end(task_id, url)
def process_queue_concurrent(self, task_id: int, batch_size: int = 10) -> Dict[str, Any]:
"""
        Phase 2: Concurrent multi-threaded processing (Process).
        Args:
            task_id (int): Task ID
            batch_size (int): Number of URLs to process in this batch (fanned out to the thread pool for concurrent execution)
        Returns:
            dict: Batch result summary
            {
                "msg": "Batch completed",
                "count": 10
            }
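        Usage (illustrative drain loop; repeats until the queue is empty):
            while crawler_service.process_queue_concurrent(task_id, batch_size=10)["count"]:
                pass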
"""
urls = data_service.get_pending_urls(task_id, limit=batch_size)
        if not urls:
            return {"msg": "No pending urls", "count": 0}
logger.info(f"Batch started: {len(urls)} urls")
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit jobs to the thread pool
futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
            # Wait for completion (blocks until every worker finishes)
concurrent.futures.wait(futures)
return {"msg": "Batch completed", "count": len(urls)}
def search(self, query: str, task_id: Optional[int], return_num: int) -> Dict[str, Any]:
"""
        Phase 3: Intelligent search (Search).
        Pipeline: user query -> embedding -> hybrid database retrieval (coarse ranking) -> rerank model (fine ranking) -> results.
        Args:
            query (str): User query
            task_id (Optional[int]): Task ID to scope the search; None searches the whole corpus
            return_num (int): Number of results to return to the user (top K)
        Returns:
            dict: List of search results
            {
                "results": [
                    {"content": "...", "score": 0.98, "meta_info": {...}},
                    ...
                ]
            }
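        Usage (illustrative; the query is a placeholder and field names follow the structure above):
            res = crawler_service.search("how to configure auth", task_id=1, return_num=5)
            for hit in res["results"]:
                print(hit["score"], hit["content"][:80])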
"""
        # 1. Generate the query vector
vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}
        # 2. Coarse database ranking (recall 10x the requested count, clamped to [50, 100])
        coarse_limit = min(return_num * 10, 100)
        coarse_limit = max(coarse_limit, 50)
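        # Worked example: return_num=3 -> min(30, 100)=30 -> max(30, 50)=50 candidates;
        # return_num=20 -> min(200, 100)=100, which stays at 100.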
coarse_res = data_service.search(
query_text=query,
query_vector=vector,
task_id=task_id,
candidates_num=coarse_limit
)
candidates = coarse_res.get('results', [])
        if not candidates:
            return {"results": []}
        # 3. Fine ranking with the rerank model
final_res = llm_service.rerank(query, candidates, return_num)
return {"results": final_res}
crawler_service = CrawlerService()
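
# Minimal end-to-end sketch of the three phases (map -> process -> search).
# Illustrative only: it assumes FIRECRAWL_API_KEY and the backing services
# (database, embedding and rerank models) are configured; the URL and query
# below are placeholders.
if __name__ == "__main__":
    res = crawler_service.map_site("https://docs.example.com")
    task_id = res["task_id"]
    # Drain the pending queue in batches of 10, logging progress between batches.
    while crawler_service.process_queue_concurrent(task_id, batch_size=10)["count"]:
        print(crawler_service.get_task_status(task_id))
    print(crawler_service.search("getting started", task_id=task_id, return_num=5))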