修改配置和response的细节
This commit is contained in:
230
backend/services/crawler_sql_service.py
Normal file
230
backend/services/crawler_sql_service.py
Normal file
@@ -0,0 +1,230 @@
|
||||
from sqlalchemy import select, insert, update, and_
|
||||
from ..database import db_instance
|
||||
from ..utils import normalize_url
|
||||
|
||||
class CrawlerSqlService:
|
||||
def __init__(self):
|
||||
self.db = db_instance
|
||||
|
||||
def register_task(self, url: str):
|
||||
"""完全使用库 API 实现的注册"""
|
||||
clean_url = normalize_url(url)
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 使用 select() API
|
||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
result = {
|
||||
"task_id": existing[0],
|
||||
"is_new_task": False,
|
||||
"msg": "Task already exists"
|
||||
}
|
||||
else:
|
||||
# 使用 insert() API
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
result = {
|
||||
"task_id": new_task[0],
|
||||
"is_new_task": True,
|
||||
"msg": "New task created successfully"
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
"""通用 API 实现的批量添加(含详细返回)"""
|
||||
success_urls, skipped_urls, failed_urls = [], [], []
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
# 检查队列中是否已存在该 URL
|
||||
check_q = select(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
if conn.execute(check_q).fetchone():
|
||||
skipped_urls.append(clean_url)
|
||||
continue
|
||||
|
||||
# 插入新 URL
|
||||
conn.execute(insert(self.db.queue).values(
|
||||
task_id=task_id, url=clean_url, status='pending'
|
||||
))
|
||||
success_urls.append(clean_url)
|
||||
except Exception:
|
||||
failed_urls.append(clean_url)
|
||||
|
||||
# 构造返回消息
|
||||
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
|
||||
|
||||
return {
|
||||
"success_urls": success_urls,
|
||||
"skipped_urls": skipped_urls,
|
||||
"failed_urls": failed_urls,
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
"""原子锁定 API 实现"""
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
).limit(limit)
|
||||
|
||||
urls = [r[0] for r in conn.execute(query).fetchall()]
|
||||
|
||||
if urls:
|
||||
upd = update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing')
|
||||
conn.execute(upd)
|
||||
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
|
||||
else:
|
||||
result = {"urls": [], "msg": "Queue is empty"}
|
||||
|
||||
return result
|
||||
|
||||
def save_results(self, task_id: int, results: list):
|
||||
"""
|
||||
保存同一 URL 的多个切片。
|
||||
"""
|
||||
if not results:
|
||||
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
|
||||
|
||||
# 1. 基础信息提取
|
||||
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
||||
target_url = normalize_url(first_item.get('source_url'))
|
||||
|
||||
# 结果统计容器
|
||||
inserted_chunks = []
|
||||
updated_chunks = []
|
||||
failed_chunks = []
|
||||
is_page_update = False
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 2. 判断该 URL 是否已经有切片存在
|
||||
check_page_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
||||
).limit(1)
|
||||
if conn.execute(check_page_stmt).fetchone():
|
||||
is_page_update = True
|
||||
|
||||
# 3. 逐个处理切片
|
||||
for res in results:
|
||||
data = res if isinstance(res, dict) else res.__dict__
|
||||
c_idx = data.get('chunk_index')
|
||||
|
||||
try:
|
||||
# 检查切片是否存在
|
||||
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(
|
||||
self.db.chunks.c.task_id == task_id,
|
||||
self.db.chunks.c.source_url == target_url,
|
||||
self.db.chunks.c.chunk_index == c_idx
|
||||
)
|
||||
)
|
||||
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
||||
|
||||
if existing_chunk:
|
||||
# 覆盖更新
|
||||
upd_stmt = update(self.db.chunks).where(
|
||||
self.db.chunks.c.id == existing_chunk[0]
|
||||
).values(
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(upd_stmt)
|
||||
updated_chunks.append(c_idx)
|
||||
else:
|
||||
# 插入新切片
|
||||
ins_stmt = insert(self.db.chunks).values(
|
||||
task_id=task_id,
|
||||
source_url=target_url,
|
||||
chunk_index=c_idx,
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(ins_stmt)
|
||||
inserted_chunks.append(c_idx)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Chunk {c_idx} failed: {e}")
|
||||
failed_chunks.append(c_idx)
|
||||
|
||||
# 4. 更新队列状态
|
||||
conn.execute(
|
||||
update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
||||
).values(status='completed')
|
||||
)
|
||||
|
||||
# 构造返回
|
||||
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
|
||||
|
||||
return {
|
||||
"source_url": target_url,
|
||||
"is_page_update": is_page_update,
|
||||
"detail": {
|
||||
"inserted_chunk_indexes": inserted_chunks,
|
||||
"updated_chunk_indexes": updated_chunks,
|
||||
"failed_chunk_indexes": failed_chunks
|
||||
},
|
||||
"counts": {
|
||||
"inserted": len(inserted_chunks),
|
||||
"updated": len(updated_chunks),
|
||||
"failed": len(failed_chunks)
|
||||
},
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
高性能向量搜索方法
|
||||
"""
|
||||
results = []
|
||||
msg = ""
|
||||
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id,
|
||||
self.db.chunks.c.source_url,
|
||||
self.db.chunks.c.title,
|
||||
self.db.chunks.c.content,
|
||||
self.db.chunks.c.chunk_index
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
stmt = stmt.order_by(
|
||||
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||
).limit(limit)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
|
||||
for r in rows:
|
||||
results.append({
|
||||
"task_id": r[0],
|
||||
"source_url": r[1],
|
||||
"title": r[2],
|
||||
"content": r[3],
|
||||
"chunk_index": r[4]
|
||||
})
|
||||
|
||||
if results:
|
||||
msg = f"Found {len(results)} matches"
|
||||
else:
|
||||
msg = "No matching content found"
|
||||
|
||||
return {"results": results, "msg": msg}
|
||||
|
||||
|
||||
crawler_sql_service = CrawlerSqlService()
|
||||
Reference in New Issue
Block a user