# wiki_crawler/backend/service.py — crawler persistence service (SQLAlchemy Core).
# service.py
import logging

from sqlalchemy import select, insert, update, delete, and_

from .database import db_instance
from .utils import normalize_url
class CrawlerSqlService:
    """SQL-backed persistence layer for the wiki crawler.

    Wraps the shared ``db_instance`` (SQLAlchemy engine plus the ``tasks``,
    ``queue`` and ``chunks`` table objects) and exposes task registration,
    URL-queue management, chunk storage and vector search.  All statements
    are built with the SQLAlchemy Core expression API.

    NOTE(review): ``insert(...).returning(...)`` and ``FOR UPDATE SKIP
    LOCKED`` assume a backend that supports them (PostgreSQL, which the
    pgvector ``cosine_distance`` usage below also implies) — confirm.
    """

    def __init__(self):
        # Shared database facade: exposes .engine and the table metadata.
        self.db = db_instance

    def register_task(self, url: str):
        """Register a crawl task for *url*, reusing an existing one if present.

        :param url: root URL of the crawl; normalized before lookup/insert.
        :return: ``{"task_id": <int>, "is_new_task": <bool>}``
        """
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            # Look for an already-registered task with the same root URL.
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False}
            # Not found: insert and fetch the generated id in one round trip.
            stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
            new_task = conn.execute(stmt).fetchone()
            return {"task_id": new_task[0], "is_new_task": True}

    def add_urls(self, task_id: int, urls: list[str]):
        """Batch-add URLs to a task's crawl queue, with a detailed report.

        URLs already queued for the task are skipped.  Each URL is handled
        inside a SAVEPOINT (``begin_nested``) so that one failing insert
        cannot abort the enclosing transaction — on PostgreSQL an error
        would otherwise poison every following statement.

        :return: ``{"success_urls": [...], "skipped_urls": [...], "failed_urls": [...]}``
        """
        success_urls, skipped_urls, failed_urls = [], [], []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    # SAVEPOINT: rolls back just this URL on error.
                    with conn.begin_nested():
                        # Skip if this URL is already queued for the task.
                        check_q = select(self.db.queue).where(
                            and_(self.db.queue.c.task_id == task_id,
                                 self.db.queue.c.url == clean_url)
                        )
                        if conn.execute(check_q).fetchone():
                            skipped_urls.append(clean_url)
                            continue
                        conn.execute(insert(self.db.queue).values(
                            task_id=task_id, url=clean_url, status='pending'
                        ))
                        success_urls.append(clean_url)
                except Exception:
                    failed_urls.append(clean_url)
        return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}

    def get_pending_urls(self, task_id: int, limit: int):
        """Atomically claim up to *limit* pending URLs for a task.

        The selected rows are flipped to ``'processing'`` inside the same
        transaction.  ``FOR UPDATE SKIP LOCKED`` is requested so that
        concurrent workers skip rows a peer already holds instead of
        double-claiming them (emitted only on backends that support it).

        :return: ``{"urls": [<str>, ...]}``
        """
        with self.db.engine.begin() as conn:
            query = (
                select(self.db.queue.c.url)
                .where(and_(self.db.queue.c.task_id == task_id,
                            self.db.queue.c.status == 'pending'))
                .limit(limit)
                .with_for_update(skip_locked=True)
            )
            urls = [row[0] for row in conn.execute(query)]
            if urls:
                conn.execute(
                    update(self.db.queue)
                    .where(and_(self.db.queue.c.task_id == task_id,
                                self.db.queue.c.url.in_(urls)))
                    .values(status='processing')
                )
            return {"urls": urls}

    def save_results(self, task_id: int, results: list):
        """Persist the chunks produced from one crawled page.

        All items in *results* are assumed to share a single ``source_url``.
        Existing ``(task_id, source_url, chunk_index)`` rows are overwritten,
        new indexes are inserted, and the page's queue row is marked
        ``'completed'``.

        :param results: list of dicts (or simple objects, read via
            ``__dict__``) with ``source_url``, ``chunk_index``, ``title``,
            ``content`` and ``embedding`` keys.
        :return: per-chunk insert/update/failure details, counts, and
            ``is_page_update`` (whether the page already had stored chunks).
        """
        if not results:
            return {"msg": "No data provided"}
        # Assumes every item carries the same source_url; take it from the first.
        first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
        target_url = normalize_url(first_item.get('source_url'))

        inserted_chunks = []
        updated_chunks = []
        failed_chunks = []
        is_page_update = False

        with self.db.engine.begin() as conn:
            # The page counts as a "re-crawl" when any chunk already exists for it.
            check_page_stmt = select(self.db.chunks.c.id).where(
                and_(self.db.chunks.c.task_id == task_id,
                     self.db.chunks.c.source_url == target_url)
            ).limit(1)
            if conn.execute(check_page_stmt).fetchone():
                is_page_update = True

            for res in results:
                data = res if isinstance(res, dict) else res.__dict__
                c_idx = data.get('chunk_index')
                try:
                    # SAVEPOINT so a bad chunk does not abort the transaction
                    # (PostgreSQL otherwise rejects every later statement).
                    with conn.begin_nested():
                        # Does a chunk with this index already exist for the page?
                        find_chunk_stmt = select(self.db.chunks.c.id).where(
                            and_(
                                self.db.chunks.c.task_id == task_id,
                                self.db.chunks.c.source_url == target_url,
                                self.db.chunks.c.chunk_index == c_idx,
                            )
                        )
                        existing_chunk = conn.execute(find_chunk_stmt).fetchone()

                        if existing_chunk:
                            # Overwrite the stored chunk in place.
                            conn.execute(
                                update(self.db.chunks)
                                .where(self.db.chunks.c.id == existing_chunk[0])
                                .values(
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding'),
                                )
                            )
                            updated_chunks.append(c_idx)
                        else:
                            # Brand-new chunk index for this page.
                            conn.execute(
                                insert(self.db.chunks).values(
                                    task_id=task_id,
                                    source_url=target_url,
                                    chunk_index=c_idx,
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding'),
                                )
                            )
                            inserted_chunks.append(c_idx)
                except Exception as e:
                    # Log (not print) so failures reach the app's log handlers.
                    logging.getLogger(__name__).warning("Chunk %s failed: %s", c_idx, e)
                    failed_chunks.append(c_idx)

            # Mark the page done in the queue even if some chunks failed;
            # failed indexes are reported back to the caller for retry logic.
            conn.execute(
                update(self.db.queue)
                .where(and_(self.db.queue.c.task_id == task_id,
                            self.db.queue.c.url == target_url))
                .values(status='completed')
            )

        return {
            "source_url": target_url,
            "is_page_update": is_page_update,  # page had content before this call
            "detail": {
                "inserted_chunk_indexes": inserted_chunks,
                "updated_chunk_indexes": updated_chunks,
                "failed_chunk_indexes": failed_chunks,
            },
            "counts": {
                "inserted": len(inserted_chunks),
                "updated": len(updated_chunks),
                "failed": len(failed_chunks),
            },
        }

    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
        """Vector similarity search over stored chunks.

        :param query_embedding: embedding vector of the question
        :param task_id: optional task filter; when ``None``, search the whole table
        :param limit: number of rows to return
        :return: list of dicts with task_id, source_url, title, content, chunk_index
        """
        with self.db.engine.connect() as conn:
            # task_id is returned too, so global searches can attribute results.
            stmt = select(
                self.db.chunks.c.task_id,
                self.db.chunks.c.source_url,
                self.db.chunks.c.title,
                self.db.chunks.c.content,
                self.db.chunks.c.chunk_index,
            )
            # Optional per-task filter.
            if task_id is not None:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)
            # Order by cosine distance (1 - cosine similarity): smaller is closer.
            stmt = stmt.order_by(
                self.db.chunks.c.embedding.cosine_distance(query_embedding)
            ).limit(limit)
            rows = conn.execute(stmt).fetchall()
            return [
                {
                    "task_id": row[0],
                    "source_url": row[1],
                    "title": row[2],
                    "content": row[3],
                    "chunk_index": row[4],
                }
                for row in rows
            ]
2025-12-20 17:08:54 +08:00
2025-12-22 22:50:07 +08:00
# Module-level singleton: import this shared instance instead of constructing
# a new CrawlerSqlService per caller.
crawler_sql_service = CrawlerSqlService()