Files
wiki_crawler/backend/service.py
2025-12-20 17:08:54 +08:00

91 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from sqlalchemy import select, update, and_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from .database import db_instance
from .utils import normalize_url
class CrawlerService:
    """Coordinates a crawl via the database: task registry, URL work queue,
    and storage of extracted content + embeddings.

    All operations run inside short-lived transactions on the shared engine
    (``self.db.engine.begin()`` commits on success, rolls back on error).
    """

    def __init__(self):
        # Engine and table metadata come from the project-level singleton.
        self.db = db_instance

    def register_task(self, url: str) -> dict:
        """Register a new crawl task and seed its queue with the root URL.

        Args:
            url: Root URL of the crawl; normalized before use so that
                equivalent spellings map to the same task.

        Returns:
            ``{"task_id": int, "is_new_task": bool}`` — an existing task with
            the same normalized root URL is reused rather than duplicated.
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # 1. Dedup: reuse an existing task for this root URL.
            # NOTE(review): select-then-insert is racy under concurrent
            # registration unless tasks.root_url carries a unique
            # constraint — confirm the schema.
            find_stmt = select(self.db.tasks.c.id).where(
                self.db.tasks.c.root_url == clean_url
            )
            existing = conn.execute(find_stmt).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}
            # 2. Insert the new task; RETURNING fetches the generated id in
            # the same round trip.
            new_task = conn.execute(
                pg_insert(self.db.tasks)
                .values(root_url=clean_url)
                .returning(self.db.tasks.c.id)
            ).fetchone()
            task_id = new_task[0]
            # 3. Seed the queue with the root URL as the first pending item.
            conn.execute(
                pg_insert(self.db.queue).values(
                    task_id=task_id, url=clean_url, status='pending'
                )
            )
            return {"task_id": task_id, "is_new_task": True}

    def add_urls(self, task_id: int, urls: list) -> dict:
        """Bulk-insert newly discovered URLs as 'pending', skipping duplicates.

        Relies on the (task_id, url) unique index: already-known URLs are
        dropped via ON CONFLICT DO NOTHING, so re-submitting a batch is safe.

        Args:
            task_id: Owning task.
            urls: Candidate URLs; each is normalized before insertion.

        Returns:
            ``{"added_count": int}`` — number of genuinely new rows.
        """
        added_count = 0
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                stmt = pg_insert(self.db.queue).values(
                    task_id=task_id,
                    url=clean_url,
                    status='pending',
                ).on_conflict_do_nothing(index_elements=['task_id', 'url'])
                # rowcount is 0 when the row already existed (conflict skipped).
                if conn.execute(stmt).rowcount > 0:
                    added_count += 1
        return {"added_count": added_count}

    def get_pending_urls(self, task_id: int, limit: int) -> dict:
        """Atomically claim up to ``limit`` pending URLs and mark them 'processing'.

        Uses SELECT ... FOR UPDATE SKIP LOCKED so concurrent workers never
        claim the same rows. Without the row lock, two workers could SELECT
        the same 'pending' batch before either UPDATE commits and crawl the
        same URLs twice.

        Returns:
            ``{"urls": list[str]}`` — the claimed URLs (possibly empty).
        """
        with self.db.engine.begin() as conn:
            stmt = (
                select(self.db.queue.c.url)
                .where(and_(
                    self.db.queue.c.task_id == task_id,
                    self.db.queue.c.status == 'pending',
                ))
                .limit(limit)
                # skip_locked: contended rows are skipped, not waited on,
                # so workers never block each other.
                .with_for_update(skip_locked=True)
            )
            urls = [row[0] for row in conn.execute(stmt).fetchall()]
            if urls:
                conn.execute(
                    update(self.db.queue)
                    .where(and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.url.in_(urls),
                    ))
                    .values(status='processing')
                )
        return {"urls": urls}

    def save_results(self, task_id: int, results: list) -> dict:
        """Persist crawled content + embeddings and close out the queue rows.

        Args:
            task_id: Owning task.
            results: Items exposing ``.url``, ``.title``, ``.content`` and
                ``.embedding`` attributes (shape/type of the embedding is
                whatever the chunks table expects — defined elsewhere).

        Returns:
            ``{"inserted": int}`` — number of result rows written.
        """
        with self.db.engine.begin() as conn:
            for res in results:
                clean_url = normalize_url(res.url)
                # Store the extracted chunk.
                conn.execute(
                    pg_insert(self.db.chunks).values(
                        task_id=task_id,
                        source_url=clean_url,
                        title=res.title,
                        content=res.content,
                        embedding=res.embedding,
                    )
                )
                # Close the loop: this URL is no longer in flight.
                conn.execute(
                    update(self.db.queue)
                    .where(and_(
                        self.db.queue.c.task_id == task_id,
                        self.db.queue.c.url == clean_url,
                    ))
                    .values(status='completed')
                )
        return {"inserted": len(results)}
# Module-level singleton shared by the API layer.
crawler_service = CrawlerService()