# wiki_crawler/backend/services/crawler_service.py

import concurrent.futures
import threading
import logging
from typing import Dict, Any, Optional
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor

# Logger dedicated to this module
logger = logging.getLogger(__name__)


class CrawlerService:
    """
    Crawler business-service layer.

    Responsibilities:
    1. Coordinate the external API (Firecrawl) with internal services (DataService, LLMService).
    2. Manage multi-threaded crawl tasks and their state.
    3. Provide a unified search entry point (hybrid retrieval + rerank).
    """
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        self.max_workers = 5  # maximum thread-pool concurrency

        # In-memory state tracking: { task_id: set([url1, url2]) }
self._active_workers: Dict[int, set] = {}
self._lock = threading.Lock()
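        # The lock guards _active_workers: pool threads mutate it via
        # _track_start/_track_end while get_task_status reads it concurrently.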

    def _track_start(self, task_id: int, url: str):
        """[Internal] Mark a URL as having started processing."""
with self._lock:
if task_id not in self._active_workers:
self._active_workers[task_id] = set()
self._active_workers[task_id].add(url)

    def _track_end(self, task_id: int, url: str):
        """[Internal] Mark a URL as having finished processing."""
with self._lock:
if task_id in self._active_workers:
self._active_workers[task_id].discard(url)

    def get_task_status(self, task_id: int) -> Optional[Dict[str, Any]]:
        """
        Get the live, consolidated status of a task.

        Args:
            task_id (int): Task ID.

        Returns:
            dict: Database statistics plus live thread info; None if the task
            does not exist. Example structure:
            {
                "root_url": "https://example.com",
                "stats": {"pending": 10, "processing": 2, "completed": 5, "failed": 0},
                "active_threads": ["https://example.com/page1"],
                "active_thread_count": 1
            }
        """
        # 1. Database-level statistics (macro view)
db_data = data_service.get_task_monitor_data(task_id)
if not db_data:
return None
        # 2. In-memory active worker threads (micro view)
with self._lock:
active_urls = list(self._active_workers.get(task_id, []))

        # Log the current status
        logger.info(f"Task {task_id} active threads: {active_urls}")
        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")
return {
"root_url": db_data["root_url"],
"stats": db_data["db_stats"],
"active_threads": active_urls,
"active_thread_count": len(active_urls)
}
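
    # Usage sketch (hypothetical polling caller; the task id is a placeholder):
    # poll until no workers remain and the pending queue is empty.
    #
    #     status = crawler_service.get_task_status(task_id=123)
    #     if status and status["active_thread_count"] == 0 \
    #             and status["stats"].get("pending", 0) == 0:
    #         print("task finished")
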
def map_site(self, start_url: str) -> Dict[str, Any]:
"""
第一阶段站点地图扫描 (Map)
Args:
start_url (str): 目标网站的根 URL
Returns:
dict: 包含任务 ID 和发现链接数的字典
{
"task_id": 123,
"count": 50,
"is_new": True
}
"""
logger.info(f"Mapping: {start_url}")
try:
task_res = data_service.register_task(start_url)
urls_to_add = [start_url]

            # If the task already exists, skip re-mapping and return immediately
if not task_res['is_new_task']:
logger.info(f"Task {task_res['task_id']} exists, skipping map.")
return {
"task_id": task_res['task_id'],
"count": 0,
"is_new": False
}
            # New task: run Map
try:
map_res = self.firecrawl.map(start_url)

                # Handle return shapes from different SDK versions
found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
for link in found_links:
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
urls_to_add.append(u)
logger.info(f"Map found {len(found_links)} links")
except Exception as e:
logger.warning(f"Map failed, proceeding with seed only: {e}")
if urls_to_add:
data_service.add_urls(task_res['task_id'], urls_to_add)
return {
"task_id": task_res['task_id'],
"count": len(urls_to_add),
"is_new": True
}
except Exception as e:
logger.error(f"Map failed: {e}")
            raise

    def _process_single_url(self, task_id: int, url: str):
        """[Internal worker] Per-URL processing logic, run on a pool thread."""
        # 1. In-memory marker: start
self._track_start(task_id, url)
logger.info(f"[THREAD START] {url}")
try:
            # 2. Scrape
scrape_res = self.firecrawl.scrape(
url, formats=['markdown'], only_main_content=True
)

            # Compatibility extraction (dict vs. object SDK results)
raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
if not raw_md:
data_service.mark_url_status(task_id, url, 'failed')
return
            # 3. Clean & split
clean_md = text_processor.clean_markdown(raw_md)
chunks = text_processor.split_markdown(clean_md)
chunks_data = []
for i, chunk in enumerate(chunks):
headers = chunk['metadata']
path = " > ".join(headers.values())
emb_input = f"{title}\n{path}\n{chunk['content']}"
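                # Illustrative input, e.g.: "Page Title\nInstall > Linux\n<chunk body>"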
vector = llm_service.get_embedding(emb_input)
if vector:
chunks_data.append({
"index": i, "content": chunk['content'], "embedding": vector,
"meta_info": {"header_path": path, "headers": headers}
})
            # 4. Persist (marks the URL completed automatically)
if chunks_data:
data_service.save_chunks(task_id, url, title, chunks_data)
else:
data_service.mark_url_status(task_id, url, 'failed')
except Exception as e:
logger.error(f"[THREAD ERROR] {url}: {e}")
data_service.mark_url_status(task_id, url, 'failed')
finally:
            # 5. In-memory marker: end (always removed, success or failure)
self._track_end(task_id, url)

    def process_queue_concurrent(self, task_id: int, batch_size: int = 10) -> Dict[str, Any]:
        """
        Phase 2: concurrent queue processing (Process).

        Args:
            task_id (int): Task ID.
            batch_size (int): Number of URLs to process in this batch; they are
                dispatched to the thread pool and run concurrently.

        Returns:
            dict: Batch summary, e.g.
            {
                "msg": "Batch completed",
                "count": 10
            }
        """
urls = data_service.get_pending_urls(task_id, limit=batch_size)
        if not urls:
            return {"msg": "No pending urls", "count": 0}
logger.info(f"Batch started: {len(urls)} urls")
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit tasks to the thread pool
futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}

            # Wait for completion (blocks until every worker finishes)
concurrent.futures.wait(futures)
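            # Note: exiting the `with` block also joins the pool (it calls
            # shutdown(wait=True)), so the explicit wait above mainly makes
            # the blocking point obvious.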
return {"msg": "Batch completed", "count": len(urls)}

    def search(self, query: str, task_id: Optional[int], return_num: int) -> Dict[str, Any]:
        """
        Phase 3: intelligent search (Search).

        Flow: user question -> embedding -> hybrid database retrieval (coarse
        ranking) -> rerank model (fine ranking) -> results.

        Args:
            query (str): User question.
            task_id (Optional[int]): Task ID to search within; None searches
                the whole database.
            return_num (int): Number of results ultimately returned (top K).

        Returns:
            dict: List of search results, e.g.
            {
                "results": [
                    {"content": "...", "score": 0.98, "meta_info": {...}},
                    ...
                ]
            }
        """
        # 1. Generate the query vector
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}

        # 2. Coarse database ranking: recall 10x the requested count, clamped to [50, 100]
coarse_limit = min(return_num * 10, 100)
coarse_limit = max(coarse_limit, 50)
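        # e.g. return_num=3  -> min(30, 100)=30   -> max(30, 50)=50   candidates
        #      return_num=20 -> min(200, 100)=100 -> max(100, 50)=100 candidates
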
coarse_res = data_service.search(
query_text=query,
query_vector=vector,
task_id=task_id,
candidates_num=coarse_limit
)
candidates = coarse_res.get('results', [])

        if not candidates:
            return {"results": []}

        # 3. Fine ranking (Rerank)
final_res = llm_service.rerank(query, candidates, return_num)
return {"results": final_res}


crawler_service = CrawlerService()
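

# Minimal end-to-end sketch of the three phases (a hypothetical driver; assumes
# settings.FIRECRAWL_API_KEY and the backing data/LLM services are configured,
# and the URL below is a placeholder).
if __name__ == "__main__":
    res = crawler_service.map_site("https://docs.example.com")  # Phase 1: map
    while crawler_service.process_queue_concurrent(res["task_id"])["count"]:
        pass  # Phase 2: drain the pending queue batch by batch
    hits = crawler_service.search("how do I install?", res["task_id"], return_num=5)
    for hit in hits["results"]:  # Phase 3: search
        print(hit.get("score"), str(hit.get("content", ""))[:80])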