Files
wiki_crawler/backend/services/data_service.py
2026-01-27 01:41:45 +08:00

212 lines
9.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy import select, insert, update, and_, text, func, desc
from backend.core.database import db
from backend.utils.common import normalize_url
import logging
# Module-specific logger.
# __name__ resolves to this module's dotted import path (e.g. "backend.services.data_service").
logger = logging.getLogger(__name__)
class DataService:
    """Data persistence service layer.

    Wraps the shared ``db`` object (engine plus the ``tasks``, ``queue`` and
    ``chunks`` tables) and exposes task registration, crawl-queue handling,
    chunk storage and hybrid (vector + keyword) search.
    """

    def __init__(self):
        self.db = db

    def register_task(self, url: str):
        """Return the task for ``url``, creating it if it does not exist yet.

        The URL is normalized before the lookup.

        Returns:
            ``{"task_id": <id>, "is_new_task": <bool>}``.
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list[str]):
        """Queue every URL in ``urls`` that is not already queued for the task.

        This is best-effort: a URL whose check/insert fails is logged and
        skipped, and the remaining URLs are still processed.

        Returns:
            ``{"msg": "Added <n> new urls"}`` where ``n`` counts the inserted rows.
        """
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                check_q = select(self.db.queue.c.id).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                )
                try:
                    # Run each URL inside its own SAVEPOINT: on PostgreSQL a failed
                    # statement aborts the enclosing transaction, which would make
                    # every following URL fail as well.
                    with conn.begin_nested():
                        already_queued = conn.execute(check_q).fetchone() is not None
                        if not already_queued:
                            conn.execute(
                                insert(self.db.queue).values(
                                    task_id=task_id, url=clean_url, status='pending'
                                )
                            )
                except Exception:
                    logger.warning(
                        "Could not queue url %s for task %s", clean_url, task_id, exc_info=True
                    )
                    continue
                if not already_queued:
                    success_urls.append(clean_url)
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to ``limit`` pending URLs of a task and return them.

        The claimed rows are switched to ``'processing'`` in the same statement
        that selects them.
        """
        with self.db.engine.begin() as conn:
            # Atomic claim: select the pending rows FOR UPDATE SKIP LOCKED and
            # mark them as processing in a single UPDATE ... RETURNING.
            subquery = select(self.db.queue.c.id).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            stmt = update(self.db.queue).where(
                self.db.queue.c.id.in_(subquery)
            ).values(status='processing').returning(self.db.queue.c.url)
            result = conn.execute(stmt).fetchall()
            return [r[0] for r in result]

    def mark_url_status(self, task_id: int, url: str, status: str):
        """Set the queue status of the (normalized) ``url`` within a task."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status=status))

    def get_all_tasks(self):
        """Return all registered tasks (the knowledge-base list), ordered by id.

        Intended for front-end display or for route selection in a workflow.

        Returns:
            A list of ``{"task_id", "root_url", "name"}`` dicts, where ``name``
            is a short alias derived from the root URL.
        """
        with self.db.engine.connect() as conn:
            # Only id and root_url are read from the tasks table.
            stmt = select(self.db.tasks.c.id, self.db.tasks.c.root_url).order_by(self.db.tasks.c.id)
            rows = conn.execute(stmt).fetchall()
            return [
                {"task_id": r[0], "root_url": r[1], "name": self._extract_name(r[1])}
                for r in rows
            ]

    def _extract_name(self, url: str) -> str:
        """Derive a short alias from a URL, e.g. ``docs.firecrawl.dev`` -> ``firecrawl``.

        The second-to-last dot-separated label of the network location is used;
        a single-label host is returned as is. If the URL has no network
        location, or cannot be parsed, the URL itself is returned.
        """
        try:
            from urllib.parse import urlparse
            domain = urlparse(url).netloc
            if not domain:
                # No host part (e.g. missing scheme): an empty alias is useless.
                return url
            parts = domain.split('.')
            if len(parts) >= 2:
                return parts[-2]
            return domain
        except Exception:
            return url

    def get_task_monitor_data(self, task_id: int):
        """Return the persisted state of a task (database-level monitoring).

        Returns:
            ``None`` if the task does not exist, otherwise
            ``{"root_url": ..., "db_stats": {...}}`` where ``db_stats`` holds the
            number of queue rows per status plus their ``"total"``.
        """
        with self.db.engine.connect() as conn:
            # 1. Make sure the task exists.
            task_exists = conn.execute(
                select(self.db.tasks.c.root_url).where(self.db.tasks.c.id == task_id)
            ).fetchone()
            if not task_exists:
                return None

            # 2. Count the queue rows per status.
            stats_rows = conn.execute(select(
                self.db.queue.c.status, func.count(self.db.queue.c.id)
            ).where(self.db.queue.c.task_id == task_id).group_by(self.db.queue.c.status)).fetchall()

            stats = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
            for status, count in stats_rows:
                # Statuses outside the four known ones are ignored.
                if status in stats:
                    stats[status] = count
            stats["total"] = sum(stats.values())

            return {
                "root_url": task_exists[0],
                "db_stats": stats
            }

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Upsert the chunks of a page and mark the page as completed in the queue.

        Args:
            task_id: Owning task.
            source_url: Page URL; normalized before use.
            title: Title stored with every chunk.
            chunks_data: Items with the keys ``index``, ``content`` and
                ``embedding`` and an optional ``meta_info``.
        """
        clean_url = normalize_url(source_url)
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})
                # A chunk is identified by (task_id, source_url, chunk_index).
                existing = conn.execute(select(self.db.chunks.c.id).where(
                    and_(self.db.chunks.c.task_id == task_id,
                         self.db.chunks.c.source_url == clean_url,
                         self.db.chunks.c.chunk_index == idx)
                )).fetchone()
                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta
                }
                if existing:
                    conn.execute(
                        update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values)
                    )
                else:
                    conn.execute(insert(self.db.chunks).values(**values))

            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status='completed'))

    def search(self, query_text: str, query_vector: list, task_id=None,
               candidates_num: int = 50, vector_weight: float = 0.7):
        """Hybrid search over the stored chunks.

        The score of a chunk is
        ``vector_weight * (1 - cosine distance) + (1 - vector_weight) * ts_rank``.

        Args:
            query_text: Text handed to ``websearch_to_tsquery('english', ...)``.
            query_vector: Query embedding; objects with ``tolist()`` and nested
                single-row lists are flattened first.
            task_id: When truthy, restricts the search to this task.
            candidates_num: Maximum number of rows returned.
            vector_weight: Weight of the vector score; ``0`` means pure keyword
                search, which additionally requires a full-text match.

        Returns:
            ``{"results": [...], "msg": ...}``. If the query fails and
            ``vector_weight > 0``, the result of the pure vector fallback is
            returned; otherwise an empty result list.
        """
        # Normalize the vector format.
        if hasattr(query_vector, 'tolist'):
            query_vector = query_vector.tolist()
        if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
            query_vector = query_vector[0]

        results = []
        with self.db.engine.connect() as conn:
            # 1. Build the query object (a tsquery).
            keyword_query = func.websearch_to_tsquery('english', query_text)

            # 2. Score components.
            vector_dist = self.db.chunks.c.embedding.cosine_distance(query_vector)
            vector_score = (1 - vector_dist)
            # ts_rank expects (tsvector, tsquery).
            keyword_rank = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
            keyword_score = func.coalesce(keyword_rank, 0)
            keyword_weight = 1.0 - vector_weight
            final_score = (vector_score * vector_weight + keyword_score * keyword_weight).label("score")

            stmt = select(
                self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
                self.db.chunks.c.content, self.db.chunks.c.meta_info, final_score
            )

            # Only the pure keyword mode (vector_weight == 0) filters on a text match.
            if vector_weight == 0:
                # Use the PostgreSQL `@@` operator directly (content_tsvector @@ keyword_query).
                # `.match()` would make SQLAlchemy generate plainto_tsquery, which errors out.
                stmt = stmt.where(self.db.chunks.c.content_tsvector.op('@@')(keyword_query))

            if task_id:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            stmt = stmt.order_by(desc("score")).limit(candidates_num)

            try:
                rows = conn.execute(stmt).fetchall()
                results = [
                    {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3],
                     "meta_info": r[4], "score": float(r[5])}
                    for r in rows
                ]
            except Exception as e:
                # Log with traceback to ease debugging.
                logger.exception("Search failed: %s", e)
                # Falling back only helps in hybrid/vector mode; in pure keyword
                # mode a vector search cannot replace the failed query.
                if vector_weight > 0:
                    return self._fallback_vector_search(query_vector, task_id, candidates_num)
                return {"results": [], "msg": "Keyword search failed"}

        return {"results": results, "msg": f"Found {len(results)}"}

    def _fallback_vector_search(self, vector, task_id, limit):
        """Pure vector search used when the hybrid query fails.

        Rows are ordered by cosine distance to ``vector``. Each result carries
        the same keys as the results of ``search``; ``score`` is always ``0.0``.
        """
        logger.warning("Fallback to pure vector search")
        with self.db.engine.connect() as conn:
            stmt = select(
                self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
                self.db.chunks.c.content, self.db.chunks.c.meta_info
            )
            if task_id:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            stmt = stmt.order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
            rows = conn.execute(stmt).fetchall()
            results = [
                {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3],
                 "meta_info": r[4], "score": 0.0}
                for r in rows
            ]
            return {"results": results, "msg": "Fallback found"}
# Module-level DataService instance, created when this module is imported.
data_service = DataService()