91 lines
3.5 KiB
Python
91 lines
3.5 KiB
Python
|
|
from sqlalchemy import select, update, and_
|
|||
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|||
|
|
from .database import db_instance
|
|||
|
|
from .utils import normalize_url
|
|||
|
|
|
|||
|
|
class CrawlerService:
    """Persistence layer for a crawl job.

    Manages three Postgres tables through SQLAlchemy Core: the task
    registry (``tasks``), the URL frontier queue (``queue``), and the
    extracted-content store (``chunks``).
    """

    def __init__(self):
        # Shared engine/table-metadata holder provided by the application.
        self.db = db_instance

    def register_task(self, url: str):
        """Register a new crawl task rooted at *url* and seed its queue.

        Returns ``{"task_id": int, "is_new_task": bool}``. If a task with
        the same normalized root URL already exists, it is reused instead
        of duplicated.
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # 1. Dedupe: reuse an existing task with this root URL.
            # NOTE(review): check-then-insert is racy under concurrent
            # registration unless tasks.root_url carries a unique
            # constraint — confirm the schema before relying on this.
            find_stmt = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(find_stmt).fetchone()

            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            # 2. Insert the new task and fetch its generated id.
            new_task = conn.execute(
                pg_insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            ).fetchone()
            task_id = new_task[0]

            # 3. Seed the frontier queue with the root URL itself.
            conn.execute(
                pg_insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending')
            )
            return {"task_id": task_id, "is_new_task": True}

    def add_urls(self, task_id: int, urls: list):
        """Enqueue newly discovered URLs for *task_id* (auto-deduplicated).

        Duplicates already present in the queue are skipped via
        ``ON CONFLICT DO NOTHING`` on (task_id, url). Returns
        ``{"added_count": n}`` where *n* counts rows actually inserted.
        """
        added_count = 0
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                stmt = pg_insert(self.db.queue).values(
                    task_id=task_id,
                    url=clean_url,
                    status='pending'
                ).on_conflict_do_nothing(index_elements=['task_id', 'url'])

                # rowcount is 0 when the conflict clause suppressed the insert.
                res = conn.execute(stmt)
                if res.rowcount > 0:
                    added_count += 1
        return {"added_count": added_count}

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim up to *limit* pending URLs and lock them.

        The SELECT takes row locks with ``FOR UPDATE SKIP LOCKED``
        (Postgres-specific, like the ``pg_insert`` usage elsewhere in this
        class) so concurrent workers claim disjoint sets of rows instead
        of double-processing the same URLs. Claimed rows are flipped to
        'processing' inside the same transaction. Returns ``{"urls": [...]}``.
        """
        with self.db.engine.begin() as conn:
            # Without the row lock, two workers could both read the same
            # 'pending' rows before either UPDATE runs — the lock plus
            # SKIP LOCKED makes the claim genuinely atomic per row.
            stmt = select(self.db.queue.c.url).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)

            urls = [r[0] for r in conn.execute(stmt).fetchall()]

            if urls:
                conn.execute(
                    update(self.db.queue).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
                    ).values(status='processing')
                )
            return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """Persist crawled content and close the loop on queue status.

        Each item in *results* is expected to expose ``.url``, ``.title``,
        ``.content`` and ``.embedding`` attributes (shape of the embedding
        is opaque here — the chunks table defines it). Returns
        ``{"inserted": len(results)}``.
        """
        with self.db.engine.begin() as conn:
            for res in results:
                clean_url = normalize_url(res.url)

                # Store the extracted chunk (text content + embedding vector).
                conn.execute(
                    pg_insert(self.db.chunks).values(
                        task_id=task_id,
                        source_url=clean_url,
                        title=res.title,
                        content=res.content,
                        embedding=res.embedding
                    )
                )

                # Mark the source URL as completed in the frontier queue.
                conn.execute(
                    update(self.db.queue).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    ).values(status='completed')
                )
        return {"inserted": len(results)}
|
|||
|
|
|
|||
|
|
# Global singleton: the single CrawlerService instance shared by importers.
crawler_service = CrawlerService()
|