wiki_crawler/backend/services/crawler_service.py

import concurrent.futures
import threading
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor
import logging
# Module-specific logger; __name__ resolves to a dotted path such as
# "backend.services.crawler_service"
logger = logging.getLogger(__name__)

class CrawlerService:
    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        self.max_workers = 5
        # In-memory tracking of URLs currently being processed.
        # Structure: { task_id: set of URLs }
        self._active_workers = {}
        self._lock = threading.Lock()

    def _track_start(self, task_id, url):
        """Start tracking a URL."""
        with self._lock:
            if task_id not in self._active_workers:
                self._active_workers[task_id] = set()
            self._active_workers[task_id].add(url)

    def _track_end(self, task_id, url):
        """Stop tracking a URL."""
        with self._lock:
            if task_id in self._active_workers:
                self._active_workers[task_id].discard(url)

    def get_task_status(self, task_id: int):
        """
        [Monitoring] Full status = database statistics + live worker-thread list.
        """
        # 1. Database-level statistics (macro view)
        db_data = data_service.get_task_monitor_data(task_id)
        if not db_data:
            return None

        # 2. In-memory active worker threads (micro view)
        with self._lock:
            active_urls = list(self._active_workers.get(task_id, []))

        logger.info(f"Task {task_id} active threads: {active_urls}")
        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")

        return {
            "root_url": db_data["root_url"],
            "stats": db_data["db_stats"],              # pending / completed / failed counts
            "active_threads": active_urls,             # URLs currently being processed
            "active_thread_count": len(active_urls)
        }

    def map_site(self, start_url: str):
        """Stage 1: site-map scan."""
        logger.info(f"Mapping: {start_url}")
        try:
            task_res = data_service.register_task(start_url)
            urls_to_add = [start_url]

            # If the task already exists, skip mapping and return immediately
            if not task_res['is_new_task']:
                logger.info(f"Task {task_res['task_id']} exists, skipping map.")
                return {
                    "task_id": task_res['task_id'],
                    "count": 0,
                    "is_new": False
                }

            # New task: run the site map
            try:
                map_res = self.firecrawl.map(start_url)
                found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
                for link in found_links:
                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                    urls_to_add.append(u)
                logger.info(f"Map found {len(found_links)} links")
            except Exception as e:
                logger.warning(f"Map failed, proceeding with seed only: {e}")

            if urls_to_add:
                data_service.add_urls(task_res['task_id'], urls_to_add)

            return {
                "task_id": task_res['task_id'],
                "count": len(urls_to_add),
                "is_new": True
            }
        except Exception as e:
            logger.error(f"Map failed: {e}")
            raise

    def _process_single_url(self, task_id: int, url: str):
        """[Worker] Process a single URL in a worker thread."""
        # 1. In-memory mark: start
        self._track_start(task_id, url)
        logger.info(f"[THREAD START] {url}")
        try:
            # 2. Scrape
            scrape_res = self.firecrawl.scrape(
                url, formats=['markdown'], only_main_content=True
            )
            raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
            metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
            title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
            if not raw_md:
                data_service.mark_url_status(task_id, url, 'failed')
                return

            # 3. Clean & split
            clean_md = text_processor.clean_markdown(raw_md)
            chunks = text_processor.split_markdown(clean_md)
            chunks_data = []
            for i, chunk in enumerate(chunks):
                headers = chunk['metadata']
                path = " > ".join(headers.values())
                emb_input = f"{title}\n{path}\n{chunk['content']}"
                vector = llm_service.get_embedding(emb_input)
                if vector:
                    chunks_data.append({
                        "index": i, "content": chunk['content'], "embedding": vector,
                        "meta_info": {"header_path": path, "headers": headers}
                    })

            # 4. Persist (marks the URL as completed)
            if chunks_data:
                data_service.save_chunks(task_id, url, title, chunks_data)
            else:
                data_service.mark_url_status(task_id, url, 'failed')
        except Exception as e:
            logger.error(f"[THREAD ERROR] {url}: {e}")
            data_service.mark_url_status(task_id, url, 'failed')
        finally:
            # 5. In-memory mark: end (always remove, whether success or failure)
            self._track_end(task_id, url)

    def process_queue_concurrent(self, task_id: int, batch_size: int = 10):
        """Stage 2: concurrent processing with a thread pool."""
        urls = data_service.get_pending_urls(task_id, limit=batch_size)
        if not urls:
            return {"msg": "No pending urls"}
        logger.info(f"Batch started: {len(urls)} urls")
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit one worker per URL to the thread pool
            futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
            # Wait for all workers to finish
            concurrent.futures.wait(futures)
        return {"msg": "Batch completed", "count": len(urls)}

    def search(self, query: str, task_id, return_num: int):
        """Stage 3: search (coarse vector retrieval, then rerank)."""
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}
        # Retrieve a wider candidate pool (clamped to 50-100) before reranking
        coarse_limit = min(return_num * 10, 100)
        coarse_limit = max(coarse_limit, 50)
        coarse_res = data_service.search(query, vector, task_id, coarse_limit)
        candidates = coarse_res.get('results', [])
        if not candidates:
            return {"results": []}
        final_res = llm_service.rerank(query, candidates, return_num)
        return {"results": final_res}

crawler_service = CrawlerService()
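

# --- Illustrative usage sketch (not part of the service API) ---
# A minimal walkthrough of the three stages, assuming the backend configuration
# (FIRECRAWL_API_KEY, database, embedding/rerank services) is already set up.
# The start URL and query below are placeholders, not values from this project.
if __name__ == "__main__":
    # Stage 1: register the task and collect URLs via the site map.
    map_result = crawler_service.map_site("https://example.com/wiki")
    task_id = map_result["task_id"]

    # Stage 2: process one batch of pending URLs with the worker pool.
    crawler_service.process_queue_concurrent(task_id, batch_size=10)

    # Monitoring: database stats plus the URLs currently held by worker threads.
    print(crawler_service.get_task_status(task_id))

    # Stage 3: embed the query, run a coarse vector search, then rerank.
    print(crawler_service.search("How do I configure the crawler?", task_id, return_num=5))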