import concurrent.futures
import logging
import threading

from firecrawl import FirecrawlApp

from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor

# Module-level logger; __name__ resolves to the dotted import path,
# e.g. "backend.services.crawler_service".
logger = logging.getLogger(__name__)


class CrawlerService:
    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        self.max_workers = 5

        # In-memory progress tracking for live monitoring.
        # Structure: { task_id: set of URLs currently being processed }
        self._active_workers = {}
        self._lock = threading.Lock()
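
        # The workers here are I/O-bound (HTTP scraping plus embedding calls),
        # so a small thread pool works well even under the GIL; max_workers
        # also caps concurrent requests against the Firecrawl API.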

    def _track_start(self, task_id, url):
        """Begin tracking a URL as in-flight."""
        with self._lock:
            if task_id not in self._active_workers:
                self._active_workers[task_id] = set()
            self._active_workers[task_id].add(url)

    def _track_end(self, task_id, url):
        """Stop tracking a URL once its worker finishes."""
        with self._lock:
            if task_id in self._active_workers:
                self._active_workers[task_id].discard(url)
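
    # The lock matters because the check-then-create in _track_start and the
    # membership test in _track_end are compound operations; without it, two
    # worker threads could interleave and lose an update.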

    def get_task_status(self, task_id: int):
        """
        [Monitoring] Combined status: database statistics plus the live worker list.
        """
        # 1. Database-level counters (the macro view)
        db_data = data_service.get_task_monitor_data(task_id)
        if not db_data:
            return None

        # 2. In-memory set of in-flight URLs (the micro view)
        with self._lock:
            active_urls = list(self._active_workers.get(task_id, []))

        logger.info(f"Task {task_id} active threads: {active_urls}")
        logger.info(f"Task {task_id} stats: {db_data['db_stats']}")

        return {
            "root_url": db_data["root_url"],
            "stats": db_data["db_stats"],          # pending / completed / failed counts
            "active_threads": active_urls,         # URLs currently being processed
            "active_thread_count": len(active_urls)
        }
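
    # Illustrative response shape (the exact keys inside "stats" come from
    # data_service.get_task_monitor_data, so treat the inner names as an example):
    #   {
    #       "root_url": "https://example.com",
    #       "stats": {"pending": 42, "completed": 10, "failed": 1},
    #       "active_threads": ["https://example.com/docs/a"],
    #       "active_thread_count": 1
    #   }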

    def map_site(self, start_url: str):
        """Stage 1: site-map scan."""
        logger.info(f"Mapping: {start_url}")
        try:
            task_res = data_service.register_task(start_url)
            urls_to_add = [start_url]

            # If the task already exists, skip re-mapping and return early.
            if not task_res['is_new_task']:
                logger.info(f"Task {task_res['task_id']} exists, skipping map.")
                return {
                    "task_id": task_res['task_id'],
                    "count": 0,
                    "is_new": False
                }

            # New task: discover URLs via Firecrawl's map step.
            try:
                map_res = self.firecrawl.map(start_url)
                # The SDK may return a plain dict or a response object,
                # depending on the version; handle both shapes defensively.
                found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
                for link in found_links:
                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                    urls_to_add.append(u)
                logger.info(f"Map found {len(found_links)} links")
            except Exception as e:
                logger.warning(f"Map failed, proceeding with seed only: {e}")

            if urls_to_add:
                data_service.add_urls(task_res['task_id'], urls_to_add)

            return {
                "task_id": task_res['task_id'],
                "count": len(urls_to_add),
                "is_new": True
            }
        except Exception as e:
            logger.error(f"Map failed: {e}")
            raise
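
    # Note: urls_to_add can contain the seed URL twice when the map step also
    # returns it; this relies on data_service.add_urls de-duplicating on insert
    # (an assumption about that service, not verified here).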

    def _process_single_url(self, task_id: int, url: str):
        """[Worker] Process a single URL; runs on a pool thread."""
        # 1. In-memory mark: started
        self._track_start(task_id, url)
        logger.info(f"[THREAD START] {url}")

        try:
            # 2. Scrape
            scrape_res = self.firecrawl.scrape(
                url, formats=['markdown'], only_main_content=True
            )

            # The SDK may return a plain dict or a response object; handle both.
            raw_md = scrape_res.get('markdown', '') if isinstance(scrape_res, dict) else getattr(scrape_res, 'markdown', '')
            metadata = scrape_res.get('metadata', {}) if isinstance(scrape_res, dict) else getattr(scrape_res, 'metadata', {})
            title = metadata.get('title', url) if isinstance(metadata, dict) else getattr(metadata, 'title', url)

            if not raw_md:
                data_service.mark_url_status(task_id, url, 'failed')
                return

            # 3. Clean & split
            clean_md = text_processor.clean_markdown(raw_md)
            chunks = text_processor.split_markdown(clean_md)

            chunks_data = []
            for i, chunk in enumerate(chunks):
                headers = chunk['metadata']
                path = " > ".join(headers.values())
                # Prepend the page title and header path so the embedding
                # carries document context, not just the chunk text.
                emb_input = f"{title}\n{path}\n{chunk['content']}"
                vector = llm_service.get_embedding(emb_input)
                if vector:
                    chunks_data.append({
                        "index": i, "content": chunk['content'], "embedding": vector,
                        "meta_info": {"header_path": path, "headers": headers}
                    })

            # 4. Persist (save_chunks also marks the URL as completed)
            if chunks_data:
                data_service.save_chunks(task_id, url, title, chunks_data)
            else:
                data_service.mark_url_status(task_id, url, 'failed')

        except Exception as e:
            logger.error(f"[THREAD ERROR] {url}: {e}")
            data_service.mark_url_status(task_id, url, 'failed')

        finally:
            # 5. In-memory mark: finished (removed on success and failure alike)
            self._track_end(task_id, url)
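
    # Chunk shape expected from text_processor.split_markdown, inferred from
    # the field accesses above (an assumption, not a documented contract):
    #   {"content": "...", "metadata": {"Header 1": "Guide", "Header 2": "Install"}}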

    def process_queue_concurrent(self, task_id: int, batch_size: int = 10):
        """Stage 2: concurrent processing of one batch with a thread pool."""
        urls = data_service.get_pending_urls(task_id, limit=batch_size)
        if not urls:
            return {"msg": "No pending urls"}

        logger.info(f"Batch started: {len(urls)} urls")

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit one worker per URL to the pool.
            futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
            # Block until the whole batch has finished.
            concurrent.futures.wait(futures)

        return {"msg": "Batch completed", "count": len(urls)}
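
    # One call drains at most batch_size URLs; callers are expected to invoke
    # this repeatedly until it reports "No pending urls" (see the sketch at the
    # bottom of this file).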

    def search(self, query: str, task_id: int, return_num: int):
        """Stage 3: search (coarse vector recall, then rerank)."""
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}

        # Over-fetch for the coarse stage: roughly 10x the requested count,
        # clamped to [50, 100], so the reranker has enough candidates.
        coarse_limit = min(return_num * 10, 100)
        coarse_limit = max(coarse_limit, 50)

        coarse_res = data_service.search(query, vector, task_id, coarse_limit)
        candidates = coarse_res.get('results', [])

        if not candidates:
            return {"results": []}

        final_res = llm_service.rerank(query, candidates, return_num)
        return {"results": final_res}


crawler_service = CrawlerService()
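

# Minimal end-to-end sketch of how the three stages compose. This assumes
# settings, the data/llm services, and a real FIRECRAWL_API_KEY are configured;
# in the actual app these calls are presumably driven by API routes instead.
if __name__ == "__main__":
    res = crawler_service.map_site("https://example.com")   # stage 1: discover URLs
    while True:                                              # stage 2: drain the queue
        batch = crawler_service.process_queue_concurrent(res["task_id"])
        if batch.get("msg") == "No pending urls":
            break
    # Stage 3: coarse vector recall + rerank, top 5 results.
    print(crawler_service.search("how do I get started?", res["task_id"], 5))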