111 lines
4.8 KiB
Python
111 lines
4.8 KiB
Python
|
|
import logging

from sqlalchemy import select, insert, update, and_
from sqlalchemy.exc import SQLAlchemyError

from backend.core.database import db
from backend.utils.common import normalize_url
|
|||
|
|
|
|||
|
|
class DataService:
    """Data persistence service layer.

    Only performs database CRUD operations; contains no external API call logic.
    """

    # Logger used to record per-URL enqueue failures in add_urls.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.db = db

    def register_task(self, url: str):
        """Return the task whose root URL is ``url``, creating it if absent.

        ``url`` is passed through ``normalize_url`` before lookup and storage.

        Returns:
            dict with ``task_id``, ``is_new_task`` (False when a task with the
            same normalized root URL already existed) and ``msg``.
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}

            # NOTE(review): check-then-insert is not atomic; two concurrent calls
            # for the same URL can both reach this insert. Unless tasks.root_url
            # has a unique constraint (not visible here), duplicates may result.
            new_task = conn.execute(
                insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            ).fetchone()
            return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}

    def add_urls(self, task_id: int, urls: list[str]):
        """Enqueue ``urls`` for ``task_id`` with status ``'pending'``.

        Each URL is normalized with ``normalize_url``. A URL already queued for
        this task is skipped. Each URL is processed inside its own SAVEPOINT, so
        a database error on one URL rolls back only that URL's work. A failed
        statement would otherwise abort the whole surrounding transaction on
        PostgreSQL, silently losing every other insert. Failures are logged and
        the URL is skipped.

        Returns:
            dict with ``msg`` reporting how many URLs were newly inserted.
        """
        queue = self.db.queue
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    with conn.begin_nested():
                        already_queued = conn.execute(
                            select(queue).where(
                                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
                            )
                        ).fetchone() is not None
                        if not already_queued:
                            conn.execute(
                                insert(queue).values(task_id=task_id, url=clean_url, status='pending')
                            )
                except SQLAlchemyError:
                    # Best-effort: skip this URL, but keep a trace of the error.
                    self._logger.warning(
                        "Failed to enqueue url %s for task %s", clean_url, task_id, exc_info=True
                    )
                    continue
                if not already_queued:
                    success_urls.append(clean_url)
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to ``limit`` pending URLs of ``task_id``, marking them ``'processing'``.

        The select and the status update run in one transaction. The selected
        rows are locked with ``FOR UPDATE SKIP LOCKED``, so a concurrent caller
        skips rows already being claimed rather than returning them a second time.

        Returns:
            dict with ``urls`` (claimed URLs, possibly empty) and ``msg``.
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            query = (
                select(queue.c.url)
                .where(and_(queue.c.task_id == task_id, queue.c.status == 'pending'))
                .limit(limit)
                .with_for_update(skip_locked=True)
            )
            urls = [row[0] for row in conn.execute(query).fetchall()]

            if urls:
                conn.execute(
                    update(queue)
                    .where(and_(queue.c.task_id == task_id, queue.c.url.in_(urls)))
                    .values(status='processing')
                )

        return {"urls": urls, "msg": "Fetched pending urls"}

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Save chunk data for one page (Phase 1.5: supports ``meta_info``).

        Args:
            task_id: Task the chunks belong to.
            source_url: Page URL; passed through ``normalize_url`` before use.
            title: Title stored on every chunk row.
            chunks_data: Items shaped like
                ``{'index': int, 'content': str, 'embedding': list, 'meta_info': dict}``.
                ``meta_info`` is optional and defaults to ``{}``.

        A chunk row is identified by (task_id, normalized source_url, index).
        An existing row is updated in place; otherwise a new row is inserted.
        The queue row(s) for (task_id, url) are then set to ``'completed'``,
        all in one transaction.

        NOTE(review): rows whose chunk_index does not appear in ``chunks_data``
        (e.g. a re-crawled page that now yields fewer chunks) are left in place
        and remain searchable. Confirm this is intended.

        Returns:
            dict with ``msg`` and ``count`` (items written, inserts and updates combined).
        """
        clean_url = normalize_url(source_url)
        chunks = self.db.chunks
        queue = self.db.queue
        count = 0
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})

                # Look for an existing row for this (task, page, chunk index).
                existing = conn.execute(
                    select(chunks.c.id).where(
                        and_(
                            chunks.c.task_id == task_id,
                            chunks.c.source_url == clean_url,
                            chunks.c.chunk_index == idx,
                        )
                    )
                ).fetchone()

                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta,
                }

                if existing:
                    conn.execute(update(chunks).where(chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(chunks).values(**values))
                count += 1

            # Mark the page's queue entry as completed.
            conn.execute(
                update(queue)
                .where(and_(queue.c.task_id == task_id, queue.c.url == clean_url))
                .values(status='completed')
            )

        return {"msg": f"Saved {count} chunks", "count": count}

    def search(self, vector: list, task_id: int = None, limit: int = 5):
        """Return up to ``limit`` chunks nearest to ``vector``.

        Rows are ordered by ascending ``embedding.cosine_distance(vector)``.
        When ``task_id`` is given, only that task's chunks are searched.
        ``None`` searches across all tasks.

        Returns:
            dict with ``results`` (dicts with task_id, source_url, title,
            content, meta_info) and ``msg``.
        """
        chunks = self.db.chunks
        stmt = (
            select(
                chunks.c.task_id, chunks.c.source_url, chunks.c.title,
                chunks.c.content, chunks.c.meta_info,
            )
            .order_by(chunks.c.embedding.cosine_distance(vector))
            .limit(limit)
        )
        # Compare against None so a falsy id such as 0 still applies the filter.
        if task_id is not None:
            stmt = stmt.where(chunks.c.task_id == task_id)

        with self.db.engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()

        results = [
            {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]}
            for r in rows
        ]
        return {"results": results, "msg": f"Found {len(results)}"}
|
|||
|
|
|
|||
|
|
# Module-level DataService instance, created when this module is imported.
data_service = DataService()
|