v3接口restful风格,规范化接口;添加mcp服务器;新增log模块

This commit is contained in:
2026-01-19 23:54:29 +08:00
parent 389c13a2a7
commit 7c99e67a7f
14 changed files with 780 additions and 376 deletions

View File

@@ -1,11 +1,14 @@
from sqlalchemy import select, insert, update, and_, text, func, desc
from backend.core.database import db
from backend.utils.common import normalize_url
import logging
# Module-scoped logger for this service.
# __name__ resolves to the module's dotted import path (e.g. "backend.services.crawler_service").
logger = logging.getLogger(__name__)
class DataService:
    """
    Data persistence service layer.

    Responsible only for database CRUD operations; contains no external
    API-call logic.
    """
    def __init__(self):
        # Shared database handle (engine + table metadata) imported at module level.
        self.db = db
@@ -17,11 +20,11 @@ class DataService:
existing = conn.execute(query).fetchone()
if existing:
return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
return {"task_id": existing[0], "is_new_task": False}
else:
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
new_task = conn.execute(stmt).fetchone()
return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}
return {"task_id": new_task[0], "is_new_task": True}
def add_urls(self, task_id: int, urls: list[str]):
success_urls = []
@@ -29,7 +32,7 @@ class DataService:
for url in urls:
clean_url = normalize_url(url)
try:
check_q = select(self.db.queue).where(
check_q = select(self.db.queue.c.id).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
)
if not conn.execute(check_q).fetchone():
@@ -41,115 +44,111 @@ class DataService:
def get_pending_urls(self, task_id: int, limit: int):
    """Atomically claim up to ``limit`` pending URLs for a task.

    Selects pending queue rows with ``FOR UPDATE SKIP LOCKED`` so that
    concurrent workers never claim the same rows, marks the claimed rows
    as ``processing`` in the same transaction, and returns their URLs.

    :param task_id: id of the crawl task whose queue is consumed
    :param limit: maximum number of URLs to claim
    :return: list of claimed URL strings (possibly empty)
    """
    with self.db.engine.begin() as conn:
        # Atomic claim: SKIP LOCKED lets concurrent workers skip rows
        # already locked by another transaction instead of blocking.
        subquery = select(self.db.queue.c.id).where(
            and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
        ).limit(limit).with_for_update(skip_locked=True)
        # Mark the locked rows as processing and return their URLs in one statement.
        stmt = update(self.db.queue).where(
            self.db.queue.c.id.in_(subquery)
        ).values(status='processing').returning(self.db.queue.c.url)
        result = conn.execute(stmt).fetchall()
        return [r[0] for r in result]
def mark_url_status(self, task_id: int, url: str, status: str):
    """Set the queue status of a single URL belonging to a task.

    The URL is normalized before matching so callers may pass any
    equivalent spelling of the same address.
    """
    normalized = normalize_url(url)
    queue = self.db.queue
    stmt = (
        update(queue)
        .where(and_(queue.c.task_id == task_id, queue.c.url == normalized))
        .values(status=status)
    )
    with self.db.engine.begin() as conn:
        conn.execute(stmt)
def get_task_monitor_data(self, task_id: int):
    """[DB-layer monitoring] Return the persisted state of a task.

    :param task_id: id of the task to inspect
    :return: ``{"root_url": ..., "db_stats": {...}}`` or ``None`` when the
        task id does not exist.
    """
    tasks, queue = self.db.tasks, self.db.queue
    with self.db.engine.connect() as conn:
        # 1. Make sure the task exists before counting anything.
        task_row = conn.execute(
            select(tasks.c.root_url).where(tasks.c.id == task_id)
        ).fetchone()
        if not task_row:
            return None
        # 2. Count queue entries grouped by status.
        counted = conn.execute(
            select(queue.c.status, func.count(queue.c.id))
            .where(queue.c.task_id == task_id)
            .group_by(queue.c.status)
        ).fetchall()
        stats = dict.fromkeys(("pending", "processing", "completed", "failed"), 0)
        for status_name, amount in counted:
            if status_name in stats:
                stats[status_name] = amount
        stats["total"] = sum(stats.values())
        return {"root_url": task_row[0], "db_stats": stats}
