修改配置和response的细节

This commit is contained in:
2025-12-30 16:57:31 +08:00
parent 8972246445
commit d191b13455
5 changed files with 308 additions and 205 deletions

View File

@@ -1,7 +1,8 @@
# backend/main.py
from fastapi import FastAPI, APIRouter, BackgroundTasks
from .service import crawler_sql_service
from .workflow import workflow
# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow
from .services.crawler_sql_service import crawler_sql_service
from .services.automated_crawler import workflow
from .schemas import (
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
AutoMapRequest, AutoProcessRequest, TextSearchRequest
@@ -10,6 +11,11 @@ from .utils import make_response
app = FastAPI(title="Wiki Crawler API")
# ==========================================
# 工具函数
# ==========================================
# ==========================================
# V1 Router: 原始的底层接口 (Manual Control)
# ==========================================
@@ -18,8 +24,10 @@ router_v1 = APIRouter()
@router_v1.post("/register")
async def register(req: RegisterRequest):
try:
data = crawler_sql_service.register_task(req.url)
return make_response(1, "Success", data)
# Service 返回: {'task_id': 1, 'is_new_task': True, 'msg': '...'}
res = crawler_sql_service.register_task(req.url)
# 使用 pop 将 msg 提取出来作为响应的 msg剩下的作为 data
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@@ -27,44 +35,43 @@ async def register(req: RegisterRequest):
async def add_urls(req: AddUrlsRequest):
try:
urls = req.urls_obj["urls"]
data = crawler_sql_service.add_urls(req.task_id, urls=urls)
return make_response(1, "Success", data)
res = crawler_sql_service.add_urls(req.task_id, urls=urls)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/pending_urls")
async def pending_urls(req: PendingRequest):
try:
data = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
msg = "Success" if data["urls"] else "Queue Empty"
return make_response(1, msg, data)
res = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
# 即使队列为空Service 也会返回 msg="Queue is empty"
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/save_results")
async def save_results(req: SaveResultsRequest):
try:
data = crawler_sql_service.save_results(req.task_id, req.results)
return make_response(1, "Success", data)
res = crawler_sql_service.save_results(req.task_id, req.results)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/search")
async def search_v1(req: SearchRequest):
"""V1 搜索:需要客户端自己传向量"""
"""V1 搜索:客户端手动传向量"""
try:
vector = req.query_embedding['vector']
# 注意这里需要确认你数据库的向量维度。TextEmbedding V3 可能是 1024V2 是 1536。
# 请根据你的 PGVector 设置进行匹配。
if not vector:
return make_response(2, "Vector is empty", None)
data = crawler_sql_service.search_knowledge(
# Service 现在返回 {'results': [...], 'msg': 'Found ...'}
res = crawler_sql_service.search_knowledge(
query_embedding=vector,
task_id=req.task_id,
limit=req.limit
)
return make_response(1, "Search Done", data)
return make_response(1, res.pop("msg", "Search Done"), res)
except Exception as e:
return make_response(0, str(e))
@@ -75,16 +82,14 @@ async def search_v1(req: SearchRequest):
router_v2 = APIRouter()
@router_v2.post("/auto/map")
async def auto_map(req: AutoMapRequest, background_tasks: BackgroundTasks):
async def auto_map(req: AutoMapRequest):
"""
[步] 输入首页 URL自动调用 Firecrawl Map 并入库
[步] 输入首页 URL自动调用 Firecrawl Map 并入库
"""
# 也可以放入 background_tasks但 map 通常比较快这里演示同步返回任务ID
try:
# 为了不阻塞主线程,如果 map 很慢,建议放入 background_tasks
# 这里为了能立刻看到 task_id先同步调用 (Firecrawl Map 比较快)
data = workflow.map_and_ingest(req.url)
return make_response(1, "Mapping Started", data)
# Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...}
res = workflow.map_and_ingest(req.url)
return make_response(1, res.pop("msg", "Mapping Started"), res)
except Exception as e:
return make_response(0, str(e))
@@ -93,9 +98,14 @@ async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTask
"""
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
"""
try:
# 将耗时操作放入后台任务
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
return make_response(1, "Processing started in background", {"task_id": req.task_id})
# 因为是后台任务,无法立即获取 Service 的返回值 msg只能返回通用消息
return make_response(1, "Background processing started", {"task_id": req.task_id})
except Exception as e:
return make_response(0, str(e))
@router_v2.post("/search")
async def search_v2(req: TextSearchRequest):
@@ -103,8 +113,9 @@ async def search_v2(req: TextSearchRequest):
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
"""
try:
data = workflow.search_with_embedding(req.query, req.task_id, req.limit)
return make_response(1, "Search Success", data)
# Workflow 返回 {'results': [...], 'msg': '...'}
res = workflow.search_with_embedding(req.query, req.task_id, req.limit)
return make_response(1, res.pop("msg", "Search Success"), res)
except Exception as e:
return make_response(0, f"Search Failed: {str(e)}")

View File

@@ -0,0 +1,201 @@
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter
from ..config import settings
from .crawler_sql_service import crawler_sql_service
# 初始化配置
dashscope.api_key = settings.DASHSCOPE_API_KEY
class AutomatedCrawler:
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
separators=["\n\n", "\n", "", "", "", " ", ""]
)
def _get_embedding(self, text: str):
"""内部方法:调用 Dashscope 生成向量"""
# 注意:此方法是内部辅助,出错返回 None由调用方处理状态
embedding = None
try:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 确认你的模型版本
input=text,
dimension=1536
)
if resp.status_code == HTTPStatus.OK:
embedding = resp.output['embeddings'][0]['embedding']
else:
print(f"Embedding API Error: {resp}")
except Exception as e:
print(f"Embedding Exception: {e}")
return embedding
def map_and_ingest(self, start_url: str):
"""
V2 步骤1: 地图式扫描并入库
"""
print(f"[WorkFlow] Start mapping: {start_url}")
result = {}
try:
# 1. 在数据库注册任务
task_info = crawler_sql_service.register_task(start_url)
task_id = task_info['task_id']
is_new_task = task_info['is_new_task']
# 2. 调用 Firecrawl Map
if is_new_task:
map_result = self.firecrawl.map(start_url)
urls = []
# 兼容 firecrawl sdk 不同版本的返回结构
# 如果 map_result 是对象且有 links 属性
if hasattr(map_result, 'links'):
for link in map_result.links:
# 假设 link 是对象或字典,视具体 SDK 版本而定
# 如果 link 是字符串直接 append
if isinstance(link, str):
urls.append(link)
else:
urls.append(getattr(link, 'url', str(link)))
# 如果是字典
elif isinstance(map_result, dict):
urls = map_result.get('links', [])
print(f"[WorkFlow] Found {len(urls)} links")
# 3. 批量入库
res = {"msg": "No urls found to add"}
if urls:
res = crawler_sql_service.add_urls(task_id, urls)
result = {
"msg": "Task successfully mapped and URLs added",
"task_id": task_id,
"is_new_task": is_new_task,
"url_count": len(urls),
"map_detail": res
}
else:
result = {
"msg": "Task already exists, skipped mapping",
"task_id": task_id,
"is_new_task": False,
"url_count": 0,
"map_detail": {}
}
except Exception as e:
print(f"[WorkFlow] Map Error: {e}")
# 向上抛出异常,由 main.py 捕获并返回错误 Response
raise e
return result
def process_task_queue(self, task_id: int, limit: int = 10):
"""
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
"""
processed_count = 0
total_chunks_saved = 0
result = {}
# 1. 获取待处理 URL
pending = crawler_sql_service.get_pending_urls(task_id, limit)
urls = pending['urls']
if not urls:
result = {"msg": "Queue is empty, no processing needed", "processed_count": 0}
else:
for url in urls:
try:
print(f"[WorkFlow] Processing: {url}")
# 2. 单页抓取
scrape_res = self.firecrawl.scrape(
url,
params={'formats': ['markdown'], 'onlyMainContent': True}
)
# 兼容 SDK 返回类型 (对象或字典)
content = ""
metadata = {}
if isinstance(scrape_res, dict):
content = scrape_res.get('markdown', '')
metadata = scrape_res.get('metadata', {})
else:
content = getattr(scrape_res, 'markdown', '')
metadata = getattr(scrape_res, 'metadata', {})
if not metadata and hasattr(scrape_res, 'metadata_dict'):
metadata = scrape_res.metadata_dict
title = metadata.get('title', url)
if not content:
print(f"[WorkFlow] Skip empty content: {url}")
continue
# 3. 切片
chunks = self.splitter.split_text(content)
results_to_save = []
# 4. 向量化
for idx, chunk_text in enumerate(chunks):
vector = self._get_embedding(chunk_text)
if vector:
results_to_save.append({
"source_url": url,
"chunk_index": idx,
"title": title,
"content": chunk_text,
"embedding": vector
})
# 5. 保存
if results_to_save:
save_res = crawler_sql_service.save_results(task_id, results_to_save)
processed_count += 1
total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated']
except Exception as e:
print(f"[WorkFlow] Error processing {url}: {e}")
# 此处不抛出异常,以免打断整个批次的循环
# 实际生产建议在这里调用 service 将 url 标记为 failed
result = {
"msg": f"Batch processing complete. URLs processed: {processed_count}",
"processed_urls": processed_count,
"total_chunks_saved": total_chunks_saved
}
return result
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
"""
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
"""
result = {}
# 1. 获取向量
vector = self._get_embedding(query_text)
if not vector:
result = {
"msg": "Failed to generate embedding for query",
"results": []
}
else:
# 2. 执行搜索
# search_knowledge 现在已经返回带 msg 的字典了
result = crawler_sql_service.search_knowledge(vector, task_id, limit)
return result
# 单例模式
workflow = AutomatedCrawler()

View File

@@ -1,7 +1,6 @@
# service.py
from sqlalchemy import select, insert, update, delete, and_
from .database import db_instance
from .utils import normalize_url
from sqlalchemy import select, insert, update, and_
from ..database import db_instance
from ..utils import normalize_url
class CrawlerSqlService:
def __init__(self):
@@ -10,18 +9,30 @@ class CrawlerSqlService:
def register_task(self, url: str):
"""完全使用库 API 实现的注册"""
clean_url = normalize_url(url)
result = {}
with self.db.engine.begin() as conn:
# 使用 select() API
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
existing = conn.execute(query).fetchone()
if existing:
return {"task_id": existing[0], "is_new_task": False}
result = {
"task_id": existing[0],
"is_new_task": False,
"msg": "Task already exists"
}
else:
# 使用 insert() API
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
new_task = conn.execute(stmt).fetchone()
return {"task_id": new_task[0], "is_new_task": True}
result = {
"task_id": new_task[0],
"is_new_task": True,
"msg": "New task created successfully"
}
return result
def add_urls(self, task_id: int, urls: list[str]):
"""通用 API 实现的批量添加(含详细返回)"""
@@ -31,7 +42,7 @@ class CrawlerSqlService:
for url in urls:
clean_url = normalize_url(url)
try:
# 检查队列中是否已存在该 URL (通用写法)
# 检查队列中是否已存在该 URL
check_q = select(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
)
@@ -47,10 +58,20 @@ class CrawlerSqlService:
except Exception:
failed_urls.append(clean_url)
return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}
# 构造返回消息
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
return {
"success_urls": success_urls,
"skipped_urls": skipped_urls,
"failed_urls": failed_urls,
"msg": msg
}
def get_pending_urls(self, task_id: int, limit: int):
"""原子锁定 API 实现"""
result = {}
with self.db.engine.begin() as conn:
query = select(self.db.queue.c.url).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
@@ -63,17 +84,20 @@ class CrawlerSqlService:
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
).values(status='processing')
conn.execute(upd)
return {"urls": urls}
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
else:
result = {"urls": [], "msg": "Queue is empty"}
return result
def save_results(self, task_id: int, results: list):
"""
保存同一 URL 的多个切片
返回 URL 下切片的详细处理统计及页面更新状态
"""
if not results:
return {"msg": "No data provided"}
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
# 1. 基础信息提取 (假设 results 里的 source_url 都是一致的)
# 1. 基础信息提取
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
target_url = normalize_url(first_item.get('source_url'))
@@ -84,7 +108,7 @@ class CrawlerSqlService:
is_page_update = False
with self.db.engine.begin() as conn:
# 2. 判断该 URL 是否已经有切片存在 (以此判定是否为“页面更新”)
# 2. 判断该 URL 是否已经有切片存在
check_page_stmt = select(self.db.chunks.c.id).where(
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
).limit(1)
@@ -97,7 +121,7 @@ class CrawlerSqlService:
c_idx = data.get('chunk_index')
try:
# 检查具体某个 index 的切片是否存在
# 检查切片是否存在
find_chunk_stmt = select(self.db.chunks.c.id).where(
and_(
self.db.chunks.c.task_id == task_id,
@@ -108,7 +132,7 @@ class CrawlerSqlService:
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
if existing_chunk:
# 覆盖更新现有切片
# 覆盖更新
upd_stmt = update(self.db.chunks).where(
self.db.chunks.c.id == existing_chunk[0]
).values(
@@ -135,16 +159,19 @@ class CrawlerSqlService:
print(f"Chunk {c_idx} failed: {e}")
failed_chunks.append(c_idx)
# 4. 最终更新队列状态
# 4. 更新队列状态
conn.execute(
update(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
).values(status='completed')
)
# 构造返回
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
return {
"source_url": target_url,
"is_page_update": is_page_update, # 标志:此页面此前是否有过内容
"is_page_update": is_page_update,
"detail": {
"inserted_chunk_indexes": inserted_chunks,
"updated_chunk_indexes": updated_chunks,
@@ -154,20 +181,18 @@ class CrawlerSqlService:
"inserted": len(inserted_chunks),
"updated": len(updated_chunks),
"failed": len(failed_chunks)
},
"msg": msg
}
}
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
"""
高性能向量搜索方法
:param query_embedding: 问题的向量
:param task_id: 可选的任务ID不传则搜全表
:param limit: 返回结果数量
"""
results = []
msg = ""
with self.db.engine.connect() as conn:
# 1. 选择需要的字段
# 我们同时返回 task_id方便在全库搜索时知道来源哪个任务
stmt = select(
self.db.chunks.c.task_id,
self.db.chunks.c.source_url,
@@ -176,20 +201,15 @@ class CrawlerSqlService:
self.db.chunks.c.chunk_index
)
# 2. 动态添加过滤条件
if task_id is not None:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
# 3. 按余弦距离排序1 - 余弦相似度)
# 距离越小,相似度越高
stmt = stmt.order_by(
self.db.chunks.c.embedding.cosine_distance(query_embedding)
).limit(limit)
# 4. 执行并解析结果
rows = conn.execute(stmt).fetchall()
results = []
for r in rows:
results.append({
"task_id": r[0],
@@ -199,7 +219,12 @@ class CrawlerSqlService:
"chunk_index": r[4]
})
return results
if results:
msg = f"Found {len(results)} matches"
else:
msg = "No matching content found"
return {"results": results, "msg": msg}
crawler_sql_service = CrawlerSqlService()

View File

@@ -29,5 +29,11 @@ def normalize_url(url: str) -> str:
if not path: path = ""
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
def make_response(code: int, msg: str, data: any = None):
def make_response(code: int, msg: str = "Success", data: any = None):
"""
统一响应格式
:param code: 1 成功, 0 失败, 其他自定义
:param msg: 提示信息
:param data: 返回数据
"""
return {"code": code, "msg": msg, "data": data}

View File

@@ -1,140 +0,0 @@
# backend/workflow.py
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter
from .config import settings
from .service import crawler_sql_service
# 初始化配置
dashscope.api_key = settings.DASHSCOPE_API_KEY
class AutomatedCrawler:
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
# 文本切分器每块500字符重叠100字符保证语义连贯
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
separators=["\n\n", "\n", "", "", "", " ", ""]
)
def _get_embedding(self, text: str):
"""调用 Dashscope 生成向量"""
try:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v4,
input=text,
dimension=1536
)
if resp.status_code == HTTPStatus.OK:
# 阿里 text-embedding-v3 默认维度是 1024v2 是 1536
# 如果你的数据库是 1536 维,请使用 text_embedding_v2 或调整参数
return resp.output['embeddings'][0]['embedding']
else:
print(f"Embedding API Error: {resp}")
return None
except Exception as e:
print(f"Embedding Exception: {e}")
return None
def map_and_ingest(self, start_url: str):
"""
V2 步骤1: 地图式扫描并入库
"""
print(f"[WorkFlow] Start mapping: {start_url}")
# 1. 在数据库注册任务
task_info = crawler_sql_service.register_task(start_url)
task_id = task_info['task_id']
try:
# 2. 调用 Firecrawl Map
map_result = self.firecrawl.map(start_url)
urls=[]
for link in map_result.links:
urls.append(link.url)
print(f"[WorkFlow] Found {len(urls)} links")
# 3. 批量入库 (状态为 pending)
if urls:
res = crawler_sql_service.add_urls(task_id, urls)
return {
"task_id": task_id,
"map_result": res
}
except Exception as e:
print(f"[WorkFlow] Map Error: {e}")
raise e
def process_task_queue(self, task_id: int, limit: int = 10):
"""
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
"""
# 1. 获取待处理 URL
pending = crawler_sql_service.get_pending_urls(task_id, limit)
urls = pending['urls']
if not urls:
return {"msg": "No pending urls", "count": 0}
processed_count = 0
for url in urls:
try:
print(f"[WorkFlow] Processing: {url}")
# 2. 单页抓取 (Markdown)
scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)
content = scrape_res.markdown
metadata = scrape_res.metadata_dict
title = metadata.get('title', url)
if not content:
print(f"[WorkFlow] Skip empty content: {url}")
continue
# 3. 切片
chunks = self.splitter.split_text(content)
results_to_save = []
# 4. 逐个切片向量化 (可以优化为批量调用 embedding API)
for idx, chunk_text in enumerate(chunks):
vector = self._get_embedding(chunk_text)
if vector:
results_to_save.append({
"source_url": url,
"chunk_index": idx,
"title": title,
"content": chunk_text,
"embedding": vector
})
# 5. 保存结果
if results_to_save:
crawler_sql_service.save_results(task_id, results_to_save)
processed_count += 1
except Exception as e:
print(f"[WorkFlow] Error processing {url}: {e}")
# 实际生产中这里应该把 URL 状态改为 'failed'
return {"processed_urls": processed_count, "total_chunks": "dynamic"}
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
"""
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
"""
vector = self._get_embedding(query_text)
if not vector:
raise Exception("Failed to generate embedding for query")
return crawler_sql_service.search_knowledge(vector, task_id, limit)
# 单例模式
workflow = AutomatedCrawler()