diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..48e4ceb --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__/ +.venv \ No newline at end of file diff --git a/backend/__pycache__/config.cpython-313.pyc b/backend/__pycache__/config.cpython-313.pyc deleted file mode 100644 index 6505e50..0000000 Binary files a/backend/__pycache__/config.cpython-313.pyc and /dev/null differ diff --git a/backend/__pycache__/database.cpython-313.pyc b/backend/__pycache__/database.cpython-313.pyc deleted file mode 100644 index 37a2b12..0000000 Binary files a/backend/__pycache__/database.cpython-313.pyc and /dev/null differ diff --git a/backend/__pycache__/main.cpython-313.pyc b/backend/__pycache__/main.cpython-313.pyc deleted file mode 100644 index ef786ff..0000000 Binary files a/backend/__pycache__/main.cpython-313.pyc and /dev/null differ diff --git a/backend/__pycache__/schemas.cpython-313.pyc b/backend/__pycache__/schemas.cpython-313.pyc deleted file mode 100644 index a6a884b..0000000 Binary files a/backend/__pycache__/schemas.cpython-313.pyc and /dev/null differ diff --git a/backend/__pycache__/service.cpython-313.pyc b/backend/__pycache__/service.cpython-313.pyc deleted file mode 100644 index 6d7cc5a..0000000 Binary files a/backend/__pycache__/service.cpython-313.pyc and /dev/null differ diff --git a/backend/__pycache__/utils.cpython-313.pyc b/backend/__pycache__/utils.cpython-313.pyc deleted file mode 100644 index 86cab0e..0000000 Binary files a/backend/__pycache__/utils.cpython-313.pyc and /dev/null differ diff --git a/backend/database.py b/backend/database.py index df8b970..abaa07f 100644 --- a/backend/database.py +++ b/backend/database.py @@ -1,22 +1,32 @@ -from sqlalchemy import create_engine, MetaData, Table +from sqlalchemy import create_engine, MetaData, Table, event +from pgvector.sqlalchemy import Vector # 必须导入这个 from .config import settings class Database: def __init__(self): + # 1. 创建引擎 self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True) + + # 2. 【核心修复】手动注册 vector 类型,让反射能识别它 + # 这告诉 SQLAlchemy:如果在数据库里看到名为 "vector" 的类型,请使用 pgvector 库的 Vector 类来处理 + self.engine.dialect.ischema_names['vector'] = Vector + self.metadata = MetaData() self.tasks = None self.queue = None self.chunks = None + self._reflect_tables() def _reflect_tables(self): try: # 自动从数据库加载表结构 + # 因为上面注册了 ischema_names,现在 chunks_table.c.embedding 就能被正确识别为 Vector 类型了 self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine) self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine) self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine) except Exception as e: print(f"❌ 数据库表加载失败: {e}") +# 全局单例 db_instance = Database() \ No newline at end of file diff --git a/nodes/register.py b/nodes/register.py deleted file mode 100644 index 787b6d3..0000000 --- a/nodes/register.py +++ /dev/null @@ -1,136 +0,0 @@ -import json -from urllib.parse import urlparse, urlunparse -from langchain_community.utilities import SQLDatabase -from sqlalchemy import Table, MetaData, select, insert - -# --- 1. 工具函数:URL 标准化 --- -def normalize_url(url: str) -> str: - """ - 标准化 URL,解决末尾斜杠、大小写、锚点造成的重复问题 - """ - if not url: - return "" - # 去除前后空格 - url = url.strip() - # 解析 URL - parsed = urlparse(url) - # 1. 协议和域名转小写 - scheme = parsed.scheme.lower() - netloc = parsed.netloc.lower() - # 2. 路径处理:去除末尾斜杠 - path = parsed.path - if path.endswith('/'): - path = path.rstrip('/') - # 3. 忽略 Query 参数排序或 Fragment (#) - # 这里保留 Query 参数,但丢弃 Fragment,因为锚点指向同一页面 - query = parsed.query - - # 重新拼接 - return urlunparse((scheme, netloc, path, parsed.params, query, "")) - -# --- 2. 数据库连接工厂 --- -def get_db_connection(db_url: str): - """ - 获取通用数据库连接,处理协议兼容性 - """ - if not db_url: - raise ValueError("数据库连接字符串 (db_url) 不能为空") - - # 修复常见协议头报错 - if db_url.startswith("postgres://"): - db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1) - elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url: - db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1) - - try: - # pool_pre_ping=True 用于在获取连接前检查有效性,防止超时断开 - return SQLDatabase.from_uri(db_url, engine_args={"pool_pre_ping": True}) - except Exception as e: - raise RuntimeError(f"数据库连接失败: {str(e)}") - -# --- 3. 核心业务逻辑 --- -def _logic_handler(db: SQLDatabase, inputs: dict): - """ - 业务逻辑:注册任务并初始化队列 - """ - engine = db._engine - metadata = MetaData() - - # 标准化输入 URL - raw_url = inputs.get("url", "") - if not raw_url: - raise ValueError("输入参数 'url' 缺失") - clean_url = normalize_url(raw_url) - - # 反射获取表结构(无需写SQL) - tasks_table = Table('crawl_tasks', metadata, autoload_with=engine) - queue_table = Table('crawl_queue', metadata, autoload_with=engine) - - with engine.begin() as conn: - # 1. 查询该 root_url 是否已存在 - find_stmt = select(tasks_table.c.id).where(tasks_table.c.root_url == clean_url) - existing_task = conn.execute(find_stmt).fetchone() - - if existing_task: - # 任务已存在,直接返回 - return { - "task_id": existing_task[0], - "is_new_task": 0, - "url": clean_url - } - - # 2. 任务不存在,创建新任务记录 - # .returning(tasks_table.c.id) 是 PostgreSQL 获取自增 ID 的标准写法 - insert_task_stmt = insert(tasks_table).values( - root_url=clean_url, - status='running' - ).returning(tasks_table.c.id) - - new_task_id = conn.execute(insert_task_stmt).fetchone()[0] - - # 3. 初始化任务队列:将根 URL 作为第一条待爬取数据 - # 确保根 URL 也经过标准化处理 - insert_queue_stmt = insert(queue_table).values( - task_id=new_task_id, - url=clean_url, - status='pending' - ) - conn.execute(insert_queue_stmt) - - return { - "task_id": new_task_id, - "is_new_task": 1, - "url": clean_url - } - -# --- 4. Dify 节点主入口 --- -def main(url: str, DB_URL: str): - """ - Dify 节点入口函数 - """ - ret = {"code": 0, "msg": "unknown", "data": None} - - # 从输入或环境变量获取数据库地址 - db_url = DB_URL - - try: - # 获取连接 - db = get_db_connection(db_url) - - # 处理逻辑 - result_data = _logic_handler(db, url) - - ret["code"] = 1 - ret["msg"] = "注册成功" - ret["data"] = result_data - - except Exception as e: - ret["code"] = 0 - ret["msg"] = str(e) - ret["data"] = None - - return { - "code": ret["code"], - "msg": ret["msg"], - "data": ret["data"] - } \ No newline at end of file diff --git a/nodes/template.py b/nodes/template.py deleted file mode 100644 index 79cfa3e..0000000 --- a/nodes/template.py +++ /dev/null @@ -1,106 +0,0 @@ -import json -from urllib.parse import urlparse, urlunparse -from langchain_community.utilities import SQLDatabase -from sqlalchemy import Table, MetaData, select, insert, update, delete, and_ - -# --- 工具函数:URL 标准化 --- -def normalize_url(url: str) -> str: - """ - 标准化 URL,确保末尾斜杠、大小写等不影响唯一性判定 - """ - if not url: - return url - - # 1. 解析 URL - parsed = urlparse(url.strip()) - - # 2. 转换协议和域名为小写 (Domain 是不区分大小写的) - scheme = parsed.scheme.lower() - netloc = parsed.netloc.lower() - - # 3. 处理路径:去除末尾的斜杠 - path = parsed.path - if path.endswith('/'): - path = path.rstrip('/') - - # 4. 去除 Fragment (#部分),保留 Query 参数 - # 如果需要忽略 Query 参数,可以将 query 设置为 "" - query = parsed.query - - # 5. 重新拼接 - normalized = urlunparse((scheme, netloc, path, parsed.params, query, "")) - return normalized - -# --- 数据库连接工厂 --- -def get_db_connection(db_url: str): - """ - 获取通用数据库连接,处理协议兼容性 - """ - if db_url.startswith("postgres://"): - db_url = db_url.replace("postgres://", "postgresql+psycopg2://", 1) - elif db_url.startswith("postgresql://") and "+psycopg2" not in db_url: - db_url = db_url.replace("postgresql://", "postgresql+psycopg2://", 1) - - try: - # engine_args 确保连接池在 Dify 高并发下更稳定 - return SQLDatabase.from_uri(db_url, engine_args={ - "pool_pre_ping": True, - "pool_recycle": 3600 - }) - except Exception as e: - raise RuntimeError(f"DB_CONNECT_ERROR: {str(e)}") - -# --- Dify 节点主入口 --- -def main(inputs: dict): - """ - Dify 节点主入口函数 - """ - ret = {"code": 0, "msg": "unknown", "data": None} - - # 预设数据库连接字符串 (建议在 Dify 环境变量中配置) - db_url = inputs.get("db_url") - - try: - # 1. 初始化数据库 - db = get_db_connection(db_url) - - # 2. 执行具体的业务逻辑 - result_data = _logic_handler(db, inputs) - - ret["code"] = 1 - ret["msg"] = "success" - ret["data"] = result_data - - except Exception as e: - ret["code"] = 0 - ret["msg"] = str(e) - ret["data"] = None - - return ret - -# ------------------------------------------------- -# 业务逻辑处理器:每个节点只需修改这里 -# ------------------------------------------------- -def _logic_handler(db: SQLDatabase, inputs: dict): - """ - 在这里编写具体的业务操作 - """ - engine = db._engine - metadata = MetaData() - - # 示例:获取并标准化 URL - raw_url = inputs.get("url", "") - clean_url = normalize_url(raw_url) - - # 反射获取表对象 - # tasks = Table('crawl_tasks', metadata, autoload_with=engine) - - # 使用 SQLAlchemy Core 进行操作(无需写原生SQL) - # with engine.begin() as conn: - # stmt = select(tasks).where(tasks.c.root_url == clean_url) - # result = conn.execute(stmt).fetchone() - - return { - "processed_url": clean_url, - "info": "逻辑已执行" - } \ No newline at end of file