# wiki_crawler/backend/services/data_service.py
from sqlalchemy import select, insert, update, and_, text, func, desc
from backend.core.database import db
from backend.utils.common import normalize_url
import logging
# Module-level logger dedicated to this module.
# __name__ resolves to the dotted module path, e.g. "backend.services.data_service".
logger = logging.getLogger(__name__)
class DataService:
    """
    Data persistence service layer.

    Wraps all database access for crawl tasks, the URL queue, and content
    chunks behind a small API.  Each public method opens its own
    connection or transaction on the shared ``db.engine``:
    ``engine.begin()`` for writes (auto commit/rollback) and
    ``engine.connect()`` for read-only queries.
    """

    def __init__(self):
        # Shared database facade (engine + table metadata objects) imported
        # from backend.core.database.
        self.db = db

    def register_task(self, url: str):
        """Register a crawl task for *url*, reusing an existing one if present.

        The URL is normalized first so equivalent spellings map to one task.

        Returns:
            dict: ``{"task_id": int, "is_new_task": bool}``.
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}
            # No task yet: insert and return the freshly generated id.
            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list[str]):
        """Insert *urls* into the crawl queue for *task_id*, skipping duplicates.

        Each URL is normalized, checked for existence under the task, and
        inserted with status ``'pending'``.  A failure on one URL does not
        abort the rest of the batch.

        Returns:
            dict: ``{"msg": str}`` summarizing how many URLs were added.
        """
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    check_q = select(self.db.queue.c.id).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    )
                    if not conn.execute(check_q).fetchone():
                        conn.execute(insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending'))
                        success_urls.append(clean_url)
                except Exception as e:
                    # Best-effort semantics: keep processing the remaining
                    # URLs, but record the failure instead of hiding it.
                    logger.warning("Failed to enqueue url %s for task %s: %s", clean_url, task_id, e)
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim up to *limit* pending URLs for *task_id*.

        Selected rows are locked (``FOR UPDATE SKIP LOCKED``) and flipped to
        ``'processing'`` in one statement, so concurrent workers never claim
        the same URL twice.

        Returns:
            list[str]: the claimed URLs (may be empty).
        """
        with self.db.engine.begin() as conn:
            # Atomic claim: lock candidate rows, skipping rows already
            # locked by other workers.
            subquery = select(self.db.queue.c.id).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            stmt = update(self.db.queue).where(
                self.db.queue.c.id.in_(subquery)
            ).values(status='processing').returning(self.db.queue.c.url)
            result = conn.execute(stmt).fetchall()
            return [r[0] for r in result]

    def mark_url_status(self, task_id: int, url: str, status: str):
        """Set the queue *status* for (*task_id*, normalized *url*)."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status=status))

    def get_all_tasks(self):
        """Return every registered task (the "knowledge base" list).

        Used by the frontend listing and by workflow routing.

        Returns:
            list[dict]: ``{"task_id", "root_url", "name"}`` per task,
            ordered by id; ``name`` is a short alias derived from the URL.
        """
        with self.db.engine.connect() as conn:
            stmt = select(self.db.tasks.c.id, self.db.tasks.c.root_url).order_by(self.db.tasks.c.id)
            rows = conn.execute(stmt).fetchall()
            return [
                {"task_id": r[0], "root_url": r[1], "name": self._extract_name(r[1])}
                for r in rows
            ]

    def _extract_name(self, url: str) -> str:
        """Derive a short alias from *url*.

        E.g. ``docs.firecrawl.dev`` -> ``firecrawl``.  Falls back to the
        bare domain (or the raw input on parse failure).
        """
        try:
            from urllib.parse import urlparse
            domain = urlparse(url).netloc
            parts = domain.split('.')
            if len(parts) >= 2:
                # Second-to-last label is usually the organization name.
                return parts[-2]
            return domain
        except Exception:
            # Defensive: never let a malformed URL break task listing.
            return url

    def get_task_by_root_url(self, url: str):
        """Return the id of the task whose root is *url*, or None if absent."""
        clean_url = normalize_url(url)
        with self.db.engine.connect() as conn:
            row = conn.execute(select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)).fetchone()
            return row[0] if row else None

    def create_task_with_urls(self, url: str, urls: list[str]):
        """Create a task and bulk-insert its URLs in a single transaction.

        If the task already exists no new task is created; the deduplicated
        URLs are appended to the existing task instead.

        Returns:
            dict: ``{"task_id": int, "is_new_task": bool, "added": int}``.
        """
        clean_root = normalize_url(url)
        clean_urls = [normalize_url(u) for u in urls]
        added_count = 0
        with self.db.engine.begin() as conn:
            # 1. Reuse an existing task when the root URL is already known.
            existing = conn.execute(select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_root)).fetchone()
            if existing:
                task_id = existing[0]
                is_new = False
            else:
                stmt = insert(self.db.tasks).values(root_url=clean_root).returning(self.db.tasks.c.id)
                task_id = conn.execute(stmt).fetchone()[0]
                is_new = True
            # 2. Insert URLs one by one, skipping existing entries.  The
            #    row-at-a-time check keeps this portable across backends.
            for u in clean_urls:
                try:
                    exists_q = select(self.db.queue.c.id).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == u)
                    )
                    if not conn.execute(exists_q).fetchone():
                        conn.execute(insert(self.db.queue).values(task_id=task_id, url=u, status='pending'))
                        added_count += 1
                except Exception as e:
                    # Skip the failing URL but keep going; log for diagnosis.
                    logger.warning("Failed to insert url %s for task %s: %s", u, task_id, e)
                    continue
        return {"task_id": task_id, "is_new_task": is_new, "added": added_count}

    def delete_task(self, task_id: int):
        """Delete a task plus its queue rows and chunks.

        Destructive — intended mainly for rollback.  Children are removed
        first so no orphaned rows survive if there are FK constraints.

        Returns:
            bool: True on success, False if any delete failed (rolled back).
        """
        with self.db.engine.begin() as conn:
            try:
                conn.execute(text("DELETE FROM chunks WHERE task_id = :tid"), {"tid": task_id})
                conn.execute(text("DELETE FROM queue WHERE task_id = :tid"), {"tid": task_id})
                conn.execute(text("DELETE FROM tasks WHERE id = :tid"), {"tid": task_id})
                return True
            except Exception as e:
                logger.error(f"Failed to delete task {task_id}: {e}")
                return False

    def get_task_monitor_data(self, task_id: int):
        """Return persisted queue statistics for *task_id* (DB-level monitor).

        Returns:
            dict | None: ``{"root_url": str, "db_stats": dict}`` with counts
            per status plus ``"total"``, or None if the task does not exist.
        """
        with self.db.engine.connect() as conn:
            # 1. Verify the task exists before aggregating.
            task_exists = conn.execute(select(self.db.tasks.c.root_url).where(self.db.tasks.c.id == task_id)).fetchone()
            if not task_exists:
                return None
            # 2. Count queue entries grouped by status.
            stats_rows = conn.execute(select(
                self.db.queue.c.status, func.count(self.db.queue.c.id)
            ).where(self.db.queue.c.task_id == task_id).group_by(self.db.queue.c.status)).fetchall()
            stats = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
            for status, count in stats_rows:
                # Ignore unknown statuses so a bad row can't break monitoring.
                if status in stats:
                    stats[status] = count
            stats["total"] = sum(stats.values())
            return {
                "root_url": task_exists[0],
                "db_stats": stats
            }

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Upsert content chunks for a page and mark its queue entry done.

        Args:
            task_id: owning task.
            source_url: page the chunks came from (normalized before use).
            title: page title stored on every chunk.
            chunks_data: items with ``index``, ``content``, ``embedding`` and
                optional ``meta_info`` keys.
        """
        clean_url = normalize_url(source_url)
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})
                # Manual upsert keyed on (task_id, source_url, chunk_index)
                # so re-crawling a page overwrites its previous chunks.
                existing = conn.execute(select(self.db.chunks.c.id).where(
                    and_(self.db.chunks.c.task_id == task_id,
                         self.db.chunks.c.source_url == clean_url,
                         self.db.chunks.c.chunk_index == idx)
                )).fetchone()
                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta
                }
                if existing:
                    conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(self.db.chunks).values(**values))
            # Saving chunks implies the page crawl completed successfully.
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status='completed'))

    def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 50):
        """Hybrid search: cosine similarity (70%) + Postgres full-text rank (30%).

        Args:
            query_text: keyword query (fed to ``websearch_to_tsquery``).
            query_vector: embedding of the query; numpy arrays and nested
                single-row lists are flattened first.
            task_id: optional filter restricting results to one task.
            candidates_num: maximum number of rows returned.

        Returns:
            dict: ``{"results": list[dict], "msg": str}``; falls back to a
            pure vector search if the hybrid query fails.
        """
        # Normalize the vector: accept numpy arrays and [[...]] wrappers.
        if hasattr(query_vector, 'tolist'):
            query_vector = query_vector.tolist()
        if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
            query_vector = query_vector[0]
        results = []
        with self.db.engine.connect() as conn:
            keyword_query = func.websearch_to_tsquery('english', query_text)
            # cosine_distance is 0 for identical vectors; invert to a score.
            vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
            keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
            # coalesce() guards against NULL ranks for rows with no match.
            final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
            stmt = select(
                self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
                self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
            )
            # "is not None" so a falsy-but-valid id (e.g. 0) still filters.
            if task_id is not None:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            stmt = stmt.order_by(desc("score")).limit(candidates_num)
            try:
                rows = conn.execute(stmt).fetchall()
                results = [{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4], "score": float(r[5])} for r in rows]
            except Exception as e:
                logger.error(f"Hybrid search failed: {e}")
                return self._fallback_vector_search(query_vector, task_id, candidates_num)
        return {"results": results, "msg": f"Hybrid found {len(results)}"}

    def _fallback_vector_search(self, vector, task_id, limit):
        """Pure vector search used when the hybrid query fails.

        Returns results with a placeholder score of 0.0 since no ranking
        value is computed here.
        """
        logger.warning("Fallback to pure vector search")
        with self.db.engine.connect() as conn:
            stmt = select(
                self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
                self.db.chunks.c.content, self.db.chunks.c.meta_info
            )
            if task_id is not None:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            # Order by distance (ascending) = nearest neighbours first.
            stmt = stmt.order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
            rows = conn.execute(stmt).fetchall()
            return {"results": [{"content": r[3], "meta_info": r[4], "score": 0.0} for r in rows], "msg": "Fallback found"}
# Module-level singleton; other backend modules import and share this instance.
data_service = DataService()