修改配置和response的细节
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
# backend/main.py
|
# backend/main.py
|
||||||
from fastapi import FastAPI, APIRouter, BackgroundTasks
|
from fastapi import FastAPI, APIRouter, BackgroundTasks
|
||||||
from .service import crawler_sql_service
|
# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow
|
||||||
from .workflow import workflow
|
from .services.crawler_sql_service import crawler_sql_service
|
||||||
|
from .services.automated_crawler import workflow
|
||||||
from .schemas import (
|
from .schemas import (
|
||||||
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
|
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
|
||||||
AutoMapRequest, AutoProcessRequest, TextSearchRequest
|
AutoMapRequest, AutoProcessRequest, TextSearchRequest
|
||||||
@@ -10,6 +11,11 @@ from .utils import make_response
|
|||||||
|
|
||||||
app = FastAPI(title="Wiki Crawler API")
|
app = FastAPI(title="Wiki Crawler API")
|
||||||
|
|
||||||
|
# ==========================================
|
||||||
|
# 工具函数
|
||||||
|
# ==========================================
|
||||||
|
|
||||||
|
|
||||||
# ==========================================
|
# ==========================================
|
||||||
# V1 Router: 原始的底层接口 (Manual Control)
|
# V1 Router: 原始的底层接口 (Manual Control)
|
||||||
# ==========================================
|
# ==========================================
|
||||||
@@ -18,8 +24,10 @@ router_v1 = APIRouter()
|
|||||||
@router_v1.post("/register")
|
@router_v1.post("/register")
|
||||||
async def register(req: RegisterRequest):
|
async def register(req: RegisterRequest):
|
||||||
try:
|
try:
|
||||||
data = crawler_sql_service.register_task(req.url)
|
# Service 返回: {'task_id': 1, 'is_new_task': True, 'msg': '...'}
|
||||||
return make_response(1, "Success", data)
|
res = crawler_sql_service.register_task(req.url)
|
||||||
|
# 使用 pop 将 msg 提取出来作为响应的 msg,剩下的作为 data
|
||||||
|
return make_response(1, res.pop("msg", "Success"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@@ -27,44 +35,43 @@ async def register(req: RegisterRequest):
|
|||||||
async def add_urls(req: AddUrlsRequest):
|
async def add_urls(req: AddUrlsRequest):
|
||||||
try:
|
try:
|
||||||
urls = req.urls_obj["urls"]
|
urls = req.urls_obj["urls"]
|
||||||
data = crawler_sql_service.add_urls(req.task_id, urls=urls)
|
res = crawler_sql_service.add_urls(req.task_id, urls=urls)
|
||||||
return make_response(1, "Success", data)
|
return make_response(1, res.pop("msg", "Success"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@router_v1.post("/pending_urls")
|
@router_v1.post("/pending_urls")
|
||||||
async def pending_urls(req: PendingRequest):
|
async def pending_urls(req: PendingRequest):
|
||||||
try:
|
try:
|
||||||
data = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
|
res = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
|
||||||
msg = "Success" if data["urls"] else "Queue Empty"
|
# 即使队列为空,Service 也会返回 msg="Queue is empty"
|
||||||
return make_response(1, msg, data)
|
return make_response(1, res.pop("msg", "Success"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@router_v1.post("/save_results")
|
@router_v1.post("/save_results")
|
||||||
async def save_results(req: SaveResultsRequest):
|
async def save_results(req: SaveResultsRequest):
|
||||||
try:
|
try:
|
||||||
data = crawler_sql_service.save_results(req.task_id, req.results)
|
res = crawler_sql_service.save_results(req.task_id, req.results)
|
||||||
return make_response(1, "Success", data)
|
return make_response(1, res.pop("msg", "Success"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@router_v1.post("/search")
|
@router_v1.post("/search")
|
||||||
async def search_v1(req: SearchRequest):
|
async def search_v1(req: SearchRequest):
|
||||||
"""V1 搜索:需要客户端自己传向量"""
|
"""V1 搜索:客户端手动传向量"""
|
||||||
try:
|
try:
|
||||||
vector = req.query_embedding['vector']
|
vector = req.query_embedding['vector']
|
||||||
# 注意:这里需要确认你数据库的向量维度。TextEmbedding V3 可能是 1024,V2 是 1536。
|
|
||||||
# 请根据你的 PGVector 设置进行匹配。
|
|
||||||
if not vector:
|
if not vector:
|
||||||
return make_response(2, "Vector is empty", None)
|
return make_response(2, "Vector is empty", None)
|
||||||
|
|
||||||
data = crawler_sql_service.search_knowledge(
|
# Service 现在返回 {'results': [...], 'msg': 'Found ...'}
|
||||||
|
res = crawler_sql_service.search_knowledge(
|
||||||
query_embedding=vector,
|
query_embedding=vector,
|
||||||
task_id=req.task_id,
|
task_id=req.task_id,
|
||||||
limit=req.limit
|
limit=req.limit
|
||||||
)
|
)
|
||||||
return make_response(1, "Search Done", data)
|
return make_response(1, res.pop("msg", "Search Done"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@@ -75,16 +82,14 @@ async def search_v1(req: SearchRequest):
|
|||||||
router_v2 = APIRouter()
|
router_v2 = APIRouter()
|
||||||
|
|
||||||
@router_v2.post("/auto/map")
|
@router_v2.post("/auto/map")
|
||||||
async def auto_map(req: AutoMapRequest, background_tasks: BackgroundTasks):
|
async def auto_map(req: AutoMapRequest):
|
||||||
"""
|
"""
|
||||||
[异步] 输入首页 URL,自动调用 Firecrawl Map 并入库
|
[同步] 输入首页 URL,自动调用 Firecrawl Map 并入库
|
||||||
"""
|
"""
|
||||||
# 也可以放入 background_tasks,但 map 通常比较快,这里演示同步返回任务ID
|
|
||||||
try:
|
try:
|
||||||
# 为了不阻塞主线程,如果 map 很慢,建议放入 background_tasks
|
# Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...}
|
||||||
# 这里为了能立刻看到 task_id,先同步调用 (Firecrawl Map 比较快)
|
res = workflow.map_and_ingest(req.url)
|
||||||
data = workflow.map_and_ingest(req.url)
|
return make_response(1, res.pop("msg", "Mapping Started"), res)
|
||||||
return make_response(1, "Mapping Started", data)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@@ -93,9 +98,14 @@ async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTask
|
|||||||
"""
|
"""
|
||||||
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
|
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
|
||||||
"""
|
"""
|
||||||
# 将耗时操作放入后台任务
|
try:
|
||||||
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
|
# 将耗时操作放入后台任务
|
||||||
return make_response(1, "Processing started in background", {"task_id": req.task_id})
|
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
|
||||||
|
|
||||||
|
# 因为是后台任务,无法立即获取 Service 的返回值 msg,只能返回通用消息
|
||||||
|
return make_response(1, "Background processing started", {"task_id": req.task_id})
|
||||||
|
except Exception as e:
|
||||||
|
return make_response(0, str(e))
|
||||||
|
|
||||||
@router_v2.post("/search")
|
@router_v2.post("/search")
|
||||||
async def search_v2(req: TextSearchRequest):
|
async def search_v2(req: TextSearchRequest):
|
||||||
@@ -103,8 +113,9 @@ async def search_v2(req: TextSearchRequest):
|
|||||||
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
|
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
data = workflow.search_with_embedding(req.query, req.task_id, req.limit)
|
# Workflow 返回 {'results': [...], 'msg': '...'}
|
||||||
return make_response(1, "Search Success", data)
|
res = workflow.search_with_embedding(req.query, req.task_id, req.limit)
|
||||||
|
return make_response(1, res.pop("msg", "Search Success"), res)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, f"Search Failed: {str(e)}")
|
return make_response(0, f"Search Failed: {str(e)}")
|
||||||
|
|
||||||
|
|||||||
201
backend/services/automated_crawler.py
Normal file
201
backend/services/automated_crawler.py
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
import dashscope
|
||||||
|
from http import HTTPStatus
|
||||||
|
from firecrawl import FirecrawlApp
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
from ..config import settings
|
||||||
|
from .crawler_sql_service import crawler_sql_service
|
||||||
|
|
||||||
|
# 初始化配置
|
||||||
|
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||||
|
|
||||||
|
class AutomatedCrawler:
|
||||||
|
def __init__(self):
|
||||||
|
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||||
|
self.splitter = RecursiveCharacterTextSplitter(
|
||||||
|
chunk_size=500,
|
||||||
|
chunk_overlap=100,
|
||||||
|
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_embedding(self, text: str):
|
||||||
|
"""内部方法:调用 Dashscope 生成向量"""
|
||||||
|
# 注意:此方法是内部辅助,出错返回 None,由调用方处理状态
|
||||||
|
embedding = None
|
||||||
|
try:
|
||||||
|
resp = dashscope.TextEmbedding.call(
|
||||||
|
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 确认你的模型版本
|
||||||
|
input=text,
|
||||||
|
dimension=1536
|
||||||
|
)
|
||||||
|
if resp.status_code == HTTPStatus.OK:
|
||||||
|
embedding = resp.output['embeddings'][0]['embedding']
|
||||||
|
else:
|
||||||
|
print(f"Embedding API Error: {resp}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Embedding Exception: {e}")
|
||||||
|
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
def map_and_ingest(self, start_url: str):
|
||||||
|
"""
|
||||||
|
V2 步骤1: 地图式扫描并入库
|
||||||
|
"""
|
||||||
|
print(f"[WorkFlow] Start mapping: {start_url}")
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 1. 在数据库注册任务
|
||||||
|
task_info = crawler_sql_service.register_task(start_url)
|
||||||
|
task_id = task_info['task_id']
|
||||||
|
is_new_task = task_info['is_new_task']
|
||||||
|
|
||||||
|
# 2. 调用 Firecrawl Map
|
||||||
|
if is_new_task:
|
||||||
|
map_result = self.firecrawl.map(start_url)
|
||||||
|
|
||||||
|
urls = []
|
||||||
|
# 兼容 firecrawl sdk 不同版本的返回结构
|
||||||
|
# 如果 map_result 是对象且有 links 属性
|
||||||
|
if hasattr(map_result, 'links'):
|
||||||
|
for link in map_result.links:
|
||||||
|
# 假设 link 是对象或字典,视具体 SDK 版本而定
|
||||||
|
# 如果 link 是字符串直接 append
|
||||||
|
if isinstance(link, str):
|
||||||
|
urls.append(link)
|
||||||
|
else:
|
||||||
|
urls.append(getattr(link, 'url', str(link)))
|
||||||
|
# 如果是字典
|
||||||
|
elif isinstance(map_result, dict):
|
||||||
|
urls = map_result.get('links', [])
|
||||||
|
|
||||||
|
print(f"[WorkFlow] Found {len(urls)} links")
|
||||||
|
|
||||||
|
# 3. 批量入库
|
||||||
|
res = {"msg": "No urls found to add"}
|
||||||
|
if urls:
|
||||||
|
res = crawler_sql_service.add_urls(task_id, urls)
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"msg": "Task successfully mapped and URLs added",
|
||||||
|
"task_id": task_id,
|
||||||
|
"is_new_task": is_new_task,
|
||||||
|
"url_count": len(urls),
|
||||||
|
"map_detail": res
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
result = {
|
||||||
|
"msg": "Task already exists, skipped mapping",
|
||||||
|
"task_id": task_id,
|
||||||
|
"is_new_task": False,
|
||||||
|
"url_count": 0,
|
||||||
|
"map_detail": {}
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WorkFlow] Map Error: {e}")
|
||||||
|
# 向上抛出异常,由 main.py 捕获并返回错误 Response
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def process_task_queue(self, task_id: int, limit: int = 10):
|
||||||
|
"""
|
||||||
|
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
|
||||||
|
"""
|
||||||
|
processed_count = 0
|
||||||
|
total_chunks_saved = 0
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# 1. 获取待处理 URL
|
||||||
|
pending = crawler_sql_service.get_pending_urls(task_id, limit)
|
||||||
|
urls = pending['urls']
|
||||||
|
|
||||||
|
if not urls:
|
||||||
|
result = {"msg": "Queue is empty, no processing needed", "processed_count": 0}
|
||||||
|
else:
|
||||||
|
for url in urls:
|
||||||
|
try:
|
||||||
|
print(f"[WorkFlow] Processing: {url}")
|
||||||
|
# 2. 单页抓取
|
||||||
|
scrape_res = self.firecrawl.scrape(
|
||||||
|
url,
|
||||||
|
params={'formats': ['markdown'], 'onlyMainContent': True}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 兼容 SDK 返回类型 (对象或字典)
|
||||||
|
content = ""
|
||||||
|
metadata = {}
|
||||||
|
|
||||||
|
if isinstance(scrape_res, dict):
|
||||||
|
content = scrape_res.get('markdown', '')
|
||||||
|
metadata = scrape_res.get('metadata', {})
|
||||||
|
else:
|
||||||
|
content = getattr(scrape_res, 'markdown', '')
|
||||||
|
metadata = getattr(scrape_res, 'metadata', {})
|
||||||
|
if not metadata and hasattr(scrape_res, 'metadata_dict'):
|
||||||
|
metadata = scrape_res.metadata_dict
|
||||||
|
|
||||||
|
title = metadata.get('title', url)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
print(f"[WorkFlow] Skip empty content: {url}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 3. 切片
|
||||||
|
chunks = self.splitter.split_text(content)
|
||||||
|
results_to_save = []
|
||||||
|
|
||||||
|
# 4. 向量化
|
||||||
|
for idx, chunk_text in enumerate(chunks):
|
||||||
|
vector = self._get_embedding(chunk_text)
|
||||||
|
if vector:
|
||||||
|
results_to_save.append({
|
||||||
|
"source_url": url,
|
||||||
|
"chunk_index": idx,
|
||||||
|
"title": title,
|
||||||
|
"content": chunk_text,
|
||||||
|
"embedding": vector
|
||||||
|
})
|
||||||
|
|
||||||
|
# 5. 保存
|
||||||
|
if results_to_save:
|
||||||
|
save_res = crawler_sql_service.save_results(task_id, results_to_save)
|
||||||
|
processed_count += 1
|
||||||
|
total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated']
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[WorkFlow] Error processing {url}: {e}")
|
||||||
|
# 此处不抛出异常,以免打断整个批次的循环
|
||||||
|
# 实际生产建议在这里调用 service 将 url 标记为 failed
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"msg": f"Batch processing complete. URLs processed: {processed_count}",
|
||||||
|
"processed_urls": processed_count,
|
||||||
|
"total_chunks_saved": total_chunks_saved
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
|
||||||
|
"""
|
||||||
|
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
# 1. 获取向量
|
||||||
|
vector = self._get_embedding(query_text)
|
||||||
|
|
||||||
|
if not vector:
|
||||||
|
result = {
|
||||||
|
"msg": "Failed to generate embedding for query",
|
||||||
|
"results": []
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 2. 执行搜索
|
||||||
|
# search_knowledge 现在已经返回带 msg 的字典了
|
||||||
|
result = crawler_sql_service.search_knowledge(vector, task_id, limit)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
# 单例模式
|
||||||
|
workflow = AutomatedCrawler()
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
# service.py
|
from sqlalchemy import select, insert, update, and_
|
||||||
from sqlalchemy import select, insert, update, delete, and_
|
from ..database import db_instance
|
||||||
from .database import db_instance
|
from ..utils import normalize_url
|
||||||
from .utils import normalize_url
|
|
||||||
|
|
||||||
class CrawlerSqlService:
|
class CrawlerSqlService:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -10,18 +9,30 @@ class CrawlerSqlService:
|
|||||||
def register_task(self, url: str):
|
def register_task(self, url: str):
|
||||||
"""完全使用库 API 实现的注册"""
|
"""完全使用库 API 实现的注册"""
|
||||||
clean_url = normalize_url(url)
|
clean_url = normalize_url(url)
|
||||||
|
result = {}
|
||||||
|
|
||||||
with self.db.engine.begin() as conn:
|
with self.db.engine.begin() as conn:
|
||||||
# 使用 select() API
|
# 使用 select() API
|
||||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||||
existing = conn.execute(query).fetchone()
|
existing = conn.execute(query).fetchone()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
return {"task_id": existing[0], "is_new_task": False}
|
result = {
|
||||||
|
"task_id": existing[0],
|
||||||
# 使用 insert() API
|
"is_new_task": False,
|
||||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
"msg": "Task already exists"
|
||||||
new_task = conn.execute(stmt).fetchone()
|
}
|
||||||
return {"task_id": new_task[0], "is_new_task": True}
|
else:
|
||||||
|
# 使用 insert() API
|
||||||
|
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||||
|
new_task = conn.execute(stmt).fetchone()
|
||||||
|
result = {
|
||||||
|
"task_id": new_task[0],
|
||||||
|
"is_new_task": True,
|
||||||
|
"msg": "New task created successfully"
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def add_urls(self, task_id: int, urls: list[str]):
|
def add_urls(self, task_id: int, urls: list[str]):
|
||||||
"""通用 API 实现的批量添加(含详细返回)"""
|
"""通用 API 实现的批量添加(含详细返回)"""
|
||||||
@@ -31,7 +42,7 @@ class CrawlerSqlService:
|
|||||||
for url in urls:
|
for url in urls:
|
||||||
clean_url = normalize_url(url)
|
clean_url = normalize_url(url)
|
||||||
try:
|
try:
|
||||||
# 检查队列中是否已存在该 URL (通用写法)
|
# 检查队列中是否已存在该 URL
|
||||||
check_q = select(self.db.queue).where(
|
check_q = select(self.db.queue).where(
|
||||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||||
)
|
)
|
||||||
@@ -46,11 +57,21 @@ class CrawlerSqlService:
|
|||||||
success_urls.append(clean_url)
|
success_urls.append(clean_url)
|
||||||
except Exception:
|
except Exception:
|
||||||
failed_urls.append(clean_url)
|
failed_urls.append(clean_url)
|
||||||
|
|
||||||
return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}
|
# 构造返回消息
|
||||||
|
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success_urls": success_urls,
|
||||||
|
"skipped_urls": skipped_urls,
|
||||||
|
"failed_urls": failed_urls,
|
||||||
|
"msg": msg
|
||||||
|
}
|
||||||
|
|
||||||
def get_pending_urls(self, task_id: int, limit: int):
|
def get_pending_urls(self, task_id: int, limit: int):
|
||||||
"""原子锁定 API 实现"""
|
"""原子锁定 API 实现"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
with self.db.engine.begin() as conn:
|
with self.db.engine.begin() as conn:
|
||||||
query = select(self.db.queue.c.url).where(
|
query = select(self.db.queue.c.url).where(
|
||||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||||
@@ -63,17 +84,20 @@ class CrawlerSqlService:
|
|||||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||||
).values(status='processing')
|
).values(status='processing')
|
||||||
conn.execute(upd)
|
conn.execute(upd)
|
||||||
return {"urls": urls}
|
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
|
||||||
|
else:
|
||||||
|
result = {"urls": [], "msg": "Queue is empty"}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def save_results(self, task_id: int, results: list):
|
def save_results(self, task_id: int, results: list):
|
||||||
"""
|
"""
|
||||||
保存同一 URL 的多个切片。
|
保存同一 URL 的多个切片。
|
||||||
返回:该 URL 下切片的详细处理统计及页面更新状态。
|
|
||||||
"""
|
"""
|
||||||
if not results:
|
if not results:
|
||||||
return {"msg": "No data provided"}
|
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
|
||||||
|
|
||||||
# 1. 基础信息提取 (假设 results 里的 source_url 都是一致的)
|
# 1. 基础信息提取
|
||||||
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
||||||
target_url = normalize_url(first_item.get('source_url'))
|
target_url = normalize_url(first_item.get('source_url'))
|
||||||
|
|
||||||
@@ -84,7 +108,7 @@ class CrawlerSqlService:
|
|||||||
is_page_update = False
|
is_page_update = False
|
||||||
|
|
||||||
with self.db.engine.begin() as conn:
|
with self.db.engine.begin() as conn:
|
||||||
# 2. 判断该 URL 是否已经有切片存在 (以此判定是否为“页面更新”)
|
# 2. 判断该 URL 是否已经有切片存在
|
||||||
check_page_stmt = select(self.db.chunks.c.id).where(
|
check_page_stmt = select(self.db.chunks.c.id).where(
|
||||||
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
||||||
).limit(1)
|
).limit(1)
|
||||||
@@ -97,7 +121,7 @@ class CrawlerSqlService:
|
|||||||
c_idx = data.get('chunk_index')
|
c_idx = data.get('chunk_index')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 检查具体某个 index 的切片是否存在
|
# 检查切片是否存在
|
||||||
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
||||||
and_(
|
and_(
|
||||||
self.db.chunks.c.task_id == task_id,
|
self.db.chunks.c.task_id == task_id,
|
||||||
@@ -108,7 +132,7 @@ class CrawlerSqlService:
|
|||||||
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
||||||
|
|
||||||
if existing_chunk:
|
if existing_chunk:
|
||||||
# 覆盖更新现有切片
|
# 覆盖更新
|
||||||
upd_stmt = update(self.db.chunks).where(
|
upd_stmt = update(self.db.chunks).where(
|
||||||
self.db.chunks.c.id == existing_chunk[0]
|
self.db.chunks.c.id == existing_chunk[0]
|
||||||
).values(
|
).values(
|
||||||
@@ -135,16 +159,19 @@ class CrawlerSqlService:
|
|||||||
print(f"Chunk {c_idx} failed: {e}")
|
print(f"Chunk {c_idx} failed: {e}")
|
||||||
failed_chunks.append(c_idx)
|
failed_chunks.append(c_idx)
|
||||||
|
|
||||||
# 4. 最终更新队列状态
|
# 4. 更新队列状态
|
||||||
conn.execute(
|
conn.execute(
|
||||||
update(self.db.queue).where(
|
update(self.db.queue).where(
|
||||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
||||||
).values(status='completed')
|
).values(status='completed')
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 构造返回
|
||||||
|
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"source_url": target_url,
|
"source_url": target_url,
|
||||||
"is_page_update": is_page_update, # 标志:此页面此前是否有过内容
|
"is_page_update": is_page_update,
|
||||||
"detail": {
|
"detail": {
|
||||||
"inserted_chunk_indexes": inserted_chunks,
|
"inserted_chunk_indexes": inserted_chunks,
|
||||||
"updated_chunk_indexes": updated_chunks,
|
"updated_chunk_indexes": updated_chunks,
|
||||||
@@ -154,20 +181,18 @@ class CrawlerSqlService:
|
|||||||
"inserted": len(inserted_chunks),
|
"inserted": len(inserted_chunks),
|
||||||
"updated": len(updated_chunks),
|
"updated": len(updated_chunks),
|
||||||
"failed": len(failed_chunks)
|
"failed": len(failed_chunks)
|
||||||
}
|
},
|
||||||
|
"msg": msg
|
||||||
}
|
}
|
||||||
|
|
||||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||||
"""
|
"""
|
||||||
高性能向量搜索方法
|
高性能向量搜索方法
|
||||||
:param query_embedding: 问题的向量
|
|
||||||
:param task_id: 可选的任务ID,不传则搜全表
|
|
||||||
:param limit: 返回结果数量
|
|
||||||
"""
|
"""
|
||||||
|
results = []
|
||||||
|
msg = ""
|
||||||
|
|
||||||
|
|
||||||
with self.db.engine.connect() as conn:
|
with self.db.engine.connect() as conn:
|
||||||
# 1. 选择需要的字段
|
|
||||||
# 我们同时返回 task_id,方便在全库搜索时知道来源哪个任务
|
|
||||||
stmt = select(
|
stmt = select(
|
||||||
self.db.chunks.c.task_id,
|
self.db.chunks.c.task_id,
|
||||||
self.db.chunks.c.source_url,
|
self.db.chunks.c.source_url,
|
||||||
@@ -176,20 +201,15 @@ class CrawlerSqlService:
|
|||||||
self.db.chunks.c.chunk_index
|
self.db.chunks.c.chunk_index
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 动态添加过滤条件
|
|
||||||
if task_id is not None:
|
if task_id is not None:
|
||||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||||
|
|
||||||
# 3. 按余弦距离排序(1 - 余弦相似度)
|
|
||||||
# 距离越小,相似度越高
|
|
||||||
stmt = stmt.order_by(
|
stmt = stmt.order_by(
|
||||||
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||||
).limit(limit)
|
).limit(limit)
|
||||||
|
|
||||||
# 4. 执行并解析结果
|
|
||||||
rows = conn.execute(stmt).fetchall()
|
rows = conn.execute(stmt).fetchall()
|
||||||
|
|
||||||
results = []
|
|
||||||
for r in rows:
|
for r in rows:
|
||||||
results.append({
|
results.append({
|
||||||
"task_id": r[0],
|
"task_id": r[0],
|
||||||
@@ -198,8 +218,13 @@ class CrawlerSqlService:
|
|||||||
"content": r[3],
|
"content": r[3],
|
||||||
"chunk_index": r[4]
|
"chunk_index": r[4]
|
||||||
})
|
})
|
||||||
|
|
||||||
return results
|
if results:
|
||||||
|
msg = f"Found {len(results)} matches"
|
||||||
|
else:
|
||||||
|
msg = "No matching content found"
|
||||||
|
|
||||||
|
return {"results": results, "msg": msg}
|
||||||
|
|
||||||
|
|
||||||
crawler_sql_service = CrawlerSqlService()
|
crawler_sql_service = CrawlerSqlService()
|
||||||
@@ -29,5 +29,11 @@ def normalize_url(url: str) -> str:
|
|||||||
if not path: path = ""
|
if not path: path = ""
|
||||||
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
|
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
|
||||||
|
|
||||||
def make_response(code: int, msg: str, data: any = None):
|
def make_response(code: int, msg: str = "Success", data: any = None):
|
||||||
|
"""
|
||||||
|
统一响应格式
|
||||||
|
:param code: 1 成功, 0 失败, 其他自定义
|
||||||
|
:param msg: 提示信息
|
||||||
|
:param data: 返回数据
|
||||||
|
"""
|
||||||
return {"code": code, "msg": msg, "data": data}
|
return {"code": code, "msg": msg, "data": data}
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
# backend/workflow.py
|
|
||||||
import dashscope
|
|
||||||
from http import HTTPStatus
|
|
||||||
from firecrawl import FirecrawlApp
|
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
from .config import settings
|
|
||||||
from .service import crawler_sql_service
|
|
||||||
|
|
||||||
# 初始化配置
|
|
||||||
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
|
||||||
|
|
||||||
class AutomatedCrawler:
|
|
||||||
def __init__(self):
|
|
||||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
|
||||||
# 文本切分器:每块500字符,重叠100字符,保证语义连贯
|
|
||||||
self.splitter = RecursiveCharacterTextSplitter(
|
|
||||||
chunk_size=500,
|
|
||||||
chunk_overlap=100,
|
|
||||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_embedding(self, text: str):
|
|
||||||
"""调用 Dashscope 生成向量"""
|
|
||||||
try:
|
|
||||||
resp = dashscope.TextEmbedding.call(
|
|
||||||
model=dashscope.TextEmbedding.Models.text_embedding_v4,
|
|
||||||
input=text,
|
|
||||||
dimension=1536
|
|
||||||
)
|
|
||||||
if resp.status_code == HTTPStatus.OK:
|
|
||||||
# 阿里 text-embedding-v3 默认维度是 1024,v2 是 1536
|
|
||||||
# 如果你的数据库是 1536 维,请使用 text_embedding_v2 或调整参数
|
|
||||||
return resp.output['embeddings'][0]['embedding']
|
|
||||||
else:
|
|
||||||
print(f"Embedding API Error: {resp}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Embedding Exception: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def map_and_ingest(self, start_url: str):
|
|
||||||
"""
|
|
||||||
V2 步骤1: 地图式扫描并入库
|
|
||||||
"""
|
|
||||||
print(f"[WorkFlow] Start mapping: {start_url}")
|
|
||||||
|
|
||||||
# 1. 在数据库注册任务
|
|
||||||
task_info = crawler_sql_service.register_task(start_url)
|
|
||||||
task_id = task_info['task_id']
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 2. 调用 Firecrawl Map
|
|
||||||
map_result = self.firecrawl.map(start_url)
|
|
||||||
|
|
||||||
urls=[]
|
|
||||||
for link in map_result.links:
|
|
||||||
urls.append(link.url)
|
|
||||||
|
|
||||||
print(f"[WorkFlow] Found {len(urls)} links")
|
|
||||||
|
|
||||||
# 3. 批量入库 (状态为 pending)
|
|
||||||
if urls:
|
|
||||||
res = crawler_sql_service.add_urls(task_id, urls)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": task_id,
|
|
||||||
"map_result": res
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WorkFlow] Map Error: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
def process_task_queue(self, task_id: int, limit: int = 10):
|
|
||||||
"""
|
|
||||||
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
|
|
||||||
"""
|
|
||||||
# 1. 获取待处理 URL
|
|
||||||
pending = crawler_sql_service.get_pending_urls(task_id, limit)
|
|
||||||
urls = pending['urls']
|
|
||||||
|
|
||||||
if not urls:
|
|
||||||
return {"msg": "No pending urls", "count": 0}
|
|
||||||
|
|
||||||
processed_count = 0
|
|
||||||
|
|
||||||
for url in urls:
|
|
||||||
try:
|
|
||||||
print(f"[WorkFlow] Processing: {url}")
|
|
||||||
# 2. 单页抓取 (Markdown)
|
|
||||||
scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)
|
|
||||||
|
|
||||||
content = scrape_res.markdown
|
|
||||||
|
|
||||||
metadata = scrape_res.metadata_dict
|
|
||||||
title = metadata.get('title', url)
|
|
||||||
|
|
||||||
if not content:
|
|
||||||
print(f"[WorkFlow] Skip empty content: {url}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 3. 切片
|
|
||||||
chunks = self.splitter.split_text(content)
|
|
||||||
|
|
||||||
results_to_save = []
|
|
||||||
|
|
||||||
# 4. 逐个切片向量化 (可以优化为批量调用 embedding API)
|
|
||||||
for idx, chunk_text in enumerate(chunks):
|
|
||||||
vector = self._get_embedding(chunk_text)
|
|
||||||
if vector:
|
|
||||||
results_to_save.append({
|
|
||||||
"source_url": url,
|
|
||||||
"chunk_index": idx,
|
|
||||||
"title": title,
|
|
||||||
"content": chunk_text,
|
|
||||||
"embedding": vector
|
|
||||||
})
|
|
||||||
|
|
||||||
# 5. 保存结果
|
|
||||||
if results_to_save:
|
|
||||||
crawler_sql_service.save_results(task_id, results_to_save)
|
|
||||||
processed_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[WorkFlow] Error processing {url}: {e}")
|
|
||||||
# 实际生产中这里应该把 URL 状态改为 'failed'
|
|
||||||
|
|
||||||
return {"processed_urls": processed_count, "total_chunks": "dynamic"}
|
|
||||||
|
|
||||||
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
|
|
||||||
"""
|
|
||||||
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
|
|
||||||
"""
|
|
||||||
vector = self._get_embedding(query_text)
|
|
||||||
if not vector:
|
|
||||||
raise Exception("Failed to generate embedding for query")
|
|
||||||
|
|
||||||
return crawler_sql_service.search_knowledge(vector, task_id, limit)
|
|
||||||
|
|
||||||
# 单例模式
|
|
||||||
workflow = AutomatedCrawler()
|
|
||||||
Reference in New Issue
Block a user