Files
wiki_crawler/backend/service.py

207 lines
8.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# service.py
from sqlalchemy import select, insert, update, delete, and_
from .database import db_instance
from .utils import normalize_url
class CrawlerService:
    """Database operations for crawl tasks, the URL queue and content chunks.

    Every method works through ``self.db`` (the imported ``db_instance``),
    from which it uses ``engine`` and the ``tasks``, ``queue`` and ``chunks``
    table objects.
    """

    def __init__(self):
        self.db = db_instance

    @staticmethod
    def _as_dict(item):
        """Return *item* unchanged if it is a dict, otherwise its ``__dict__``."""
        return item if isinstance(item, dict) else item.__dict__

    def register_task(self, url: str):
        """Register a crawl task for *url*, reusing an existing task if any.

        Args:
            url: Root URL of the task; normalized with ``normalize_url``
                before lookup and insert.

        Returns:
            ``{"task_id": <id>, "is_new_task": <bool>}`` where ``is_new_task``
            is False when a task with the same normalized root URL existed.
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}

            # NOTE(review): lookup-then-insert can still race between two
            # concurrent callers; presumably a unique constraint on root_url
            # guards against duplicates - confirm against the schema.
            new_task = conn.execute(
                insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            ).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls_obj: dict):
        """Add a batch of URLs to a task's queue and report the outcome per URL.

        Args:
            task_id: Task the URLs belong to.
            urls_obj: Mapping whose ``"urls"`` entry is the list of URLs to
                add; a missing entry is treated as an empty list.

        Returns:
            A dict with ``success_urls`` (newly queued), ``skipped_urls``
            (already queued for this task) and ``failed_urls`` (database
            error while processing), each a list of normalized URLs.
        """
        queue = self.db.queue
        success_urls, skipped_urls, failed_urls = [], [], []
        urls = urls_obj.get("urls", [])

        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    # Each URL runs inside its own SAVEPOINT so that a failed
                    # statement is rolled back in isolation. Without it, one
                    # error can leave the outer transaction aborted, making
                    # every later URL fail and discarding the rows that were
                    # already reported as successful.
                    with conn.begin_nested():
                        already_queued = conn.execute(
                            select(queue.c.url).where(
                                and_(
                                    queue.c.task_id == task_id,
                                    queue.c.url == clean_url,
                                )
                            )
                        ).fetchone() is not None
                        if not already_queued:
                            conn.execute(
                                insert(queue).values(
                                    task_id=task_id,
                                    url=clean_url,
                                    status='pending',
                                )
                            )
                except Exception:
                    failed_urls.append(clean_url)
                else:
                    if already_queued:
                        skipped_urls.append(clean_url)
                    else:
                        success_urls.append(clean_url)

        return {
            "success_urls": success_urls,
            "skipped_urls": skipped_urls,
            "failed_urls": failed_urls,
        }

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to *limit* pending URLs of a task and mark them processing.

        The rows are selected with ``FOR UPDATE SKIP LOCKED`` so that two
        callers running at the same time do not claim the same URLs: rows
        already locked by another transaction are skipped instead of being
        returned twice. On backends without row locking SQLAlchemy omits the
        clause and the behavior falls back to a plain select-then-update.

        Args:
            task_id: Task whose queue is read.
            limit: Maximum number of URLs to claim.

        Returns:
            ``{"urls": [...]}`` with the claimed URLs (possibly empty).
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            claim_query = (
                select(queue.c.url)
                .where(
                    and_(
                        queue.c.task_id == task_id,
                        queue.c.status == 'pending',
                    )
                )
                .limit(limit)
                .with_for_update(skip_locked=True)
            )
            urls = [row[0] for row in conn.execute(claim_query).fetchall()]

            if urls:
                conn.execute(
                    update(queue)
                    .where(
                        and_(
                            queue.c.task_id == task_id,
                            queue.c.url.in_(urls),
                        )
                    )
                    .values(status='processing')
                )
        return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """Store the chunks of one crawled URL and mark the URL completed.

        All items in *results* are assumed to share the ``source_url`` of the
        first item. Each item may be a dict or an object whose ``__dict__``
        holds ``source_url``, ``chunk_index``, ``title``, ``content`` and
        ``embedding``.

        Args:
            task_id: Task the chunks belong to.
            results: Chunks of a single page.

        Returns:
            ``{"msg": "No data provided"}`` when *results* is empty; otherwise
            a dict with ``source_url``, ``is_page_update`` (whether the page
            already had chunks before this call), ``detail`` (lists of
            inserted / updated / failed chunk indexes) and ``counts``.
        """
        if not results:
            return {"msg": "No data provided"}

        chunks = self.db.chunks
        queue = self.db.queue

        # The page URL is taken from the first item only.
        target_url = normalize_url(self._as_dict(results[0]).get('source_url'))

        inserted_chunks = []
        updated_chunks = []
        failed_chunks = []

        with self.db.engine.begin() as conn:
            # Any chunk already stored for this URL means the page is being
            # updated rather than saved for the first time.
            is_page_update = conn.execute(
                select(chunks.c.id)
                .where(
                    and_(
                        chunks.c.task_id == task_id,
                        chunks.c.source_url == target_url,
                    )
                )
                .limit(1)
            ).fetchone() is not None

            for res in results:
                data = self._as_dict(res)
                c_idx = data.get('chunk_index')
                try:
                    # One SAVEPOINT per chunk: a failing chunk is rolled back
                    # alone instead of aborting the whole transaction, which
                    # would silently drop the chunks reported as saved and
                    # the final queue status update.
                    with conn.begin_nested():
                        existing_chunk = conn.execute(
                            select(chunks.c.id).where(
                                and_(
                                    chunks.c.task_id == task_id,
                                    chunks.c.source_url == target_url,
                                    chunks.c.chunk_index == c_idx,
                                )
                            )
                        ).fetchone()

                        if existing_chunk:
                            # Overwrite the chunk stored under this index.
                            conn.execute(
                                update(chunks)
                                .where(chunks.c.id == existing_chunk[0])
                                .values(
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding'),
                                )
                            )
                        else:
                            conn.execute(
                                insert(chunks).values(
                                    task_id=task_id,
                                    source_url=target_url,
                                    chunk_index=c_idx,
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding'),
                                )
                            )
                except Exception as e:
                    print(f"Chunk {c_idx} failed: {e}")
                    failed_chunks.append(c_idx)
                else:
                    if existing_chunk:
                        updated_chunks.append(c_idx)
                    else:
                        inserted_chunks.append(c_idx)

            # The queue entry is marked completed regardless of per-chunk
            # failures, as in the original flow.
            conn.execute(
                update(queue)
                .where(
                    and_(
                        queue.c.task_id == task_id,
                        queue.c.url == target_url,
                    )
                )
                .values(status='completed')
            )

        return {
            "source_url": target_url,
            # Flag: whether this page had stored content before this call.
            "is_page_update": is_page_update,
            "detail": {
                "inserted_chunk_indexes": inserted_chunks,
                "updated_chunk_indexes": updated_chunks,
                "failed_chunk_indexes": failed_chunks,
            },
            "counts": {
                "inserted": len(inserted_chunks),
                "updated": len(updated_chunks),
                "failed": len(failed_chunks),
            },
        }

    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
        """Return the stored chunks closest to *query_embedding*.

        Args:
            query_embedding: Embedding vector of the question.
            task_id: Optional task filter; when None all tasks are searched.
            limit: Maximum number of chunks to return.

        Returns:
            A list of dicts with ``task_id``, ``source_url``, ``title``,
            ``content`` and ``chunk_index``, ordered by ascending cosine
            distance to the query.
        """
        chunks = self.db.chunks
        with self.db.engine.connect() as conn:
            # task_id is returned as well so that an unfiltered search still
            # tells the caller which task each chunk came from.
            stmt = select(
                chunks.c.task_id,
                chunks.c.source_url,
                chunks.c.title,
                chunks.c.content,
                chunks.c.chunk_index,
            )

            if task_id is not None:
                stmt = stmt.where(chunks.c.task_id == task_id)

            # Cosine distance is 1 - cosine similarity, so the smallest
            # distance is the best match.
            stmt = stmt.order_by(
                chunks.c.embedding.cosine_distance(query_embedding)
            ).limit(limit)

            rows = conn.execute(stmt).fetchall()
            return [
                {
                    "task_id": row[0],
                    "source_url": row[1],
                    "title": row[2],
                    "content": row[3],
                    "chunk_index": row[4],
                }
                for row in rows
            ]
# Module-level CrawlerService instance, created when this module is imported.
crawler_service = CrawlerService()