import json
from urllib.parse import urlparse, urlunparse

from langchain_community.utilities import SQLDatabase
from sqlalchemy import Table, MetaData, select, insert


# --- 1. Utility: URL normalization ---
def normalize_url(url: str) -> str:
    """Return a canonical form of *url* so equivalent URLs compare equal.

    The scheme and network location are lower-cased, trailing slashes are
    removed from the path, and the fragment is dropped. Path parameters and
    the query string are kept verbatim (query parameters are not re-ordered).

    Args:
        url: Raw URL string; surrounding whitespace is ignored.

    Returns:
        The normalized URL, or an empty string when *url* is falsy.
    """
    if not url:
        return ""

    parsed = urlparse(url.strip())

    # The fragment only addresses a position inside the same page, so it is
    # discarded; everything else that identifies the resource is preserved.
    return urlunparse((
        parsed.scheme.lower(),
        parsed.netloc.lower(),
        parsed.path.rstrip('/'),
        parsed.params,
        parsed.query,
        "",
    ))


# --- 2. Database connection factory ---
def get_db_connection(db_url: str) -> SQLDatabase:
    """Build a ``SQLDatabase`` for *db_url*, fixing up PostgreSQL schemes.

    ``postgres://`` and driver-less ``postgresql://`` URLs are rewritten to
    ``postgresql+psycopg2://`` before the connection is created.

    Args:
        db_url: SQLAlchemy-style database URL.

    Returns:
        A ``SQLDatabase`` whose engine is created with ``pool_pre_ping=True``.

    Raises:
        ValueError: If *db_url* is empty.
        RuntimeError: If ``SQLDatabase.from_uri`` fails; the original
            exception is attached as ``__cause__``.
    """
    if not db_url:
        raise ValueError("数据库连接字符串 (db_url) 不能为空")

    # Rewrite scheme prefixes so the psycopg2 driver is always named.
    if db_url.startswith("postgres://"):
        db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1)
    elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url:
        db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1)

    try:
        # pool_pre_ping validates a pooled connection before handing it out,
        # which guards against connections dropped by an idle timeout.
        return SQLDatabase.from_uri(db_url, engine_args={"pool_pre_ping": True})
    except Exception as e:
        # Chain the cause so the underlying driver error stays in the traceback.
        raise RuntimeError(f"数据库连接失败: {str(e)}") from e


# --- 3. Core business logic ---
def _logic_handler(db: SQLDatabase, inputs: dict) -> dict:
    """Register a crawl task for ``inputs["url"]`` and seed its queue.

    If a row in ``crawl_tasks`` already has the normalized URL as
    ``root_url``, that task is returned unchanged. Otherwise a new task row
    (``status='running'``) and a first queue row (``status='pending'``) are
    inserted in the same transaction.

    Args:
        db: Database wrapper whose engine is used for reflection and queries.
        inputs: Mapping that must contain a non-empty ``"url"`` entry.

    Returns:
        A dict with ``task_id``, ``is_new_task`` (``0`` for an existing task,
        ``1`` for a newly created one) and the normalized ``url``.

    Raises:
        ValueError: If ``inputs`` has no usable ``"url"`` value.
    """
    # SQLDatabase keeps its engine on a private attribute; it is read here
    # directly to run SQLAlchemy Core statements.
    engine = db._engine
    metadata = MetaData()

    raw_url = inputs.get("url", "")
    if not raw_url:
        raise ValueError("输入参数 'url' 缺失")
    clean_url = normalize_url(raw_url)

    # Reflect the table definitions from the database instead of declaring them.
    tasks_table = Table('crawl_tasks', metadata, autoload_with=engine)
    queue_table = Table('crawl_queue', metadata, autoload_with=engine)

    # engine.begin() commits on success and rolls back if anything raises.
    with engine.begin() as conn:
        # 1. Look for an existing task with the same root URL.
        # NOTE(review): the lookup and the insert below are not atomic; two
        # concurrent calls could both insert unless crawl_tasks.root_url has
        # a unique constraint - confirm against the schema.
        find_stmt = select(tasks_table.c.id).where(
            tasks_table.c.root_url == clean_url
        )
        existing_task = conn.execute(find_stmt).fetchone()

        if existing_task:
            return {
                "task_id": existing_task[0],
                "is_new_task": 0,
                "url": clean_url,
            }

        # 2. No task yet: create one. RETURNING hands back the generated id
        #    (the usual way to obtain it on PostgreSQL).
        insert_task_stmt = insert(tasks_table).values(
            root_url=clean_url,
            status='running',
        ).returning(tasks_table.c.id)
        new_task_id = conn.execute(insert_task_stmt).fetchone()[0]

        # 3. Seed the queue with the (normalized) root URL as the first
        #    pending entry.
        insert_queue_stmt = insert(queue_table).values(
            task_id=new_task_id,
            url=clean_url,
            status='pending',
        )
        conn.execute(insert_queue_stmt)

        return {
            "task_id": new_task_id,
            "is_new_task": 1,
            "url": clean_url,
        }


# --- 4. Dify node entry point ---
def main(url: str, DB_URL: str) -> dict:
    """Dify node entry point: register *url* as a crawl task.

    Args:
        url: Root URL to register.
        DB_URL: Database connection string.

    Returns:
        ``{"code": 1, "msg": "注册成功", "data": {...}}`` on success, or
        ``{"code": 0, "msg": <error text>, "data": None}`` when any step
        raises.
    """
    db = None
    try:
        db = get_db_connection(DB_URL)
        # _logic_handler reads its input via inputs.get("url"), so the URL
        # must be wrapped in a dict rather than passed as a bare string.
        result_data = _logic_handler(db, {"url": url})
    except Exception as e:
        # Top-level boundary: report the failure in the node's result
        # envelope instead of letting it propagate.
        return {"code": 0, "msg": str(e), "data": None}
    finally:
        # A fresh engine (and connection pool) is created on every call;
        # release its pooled connections once the call is done.
        if db is not None:
            db._engine.dispose()

    return {"code": 1, "msg": "注册成功", "data": result_data}