140 lines
5.0 KiB
Python
140 lines
5.0 KiB
Python
|
|
# backend/workflow.py
|
|||
|
|
from http import HTTPStatus
from typing import List, Optional

import dashscope
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter

from .config import settings
from .service import crawler_sql_service
|
|||
|
|
|
|||
|
|
# Initialise configuration at import time: set the DashScope SDK's
# module-level API key from application settings.
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
|||
|
|
|
|||
|
|
class AutomatedCrawler:
    """Crawl a site with Firecrawl, chunk pages, embed chunks with DashScope, and store them.

    All persistence goes through ``crawler_sql_service``:

    1. ``map_and_ingest``        -- discover URLs for a start URL and enqueue them.
    2. ``process_task_queue``    -- scrape pending URLs, split, embed, and save chunks.
    3. ``search_with_embedding`` -- embed a query and run a vector search.
    """

    # Vector size requested from the embedding API. Presumably this must match
    # the vector column dimension in the database -- confirm before changing.
    EMBEDDING_DIMENSION = 1536

    def __init__(self, chunk_size: int = 500, chunk_overlap: int = 100):
        """Create the Firecrawl client and the text splitter.

        Args:
            chunk_size: Maximum characters per chunk (default 500).
            chunk_overlap: Characters shared by consecutive chunks (default 100),
                so text cut at a chunk boundary keeps some surrounding context.
        """
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        # Try to split on paragraphs, then lines, then sentence punctuation,
        # then spaces, and finally individual characters as a last resort.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""],
        )

    def _get_embedding(self, text: str) -> Optional[List[float]]:
        """Embed ``text`` with DashScope's ``text_embedding_v4`` model.

        Returns:
            The embedding vector, requested at ``EMBEDDING_DIMENSION`` dimensions.
            Returns ``None`` if the API answered with a non-OK status or anything
            in the call raised. Failures are printed, not raised, so a single bad
            chunk does not abort a whole crawl.
        """
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                # Pinned explicitly instead of relying on the model's default
                # dimension (NOTE(review): presumably 1024 for v3/v4 -- confirm
                # in the DashScope docs), so stored and query vectors agree.
                dimension=self.EMBEDDING_DIMENSION,
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            print(f"Embedding API Error: {resp}")
            return None
        except Exception as e:
            print(f"Embedding Exception: {e}")
            return None

    def map_and_ingest(self, start_url: str) -> dict:
        """V2 step 1: map every URL reachable from ``start_url`` and enqueue it.

        Registers a new crawl task, asks Firecrawl's ``map`` for the site's
        links, and bulk-inserts them via ``crawler_sql_service.add_urls`` so that
        ``process_task_queue`` can consume them.

        Returns:
            ``{"task_id": ..., "map_result": ...}``. ``map_result`` is the return
            value of ``add_urls``, or ``None`` when the map found no links
            (``add_urls`` is not called in that case).

        Raises:
            Any exception from Firecrawl or ``add_urls``; it is printed and then
            re-raised. Errors from ``register_task`` propagate directly.
        """
        print(f"[WorkFlow] Start mapping: {start_url}")

        # 1. Register the task in the database.
        task_info = crawler_sql_service.register_task(start_url)
        task_id = task_info['task_id']

        try:
            # 2. Ask Firecrawl for the site map; tolerate a missing link list.
            map_result = self.firecrawl.map(start_url)
            urls = [link.url for link in (map_result.links or [])]

            print(f"[WorkFlow] Found {len(urls)} links")

            # 3. Bulk-insert the URLs as pending work. Default to None so an
            #    empty map returns normally instead of hitting an unbound name.
            res = crawler_sql_service.add_urls(task_id, urls) if urls else None

            return {
                "task_id": task_id,
                "map_result": res,
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            raise

    def process_task_queue(self, task_id: int, limit: int = 10) -> dict:
        """V2 step 2: consume the queue -> scrape -> chunk -> embed -> store.

        Args:
            task_id: Task whose pending URLs should be processed.
            limit: Maximum number of pending URLs fetched in this call.

        Returns:
            ``{"msg": "No pending urls", "count": 0}`` when nothing is pending.
            Otherwise ``{"processed_urls": n, "total_chunks": "dynamic",
            "saved_chunks": m}``, where ``n`` counts URLs that had at least one
            chunk saved and ``m`` is the total number of chunks saved.
            ``"total_chunks"`` is kept as the original placeholder for
            compatibility.

        Per-URL errors are printed and the loop continues with the next URL.
        """
        # 1. Fetch pending URLs.
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']

        if not urls:
            return {"msg": "No pending urls", "count": 0}

        processed_count = 0
        saved_chunks = 0

        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape the single page as Markdown, main content only.
                scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)

                content = scrape_res.markdown

                metadata = scrape_res.metadata_dict or {}
                # Fall back to the URL when the title is absent, None, or empty.
                title = metadata.get('title') or url

                if not content:
                    # NOTE(review): nothing is reported back to the service for
                    # this URL; if get_pending_urls selects by status alone it
                    # may be returned again on the next call -- confirm.
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue

                # 3. Split into overlapping chunks.
                chunks = self.splitter.split_text(content)

                # 4. Embed chunk by chunk (could be optimised into batched
                #    embedding calls). Chunks whose embedding fails are dropped.
                results_to_save = []
                for idx, chunk_text in enumerate(chunks):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector,
                        })

                # 5. Persist this page's chunks.
                if results_to_save:
                    crawler_sql_service.save_results(task_id, results_to_save)
                    processed_count += 1
                    saved_chunks += len(results_to_save)

            except Exception as e:
                print(f"[WorkFlow] Error processing {url}: {e}")
                # TODO: in production, mark this URL's status as 'failed' here.

        return {
            "processed_urls": processed_count,
            "total_chunks": "dynamic",
            "saved_chunks": saved_chunks,
        }

    def search_with_embedding(self, query_text: str, task_id: Optional[int] = None, limit: int = 5):
        """V2 search: embed ``query_text`` and run a vector search in the database.

        Args:
            query_text: Query text to embed.
            task_id: Passed through to ``search_knowledge`` (presumably restricts
                results to one task; semantics of ``None`` defined there).
            limit: Maximum number of results, passed through unchanged.

        Returns:
            Whatever ``crawler_sql_service.search_knowledge`` returns.

        Raises:
            RuntimeError: if an embedding for the query could not be generated.
        """
        vector = self._get_embedding(query_text)
        if not vector:
            raise RuntimeError("Failed to generate embedding for query")

        return crawler_sql_service.search_knowledge(vector, task_id, limit)
|
|||
|
|
|
|||
|
|
# Module-level singleton, constructed once at import time and shared by
# everything that imports `workflow` from this module.
workflow = AutomatedCrawler()
|