删掉杂项
This commit is contained in:
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
__pycache__/
|
||||||
|
.venv
|
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,22 +1,32 @@
|
|||||||
from sqlalchemy import create_engine, MetaData, Table
|
from sqlalchemy import create_engine, MetaData, Table, event
|
||||||
|
from pgvector.sqlalchemy import Vector # 必须导入这个
|
||||||
from .config import settings
|
from .config import settings
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
# 1. 创建引擎
|
||||||
self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
|
self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
|
||||||
|
|
||||||
|
# 2. 【核心修复】手动注册 vector 类型,让反射能识别它
|
||||||
|
# 这告诉 SQLAlchemy:如果在数据库里看到名为 "vector" 的类型,请使用 pgvector 库的 Vector 类来处理
|
||||||
|
self.engine.dialect.ischema_names['vector'] = Vector
|
||||||
|
|
||||||
self.metadata = MetaData()
|
self.metadata = MetaData()
|
||||||
self.tasks = None
|
self.tasks = None
|
||||||
self.queue = None
|
self.queue = None
|
||||||
self.chunks = None
|
self.chunks = None
|
||||||
|
|
||||||
self._reflect_tables()
|
self._reflect_tables()
|
||||||
|
|
||||||
def _reflect_tables(self):
|
def _reflect_tables(self):
|
||||||
try:
|
try:
|
||||||
# 自动从数据库加载表结构
|
# 自动从数据库加载表结构
|
||||||
|
# 因为上面注册了 ischema_names,现在 chunks_table.c.embedding 就能被正确识别为 Vector 类型了
|
||||||
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
|
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
|
||||||
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
|
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
|
||||||
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
|
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"❌ 数据库表加载失败: {e}")
|
print(f"❌ 数据库表加载失败: {e}")
|
||||||
|
|
||||||
|
# 全局单例
|
||||||
db_instance = Database()
|
db_instance = Database()
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
import json
|
|
||||||
from urllib.parse import urlparse, urlunparse
|
|
||||||
from langchain_community.utilities import SQLDatabase
|
|
||||||
from sqlalchemy import Table, MetaData, select, insert
|
|
||||||
|
|
||||||
# --- 1. 工具函数:URL 标准化 ---
|
|
||||||
def normalize_url(url: str) -> str:
|
|
||||||
"""
|
|
||||||
标准化 URL,解决末尾斜杠、大小写、锚点造成的重复问题
|
|
||||||
"""
|
|
||||||
if not url:
|
|
||||||
return ""
|
|
||||||
# 去除前后空格
|
|
||||||
url = url.strip()
|
|
||||||
# 解析 URL
|
|
||||||
parsed = urlparse(url)
|
|
||||||
# 1. 协议和域名转小写
|
|
||||||
scheme = parsed.scheme.lower()
|
|
||||||
netloc = parsed.netloc.lower()
|
|
||||||
# 2. 路径处理:去除末尾斜杠
|
|
||||||
path = parsed.path
|
|
||||||
if path.endswith('/'):
|
|
||||||
path = path.rstrip('/')
|
|
||||||
# 3. 忽略 Query 参数排序或 Fragment (#)
|
|
||||||
# 这里保留 Query 参数,但丢弃 Fragment,因为锚点指向同一页面
|
|
||||||
query = parsed.query
|
|
||||||
|
|
||||||
# 重新拼接
|
|
||||||
return urlunparse((scheme, netloc, path, parsed.params, query, ""))
|
|
||||||
|
|
||||||
# --- 2. 数据库连接工厂 ---
|
|
||||||
def get_db_connection(db_url: str):
|
|
||||||
"""
|
|
||||||
获取通用数据库连接,处理协议兼容性
|
|
||||||
"""
|
|
||||||
if not db_url:
|
|
||||||
raise ValueError("数据库连接字符串 (db_url) 不能为空")
|
|
||||||
|
|
||||||
# 修复常见协议头报错
|
|
||||||
if db_url.startswith("postgres://"):
|
|
||||||
db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1)
|
|
||||||
elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url:
|
|
||||||
db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# pool_pre_ping=True 用于在获取连接前检查有效性,防止超时断开
|
|
||||||
return SQLDatabase.from_uri(db_url, engine_args={"pool_pre_ping": True})
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"数据库连接失败: {str(e)}")
|
|
||||||
|
|
||||||
# --- 3. 核心业务逻辑 ---
|
|
||||||
def _logic_handler(db: SQLDatabase, inputs: dict):
|
|
||||||
"""
|
|
||||||
业务逻辑:注册任务并初始化队列
|
|
||||||
"""
|
|
||||||
engine = db._engine
|
|
||||||
metadata = MetaData()
|
|
||||||
|
|
||||||
# 标准化输入 URL
|
|
||||||
raw_url = inputs.get("url", "")
|
|
||||||
if not raw_url:
|
|
||||||
raise ValueError("输入参数 'url' 缺失")
|
|
||||||
clean_url = normalize_url(raw_url)
|
|
||||||
|
|
||||||
# 反射获取表结构(无需写SQL)
|
|
||||||
tasks_table = Table('crawl_tasks', metadata, autoload_with=engine)
|
|
||||||
queue_table = Table('crawl_queue', metadata, autoload_with=engine)
|
|
||||||
|
|
||||||
with engine.begin() as conn:
|
|
||||||
# 1. 查询该 root_url 是否已存在
|
|
||||||
find_stmt = select(tasks_table.c.id).where(tasks_table.c.root_url == clean_url)
|
|
||||||
existing_task = conn.execute(find_stmt).fetchone()
|
|
||||||
|
|
||||||
if existing_task:
|
|
||||||
# 任务已存在,直接返回
|
|
||||||
return {
|
|
||||||
"task_id": existing_task[0],
|
|
||||||
"is_new_task": 0,
|
|
||||||
"url": clean_url
|
|
||||||
}
|
|
||||||
|
|
||||||
# 2. 任务不存在,创建新任务记录
|
|
||||||
# .returning(tasks_table.c.id) 是 PostgreSQL 获取自增 ID 的标准写法
|
|
||||||
insert_task_stmt = insert(tasks_table).values(
|
|
||||||
root_url=clean_url,
|
|
||||||
status='running'
|
|
||||||
).returning(tasks_table.c.id)
|
|
||||||
|
|
||||||
new_task_id = conn.execute(insert_task_stmt).fetchone()[0]
|
|
||||||
|
|
||||||
# 3. 初始化任务队列:将根 URL 作为第一条待爬取数据
|
|
||||||
# 确保根 URL 也经过标准化处理
|
|
||||||
insert_queue_stmt = insert(queue_table).values(
|
|
||||||
task_id=new_task_id,
|
|
||||||
url=clean_url,
|
|
||||||
status='pending'
|
|
||||||
)
|
|
||||||
conn.execute(insert_queue_stmt)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task_id": new_task_id,
|
|
||||||
"is_new_task": 1,
|
|
||||||
"url": clean_url
|
|
||||||
}
|
|
||||||
|
|
||||||
# --- 4. Dify 节点主入口 ---
|
|
||||||
def main(url: str, DB_URL: str):
|
|
||||||
"""
|
|
||||||
Dify 节点入口函数
|
|
||||||
"""
|
|
||||||
ret = {"code": 0, "msg": "unknown", "data": None}
|
|
||||||
|
|
||||||
# 从输入或环境变量获取数据库地址
|
|
||||||
db_url = DB_URL
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取连接
|
|
||||||
db = get_db_connection(db_url)
|
|
||||||
|
|
||||||
# 处理逻辑
|
|
||||||
result_data = _logic_handler(db, url)
|
|
||||||
|
|
||||||
ret["code"] = 1
|
|
||||||
ret["msg"] = "注册成功"
|
|
||||||
ret["data"] = result_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ret["code"] = 0
|
|
||||||
ret["msg"] = str(e)
|
|
||||||
ret["data"] = None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"code": ret["code"],
|
|
||||||
"msg": ret["msg"],
|
|
||||||
"data": ret["data"]
|
|
||||||
}
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
import json
|
|
||||||
from urllib.parse import urlparse, urlunparse
|
|
||||||
from langchain_community.utilities import SQLDatabase
|
|
||||||
from sqlalchemy import Table, MetaData, select, insert, update, delete, and_
|
|
||||||
|
|
||||||
# --- 工具函数:URL 标准化 ---
|
|
||||||
def normalize_url(url: str) -> str:
|
|
||||||
"""
|
|
||||||
标准化 URL,确保末尾斜杠、大小写等不影响唯一性判定
|
|
||||||
"""
|
|
||||||
if not url:
|
|
||||||
return url
|
|
||||||
|
|
||||||
# 1. 解析 URL
|
|
||||||
parsed = urlparse(url.strip())
|
|
||||||
|
|
||||||
# 2. 转换协议和域名为小写 (Domain 是不区分大小写的)
|
|
||||||
scheme = parsed.scheme.lower()
|
|
||||||
netloc = parsed.netloc.lower()
|
|
||||||
|
|
||||||
# 3. 处理路径:去除末尾的斜杠
|
|
||||||
path = parsed.path
|
|
||||||
if path.endswith('/'):
|
|
||||||
path = path.rstrip('/')
|
|
||||||
|
|
||||||
# 4. 去除 Fragment (#部分),保留 Query 参数
|
|
||||||
# 如果需要忽略 Query 参数,可以将 query 设置为 ""
|
|
||||||
query = parsed.query
|
|
||||||
|
|
||||||
# 5. 重新拼接
|
|
||||||
normalized = urlunparse((scheme, netloc, path, parsed.params, query, ""))
|
|
||||||
return normalized
|
|
||||||
|
|
||||||
# --- 数据库连接工厂 ---
|
|
||||||
def get_db_connection(db_url: str):
|
|
||||||
"""
|
|
||||||
获取通用数据库连接,处理协议兼容性
|
|
||||||
"""
|
|
||||||
if db_url.startswith("postgres://"):
|
|
||||||
db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1)
|
|
||||||
elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url:
|
|
||||||
db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# engine_args 确保连接池在 Dify 高并发下更稳定
|
|
||||||
return SQLDatabase.from_uri(db_url, engine_args={
|
|
||||||
"pool_pre_ping": True,
|
|
||||||
"pool_recycle": 3600
|
|
||||||
})
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"DB_CONNECT_ERROR: {str(e)}")
|
|
||||||
|
|
||||||
# --- Dify 节点主入口 ---
|
|
||||||
def main(inputs: dict):
|
|
||||||
"""
|
|
||||||
Dify 节点主入口函数
|
|
||||||
"""
|
|
||||||
ret = {"code": 0, "msg": "unknown", "data": None}
|
|
||||||
|
|
||||||
# 预设数据库连接字符串 (建议在 Dify 环境变量中配置)
|
|
||||||
db_url = inputs.get("db_url")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 1. 初始化数据库
|
|
||||||
db = get_db_connection(db_url)
|
|
||||||
|
|
||||||
# 2. 执行具体的业务逻辑
|
|
||||||
result_data = _logic_handler(db, inputs)
|
|
||||||
|
|
||||||
ret["code"] = 1
|
|
||||||
ret["msg"] = "success"
|
|
||||||
ret["data"] = result_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
ret["code"] = 0
|
|
||||||
ret["msg"] = str(e)
|
|
||||||
ret["data"] = None
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
# -------------------------------------------------
|
|
||||||
# 业务逻辑处理器:每个节点只需修改这里
|
|
||||||
# -------------------------------------------------
|
|
||||||
def _logic_handler(db: SQLDatabase, inputs: dict):
|
|
||||||
"""
|
|
||||||
在这里编写具体的业务操作
|
|
||||||
"""
|
|
||||||
engine = db._engine
|
|
||||||
metadata = MetaData()
|
|
||||||
|
|
||||||
# 示例:获取并标准化 URL
|
|
||||||
raw_url = inputs.get("url", "")
|
|
||||||
clean_url = normalize_url(raw_url)
|
|
||||||
|
|
||||||
# 反射获取表对象
|
|
||||||
# tasks = Table('crawl_tasks', metadata, autoload_with=engine)
|
|
||||||
|
|
||||||
# 使用 SQLAlchemy Core 进行操作(无需写原生SQL)
|
|
||||||
# with engine.begin() as conn:
|
|
||||||
# stmt = select(tasks).where(tasks.c.root_url == clean_url)
|
|
||||||
# result = conn.execute(stmt).fetchone()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"processed_url": clean_url,
|
|
||||||
"info": "逻辑已执行"
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user