from sqlalchemy import select, insert, update, and_, func, desc
from backend.core.database import db
from backend.utils.common import normalize_url
import logging

# Dedicated logger for this module.
# __name__ resolves to the dotted module path, e.g. "backend.services.crawler_service".
logger = logging.getLogger(__name__)


class DataService:
    """Data persistence service layer."""

    def __init__(self):
        self.db = db

    def register_task(self, url: str):
        """Register a crawl task for a root URL, or return the existing one."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}
            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list[str]):
        """Enqueue URLs for a task, skipping entries already in the queue."""
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    check_q = select(self.db.queue.c.id).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    )
                    if not conn.execute(check_q).fetchone():
                        conn.execute(
                            insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending')
                        )
                        success_urls.append(clean_url)
                except Exception as e:
                    logger.warning(f"Failed to enqueue {clean_url}: {e}")
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        with self.db.engine.begin() as conn:
            # Atomic claim: select pending rows and flip them to 'processing'.
            # FOR UPDATE SKIP LOCKED lets multiple workers pull batches without
            # claiming the same rows.
            subquery = select(self.db.queue.c.id).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            stmt = update(self.db.queue).where(
                self.db.queue.c.id.in_(subquery)
            ).values(status='processing').returning(self.db.queue.c.url)
            result = conn.execute(stmt).fetchall()
            return [r[0] for r in result]

    def mark_url_status(self, task_id: int, url: str, status: str):
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status=status))

    def get_all_tasks(self):
        """
        List all registered tasks (the knowledge-base catalog).
        Used by the frontend and for workflow routing.
        """
        with self.db.engine.connect() as conn:
            # Query id and root_url; assumes the tasks table has these columns.
            stmt = select(self.db.tasks.c.id, self.db.tasks.c.root_url).order_by(self.db.tasks.c.id)
            rows = conn.execute(stmt).fetchall()
            # Return a compact list.
            return [
                {"task_id": r[0], "root_url": r[1], "name": self._extract_name(r[1])}
                for r in rows
            ]

    def _extract_name(self, url: str) -> str:
        """Helper: derive a short alias from a URL, e.g. docs.firecrawl.dev -> firecrawl."""
        try:
            from urllib.parse import urlparse
            domain = urlparse(url).netloc
            parts = domain.split('.')
            if len(parts) >= 2:
                return parts[-2]
            return domain
        except Exception:
            return url

    def get_task_monitor_data(self, task_id: int):
        """[DB-level monitoring] Return the persisted state of a task."""
        with self.db.engine.connect() as conn:
            # 1. Check that the task exists.
            task_exists = conn.execute(
                select(self.db.tasks.c.root_url).where(self.db.tasks.c.id == task_id)
            ).fetchone()
            if not task_exists:
                return None
            # 2. Count queue entries per status.
            stats_rows = conn.execute(select(
                self.db.queue.c.status,
                func.count(self.db.queue.c.id)
            ).where(self.db.queue.c.task_id == task_id).group_by(self.db.queue.c.status)).fetchall()
            stats = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
            for status, count in stats_rows:
                if status in stats:
                    stats[status] = count
            stats["total"] = sum(stats.values())
            return {
                "root_url": task_exists[0],
                "db_stats": stats
            }

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Upsert embedded chunks for a page and mark its queue entry as completed."""
        clean_url = normalize_url(source_url)
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})
                existing = conn.execute(select(self.db.chunks.c.id).where(
                    and_(
                        self.db.chunks.c.task_id == task_id,
                        self.db.chunks.c.source_url == clean_url,
                        self.db.chunks.c.chunk_index == idx
                    )
                )).fetchone()
                values = {
                    "task_id": task_id,
                    "source_url": clean_url,
                    "chunk_index": idx,
                    "title": title,
                    "content": item['content'],
                    "embedding": item['embedding'],
                    "meta_info": meta
                }
                if existing:
                    conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(self.db.chunks).values(**values))
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status='completed'))

    def search(self, query_text: str, query_vector: list, task_id=None,
               candidates_num: int = 50, vector_weight: float = 0.7):
        """Hybrid search: weighted blend of pgvector cosine similarity and full-text rank."""
        # Normalize the vector format (numpy arrays, nested lists).
        if hasattr(query_vector, 'tolist'):
            query_vector = query_vector.tolist()
        if isinstance(query_vector, list) and len(query_vector) > 0 and isinstance(query_vector[0], list):
            query_vector = query_vector[0]

        results = []
        with self.db.engine.connect() as conn:
            # 1. Build the query object (this is a tsquery).
            keyword_query = func.websearch_to_tsquery('english', query_text)

            # 2. Score components (logic unchanged).
            vector_dist = self.db.chunks.c.embedding.cosine_distance(query_vector)
            vector_score = (1 - vector_dist)
            # Note: ts_rank expects (tsvector, tsquery).
            keyword_rank = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
            keyword_score = func.coalesce(keyword_rank, 0)
            keyword_weight = 1.0 - vector_weight
            final_score = (vector_score * vector_weight + keyword_score * keyword_weight).label("score")

            stmt = select(
                self.db.chunks.c.task_id,
                self.db.chunks.c.source_url,
                self.db.chunks.c.title,
                self.db.chunks.c.content,
                self.db.chunks.c.meta_info,
                final_score
            )

            # ================= Fix start =================
            # Only force a WHERE filter in pure keyword mode (vector_weight == 0).
            if vector_weight == 0:
                # Broken variant (SQLAlchemy emits plainto_tsquery and errors out):
                #   stmt = stmt.where(keyword_query.match(self.db.chunks.c.content_tsvector))
                # Correct variant: use PostgreSQL's @@ operator directly,
                # i.e. content_tsvector @@ keyword_query.
                stmt = stmt.where(self.db.chunks.c.content_tsvector.op('@@')(keyword_query))
            # ================= Fix end =================

            if task_id:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            stmt = stmt.order_by(desc("score")).limit(candidates_num)

            try:
                rows = conn.execute(stmt).fetchall()
                results = [
                    {"task_id": r[0], "source_url": r[1], "title": r[2],
                     "content": r[3], "meta_info": r[4], "score": float(r[5])}
                    for r in rows
                ]
            except Exception as e:
                # Log the full error for debugging.
                logger.error(f"Search failed: {e}")
                # Only fall back in hybrid/vector mode; if pure keyword search
                # itself failed, a fallback would not help either.
                if vector_weight > 0:
                    return self._fallback_vector_search(query_vector, task_id, candidates_num)
                return {"results": [], "msg": "Keyword search failed"}
        return {"results": results, "msg": f"Found {len(results)}"}

    def _fallback_vector_search(self, vector, task_id, limit):
        logger.warning("Fallback to pure vector search")
        with self.db.engine.connect() as conn:
            stmt = select(
                self.db.chunks.c.task_id,
                self.db.chunks.c.source_url,
                self.db.chunks.c.title,
                self.db.chunks.c.content,
                self.db.chunks.c.meta_info
            ).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
            if task_id:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            rows = conn.execute(stmt).fetchall()
            return {
                "results": [{"content": r[3], "meta_info": r[4], "score": 0.0} for r in rows],
                "msg": "Fallback found"
            }


data_service = DataService()