# service.py
from sqlalchemy import select, insert, update, delete, and_

from .database import db_instance
from .utils import normalize_url


class CrawlerService:
    def __init__(self):
        self.db = db_instance

    def register_task(self, url: str):
        """Register a crawl task using only the library API."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # Look up an existing task via the select() API
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            # Create the task via the insert() API (RETURNING needs a backend
            # that supports it, e.g. PostgreSQL)
            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list):
        """Bulk-add URLs via the generic API, with a detailed result breakdown."""
        success_urls, skipped_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    # Savepoint so one failed URL does not abort the whole batch
                    with conn.begin_nested():
                        # Skip URLs already queued for this task (generic style)
                        check_q = select(self.db.queue).where(
                            and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                        )
                        if conn.execute(check_q).fetchone():
                            skipped_urls.append(clean_url)
                            continue

                        # Insert the new URL as pending
                        conn.execute(insert(self.db.queue).values(
                            task_id=task_id, url=clean_url, status='pending'
                        ))
                        success_urls.append(clean_url)
                except Exception:
                    failed_urls.append(clean_url)
        return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim a batch of pending URLs."""
        with self.db.engine.begin() as conn:
            # Lock the selected rows so concurrent workers cannot claim the same
            # batch (SKIP LOCKED requires PostgreSQL/MySQL-style row locking)
            query = select(self.db.queue.c.url).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            urls = [r[0] for r in conn.execute(query).fetchall()]

            if urls:
                upd = update(self.db.queue).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
                ).values(status='processing')
                conn.execute(upd)
        return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """UPSERT logic via the generic API: distinguishes inserts, updates, and failures."""
        inserted_urls, updated_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for res in results:
                # Accept either a dict or an object coming from Dify
                data = res if isinstance(res, dict) else res.__dict__
                clean_url = normalize_url(data.get('source_url'))
                c_idx = data.get('chunk_index')
                try:
                    # Savepoint so one failed chunk does not abort the whole batch
                    with conn.begin_nested():
                        # 1. Check whether this chunk already exists
                        find_q = select(self.db.chunks.c.id).where(
                            and_(
                                self.db.chunks.c.task_id == task_id,
                                self.db.chunks.c.source_url == clean_url,
                                self.db.chunks.c.chunk_index == c_idx
                            )
                        )
                        existing = conn.execute(find_q).fetchone()

                        if existing:
                            # 2. Update the existing chunk
                            upd = update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(
                                title=data.get('title'),
                                content=data.get('content'),
                                embedding=data.get('embedding')
                            )
                            conn.execute(upd)
                            updated_urls.append(clean_url)
                        else:
                            # 3. Insert a new chunk
                            ins = insert(self.db.chunks).values(
                                task_id=task_id,
                                source_url=clean_url,
                                chunk_index=c_idx,
                                title=data.get('title'),
                                content=data.get('content'),
                                embedding=data.get('embedding')
                            )
                            conn.execute(ins)
                            inserted_urls.append(clean_url)

                        # 4. Mark the URL as completed in the queue
                        conn.execute(update(self.db.queue).where(
                            and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                        ).values(status='completed'))
                except Exception as e:
                    print(f"Error: {e}")
                    failed_urls.append(clean_url)

        return {
            "inserted_urls": list(set(inserted_urls)),
            "updated_urls": list(set(updated_urls)),
            "failed_urls": failed_urls
        }


crawler_service = CrawlerService()
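

# --- Usage sketch (hypothetical; not part of the service proper) ---
# A minimal example of how crawler_service might be driven end to end,
# assuming db_instance is wired to a configured engine with the tasks,
# queue, and chunks tables created. The URLs and chunk payload below are
# illustrative placeholders only.
if __name__ == "__main__":
    task = crawler_service.register_task("https://example.com")
    crawler_service.add_urls(task["task_id"], [
        "https://example.com/docs",
        "https://example.com/blog",
    ])
    batch = crawler_service.get_pending_urls(task["task_id"], limit=10)
    results = [
        {
            "source_url": u,
            "chunk_index": 0,
            "title": "Example title",
            "content": "Example chunk content",
            "embedding": None,  # illustrative; a real call would pass a vector
        }
        for u in batch["urls"]
    ]
    print(crawler_service.save_results(task["task_id"], results))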