import logging
from typing import Optional

from sqlalchemy import select, insert, update, and_
from sqlalchemy.exc import SQLAlchemyError

from backend.core.database import db
from backend.utils.common import normalize_url

logger = logging.getLogger(__name__)


class DataService:
    """
    Data persistence service layer.

    Only performs database CRUD operations; it contains no external API
    call logic.
    """

    def __init__(self):
        # Shared database handle; the methods below use its ``engine`` and
        # its ``tasks`` / ``queue`` / ``chunks`` tables.
        self.db = db

    def register_task(self, url: str) -> dict:
        """Return the task for *url*, creating it if it does not exist yet.

        The URL is normalized first and the lookup is done on ``root_url``.

        Returns:
            A dict with ``task_id``, ``is_new_task`` and ``msg``.
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            query = select(tasks.c.id).where(tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {
                    "task_id": existing[0],
                    "is_new_task": False,
                    "msg": "Task already exists",
                }

            stmt = insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {
                "task_id": new_task[0],
                "is_new_task": True,
                "msg": "New task created",
            }

    def add_urls(self, task_id: int, urls: list[str]) -> dict:
        """Queue every not-yet-queued URL of *urls* for *task_id* as 'pending'.

        Each URL is normalized and handled best-effort: a database error on
        one URL is logged and that URL is skipped, the rest of the batch is
        still processed.

        Returns:
            A dict whose ``msg`` reports how many new URLs were inserted.
        """
        success_urls = []
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                check_q = select(queue).where(
                    and_(queue.c.task_id == task_id, queue.c.url == clean_url)
                )
                try:
                    # Run each URL inside its own SAVEPOINT. On backends that
                    # abort the enclosing transaction after a failed statement
                    # (e.g. PostgreSQL), swallowing the error without rolling
                    # back would make every later statement in this batch fail
                    # as well while the method still reported success.
                    with conn.begin_nested():
                        already_queued = conn.execute(check_q).fetchone() is not None
                        if not already_queued:
                            conn.execute(
                                insert(queue).values(
                                    task_id=task_id, url=clean_url, status='pending'
                                )
                            )
                except SQLAlchemyError:
                    # Best-effort: keep going, but leave a trace instead of
                    # silently discarding the failure.
                    logger.exception(
                        "Failed to queue url %r for task %s", clean_url, task_id
                    )
                    continue

                if not already_queued:
                    success_urls.append(clean_url)

        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int) -> dict:
        """Fetch up to *limit* pending URLs of *task_id* and mark them 'processing'.

        Returns:
            A dict with the fetched ``urls`` (possibly empty) and ``msg``.
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            query = (
                select(queue.c.url)
                .where(and_(queue.c.task_id == task_id, queue.c.status == 'pending'))
                .limit(limit)
            )
            urls = [row[0] for row in conn.execute(query).fetchall()]

            # NOTE(review): the select and the update below take no row lock,
            # so two callers running at the same time could be handed the same
            # URLs — confirm whether multiple workers share a task.
            if urls:
                conn.execute(
                    update(queue)
                    .where(and_(queue.c.task_id == task_id, queue.c.url.in_(urls)))
                    .values(status='processing')
                )

        return {"urls": urls, "msg": "Fetched pending urls"}

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list) -> dict:
        """
        Save chunk data (Phase 1.5: supports meta_info).

        Every item is inserted, or updated in place when a chunk with the same
        (task_id, normalized source_url, chunk_index) already exists. The
        queue entry of the URL is then marked 'completed' in the same
        transaction.

        Args:
            task_id: Task the chunks belong to.
            source_url: Page the chunks were produced from; normalized here.
            title: Title stored on every chunk.
            chunks_data: Items shaped like
                {'index': int, 'content': str, 'embedding': list, 'meta_info': dict};
                'meta_info' is optional and defaults to {}.

        Returns:
            A dict with ``msg`` and the number of chunks written as ``count``.
        """
        clean_url = normalize_url(source_url)
        count = 0
        chunks = self.db.chunks
        queue = self.db.queue

        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})

                # Check whether this chunk is already stored.
                existing = conn.execute(
                    select(chunks.c.id).where(
                        and_(
                            chunks.c.task_id == task_id,
                            chunks.c.source_url == clean_url,
                            chunks.c.chunk_index == idx,
                        )
                    )
                ).fetchone()

                values = {
                    "task_id": task_id,
                    "source_url": clean_url,
                    "chunk_index": idx,
                    "title": title,
                    "content": item['content'],
                    "embedding": item['embedding'],
                    "meta_info": meta,
                }

                if existing:
                    conn.execute(
                        update(chunks).where(chunks.c.id == existing[0]).values(**values)
                    )
                else:
                    conn.execute(insert(chunks).values(**values))
                count += 1

            # Mark the queue entry as completed.
            conn.execute(
                update(queue)
                .where(and_(queue.c.task_id == task_id, queue.c.url == clean_url))
                .values(status='completed')
            )

        return {"msg": f"Saved {count} chunks", "count": count}

    def search(self, vector: list, task_id: Optional[int] = None, limit: int = 5) -> dict:
        """Return the *limit* chunks closest to *vector* by cosine distance.

        Args:
            vector: Query embedding compared against the stored embeddings.
            task_id: When given, restrict the search to this task. ``None``
                searches across all tasks.
            limit: Maximum number of rows to return.

        Returns:
            A dict with ``results`` (dicts holding task_id, source_url, title,
            content and meta_info) and ``msg``.
        """
        chunks = self.db.chunks
        stmt = select(
            chunks.c.task_id,
            chunks.c.source_url,
            chunks.c.title,
            chunks.c.content,
            chunks.c.meta_info,
        )
        # Compare against None explicitly so that a task id of 0 still
        # filters instead of silently searching every task.
        if task_id is not None:
            stmt = stmt.where(chunks.c.task_id == task_id)
        stmt = stmt.order_by(chunks.c.embedding.cosine_distance(vector)).limit(limit)

        with self.db.engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()

        results = [
            {
                "task_id": r[0],
                "source_url": r[1],
                "title": r[2],
                "content": r[3],
                "meta_info": r[4],
            }
            for r in rows
        ]
        return {"results": results, "msg": f"Found {len(results)}"}


# Module-level instance, created at import time.
data_service = DataService()