from sqlalchemy import select, update, and_
from sqlalchemy.dialects.postgresql import insert as pg_insert

from .database import db_instance
from .utils import normalize_url


class CrawlerService:
    """Service layer for crawl tasks: task registration, URL queueing,
    atomic work dispatch, and result persistence.

    All methods open their own transaction via ``engine.begin()`` so each
    call either fully commits or fully rolls back.
    """

    def __init__(self):
        # Shared engine/table metadata injected via the module-level db singleton.
        self.db = db_instance

    def register_task(self, url: str):
        """Register a new crawl task and seed its URL queue.

        Args:
            url: Root URL of the crawl; normalized before use so that
                equivalent spellings map to the same task.

        Returns:
            dict with ``task_id`` and ``is_new_task`` (False when a task
            for the same normalized root URL already exists).
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # 1. Dedup: reuse an existing task for the same root URL.
            find_stmt = select(self.db.tasks.c.id).where(
                self.db.tasks.c.root_url == clean_url
            )
            existing = conn.execute(find_stmt).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            # 2. Insert the new task; RETURNING avoids a second round trip
            #    for the generated primary key.
            new_task = conn.execute(
                pg_insert(self.db.tasks)
                .values(root_url=clean_url)
                .returning(self.db.tasks.c.id)
            ).fetchone()
            task_id = new_task[0]

            # 3. Seed the queue with the root URL so workers have a start point.
            conn.execute(
                pg_insert(self.db.queue).values(
                    task_id=task_id, url=clean_url, status='pending'
                )
            )
            return {"task_id": task_id, "is_new_task": True}

    def add_urls(self, task_id: int, urls: list):
        """Bulk-insert newly discovered URLs, skipping duplicates.

        Relies on the ``(task_id, url)`` unique index: ON CONFLICT DO
        NOTHING makes re-submitting known URLs a no-op, and ``rowcount``
        tells us whether each row was actually inserted.

        Args:
            task_id: Owning crawl task.
            urls: Candidate URLs; each is normalized before insertion.

        Returns:
            dict with ``added_count`` — number of URLs actually enqueued.
        """
        added_count = 0
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                stmt = pg_insert(self.db.queue).values(
                    task_id=task_id, url=clean_url, status='pending'
                ).on_conflict_do_nothing(index_elements=['task_id', 'url'])
                res = conn.execute(stmt)
                # rowcount == 0 means the conflict clause suppressed the insert.
                if res.rowcount > 0:
                    added_count += 1
        return {"added_count": added_count}

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim up to ``limit`` pending URLs for processing.

        BUGFIX: the previous SELECT-then-UPDATE was not atomic — two
        concurrent workers could read the same 'pending' rows before
        either marked them 'processing', dispatching duplicate work.
        This version uses the canonical Postgres pattern
        ``UPDATE … WHERE url IN (SELECT … LIMIT n FOR UPDATE SKIP LOCKED)
        RETURNING url``: rows locked by another transaction are skipped
        rather than waited on, so each URL is handed to exactly one worker.

        Args:
            task_id: Crawl task to draw work from.
            limit: Maximum number of URLs to claim.

        Returns:
            dict with ``urls`` — the list of claimed URLs (possibly empty).
        """
        with self.db.engine.begin() as conn:
            # Subquery locks candidate rows; SKIP LOCKED ignores rows
            # already claimed by a concurrent transaction.
            locked = (
                select(self.db.queue.c.url)
                .where(
                    and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.status == 'pending',
                    )
                )
                .limit(limit)
                .with_for_update(skip_locked=True)
            )
            stmt = (
                update(self.db.queue)
                .where(
                    and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.url.in_(locked),
                    )
                )
                .values(status='processing')
                .returning(self.db.queue.c.url)
            )
            urls = [r[0] for r in conn.execute(stmt).fetchall()]
            return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """Persist crawled content and close out the queue entries.

        For each result, stores the extracted chunk (title, content,
        embedding) and marks the source URL 'completed' in the queue.
        Runs in one transaction so a partial batch never commits.

        Args:
            task_id: Owning crawl task.
            results: Objects exposing ``url``, ``title``, ``content`` and
                ``embedding`` attributes (shape assumed from usage —
                TODO confirm against the caller's result type).

        Returns:
            dict with ``inserted`` — number of results processed.
        """
        with self.db.engine.begin() as conn:
            for res in results:
                clean_url = normalize_url(res.url)
                # Store the extracted content + embedding.
                conn.execute(
                    pg_insert(self.db.chunks).values(
                        task_id=task_id,
                        source_url=clean_url,
                        title=res.title,
                        content=res.content,
                        embedding=res.embedding,
                    )
                )
                # Close the loop: mark the queue row done.
                conn.execute(
                    update(self.db.queue)
                    .where(
                        and_(
                            self.db.queue.c.task_id == task_id,
                            self.db.queue.c.url == clean_url,
                        )
                    )
                    .values(status='completed')
                )
        return {"inserted": len(results)}


# Module-level singleton shared by the application.
crawler_service = CrawlerService()