# wiki_crawler/backend/service.py
from sqlalchemy import select, insert, update, and_
from .database import db_instance
from .utils import normalize_url
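
# Assumed interfaces (inferred from usage below, not defined in this file):
#   db_instance     -- exposes .engine (a SQLAlchemy Engine) plus the Core Table
#                      objects .tasks, .queue and .chunks
#   normalize_url() -- returns a canonical form of a URL so duplicate checks
#                      compare like with like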


class CrawlerService:
    def __init__(self):
        self.db = db_instance

    def register_task(self, url: str):
        """Register a crawl task, implemented entirely with the SQLAlchemy Core API."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # Look up an existing task for this root URL with select()
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}
            # Create the task with insert() and return the generated id
            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list):
        """Batch-add URLs via the generic Core API, with a detailed per-URL result breakdown."""
        success_urls, skipped_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    # Skip URLs that are already queued for this task (generic select check)
                    check_q = select(self.db.queue).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    )
                    if conn.execute(check_q).fetchone():
                        skipped_urls.append(clean_url)
                        continue
                    # Insert the new URL in 'pending' state
                    conn.execute(insert(self.db.queue).values(
                        task_id=task_id, url=clean_url, status='pending'
                    ))
                    success_urls.append(clean_url)
                except Exception:
                    failed_urls.append(clean_url)
        return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim a batch of pending URLs: select and mark as 'processing' in one transaction."""
        with self.db.engine.begin() as conn:
            # Select up to `limit` pending URLs for this task
            query = select(self.db.queue.c.url).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit)
            urls = [r[0] for r in conn.execute(query).fetchall()]
            if urls:
                # Flip the claimed rows to 'processing' so they are not handed out again
                upd = update(self.db.queue).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
                ).values(status='processing')
                conn.execute(upd)
        return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """UPSERT logic via the generic Core API: classify each chunk as inserted, updated, or failed."""
        inserted_urls, updated_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for res in results:
                # Accept either a dict (as sent by Dify) or an object with attributes
                data = res if isinstance(res, dict) else res.__dict__
                clean_url = normalize_url(data.get('source_url'))
                c_idx = data.get('chunk_index')
                try:
                    # 1. Check whether this chunk already exists
                    find_q = select(self.db.chunks).where(
                        and_(
                            self.db.chunks.c.task_id == task_id,
                            self.db.chunks.c.source_url == clean_url,
                            self.db.chunks.c.chunk_index == c_idx
                        )
                    )
                    existing = conn.execute(find_q).fetchone()
                    if existing:
                        # 2. Update the existing chunk in place
                        upd = update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(
                            title=data.get('title'),
                            content=data.get('content'),
                            embedding=data.get('embedding')
                        )
                        conn.execute(upd)
                        updated_urls.append(clean_url)
                    else:
                        # 3. Insert a new chunk row
                        ins = insert(self.db.chunks).values(
                            task_id=task_id,
                            source_url=clean_url,
                            chunk_index=c_idx,
                            title=data.get('title'),
                            content=data.get('content'),
                            embedding=data.get('embedding')
                        )
                        conn.execute(ins)
                        inserted_urls.append(clean_url)
                    # 4. Mark the source URL as completed in the queue
                    conn.execute(update(self.db.queue).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    ).values(status='completed'))
                except Exception as e:
                    print(f"save_results failed for {clean_url}: {e}")
                    failed_urls.append(clean_url)
        return {
            "inserted_urls": list(set(inserted_urls)),
            "updated_urls": list(set(updated_urls)),
            "failed_urls": failed_urls
        }

crawler_service = CrawlerService()
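

# The sketch below shows the intended call sequence for one crawl round, assuming the
# tables behind db_instance already exist and that the caller (e.g. a Dify workflow)
# supplies the chunked page content. The URLs, titles and the embedding are placeholders.
if __name__ == "__main__":
    task = crawler_service.register_task("https://en.wikipedia.org/wiki/Web_crawler")
    crawler_service.add_urls(task["task_id"], [
        "https://en.wikipedia.org/wiki/Web_crawler",
        "https://en.wikipedia.org/wiki/Search_engine_indexing",
    ])
    batch = crawler_service.get_pending_urls(task["task_id"], limit=5)
    report = crawler_service.save_results(task["task_id"], [
        {
            "source_url": url,
            "chunk_index": 0,
            "title": "placeholder title",
            "content": "placeholder chunk text",
            "embedding": None,  # a real caller would pass the embedding vector here
        }
        for url in batch["urls"]
    ])
    print(report)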