完成节点
This commit is contained in:
@@ -12,8 +12,10 @@ class AddUrlsRequest(BaseModel):
|
||||
task_id: int
|
||||
urls: List[str]
|
||||
|
||||
# schemas.py
class CrawlResult(BaseModel):
    """One crawled content chunk returned by a worker.

    Matches the queue/chunks tables used by CrawlerService.save_results:
    rows are keyed by (task_id, source_url, chunk_index).
    """
    url: str          # NOTE(review): may be superseded by source_url in this change — confirm against callers
    source_url: str   # page the chunk was extracted from (normalized by the service)
    chunk_index: int  # new field: position of this chunk within the source page
    title: Optional[str] = None
    content: Optional[str] = None
    embedding: Optional[List[float]] = None  # vector for the chunk; dimensionality not fixed here
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy import select, update, and_
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
# service.py
|
||||
from sqlalchemy import select, insert, update, delete, and_
|
||||
from .database import db_instance
|
||||
from .utils import normalize_url
|
||||
|
||||
@@ -8,84 +8,120 @@ class CrawlerService:
|
||||
self.db = db_instance
|
||||
|
||||
def register_task(self, url: str):
    """Register a crawl task using only the generic SQLAlchemy Core API.

    Args:
        url: Root URL of the task; normalized before any lookup/insert.

    Returns:
        dict with "task_id" and "is_new_task" (False when the root URL
        was already registered).
    """
    clean_url = normalize_url(url)
    # engine.begin() commits on success and rolls back on exception.
    with self.db.engine.begin() as conn:
        # Dedupe: reuse an existing task for the same normalized root URL.
        query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
        existing = conn.execute(query).fetchone()

        if existing:
            return {"task_id": existing[0], "is_new_task": False}

        # Insert the new task; RETURNING gives us the generated id in one round trip.
        stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
        new_task = conn.execute(stmt).fetchone()
        # NOTE(review): unlike the previous version, the root URL is no longer
        # seeded into the queue here — confirm a caller now does that via add_urls.
        return {"task_id": new_task[0], "is_new_task": True}
|
||||
|
||||
def add_urls(self, task_id: int, urls: list):
    """Batch-add newly discovered URLs using the generic Core API,
    returning a detailed per-URL breakdown.

    Args:
        task_id: Owning task id.
        urls: Candidate URLs; each is normalized before the duplicate check.

    Returns:
        dict with "success_urls", "skipped_urls" (already queued) and
        "failed_urls" (insert raised).
    """
    success_urls, skipped_urls, failed_urls = [], [], []

    with self.db.engine.begin() as conn:
        for url in urls:
            clean_url = normalize_url(url)
            try:
                # Existence check via plain select (portable, no pg-specific upsert).
                # NOTE(review): check-then-insert is racy under concurrent writers;
                # a unique (task_id, url) constraint is assumed to backstop it.
                check_q = select(self.db.queue).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                )
                if conn.execute(check_q).fetchone():
                    skipped_urls.append(clean_url)
                    continue

                # Insert the new pending URL.
                conn.execute(insert(self.db.queue).values(
                    task_id=task_id, url=clean_url, status='pending'
                ))
                success_urls.append(clean_url)
            except Exception:
                # Best-effort semantics: record the failure, keep processing the batch.
                failed_urls.append(clean_url)

    return {"success_urls": success_urls, "skipped_urls": skipped_urls, "failed_urls": failed_urls}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
    """Atomically claim up to `limit` pending URLs for a task.

    Selects pending queue rows and flips them to 'processing' inside one
    transaction so two workers do not claim the same URLs.

    Args:
        task_id: Task whose queue is polled.
        limit: Maximum number of URLs to claim.

    Returns:
        dict with "urls": the claimed URL list (possibly empty).
    """
    with self.db.engine.begin() as conn:
        query = select(self.db.queue.c.url).where(
            and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
        ).limit(limit)

        urls = [r[0] for r in conn.execute(query).fetchall()]

        if urls:
            # Mark claimed rows in the same transaction (the "lock").
            # NOTE(review): without SELECT ... FOR UPDATE this is only atomic
            # per-transaction, not contention-proof — confirm isolation level.
            upd = update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
            ).values(status='processing')
            conn.execute(upd)
        return {"urls": urls}
|
||||
|
||||
def save_results(self, task_id: int, results: list):
    """UPSERT crawl results via the generic Core API, distinguishing
    inserted vs. updated vs. failed chunks, and close the queue loop.

    Args:
        task_id: Owning task id.
        results: Items from the worker (dicts, or objects exposing
            source_url / chunk_index / title / content / embedding).

    Returns:
        dict with deduplicated "inserted_urls" and "updated_urls", plus
        "failed_urls" in encounter order.
    """
    inserted_urls, updated_urls, failed_urls = [], [], []

    with self.db.engine.begin() as conn:
        for res in results:
            # Accept either a dict (as sent by Dify) or a plain object.
            data = res if isinstance(res, dict) else res.__dict__
            clean_url = normalize_url(data.get('source_url'))
            c_idx = data.get('chunk_index')

            try:
                # 1. Does this chunk already exist? Keyed by (task, url, index).
                find_q = select(self.db.chunks).where(
                    and_(
                        self.db.chunks.c.task_id == task_id,
                        self.db.chunks.c.source_url == clean_url,
                        self.db.chunks.c.chunk_index == c_idx
                    )
                )
                existing = conn.execute(find_q).fetchone()

                if existing:
                    # 2. Update in place; existing[0] is the chunk's primary-key id.
                    upd = update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(
                        title=data.get('title'),
                        content=data.get('content'),
                        embedding=data.get('embedding')
                    )
                    conn.execute(upd)
                    updated_urls.append(clean_url)
                else:
                    # 3. First sighting: insert the chunk.
                    ins = insert(self.db.chunks).values(
                        task_id=task_id,
                        source_url=clean_url,
                        chunk_index=c_idx,
                        title=data.get('title'),
                        content=data.get('content'),
                        embedding=data.get('embedding')
                    )
                    conn.execute(ins)
                    inserted_urls.append(clean_url)

                # 4. Close the loop: mark the queue entry completed.
                conn.execute(update(self.db.queue).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                ).values(status='completed'))

            except Exception as e:
                # Best-effort batch: report and continue with remaining results.
                print(f"Error: {e}")
                failed_urls.append(clean_url)

    return {
        "inserted_urls": list(set(inserted_urls)),
        "updated_urls": list(set(updated_urls)),
        "failed_urls": failed_urls
    }
|
||||
|
||||
# Module-level singleton: every importer shares one CrawlerService
# (and therefore one db_instance connection pool).
crawler_service = CrawlerService()
|
||||
Reference in New Issue
Block a user