修改配置和response的细节
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
# backend/main.py
|
||||
from fastapi import FastAPI, APIRouter, BackgroundTasks
|
||||
from .service import crawler_sql_service
|
||||
from .workflow import workflow
|
||||
# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow
|
||||
from .services.crawler_sql_service import crawler_sql_service
|
||||
from .services.automated_crawler import workflow
|
||||
from .schemas import (
|
||||
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
|
||||
AutoMapRequest, AutoProcessRequest, TextSearchRequest
|
||||
@@ -10,6 +11,11 @@ from .utils import make_response
|
||||
|
||||
app = FastAPI(title="Wiki Crawler API")
|
||||
|
||||
# ==========================================
|
||||
# 工具函数
|
||||
# ==========================================
|
||||
|
||||
|
||||
# ==========================================
|
||||
# V1 Router: 原始的底层接口 (Manual Control)
|
||||
# ==========================================
|
||||
@@ -18,8 +24,10 @@ router_v1 = APIRouter()
|
||||
@router_v1.post("/register")
|
||||
async def register(req: RegisterRequest):
|
||||
try:
|
||||
data = crawler_sql_service.register_task(req.url)
|
||||
return make_response(1, "Success", data)
|
||||
# Service 返回: {'task_id': 1, 'is_new_task': True, 'msg': '...'}
|
||||
res = crawler_sql_service.register_task(req.url)
|
||||
# 使用 pop 将 msg 提取出来作为响应的 msg,剩下的作为 data
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@@ -27,44 +35,43 @@ async def register(req: RegisterRequest):
|
||||
async def add_urls(req: AddUrlsRequest):
|
||||
try:
|
||||
urls = req.urls_obj["urls"]
|
||||
data = crawler_sql_service.add_urls(req.task_id, urls=urls)
|
||||
return make_response(1, "Success", data)
|
||||
res = crawler_sql_service.add_urls(req.task_id, urls=urls)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/pending_urls")
|
||||
async def pending_urls(req: PendingRequest):
|
||||
try:
|
||||
data = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
|
||||
msg = "Success" if data["urls"] else "Queue Empty"
|
||||
return make_response(1, msg, data)
|
||||
res = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
|
||||
# 即使队列为空,Service 也会返回 msg="Queue is empty"
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/save_results")
|
||||
async def save_results(req: SaveResultsRequest):
|
||||
try:
|
||||
data = crawler_sql_service.save_results(req.task_id, req.results)
|
||||
return make_response(1, "Success", data)
|
||||
res = crawler_sql_service.save_results(req.task_id, req.results)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/search")
|
||||
async def search_v1(req: SearchRequest):
|
||||
"""V1 搜索:需要客户端自己传向量"""
|
||||
"""V1 搜索:客户端手动传向量"""
|
||||
try:
|
||||
vector = req.query_embedding['vector']
|
||||
# 注意:这里需要确认你数据库的向量维度。TextEmbedding V3 可能是 1024,V2 是 1536。
|
||||
# 请根据你的 PGVector 设置进行匹配。
|
||||
if not vector:
|
||||
return make_response(2, "Vector is empty", None)
|
||||
|
||||
data = crawler_sql_service.search_knowledge(
|
||||
# Service 现在返回 {'results': [...], 'msg': 'Found ...'}
|
||||
res = crawler_sql_service.search_knowledge(
|
||||
query_embedding=vector,
|
||||
task_id=req.task_id,
|
||||
limit=req.limit
|
||||
)
|
||||
return make_response(1, "Search Done", data)
|
||||
return make_response(1, res.pop("msg", "Search Done"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@@ -75,16 +82,14 @@ async def search_v1(req: SearchRequest):
|
||||
router_v2 = APIRouter()
|
||||
|
||||
@router_v2.post("/auto/map")
|
||||
async def auto_map(req: AutoMapRequest, background_tasks: BackgroundTasks):
|
||||
async def auto_map(req: AutoMapRequest):
|
||||
"""
|
||||
[异步] 输入首页 URL,自动调用 Firecrawl Map 并入库
|
||||
[同步] 输入首页 URL,自动调用 Firecrawl Map 并入库
|
||||
"""
|
||||
# 也可以放入 background_tasks,但 map 通常比较快,这里演示同步返回任务ID
|
||||
try:
|
||||
# 为了不阻塞主线程,如果 map 很慢,建议放入 background_tasks
|
||||
# 这里为了能立刻看到 task_id,先同步调用 (Firecrawl Map 比较快)
|
||||
data = workflow.map_and_ingest(req.url)
|
||||
return make_response(1, "Mapping Started", data)
|
||||
# Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...}
|
||||
res = workflow.map_and_ingest(req.url)
|
||||
return make_response(1, res.pop("msg", "Mapping Started"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@@ -93,9 +98,14 @@ async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTask
|
||||
"""
|
||||
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
|
||||
"""
|
||||
try:
|
||||
# 将耗时操作放入后台任务
|
||||
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
|
||||
return make_response(1, "Processing started in background", {"task_id": req.task_id})
|
||||
|
||||
# 因为是后台任务,无法立即获取 Service 的返回值 msg,只能返回通用消息
|
||||
return make_response(1, "Background processing started", {"task_id": req.task_id})
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v2.post("/search")
|
||||
async def search_v2(req: TextSearchRequest):
|
||||
@@ -103,8 +113,9 @@ async def search_v2(req: TextSearchRequest):
|
||||
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
|
||||
"""
|
||||
try:
|
||||
data = workflow.search_with_embedding(req.query, req.task_id, req.limit)
|
||||
return make_response(1, "Search Success", data)
|
||||
# Workflow 返回 {'results': [...], 'msg': '...'}
|
||||
res = workflow.search_with_embedding(req.query, req.task_id, req.limit)
|
||||
return make_response(1, res.pop("msg", "Search Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, f"Search Failed: {str(e)}")
|
||||
|
||||
|
||||
201
backend/services/automated_crawler.py
Normal file
201
backend/services/automated_crawler.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from firecrawl import FirecrawlApp
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from ..config import settings
|
||||
from .crawler_sql_service import crawler_sql_service
|
||||
|
||||
# Initialize configuration: hand the DashScope API key from the project settings to the SDK.
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
class AutomatedCrawler:
    """V2 workflow: map a site, scrape its pages, embed the chunks and search them.

    Persistence is delegated to ``crawler_sql_service``; Firecrawl performs the
    mapping/scraping and DashScope produces the embeddings.  Every public method
    returns a dict carrying a human-readable ``msg`` key next to its payload.
    """

    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        # 500-character chunks with a 100-character overlap.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""]
        )

    @staticmethod
    def _extract_links(map_result):
        """Flatten a Firecrawl map response into a list of URL strings.

        Accepts either a dict with a ``links`` key or an object with a ``links``
        attribute.  Each link may itself be a string, a dict with a ``url`` key
        or an object with a ``url`` attribute.  Entries without a usable URL are
        dropped.
        """
        if isinstance(map_result, dict):
            raw_links = map_result.get('links')
        else:
            raw_links = getattr(map_result, 'links', None)

        urls = []
        for link in raw_links or []:
            if isinstance(link, str):
                url = link
            elif isinstance(link, dict):
                url = link.get('url')
            else:
                url = getattr(link, 'url', str(link))
            if url:
                urls.append(url)
        return urls

    @staticmethod
    def _extract_page(scrape_res):
        """Return ``(markdown, metadata)`` from a Firecrawl scrape response.

        ``markdown`` is always a string (``""`` when missing) and ``metadata`` is
        always a dict, whether the SDK returned a dict or an object.
        """
        if isinstance(scrape_res, dict):
            content = scrape_res.get('markdown')
            metadata = scrape_res.get('metadata')
        else:
            content = getattr(scrape_res, 'markdown', None)
            metadata = getattr(scrape_res, 'metadata', None)
            if not metadata or not isinstance(metadata, dict):
                # Prefer the dict view when the SDK object exposes one.
                dict_view = getattr(scrape_res, 'metadata_dict', None)
                if isinstance(dict_view, dict):
                    metadata = dict_view

        if not isinstance(metadata, dict):
            # Unknown metadata object (or None): keep only the title, if any.
            title = getattr(metadata, 'title', None)
            metadata = {"title": title} if title else {}
        return content or "", metadata

    def _get_embedding(self, text: str):
        """Internal helper: return the DashScope embedding for *text*, or None.

        API errors and exceptions are printed and swallowed; callers decide how
        to react to a ``None`` result.
        """
        if not text or not text.strip():
            # Nothing to embed; avoid a pointless remote call.
            return None

        try:
            # NOTE(review): confirm that the chosen model supports dimension=1536
            # and that it matches the vector column size in the database.
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v3,
                input=text,
                dimension=1536
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            print(f"Embedding API Error: {resp}")
        except Exception as e:
            print(f"Embedding Exception: {e}")
        return None

    def map_and_ingest(self, start_url: str):
        """V2 step 1: register the task, map the site and queue the found URLs.

        Mapping only runs for a newly registered task; an existing task is
        returned untouched.

        :param start_url: root URL of the site to map
        :return: dict with ``msg``, ``task_id``, ``is_new_task``, ``url_count``
                 and ``map_detail`` (the result of ``add_urls``)
        :raises Exception: any registration/mapping/storage error is printed and
                 re-raised for the API layer to turn into an error response
        """
        print(f"[WorkFlow] Start mapping: {start_url}")

        try:
            # 1. Register the task in the database.
            task_info = crawler_sql_service.register_task(start_url)
            task_id = task_info['task_id']
            is_new_task = task_info['is_new_task']

            if not is_new_task:
                return {
                    "msg": "Task already exists, skipped mapping",
                    "task_id": task_id,
                    "is_new_task": False,
                    "url_count": 0,
                    "map_detail": {}
                }

            # 2. Call Firecrawl Map (response shape differs between SDK versions).
            urls = self._extract_links(self.firecrawl.map(start_url))
            print(f"[WorkFlow] Found {len(urls)} links")

            # 3. Bulk insert into the queue.
            if urls:
                map_detail = crawler_sql_service.add_urls(task_id, urls)
                msg = "Task successfully mapped and URLs added"
            else:
                map_detail = {"msg": "No urls found to add"}
                msg = "Task mapped but no URLs were found"

            return {
                "msg": msg,
                "task_id": task_id,
                "is_new_task": is_new_task,
                "url_count": len(urls),
                "map_detail": map_detail
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            # Re-raise with the original traceback; the API layer builds the response.
            raise

    def process_task_queue(self, task_id: int, limit: int = 10):
        """V2 step 2: consume the queue -> scrape -> split -> embed -> store.

        A failure on one URL is printed and does not abort the rest of the batch.

        :param task_id: task whose pending URLs are processed
        :param limit: maximum number of URLs fetched from the queue
        :return: dict with ``msg``, ``processed_urls`` and ``total_chunks_saved``
                 (the empty-queue result additionally keeps ``processed_count``)
        """
        # 1. Fetch the pending URLs.
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']

        if not urls:
            return {
                "msg": "Queue is empty, no processing needed",
                "processed_count": 0,  # kept for backward compatibility
                "processed_urls": 0,
                "total_chunks_saved": 0
            }

        processed_count = 0
        total_chunks_saved = 0

        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape a single page.
                # NOTE(review): newer Firecrawl SDKs take formats/only_main_content
                # as keyword arguments instead of `params` - confirm against the
                # installed version.
                scrape_res = self.firecrawl.scrape(
                    url,
                    params={'formats': ['markdown'], 'onlyMainContent': True}
                )
                content, metadata = self._extract_page(scrape_res)

                if not content:
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue

                title = metadata.get('title') or url

                # 3 + 4. Split into chunks and embed each one; chunks whose
                # embedding failed are left out.
                results_to_save = []
                for idx, chunk_text in enumerate(self.splitter.split_text(content)):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector
                        })

                # 5. Save.
                if results_to_save:
                    save_res = crawler_sql_service.save_results(task_id, results_to_save)
                    counts = save_res.get('counts', {})
                    processed_count += 1
                    total_chunks_saved += counts.get('inserted', 0) + counts.get('updated', 0)

            except Exception as e:
                print(f"[WorkFlow] Error processing {url}: {e}")
                # Deliberately not re-raised so the batch keeps going.
                # TODO: skipped and failed URLs are never marked as failed here;
                # call the service to update their queue status.

        return {
            "msg": f"Batch processing complete. URLs processed: {processed_count}",
            "processed_urls": processed_count,
            "total_chunks_saved": total_chunks_saved
        }

    def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
        """V2 search: embed the query text, then search the database.

        :param query_text: natural-language query
        :param task_id: optional task filter passed through to the service
        :param limit: maximum number of results
        :return: the dict returned by ``search_knowledge``, or a dict with an
                 error ``msg`` and empty ``results`` when no embedding could be
                 generated
        """
        # 1. Get the query vector.
        vector = self._get_embedding(query_text)
        if not vector:
            return {
                "msg": "Failed to generate embedding for query",
                "results": []
            }

        # 2. Run the search; the service result already carries its own msg.
        return crawler_sql_service.search_knowledge(vector, task_id, limit)
|
||||
|
||||
# Singleton: one module-level AutomatedCrawler instance, created at import time.
workflow = AutomatedCrawler()
|
||||
@@ -1,7 +1,6 @@
|
||||
# service.py
|
||||
from sqlalchemy import select, insert, update, delete, and_
|
||||
from .database import db_instance
|
||||
from .utils import normalize_url
|
||||
from sqlalchemy import select, insert, update, and_
|
||||
from ..database import db_instance
|
||||
from ..utils import normalize_url
|
||||
|
||||
class CrawlerSqlService:
|
||||
def __init__(self):
|
||||
@@ -10,18 +9,30 @@ class CrawlerSqlService:
|
||||
def register_task(self, url: str):
|
||||
"""完全使用库 API 实现的注册"""
|
||||
clean_url = normalize_url(url)
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 使用 select() API
|
||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
return {"task_id": existing[0], "is_new_task": False}
|
||||
|
||||
result = {
|
||||
"task_id": existing[0],
|
||||
"is_new_task": False,
|
||||
"msg": "Task already exists"
|
||||
}
|
||||
else:
|
||||
# 使用 insert() API
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
return {"task_id": new_task[0], "is_new_task": True}
|
||||
result = {
|
||||
"task_id": new_task[0],
|
||||
"is_new_task": True,
|
||||
"msg": "New task created successfully"
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
"""通用 API 实现的批量添加(含详细返回)"""
|
||||
@@ -31,7 +42,7 @@ class CrawlerSqlService:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
# 检查队列中是否已存在该 URL (通用写法)
|
||||
# 检查队列中是否已存在该 URL
|
||||
check_q = select(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
@@ -47,10 +58,20 @@ class CrawlerSqlService:
|
||||
except Exception:
|
||||
failed_urls.append(clean_url)
|
||||
|
||||
return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}
|
||||
# 构造返回消息
|
||||
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
|
||||
|
||||
return {
|
||||
"success_urls": success_urls,
|
||||
"skipped_urls": skipped_urls,
|
||||
"failed_urls": failed_urls,
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
"""原子锁定 API 实现"""
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
@@ -63,17 +84,20 @@ class CrawlerSqlService:
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing')
|
||||
conn.execute(upd)
|
||||
return {"urls": urls}
|
||||
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
|
||||
else:
|
||||
result = {"urls": [], "msg": "Queue is empty"}
|
||||
|
||||
return result
|
||||
|
||||
def save_results(self, task_id: int, results: list):
|
||||
"""
|
||||
保存同一 URL 的多个切片。
|
||||
返回:该 URL 下切片的详细处理统计及页面更新状态。
|
||||
"""
|
||||
if not results:
|
||||
return {"msg": "No data provided"}
|
||||
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
|
||||
|
||||
# 1. 基础信息提取 (假设 results 里的 source_url 都是一致的)
|
||||
# 1. 基础信息提取
|
||||
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
||||
target_url = normalize_url(first_item.get('source_url'))
|
||||
|
||||
@@ -84,7 +108,7 @@ class CrawlerSqlService:
|
||||
is_page_update = False
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 2. 判断该 URL 是否已经有切片存在 (以此判定是否为“页面更新”)
|
||||
# 2. 判断该 URL 是否已经有切片存在
|
||||
check_page_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
||||
).limit(1)
|
||||
@@ -97,7 +121,7 @@ class CrawlerSqlService:
|
||||
c_idx = data.get('chunk_index')
|
||||
|
||||
try:
|
||||
# 检查具体某个 index 的切片是否存在
|
||||
# 检查切片是否存在
|
||||
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(
|
||||
self.db.chunks.c.task_id == task_id,
|
||||
@@ -108,7 +132,7 @@ class CrawlerSqlService:
|
||||
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
||||
|
||||
if existing_chunk:
|
||||
# 覆盖更新现有切片
|
||||
# 覆盖更新
|
||||
upd_stmt = update(self.db.chunks).where(
|
||||
self.db.chunks.c.id == existing_chunk[0]
|
||||
).values(
|
||||
@@ -135,16 +159,19 @@ class CrawlerSqlService:
|
||||
print(f"Chunk {c_idx} failed: {e}")
|
||||
failed_chunks.append(c_idx)
|
||||
|
||||
# 4. 最终更新队列状态
|
||||
# 4. 更新队列状态
|
||||
conn.execute(
|
||||
update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
||||
).values(status='completed')
|
||||
)
|
||||
|
||||
# 构造返回
|
||||
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
|
||||
|
||||
return {
|
||||
"source_url": target_url,
|
||||
"is_page_update": is_page_update, # 标志:此页面此前是否有过内容
|
||||
"is_page_update": is_page_update,
|
||||
"detail": {
|
||||
"inserted_chunk_indexes": inserted_chunks,
|
||||
"updated_chunk_indexes": updated_chunks,
|
||||
@@ -154,20 +181,18 @@ class CrawlerSqlService:
|
||||
"inserted": len(inserted_chunks),
|
||||
"updated": len(updated_chunks),
|
||||
"failed": len(failed_chunks)
|
||||
},
|
||||
"msg": msg
|
||||
}
|
||||
}
|
||||
|
||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
高性能向量搜索方法
|
||||
:param query_embedding: 问题的向量
|
||||
:param task_id: 可选的任务ID,不传则搜全表
|
||||
:param limit: 返回结果数量
|
||||
"""
|
||||
|
||||
results = []
|
||||
msg = ""
|
||||
|
||||
with self.db.engine.connect() as conn:
|
||||
# 1. 选择需要的字段
|
||||
# 我们同时返回 task_id,方便在全库搜索时知道来源哪个任务
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id,
|
||||
self.db.chunks.c.source_url,
|
||||
@@ -176,20 +201,15 @@ class CrawlerSqlService:
|
||||
self.db.chunks.c.chunk_index
|
||||
)
|
||||
|
||||
# 2. 动态添加过滤条件
|
||||
if task_id is not None:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
# 3. 按余弦距离排序(1 - 余弦相似度)
|
||||
# 距离越小,相似度越高
|
||||
stmt = stmt.order_by(
|
||||
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||
).limit(limit)
|
||||
|
||||
# 4. 执行并解析结果
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
|
||||
results = []
|
||||
for r in rows:
|
||||
results.append({
|
||||
"task_id": r[0],
|
||||
@@ -199,7 +219,12 @@ class CrawlerSqlService:
|
||||
"chunk_index": r[4]
|
||||
})
|
||||
|
||||
return results
|
||||
if results:
|
||||
msg = f"Found {len(results)} matches"
|
||||
else:
|
||||
msg = "No matching content found"
|
||||
|
||||
return {"results": results, "msg": msg}
|
||||
|
||||
|
||||
crawler_sql_service = CrawlerSqlService()
|
||||
@@ -29,5 +29,11 @@ def normalize_url(url: str) -> str:
|
||||
if not path: path = ""
|
||||
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
|
||||
|
||||
def make_response(code: int, msg: str, data: any = None):
|
||||
def make_response(code: int, msg: str = "Success", data: any = None):
|
||||
"""
|
||||
统一响应格式
|
||||
:param code: 1 成功, 0 失败, 其他自定义
|
||||
:param msg: 提示信息
|
||||
:param data: 返回数据
|
||||
"""
|
||||
return {"code": code, "msg": msg, "data": data}
|
||||
@@ -1,140 +0,0 @@
|
||||
# backend/workflow.py
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from firecrawl import FirecrawlApp
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from .config import settings
|
||||
from .service import crawler_sql_service
|
||||
|
||||
# Initialize configuration: hand the DashScope API key from the project settings to the SDK.
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
class AutomatedCrawler:
    """Workflow that maps a site, scrapes its pages, embeds the chunks and searches them.

    Persistence is delegated to ``crawler_sql_service``; Firecrawl performs the
    mapping/scraping and DashScope produces the embeddings.
    """

    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        # Text splitter: 500-character chunks with a 100-character overlap to
        # keep the chunks semantically coherent.
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""]
        )

    def _get_embedding(self, text: str):
        """Call DashScope to embed *text*; return the vector, or None on failure.

        API errors and exceptions are printed and swallowed.
        """
        try:
            # NOTE(review): the requested dimension must match the vector column
            # of the database - confirm for the selected model.
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=1536
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            print(f"Embedding API Error: {resp}")
            return None
        except Exception as e:
            print(f"Embedding Exception: {e}")
            return None

    def map_and_ingest(self, start_url: str):
        """V2 step 1: register the task, map the site and queue the found URLs.

        :param start_url: root URL of the site to map
        :return: ``{"task_id": ..., "map_result": ...}`` where ``map_result`` is
                 the result of ``add_urls`` (an empty result of the same shape
                 when the map found no links)
        :raises Exception: mapping/storage errors are printed and re-raised
        """
        print(f"[WorkFlow] Start mapping: {start_url}")

        # 1. Register the task in the database.
        task_info = crawler_sql_service.register_task(start_url)
        task_id = task_info['task_id']

        try:
            # 2. Call Firecrawl Map.
            map_result = self.firecrawl.map(start_url)
            urls = [link.url for link in map_result.links]
            print(f"[WorkFlow] Found {len(urls)} links")

            # 3. Bulk insert (status: pending).  The default keeps `res` bound
            # when no links were found, which previously raised UnboundLocalError.
            res = {"success_urls": [], "skipped_urls": [], "failed_urls": []}
            if urls:
                res = crawler_sql_service.add_urls(task_id, urls)

            return {
                "task_id": task_id,
                "map_result": res
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            # Re-raise with the original traceback.
            raise

    def process_task_queue(self, task_id: int, limit: int = 10):
        """V2 step 2: consume the queue -> scrape -> split -> embed -> store.

        A failure on one URL is printed and does not abort the rest of the batch.

        :param task_id: task whose pending URLs are processed
        :param limit: maximum number of URLs fetched from the queue
        :return: ``{"msg": "No pending urls", "count": 0}`` for an empty queue,
                 otherwise ``{"processed_urls": <int>, "total_chunks": "dynamic"}``
        """
        # 1. Fetch the pending URLs.
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']

        if not urls:
            return {"msg": "No pending urls", "count": 0}

        processed_count = 0

        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape a single page as Markdown.
                scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)

                content = scrape_res.markdown
                metadata = scrape_res.metadata_dict
                title = metadata.get('title', url)

                if not content:
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue

                # 3. Split into chunks.
                chunks = self.splitter.split_text(content)

                # 4. Embed chunk by chunk (could be batched); chunks whose
                # embedding failed are left out.
                results_to_save = []
                for idx, chunk_text in enumerate(chunks):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector
                        })

                # 5. Save the results.
                if results_to_save:
                    crawler_sql_service.save_results(task_id, results_to_save)
                    processed_count += 1

            except Exception as e:
                print(f"[WorkFlow] Error processing {url}: {e}")
                # TODO: in production the URL status should be set to 'failed' here.

        return {"processed_urls": processed_count, "total_chunks": "dynamic"}

    def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
        """V2 search: embed the query text, then search the database.

        :param query_text: natural-language query
        :param task_id: optional task filter passed through to the service
        :param limit: maximum number of results
        :return: whatever ``crawler_sql_service.search_knowledge`` returns
        :raises Exception: when no embedding could be generated for the query
        """
        vector = self._get_embedding(query_text)
        if not vector:
            raise Exception("Failed to generate embedding for query")

        return crawler_sql_service.search_knowledge(vector, task_id, limit)
|
||||
|
||||
# Singleton: one module-level AutomatedCrawler instance, created at import time.
workflow = AutomatedCrawler()
|
||||
Reference in New Issue
Block a user