def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
    """
    Persist chunk rows for a page and mark its queue entry as completed.
    (Phase 1.5: supports meta_info.)

    Each element of ``chunks_data`` is a dict:
    ``{'index': int, 'content': str, 'embedding': list, 'meta_info': dict}``.
    An existing (task_id, source_url, chunk_index) row is updated in place;
    otherwise a new row is inserted.
    """
    normalized = normalize_url(source_url)
    chunks = self.db.chunks
    saved = 0
    with self.db.engine.begin() as conn:
        for item in chunks_data:
            chunk_index = item['index']
            row_values = {
                "task_id": task_id,
                "source_url": normalized,
                "chunk_index": chunk_index,
                "title": title,
                "content": item['content'],
                "embedding": item['embedding'],
                "meta_info": item.get('meta_info', {}),
            }
            # Upsert by hand: look the row up first, then update or insert.
            found = conn.execute(select(chunks.c.id).where(
                and_(chunks.c.task_id == task_id,
                     chunks.c.source_url == normalized,
                     chunks.c.chunk_index == chunk_index)
            )).fetchone()
            if found:
                conn.execute(update(chunks).where(chunks.c.id == found[0]).values(**row_values))
            else:
                conn.execute(insert(chunks).values(**row_values))
            saved += 1
        # Mark the queue entry completed inside the same transaction.
        conn.execute(update(self.db.queue).where(
            and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == normalized)
        ).values(status='completed'))
    return {"msg": f"Saved {saved} chunks", "count": saved}
def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
    """
    Phase 2: hybrid retrieval — weighted vector similarity + keyword rank.

    Scores each chunk as 0.7 * cosine similarity + 0.3 * ts_rank and
    returns the top ``candidates_num`` candidates. Falls back to pure
    vector search when the hybrid query raises.

    :param query_text: raw query for full-text (tsquery) matching
    :param query_vector: embedding of the query; numpy arrays and
        ``[[...]]``-nested lists are accepted and flattened
    :param task_id: optional task scope for the search
    :param candidates_num: recall size (number of candidates returned)
    :return: ``{"results": [...], "msg": ...}``
    """
    # Normalize the vector: accept numpy arrays and singly-nested lists.
    if hasattr(query_vector, 'tolist'):
        query_vector = query_vector.tolist()
    if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
        query_vector = query_vector[0]
    results = []
    with self.db.engine.connect() as conn:
        keyword_query = func.websearch_to_tsquery('english', query_text)
        # Cosine distance -> similarity in [0, 1]-ish range.
        vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
        keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
        # Weighted fusion: 70% vector, 30% keyword (ts_rank may be NULL -> 0).
        final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
        stmt = select(
            self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
            self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
        )
        if task_id:
            stmt = stmt.where(self.db.chunks.c.task_id == task_id)
        # candidates_num controls recall size.
        stmt = stmt.order_by(desc("score")).limit(candidates_num)
        try:
            rows = conn.execute(stmt).fetchall()
            results = [
                {
                    "task_id": r[0], "source_url": r[1], "title": r[2],
                    "content": r[3], "meta_info": r[4], "score": float(r[5]),
                }
                for r in rows
            ]
        except Exception as e:
            logger.error(f"Hybrid search failed: {e}")
            return self._fallback_vector_search(query_vector, task_id, candidates_num)
    return {"results": results, "msg": f"Hybrid found {len(results)}"}
def _fallback_vector_search(self, vector, task_id, limit):
    """Pure vector-similarity search used when hybrid search errors out.

    Returns the same result envelope as :meth:`search`; scores are not
    computed here, so each hit carries ``score: 0.0``.
    """
    logger.warning("Fallback to pure vector search")
    with self.db.engine.connect() as conn:
        stmt = select(
            self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
            self.db.chunks.c.content, self.db.chunks.c.meta_info
        ).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
        if task_id:
            stmt = stmt.where(self.db.chunks.c.task_id == task_id)
        rows = conn.execute(stmt).fetchall()
        return {"results": [{"content": r[3], "meta_info": r[4], "score": 0.0} for r in rows], "msg": "Fallback found"}
# Module-level singleton instance of the service.
data_service = DataService()