import logging

from sqlalchemy import select, insert, update, and_

from ..database import db_instance
from ..utils import normalize_url

logger = logging.getLogger(__name__)


class CrawlerSqlService:
    """SQL persistence layer for the crawler, built on SQLAlchemy Core.

    Table objects (``tasks``, ``queue``, ``chunks``) and the engine are
    provided by the shared ``db_instance``. The ``.returning()`` and
    pgvector ``cosine_distance`` usage below imply a PostgreSQL backend.
    """

    def __init__(self):
        # Shared database handle (engine + table metadata).
        self.db = db_instance

    def register_task(self, url: str):
        """Register a crawl task for *url*, or return the existing one.

        The URL is normalized first so equivalent spellings map to the
        same task.

        Returns:
            dict with ``task_id``, ``is_new_task`` and ``msg``.
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(self.db.tasks.c.id).where(
                    self.db.tasks.c.root_url == clean_url
                )
            ).fetchone()
            if existing:
                return {
                    "task_id": existing[0],
                    "is_new_task": False,
                    "msg": "Task already exists",
                }
            # INSERT ... RETURNING id — requires a dialect with RETURNING
            # support (PostgreSQL here).
            new_task = conn.execute(
                insert(self.db.tasks)
                .values(root_url=clean_url)
                .returning(self.db.tasks.c.id)
            ).fetchone()
            return {
                "task_id": new_task[0],
                "is_new_task": True,
                "msg": "New task created successfully",
            }

    def add_urls(self, task_id: int, urls: list[str]):
        """Batch-add URLs to a task's queue, skipping duplicates.

        Best-effort: a failure on one URL is recorded and does not abort
        the rest of the batch.

        Returns:
            dict with ``success_urls``, ``skipped_urls``, ``failed_urls``
            and a summary ``msg``.
        """
        success_urls, skipped_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    # Existence check — a single column is enough here;
                    # selecting the whole row was wasted work.
                    duplicate = conn.execute(
                        select(self.db.queue.c.url).where(
                            and_(
                                self.db.queue.c.task_id == task_id,
                                self.db.queue.c.url == clean_url,
                            )
                        )
                    ).fetchone()
                    if duplicate:
                        skipped_urls.append(clean_url)
                        continue
                    conn.execute(
                        insert(self.db.queue).values(
                            task_id=task_id,
                            url=clean_url,
                            status='pending',
                        )
                    )
                    success_urls.append(clean_url)
                except Exception:
                    # Deliberate best-effort swallow; the URL is reported
                    # back in failed_urls. NOTE(review): on PostgreSQL a
                    # failed statement poisons the transaction, so later
                    # inserts in this batch may also fail — confirm whether
                    # a per-URL savepoint is wanted.
                    failed_urls.append(clean_url)
        msg = (
            f"Added {len(success_urls)} urls, "
            f"skipped {len(skipped_urls)}, "
            f"failed {len(failed_urls)}"
        )
        return {
            "success_urls": success_urls,
            "skipped_urls": skipped_urls,
            "failed_urls": failed_urls,
            "msg": msg,
        }

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim up to *limit* pending URLs for a task.

        The selected rows are flipped to ``'processing'`` inside the same
        transaction. ``FOR UPDATE SKIP LOCKED`` row-locks the claimed rows
        so concurrent workers cannot grab the same URLs — without it the
        select/update pair was not actually atomic despite the docstring's
        claim. (Supported on PostgreSQL and MySQL 8+.)

        Returns:
            dict with ``urls`` (list of claimed URLs) and ``msg``.
        """
        with self.db.engine.begin() as conn:
            rows = conn.execute(
                select(self.db.queue.c.url)
                .where(
                    and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.status == 'pending',
                    )
                )
                .limit(limit)
                .with_for_update(skip_locked=True)
            ).fetchall()
            urls = [row[0] for row in rows]
            if not urls:
                return {"urls": [], "msg": "Queue is empty"}
            conn.execute(
                update(self.db.queue)
                .where(
                    and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.url.in_(urls),
                    )
                )
                .values(status='processing')
            )
            return {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}

    def save_results(self, task_id: int, results: list):
        """Upsert the chunks of one crawled page and mark its queue row done.

        All items in *results* are assumed to belong to the same page
        (the ``source_url`` of the first item is used for every chunk).
        Each item is a dict — or simple object — carrying ``source_url``,
        ``chunk_index``, ``title``, ``content`` and ``embedding``.

        Returns:
            dict with the normalized ``source_url``, ``is_page_update``,
            per-chunk ``detail`` lists, ``counts`` and a summary ``msg``.
        """
        if not results:
            return {
                "msg": "No data provided to save",
                "counts": {"inserted": 0, "updated": 0, "failed": 0},
            }

        def _as_dict(item):
            # Accept plain dicts or simple attribute objects.
            return item if isinstance(item, dict) else item.__dict__

        target_url = normalize_url(_as_dict(results[0]).get('source_url'))

        inserted_chunks, updated_chunks, failed_chunks = [], [], []
        is_page_update = False

        with self.db.engine.begin() as conn:
            # Report to the caller whether this page already had chunks.
            page_exists = conn.execute(
                select(self.db.chunks.c.id)
                .where(
                    and_(
                        self.db.chunks.c.task_id == task_id,
                        self.db.chunks.c.source_url == target_url,
                    )
                )
                .limit(1)
            ).fetchone()
            if page_exists:
                is_page_update = True

            for res in results:
                data = _as_dict(res)
                c_idx = data.get('chunk_index')
                try:
                    existing_chunk = conn.execute(
                        select(self.db.chunks.c.id).where(
                            and_(
                                self.db.chunks.c.task_id == task_id,
                                self.db.chunks.c.source_url == target_url,
                                self.db.chunks.c.chunk_index == c_idx,
                            )
                        )
                    ).fetchone()
                    if existing_chunk:
                        # Chunk already stored — overwrite its payload.
                        conn.execute(
                            update(self.db.chunks)
                            .where(self.db.chunks.c.id == existing_chunk[0])
                            .values(
                                title=data.get('title'),
                                content=data.get('content'),
                                embedding=data.get('embedding'),
                            )
                        )
                        updated_chunks.append(c_idx)
                    else:
                        conn.execute(
                            insert(self.db.chunks).values(
                                task_id=task_id,
                                source_url=target_url,
                                chunk_index=c_idx,
                                title=data.get('title'),
                                content=data.get('content'),
                                embedding=data.get('embedding'),
                            )
                        )
                        inserted_chunks.append(c_idx)
                except Exception:
                    # Log with traceback instead of the original bare print.
                    logger.exception("Chunk %s failed", c_idx)
                    failed_chunks.append(c_idx)

            # Mark the source URL as completed in the crawl queue.
            conn.execute(
                update(self.db.queue)
                .where(
                    and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.url == target_url,
                    )
                )
                .values(status='completed')
            )

        msg = (
            f"Saved results for {target_url}. "
            f"Inserted: {len(inserted_chunks)}, "
            f"Updated: {len(updated_chunks)}"
        )
        return {
            "source_url": target_url,
            "is_page_update": is_page_update,
            "detail": {
                "inserted_chunk_indexes": inserted_chunks,
                "updated_chunk_indexes": updated_chunks,
                "failed_chunk_indexes": failed_chunks,
            },
            "counts": {
                "inserted": len(inserted_chunks),
                "updated": len(updated_chunks),
                "failed": len(failed_chunks),
            },
            "msg": msg,
        }

    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
        """Vector similarity search over stored chunks.

        Orders by pgvector cosine distance to *query_embedding*; when
        *task_id* is given, the search is restricted to that task.

        Returns:
            dict with ``results`` (list of chunk dicts) and ``msg``.
        """
        with self.db.engine.connect() as conn:
            stmt = select(
                self.db.chunks.c.task_id,
                self.db.chunks.c.source_url,
                self.db.chunks.c.title,
                self.db.chunks.c.content,
                self.db.chunks.c.chunk_index,
            )
            if task_id is not None:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            stmt = stmt.order_by(
                self.db.chunks.c.embedding.cosine_distance(query_embedding)
            ).limit(limit)
            rows = conn.execute(stmt).fetchall()
        results = [
            {
                "task_id": r[0],
                "source_url": r[1],
                "title": r[2],
                "content": r[3],
                "chunk_index": r[4],
            }
            for r in rows
        ]
        msg = f"Found {len(results)} matches" if results else "No matching content found"
        return {"results": results, "msg": msg}


# Module-level singleton used by the rest of the application.
crawler_sql_service = CrawlerSqlService()