241 lines
11 KiB
Python
241 lines
11 KiB
Python
from sqlalchemy import select, insert, update, and_, text, func, desc
|
||
from backend.core.database import db
|
||
from backend.utils.common import normalize_url
|
||
import logging
|
||
|
||
# Module-specific logger.
# __name__ resolves to this module's dotted path, e.g. "backend.services.crawler_service".
|
||
logger = logging.getLogger(__name__)
|
||
class DataService:
    """Data persistence service layer.

    Wraps the shared ``db`` object (its ``engine`` plus the ``tasks``,
    ``queue`` and ``chunks`` tables) with task registration, URL queue
    management, chunk storage and hybrid search.
    """

    def __init__(self):
        self.db = db

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _enqueue_if_missing(self, conn, task_id: int, clean_url: str) -> bool:
        """Queue ``clean_url`` as ``'pending'`` for ``task_id`` unless already queued.

        The check and insert run inside a SAVEPOINT so that a failure on one
        URL is rolled back on its own and does not poison the caller's
        enclosing transaction (which would make every later statement fail).

        Returns:
            True when a new queue row was inserted; False when the URL was
            already queued or the insert failed (the failure is logged).
        """
        queue = self.db.queue
        try:
            with conn.begin_nested():
                already_queued = conn.execute(
                    select(queue.c.id).where(
                        and_(queue.c.task_id == task_id, queue.c.url == clean_url)
                    )
                ).fetchone()
                if already_queued:
                    return False
                conn.execute(
                    insert(queue).values(task_id=task_id, url=clean_url, status='pending')
                )
                return True
        except Exception:
            # Best effort: skip this URL but leave a trace instead of
            # swallowing the error silently.
            logger.warning(
                "Failed to enqueue url %r for task %s; skipping",
                clean_url, task_id, exc_info=True,
            )
            return False

    def _extract_name(self, url: str) -> str:
        """Derive a short alias from a URL's network location.

        Returns the second-to-last dot-separated label of the netloc
        (e.g. ``docs.firecrawl.dev`` -> ``firecrawl``), the whole netloc when
        it has fewer than two labels, or ``url`` unchanged when it cannot be
        parsed.
        """
        from urllib.parse import urlparse

        try:
            domain = urlparse(url).netloc
            parts = domain.split('.')
            if len(parts) >= 2:
                return parts[-2]
            return domain
        except Exception:
            # e.g. ValueError for a malformed IPv6 netloc; fall back to the input.
            return url

    # ------------------------------------------------------------------
    # Tasks
    # ------------------------------------------------------------------

    def register_task(self, url: str):
        """Return the task for ``url``'s normalized root, creating it if absent.

        Returns:
            ``{"task_id": int, "is_new_task": bool}``
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            new_task = conn.execute(
                insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            ).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def get_all_tasks(self):
        """List every registered task (the knowledge-base list), ordered by id.

        Returns:
            A list of ``{"task_id", "root_url", "name"}`` dicts, where
            ``name`` is the alias produced by ``_extract_name``.
        """
        tasks = self.db.tasks
        with self.db.engine.connect() as conn:
            stmt = select(tasks.c.id, tasks.c.root_url).order_by(tasks.c.id)
            rows = conn.execute(stmt).fetchall()

        return [
            {"task_id": r[0], "root_url": r[1], "name": self._extract_name(r[1])}
            for r in rows
        ]

    def get_task_by_root_url(self, url: str):
        """Return the id of the task whose root URL matches ``url``, else None."""
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.connect() as conn:
            row = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            return row[0] if row else None

    def create_task_with_urls(self, url: str, urls: list[str]):
        """Create (or reuse) a task and queue ``urls`` for it in one transaction.

        If a task for the normalized root URL already exists it is reused and
        only URLs not yet queued for it are inserted.

        Returns:
            ``{"task_id": int, "is_new_task": bool, "added": int}``
        """
        clean_root = normalize_url(url)
        clean_urls = [normalize_url(u) for u in urls]
        tasks = self.db.tasks
        added_count = 0
        with self.db.engine.begin() as conn:
            # 1. Look up an existing task, or create one and take its id.
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_root)
            ).fetchone()
            if existing:
                task_id = existing[0]
                is_new = False
            else:
                stmt = insert(tasks).values(root_url=clean_root).returning(tasks.c.id)
                task_id = conn.execute(stmt).fetchone()[0]
                is_new = True

            # 2. Queue the URLs one by one, skipping those already present.
            #    A failing URL is rolled back to its savepoint and skipped.
            for u in clean_urls:
                if self._enqueue_if_missing(conn, task_id, u):
                    added_count += 1

        return {"task_id": task_id, "is_new_task": is_new, "added": added_count}

    def delete_task(self, task_id: int):
        """Delete a task together with its queue rows and chunks.

        Use with care; mainly intended for rollback. All three deletes run in
        one transaction: on any error the transaction is rolled back, the
        error is logged and False is returned.

        Returns:
            True on success, False on failure.
        """
        # The try wraps the whole transaction block so that an error
        # propagates through the context manager (triggering a rollback)
        # before it is caught here.
        try:
            with self.db.engine.begin() as conn:
                conn.execute(text("DELETE FROM chunks WHERE task_id = :tid"), {"tid": task_id})
                conn.execute(text("DELETE FROM queue WHERE task_id = :tid"), {"tid": task_id})
                conn.execute(text("DELETE FROM tasks WHERE id = :tid"), {"tid": task_id})
            return True
        except Exception as e:
            logger.error(f"Failed to delete task {task_id}: {e}")
            return False

    def get_task_monitor_data(self, task_id: int):
        """Return the persisted queue statistics for a task.

        Returns:
            None when the task does not exist; otherwise
            ``{"root_url": str, "db_stats": {...}}`` where ``db_stats`` holds
            counts for ``pending``, ``processing``, ``completed``, ``failed``
            and their sum under ``total``. Rows with any other status are not
            counted.
        """
        tasks = self.db.tasks
        queue = self.db.queue
        with self.db.engine.connect() as conn:
            # 1. Make sure the task exists.
            task_row = conn.execute(
                select(tasks.c.root_url).where(tasks.c.id == task_id)
            ).fetchone()
            if not task_row:
                return None

            # 2. Count queue rows per status.
            stats_rows = conn.execute(
                select(queue.c.status, func.count(queue.c.id))
                .where(queue.c.task_id == task_id)
                .group_by(queue.c.status)
            ).fetchall()

            stats = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
            for status, count in stats_rows:
                if status in stats:
                    stats[status] = count
            stats["total"] = sum(stats.values())

            return {
                "root_url": task_row[0],
                "db_stats": stats
            }

    # ------------------------------------------------------------------
    # URL queue
    # ------------------------------------------------------------------

    def add_urls(self, task_id: int, urls: list[str]):
        """Queue each normalized URL for ``task_id`` unless it is already queued.

        URLs that fail to insert are logged and skipped without affecting the
        others.

        Returns:
            ``{"msg": "Added <n> new urls"}`` where ``n`` counts the rows
            actually inserted.
        """
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                if self._enqueue_if_missing(conn, task_id, clean_url):
                    success_urls.append(clean_url)
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to ``limit`` pending URLs of a task and return them.

        The selected rows are locked with ``FOR UPDATE SKIP LOCKED`` and
        switched to ``'processing'`` in the same statement.

        Returns:
            The list of claimed URLs.
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            # Select-and-mark in one statement: pick the candidate ids ...
            claimable_ids = select(queue.c.id).where(
                and_(queue.c.task_id == task_id, queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)

            # ... and flip exactly those rows to 'processing'.
            stmt = update(queue).where(
                queue.c.id.in_(claimable_ids)
            ).values(status='processing').returning(queue.c.url)

            result = conn.execute(stmt).fetchall()
            return [r[0] for r in result]

    def mark_url_status(self, task_id: int, url: str, status: str):
        """Set the queue status of ``url`` (normalized) within ``task_id``."""
        clean_url = normalize_url(url)
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            conn.execute(update(queue).where(
                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
            ).values(status=status))

    # ------------------------------------------------------------------
    # Chunks
    # ------------------------------------------------------------------

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Upsert the chunks of one page and mark its queue row completed.

        Each item of ``chunks_data`` must provide ``'index'``, ``'content'``
        and ``'embedding'``; ``'meta_info'`` is optional and defaults to
        ``{}``. A chunk is matched on (task_id, source_url, chunk_index):
        an existing row is updated, otherwise a new row is inserted. All
        writes happen in a single transaction.
        """
        clean_url = normalize_url(source_url)
        chunks = self.db.chunks
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})
                existing = conn.execute(select(chunks.c.id).where(
                    and_(chunks.c.task_id == task_id,
                         chunks.c.source_url == clean_url,
                         chunks.c.chunk_index == idx)
                )).fetchone()
                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta
                }
                if existing:
                    conn.execute(update(chunks).where(chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(chunks).values(**values))

            conn.execute(update(queue).where(
                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
            ).values(status='completed'))

    # ------------------------------------------------------------------
    # Search
    # ------------------------------------------------------------------

    def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
        """Hybrid (vector + keyword) search over stored chunks.

        The score is ``0.7 * (1 - cosine_distance) + 0.3 * ts_rank`` using
        ``websearch_to_tsquery('english', query_text)``. When ``task_id`` is
        truthy the search is restricted to that task. If the hybrid query
        fails, the error is logged and a pure vector search is used instead.

        Returns:
            ``{"results": [...], "msg": str}`` where each result has
            ``task_id``, ``source_url``, ``title``, ``content``,
            ``meta_info`` and ``score``.
        """
        # Normalize the vector: accept objects exposing .tolist() and unwrap
        # a single-row nested list.
        if hasattr(query_vector, 'tolist'):
            query_vector = query_vector.tolist()
        if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
            query_vector = query_vector[0]

        chunks = self.db.chunks
        keyword_query = func.websearch_to_tsquery('english', query_text)
        vector_score = (1 - chunks.c.embedding.cosine_distance(query_vector))
        keyword_score = func.ts_rank(chunks.c.content_tsvector, keyword_query)
        final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")

        stmt = select(
            chunks.c.task_id, chunks.c.source_url, chunks.c.title,
            chunks.c.content, chunks.c.meta_info, final_score
        )
        if task_id:
            stmt = stmt.where(chunks.c.task_id == task_id)
        stmt = stmt.order_by(desc("score")).limit(candidates_num)

        results = None
        with self.db.engine.connect() as conn:
            try:
                rows = conn.execute(stmt).fetchall()
                results = [
                    {"task_id": r[0], "source_url": r[1], "title": r[2],
                     "content": r[3], "meta_info": r[4], "score": float(r[5])}
                    for r in rows
                ]
            except Exception as e:
                logger.error(f"Hybrid search failed: {e}")

        if results is None:
            # Run the fallback only after the failed connection is released.
            return self._fallback_vector_search(query_vector, task_id, candidates_num)

        return {"results": results, "msg": f"Hybrid found {len(results)}"}

    def _fallback_vector_search(self, vector, task_id, limit):
        """Pure vector search used when the hybrid query fails.

        Rows are ordered by cosine distance to ``vector``; every result
        carries a placeholder ``score`` of 0.0.

        Returns:
            ``{"results": [...], "msg": "Fallback found"}`` with the same
            result keys as ``search``.
        """
        logger.warning("Fallback to pure vector search")
        chunks = self.db.chunks
        stmt = select(
            chunks.c.task_id, chunks.c.source_url, chunks.c.title,
            chunks.c.content, chunks.c.meta_info
        )
        if task_id:
            stmt = stmt.where(chunks.c.task_id == task_id)
        stmt = stmt.order_by(chunks.c.embedding.cosine_distance(vector)).limit(limit)

        with self.db.engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()

        return {
            "results": [
                {"task_id": r[0], "source_url": r[1], "title": r[2],
                 "content": r[3], "meta_info": r[4], "score": 0.0}
                for r in rows
            ],
            "msg": "Fallback found",
        }
|
||
|
||
# Module-level DataService instance, created at import time.
data_service = DataService()