# backend/workflow.py
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .config import settings
from .service import crawler_sql_service

# Initialize configuration
dashscope.api_key = settings.DASHSCOPE_API_KEY

class AutomatedCrawler:
    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        # Text splitter: 500-character chunks with 100-character overlap to keep
        # chunks semantically coherent. The CJK separators ("。", "，", "、") are
        # an assumption (they appear blank in this copy); adjust to the corpus.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", "。", "，", "、", " ", ""]
        )

    def _get_embedding(self, text: str):
        """Call Dashscope to generate an embedding vector."""
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=1536
            )
            if resp.status_code == HTTPStatus.OK:
                # Alibaba's text-embedding-v3 defaults to 1024 dimensions; v2 is 1536.
                # If your database column is 1536-dimensional, use text_embedding_v2
                # or, as done here with v4, set the dimension parameter explicitly.
                return resp.output['embeddings'][0]['embedding']
            else:
                print(f"Embedding API Error: {resp}")
                return None
        except Exception as e:
            print(f"Embedding Exception: {e}")
            return None
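
    def _get_embeddings_batch(self, texts: list):
        """Sketch of a batched variant of _get_embedding (not wired in yet).

        Embedding chunks one call at a time is the main bottleneck in
        process_task_queue below. Assumptions to verify against the installed
        Dashscope SDK: ``input`` accepts a list of strings, the response
        preserves input order, and the API caps the number of texts per call.
        """
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=texts,
                dimension=1536
            )
            if resp.status_code == HTTPStatus.OK:
                # Assumes one embedding per input text, in input order.
                return [item['embedding'] for item in resp.output['embeddings']]
            print(f"Batch Embedding API Error: {resp}")
            return None
        except Exception as e:
            print(f"Batch Embedding Exception: {e}")
            return None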

    def map_and_ingest(self, start_url: str):
        """
        V2 step 1: map-scan the site and ingest the discovered URLs.
        """
        print(f"[WorkFlow] Start mapping: {start_url}")
        # 1. Register the task in the database
        task_info = crawler_sql_service.register_task(start_url)
        task_id = task_info['task_id']
        try:
            # 2. Call Firecrawl Map
            map_result = self.firecrawl.map(start_url)
            urls = [link.url for link in map_result.links]
            print(f"[WorkFlow] Found {len(urls)} links")
            # 3. Bulk-insert the URLs (status 'pending'); return the task_id
            #    even when the map came back empty
            res = crawler_sql_service.add_urls(task_id, urls) if urls else None
            return {
                "task_id": task_id,
                "map_result": res
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            raise

    def process_task_queue(self, task_id: int, limit: int = 10):
        """
        V2 step 2: consume the queue -> scrape -> chunk -> embed -> store.
        """
        # 1. Fetch pending URLs
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']
        if not urls:
            return {"msg": "No pending urls", "count": 0}
        processed_count = 0
        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape the single page (Markdown)
                scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)
                content = scrape_res.markdown
                metadata = scrape_res.metadata_dict
                title = metadata.get('title', url)
                if not content:
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue
                # 3. Chunk the text
                chunks = self.splitter.split_text(content)
                results_to_save = []
                # 4. Embed each chunk one by one (could be batched via the
                #    _get_embeddings_batch sketch above)
                for idx, chunk_text in enumerate(chunks):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector
                        })
                # 5. Save the results
                if results_to_save:
                    crawler_sql_service.save_results(task_id, results_to_save)
                    processed_count += 1
            except Exception as e:
                print(f"[WorkFlow] Error processing {url}: {e}")
                # In production the URL's status should be flipped to 'failed' here
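                # so it is not retried forever, e.g. via a hypothetical helper
                # (not yet implemented in crawler_sql_service):
                # crawler_sql_service.mark_url_failed(task_id, url)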
return {"processed_urls": processed_count, "total_chunks": "dynamic"}

    def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
        """
        V2 search: query text -> auto-embed -> vector search in the database.
        """
        vector = self._get_embedding(query_text)
        if not vector:
            raise Exception("Failed to generate embedding for query")
        return crawler_sql_service.search_knowledge(vector, task_id, limit)

# Module-level singleton instance
workflow = AutomatedCrawler()
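
# Minimal end-to-end usage sketch. Assumes valid DASHSCOPE_API_KEY and
# FIRECRAWL_API_KEY in settings, an initialized database behind
# crawler_sql_service, and a reachable start URL (the URL below is a
# placeholder, not taken from the project).
if __name__ == "__main__":
    result = workflow.map_and_ingest("https://example.com/wiki")
    task_id = result["task_id"]
    print(f"Registered task {task_id}")

    # Drain the queue in batches. Bounded, because failed URLs are not yet
    # marked 'failed' and would otherwise stay pending indefinitely.
    for _ in range(100):
        batch = workflow.process_task_queue(task_id, limit=10)
        if batch.get("count") == 0:
            break
        print(f"Batch done: {batch['processed_urls']} urls processed")

    hits = workflow.search_with_embedding("What is this wiki about?", task_id=task_id)
    print(hits)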