# File: wiki_crawler/nodes/register.py
# 136 lines, 4.2 KiB, Python
# Exported 2025-12-20 17:08:54 +08:00
import json
from urllib.parse import urlparse, urlunparse
from langchain_community.utilities import SQLDatabase
from sqlalchemy import Table, MetaData, select, insert
# --- 1. Utility: URL normalization ---
def normalize_url(url: str) -> str:
    """
    Normalize a URL so that trailing slashes, scheme/host letter case and
    fragment anchors no longer produce duplicate entries.
    """
    if not url:
        return ""
    parsed = urlparse(url.strip())
    # Scheme and host are case-insensitive -> lowercase them. The path is
    # case-sensitive and kept as-is, except for trailing slashes.
    path = parsed.path.rstrip('/')
    # The query string is kept verbatim; the fragment (#anchor) is dropped
    # because it addresses the same page.
    return urlunparse((
        parsed.scheme.lower(),
        parsed.netloc.lower(),
        path,
        parsed.params,
        parsed.query,
        "",
    ))
# --- 2. Database connection factory ---
def get_db_connection(db_url: str):
    """
    Create a generic database connection, fixing up common scheme spellings.

    Args:
        db_url: SQLAlchemy-style connection string.

    Returns:
        A langchain ``SQLDatabase`` wrapping a pooled engine.

    Raises:
        ValueError: if ``db_url`` is empty.
        RuntimeError: if the connection cannot be established (the original
            exception is chained as the cause).
    """
    if not db_url:
        raise ValueError("数据库连接字符串 (db_url) 不能为空")
    # Heroku-style "postgres://" and driverless "postgresql://" URLs are
    # rewritten to explicitly select the psycopg2 driver.
    if db_url.startswith("postgres://"):
        db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1)
    elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url:
        db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1)
    try:
        # pool_pre_ping=True validates connections before handing them out,
        # guarding against stale connections after idle timeouts.
        return SQLDatabase.from_uri(db_url, engine_args={"pool_pre_ping": True})
    except Exception as e:
        # FIX: chain the original exception so the root cause stays visible
        # in the traceback instead of being flattened into a string.
        raise RuntimeError(f"数据库连接失败: {str(e)}") from e
# --- 3. Core business logic ---
def _logic_handler(db: SQLDatabase, inputs):
    """
    Register a crawl task and seed its queue.

    Args:
        db: langchain SQLDatabase wrapper; its engine is used for table
            reflection and a single transaction.
        inputs: either the raw URL string itself, or a dict payload
            containing a "url" key. (FIX: ``main()`` passes the bare
            string, which crashed ``inputs.get`` — both shapes are now
            accepted, backward-compatibly.)

    Returns:
        dict with "task_id", "is_new_task" (1 = newly created, 0 = already
        existed) and the normalized "url".

    Raises:
        ValueError: if no URL was supplied.
    """
    engine = db._engine  # NOTE(review): private attribute of SQLDatabase
    metadata = MetaData()
    # Accept both a plain string and a {"url": ...} dict payload.
    raw_url = inputs if isinstance(inputs, str) else inputs.get("url", "")
    if not raw_url:
        raise ValueError("输入参数 'url' 缺失")
    clean_url = normalize_url(raw_url)
    # Reflect the table definitions instead of hand-writing SQL.
    tasks_table = Table('crawl_tasks', metadata, autoload_with=engine)
    queue_table = Table('crawl_queue', metadata, autoload_with=engine)
    with engine.begin() as conn:
        # 1. Does a task for this root URL already exist?
        find_stmt = select(tasks_table.c.id).where(
            tasks_table.c.root_url == clean_url
        )
        existing_task = conn.execute(find_stmt).fetchone()
        if existing_task:
            # Task already registered — return it unchanged.
            return {
                "task_id": existing_task[0],
                "is_new_task": 0,
                "url": clean_url,
            }
        # 2. Create the task row; RETURNING fetches the generated id
        #    (the standard PostgreSQL way to get the autoincrement PK).
        insert_task_stmt = insert(tasks_table).values(
            root_url=clean_url,
            status='running',
        ).returning(tasks_table.c.id)
        new_task_id = conn.execute(insert_task_stmt).fetchone()[0]
        # 3. Seed the queue: the (normalized) root URL is the first
        #    pending entry to crawl.
        insert_queue_stmt = insert(queue_table).values(
            task_id=new_task_id,
            url=clean_url,
            status='pending',
        )
        conn.execute(insert_queue_stmt)
        return {
            "task_id": new_task_id,
            "is_new_task": 1,
            "url": clean_url,
        }
# --- 4. Dify node entry point ---
def main(url: str, DB_URL: str):
    """
    Dify node entry function.

    Args:
        url: the root URL to register.
        DB_URL: database connection string (uppercase name matches the
            Dify variable binding).

    Returns:
        dict with "code" (1 = success, 0 = failure), "msg" and "data".
    """
    try:
        db = get_db_connection(DB_URL)
        # FIX: _logic_handler reads inputs.get("url"), so it must receive
        # a dict payload — the original passed the bare string, which
        # raised AttributeError on every call.
        result_data = _logic_handler(db, {"url": url})
        return {"code": 1, "msg": "注册成功", "data": result_data}
    except Exception as e:
        # Top-level boundary: surface any failure through the node's
        # structured return value instead of crashing the workflow.
        return {"code": 0, "msg": str(e), "data": None}