# backend/services/crawler_service.py
from firecrawl import FirecrawlApp

from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor


class CrawlerService:
    """
    Crawler orchestration service.
    Coordinates Firecrawl, the LLM service, and the DataService.
    """

    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
    def map_site(self, start_url: str):
        print(f"[INFO] Mapping: {start_url}")
        try:
            # 1. Register the task
            task_res = data_service.register_task(start_url)

            # 2. Whether or not the task is new, queue the start_url itself
            urls_to_add = [start_url]

            # 3. Call Firecrawl Map
            try:
                map_res = self.firecrawl.map(start_url)

                # Handle both return shapes across SDK versions (object or dict)
                found_links = []
                if isinstance(map_res, dict):
                    found_links = map_res.get('links', [])
                else:
                    found_links = getattr(map_res, 'links', [])

                for link in found_links:
                    # A link may be a plain string or an object with a .url attribute
                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                    urls_to_add.append(u)

                print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
            except Exception as e:
                print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")

            # 4. Bulk-insert the URLs into the queue
            if urls_to_add:
                data_service.add_urls(task_res['task_id'], urls_to_add)

            return {
                "msg": "Mapped successfully",
                "count": len(urls_to_add),
                "task_id": task_res['task_id']
            }

        except Exception as e:
            print(f"[ERROR] Map failed: {e}")
            raise
    def process_queue(self, task_id: int, batch_size: int = 5):
        """
        Process the queue: scrape -> clean -> split -> embed -> store.
        """
        pending = data_service.get_pending_urls(task_id, batch_size)
        urls = pending['urls']
        if not urls:
            return {"msg": "Queue empty"}

        processed = 0
        for url in urls:
            try:
                print(f"[INFO] Scraping: {url}")

                # Call Firecrawl
                scrape_res = self.firecrawl.scrape(
                    url,
                    formats=['markdown'],
                    only_main_content=True
                )

                # =====================================================
                # [Core fix] Handle both dict and object return formats
                # =====================================================
                raw_md = ""
                metadata = None

                # 1. Extract the markdown body and the metadata
                if isinstance(scrape_res, dict):
                    raw_md = scrape_res.get('markdown', '')
                    metadata = scrape_res.get('metadata', {})
                else:
                    # Object form (ScrapeResponse)
                    raw_md = getattr(scrape_res, 'markdown', '')
                    metadata = getattr(scrape_res, 'metadata', {})

                # 2. Safely extract the title from the metadata
                title = url  # fall back to the URL as the title
                if metadata:
                    if isinstance(metadata, dict):
                        # metadata is a dict
                        title = metadata.get('title', url)
                    else:
                        # metadata is an object (DocumentMetadata); this is where the original error occurred
                        title = getattr(metadata, 'title', url)

                # =====================================================

                if not raw_md:
                    print(f"[WARN] Empty content for {url}")
                    continue

                # 1. Clean
                clean_md = text_processor.clean_markdown(raw_md)
                if not clean_md:
                    continue

                # 2. Split
                chunks = text_processor.split_markdown(clean_md)

                chunks_data = []
                for i, chunk in enumerate(chunks):
                    content = chunk['content']
                    headers = chunk['metadata']
                    header_path = " > ".join(headers.values())

                    # 3. Embed (title + header path + content)
                    emb_input = f"{title}\n{header_path}\n{content}"
                    vector = llm_service.get_embedding(emb_input)

                    if vector:
                        chunks_data.append({
                            "index": i,
                            "content": content,
                            "embedding": vector,
                            "meta_info": {"header_path": header_path, "headers": headers}
                        })

                # 4. Store
                if chunks_data:
                    data_service.save_chunks(task_id, url, title, chunks_data)
                    processed += 1
            except Exception as e:
                # Print the full error to simplify debugging
                print(f"[ERROR] Process URL {url} failed: {e}")

        return {"msg": "Batch processed", "count": processed}
    def search(self, query: str, task_id, limit: int):
        """Embed the query and run a search over the stored chunks."""
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}
        return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit)


crawler_service = CrawlerService()
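
# ---------------------------------------------------------------------------
# Minimal usage sketch (assumptions: the module is run as a script, the
# database behind data_service and the embedding backend behind llm_service
# are already configured, and the URL and query below are placeholders rather
# than values from the original project).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 1. Map the site and register a crawl task
    map_result = crawler_service.map_site("https://example.com/docs")

    # 2. Scrape, clean, split, embed, and store one batch of queued URLs
    crawler_service.process_queue(map_result["task_id"], batch_size=5)

    # 3. Query the stored chunks
    hits = crawler_service.search("example query", task_id=map_result["task_id"], limit=5)
    print(hits)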