Files
wiki_crawler/nodes/register.py
2025-12-20 17:08:54 +08:00

136 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
from urllib.parse import urlparse, urlunparse
from langchain_community.utilities import SQLDatabase
from sqlalchemy import Table, MetaData, select, insert
# --- 1. Utility function: URL normalization ---
def normalize_url(url: str) -> str:
    """Return a canonical form of *url* so equivalent links compare equal.

    Normalization steps:

    * surrounding whitespace is stripped;
    * the scheme and the host[:port] part are lower-cased;
    * trailing slashes are removed from the path;
    * the fragment (``#anchor``) is dropped, since it addresses the same page;
    * path params and the query string are kept unchanged (not reordered).

    Args:
        url: The raw URL. A falsy value yields ``""``.

    Returns:
        The normalized URL, or ``""`` when *url* is empty.
    """
    if not url:
        return ""

    parsed = urlparse(url.strip())
    scheme = parsed.scheme.lower()

    # Only the host is case-insensitive. Credentials in "user:pass@" are
    # case-sensitive, so they must survive untouched.
    userinfo, at_sign, hostport = parsed.netloc.rpartition("@")
    netloc = f"{userinfo}{at_sign}{hostport.lower()}"

    # rstrip is a no-op when there is no trailing slash; "/" becomes "".
    path = parsed.path.rstrip("/")

    # Empty final component discards the fragment.
    return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
# --- 2. 数据库连接工厂 ---
def get_db_connection(db_url: str):
    """Build a LangChain ``SQLDatabase`` for *db_url*.

    ``postgres://`` and driver-less ``postgresql://`` URLs are rewritten to
    ``postgresql+psycopg2://`` before connecting. Any other URL is passed
    through unchanged.

    Args:
        db_url: SQLAlchemy-style database URL.

    Returns:
        The object returned by ``SQLDatabase.from_uri``.

    Raises:
        ValueError: If *db_url* is empty.
        RuntimeError: If creating the database object fails; the original
            error is attached as ``__cause__``.
    """
    if not db_url:
        raise ValueError("数据库连接字符串 (db_url) 不能为空")

    # Pin the scheme to an explicit driver. Matching on the scheme prefix
    # alone is enough: a URL that starts with "postgresql://" cannot already
    # carry a "+driver" suffix in its scheme.
    for legacy_prefix in ("postgres://", "postgresql://"):
        if db_url.startswith(legacy_prefix):
            db_url = "postgresql+psycopg2://" + db_url[len(legacy_prefix):]
            break

    try:
        # pool_pre_ping checks a pooled connection before handing it out,
        # guarding against connections dropped after an idle timeout.
        return SQLDatabase.from_uri(db_url, engine_args={"pool_pre_ping": True})
    except Exception as e:
        # Chain explicitly so the underlying driver error stays visible.
        raise RuntimeError(f"数据库连接失败: {str(e)}") from e
# --- 3. 核心业务逻辑 ---
def _logic_handler(db: SQLDatabase, inputs: dict) -> dict:
    """Register a crawl task for a root URL and seed its queue.

    The URL in ``inputs["url"]`` is normalized, then looked up in
    ``crawl_tasks.root_url``. If a task already exists its id is returned.
    Otherwise a new task row (``status='running'``) and an initial
    ``crawl_queue`` row (``status='pending'``) are inserted in one
    transaction.

    Args:
        db: Database wrapper; its underlying engine is used directly.
        inputs: Mapping expected to contain the key ``"url"``.

    Returns:
        ``{"task_id": ..., "is_new_task": 0 | 1, "url": <normalized url>}``.

    Raises:
        ValueError: If ``inputs["url"]`` is missing, empty, or normalizes to
            an empty string (e.g. whitespace only).
    """
    raw_url = inputs.get("url", "")
    if not raw_url:
        raise ValueError("输入参数 'url' 缺失")

    clean_url = normalize_url(raw_url)
    if not clean_url:
        # A whitespace-only value is truthy but normalizes to "". Reject it
        # rather than registering a task with an empty root_url.
        raise ValueError("输入参数 'url' 缺失")

    engine = db._engine
    metadata = MetaData()

    # Reflect the table definitions from the live database, so no schema is
    # declared here.
    tasks_table = Table('crawl_tasks', metadata, autoload_with=engine)
    queue_table = Table('crawl_queue', metadata, autoload_with=engine)

    # engine.begin() commits on normal exit and rolls back on exception.
    with engine.begin() as conn:
        # 1. Is this root URL already registered?
        find_stmt = select(tasks_table.c.id).where(
            tasks_table.c.root_url == clean_url
        )
        existing_task = conn.execute(find_stmt).fetchone()
        if existing_task:
            return {
                "task_id": existing_task[0],
                "is_new_task": 0,
                "url": clean_url,
            }

        # 2. Create the task row. RETURNING hands back the generated id in
        #    the same round trip (requires backend support, e.g. PostgreSQL).
        insert_task_stmt = insert(tasks_table).values(
            root_url=clean_url,
            status='running',
        ).returning(tasks_table.c.id)
        new_task_id = conn.execute(insert_task_stmt).fetchone()[0]

        # 3. Seed the queue with the normalized root URL as the first
        #    pending entry.
        insert_queue_stmt = insert(queue_table).values(
            task_id=new_task_id,
            url=clean_url,
            status='pending',
        )
        conn.execute(insert_queue_stmt)

        return {
            "task_id": new_task_id,
            "is_new_task": 1,
            "url": clean_url,
        }
# --- 4. Dify 节点主入口 ---
def main(url: str, DB_URL: str) -> dict:
    """Dify node entry point: register *url* as a crawl task.

    Args:
        url: Root URL to register.
        DB_URL: Database connection string passed to ``get_db_connection``.

    Returns:
        A result envelope ``{"code", "msg", "data"}``. On success ``code`` is
        1, ``msg`` is ``"注册成功"`` and ``data`` is the handler result. On
        any failure ``code`` is 0, ``msg`` is the exception text and ``data``
        is ``None``.
    """
    try:
        db = get_db_connection(DB_URL)
        # _logic_handler reads its arguments from a mapping (inputs.get),
        # so the URL must be wrapped rather than passed as a bare string.
        result_data = _logic_handler(db, {"url": url})
    except Exception as e:
        # Node boundary: report every failure through the envelope instead
        # of raising into the workflow runtime.
        return {"code": 0, "msg": str(e), "data": None}

    return {"code": 1, "msg": "注册成功", "data": result_data}