From 9190fee16f65834fbb2f80d07c022a44b70a7ceb Mon Sep 17 00:00:00 2001 From: QingGang Date: Tue, 13 Jan 2026 01:37:26 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=98=E6=9B=B4=E9=A1=B9=E7=9B=AE=E6=9E=B6?= =?UTF-8?q?=E6=9E=84=EF=BC=8C=E6=8F=90=E9=AB=98=E6=89=A9=E5=B1=95=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 7 + .gitignore | 1 + backend/config.py | 26 --- backend/core/config.py | 24 +++ backend/{ => core}/database.py | 24 ++- backend/main.py | 128 +------------ backend/routers/v1.py | 30 +++ backend/routers/v2.py | 30 +++ backend/{ => schemas}/schemas.py | 0 backend/services/automated_crawler.py | 201 -------------------- backend/services/crawler_service.py | 150 +++++++++++++++ backend/services/crawler_sql_service.py | 230 ----------------------- backend/services/data_service.py | 111 +++++++++++ backend/services/llm_service.py | 30 +++ backend/utils.py | 39 ---- backend/utils/common.py | 26 +++ backend/utils/text_process.py | 61 ++++++ docs/docker.md | 23 ++- pyproject.toml | 1 + scripts/test_apis.py | 237 +++++++++++++++--------- scripts/update_sql.py | 82 ++++++++ uv.lock | 2 + 22 files changed, 740 insertions(+), 723 deletions(-) create mode 100644 .env.example delete mode 100644 backend/config.py create mode 100644 backend/core/config.py rename backend/{ => core}/database.py (51%) create mode 100644 backend/routers/v1.py create mode 100644 backend/routers/v2.py rename backend/{ => schemas}/schemas.py (100%) delete mode 100644 backend/services/automated_crawler.py create mode 100644 backend/services/crawler_service.py delete mode 100644 backend/services/crawler_sql_service.py create mode 100644 backend/services/data_service.py create mode 100644 backend/services/llm_service.py delete mode 100644 backend/utils.py create mode 100644 backend/utils/common.py create mode 100644 backend/utils/text_process.py create mode 100644 scripts/update_sql.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..98d8f26 --- /dev/null +++ b/.env.example @@ -0,0 +1,7 @@ +DB_USER=postgres +DB_PASS= +DB_HOST=localhost +DB_PORT=5432 +DB_NAME=wiki_crawler +DASHSCOPE_API_KEY= +FIRECRAWL_API_KEY= \ No newline at end of file diff --git a/.gitignore b/.gitignore index b0577cf..2b54524 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ __pycache__/ .venv wiki_backend.tar +.env \ No newline at end of file diff --git a/backend/config.py b/backend/config.py deleted file mode 100644 index 473dbc5..0000000 --- a/backend/config.py +++ /dev/null @@ -1,26 +0,0 @@ -import os - -class Settings: - # 数据库配置 - DB_USER: str = "postgres" - DB_PASS: str = "DXC_welcome001" - DB_HOST: str = "8.155.144.6" - DB_PORT: str = "25432" - DB_NAME: str = "wiki_crawler" - - DASHSCOPE_API_KEY: str = "sk-8b091493de594c5e9eb42f12f1cc5805" - FIRECRAWL_API_KEY: str = "fc-8a2af3fb6a014a27a57dfbc728cb7365" - @property # property 方法,意义:将方法转换为属性,调用时不需要加括号 - def DATABASE_URL(self) -> str: - url = f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" - return url - - def API_KEY(self, type: str) -> str: - if type == "dashscope": - return self.DASHSCOPE_API_KEY - elif type == "firecrawl": - return self.FIRECRAWL_API_KEY - else: - raise ValueError(f"Unknown API type: {type}") - -settings = Settings() \ No newline at end of file diff --git a/backend/core/config.py b/backend/core/config.py new file mode 100644 index 0000000..30704a8 --- /dev/null +++ b/backend/core/config.py @@ -0,0 +1,24 @@ +from 
pydantic_settings import BaseSettings, SettingsConfigDict + +class Settings(BaseSettings): + """ + 系统配置类 + 自动读取环境变量或 .env 文件 + """ + DB_USER: str + DB_PASS: str + DB_HOST: str + DB_PORT: str = "5432" + DB_NAME: str + + DASHSCOPE_API_KEY: str + FIRECRAWL_API_KEY: str + + # 配置:忽略多余的环境变量,指定编码 + model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8') + + @property + def DATABASE_URL(self) -> str: + return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}" + +settings = Settings() \ No newline at end of file diff --git a/backend/database.py b/backend/core/database.py similarity index 51% rename from backend/database.py rename to backend/core/database.py index abaa07f..6e9b2e7 100644 --- a/backend/database.py +++ b/backend/core/database.py @@ -1,14 +1,19 @@ -from sqlalchemy import create_engine, MetaData, Table, event -from pgvector.sqlalchemy import Vector # 必须导入这个 +from sqlalchemy import create_engine, MetaData, Table +from pgvector.sqlalchemy import Vector from .config import settings class Database: + """ + 数据库单例类 + 负责初始化连接池并反射加载现有的表结构 + """ def __init__(self): # 1. 创建引擎 + # pool_pre_ping=True 用于解决数据库连接长时间空闲后断开的问题 self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True) - # 2. 【核心修复】手动注册 vector 类型,让反射能识别它 - # 这告诉 SQLAlchemy:如果在数据库里看到名为 "vector" 的类型,请使用 pgvector 库的 Vector 类来处理 + # 2. 注册 pgvector 类型 + # 这是为了让 SQLAlchemy 反射机制能识别数据库中的 'vector' 类型 self.engine.dialect.ischema_names['vector'] = Vector self.metadata = MetaData() @@ -19,14 +24,15 @@ class Database: self._reflect_tables() def _reflect_tables(self): + """自动从数据库加载表定义""" try: - # 自动从数据库加载表结构 - # 因为上面注册了 ischema_names,现在 chunks_table.c.embedding 就能被正确识别为 Vector 类型了 + # autoload_with 会查询数据库元数据,自动填充 Column 信息 self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine) self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine) self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine) + print("[INFO] Database tables reflected successfully.") except Exception as e: - print(f"❌ 数据库表加载失败: {e}") + print(f"[ERROR] Failed to reflect tables: {e}") -# 全局单例 -db_instance = Database() \ No newline at end of file +# 全局数据库实例 +db = Database() \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index ada7b84..6b935c3 100644 --- a/backend/main.py +++ b/backend/main.py @@ -1,131 +1,13 @@ -# backend/main.py -from fastapi import FastAPI, APIRouter, BackgroundTasks -# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow -from .services.crawler_sql_service import crawler_sql_service -from .services.automated_crawler import workflow -from .schemas import ( - RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest, - AutoMapRequest, AutoProcessRequest, TextSearchRequest -) -from .utils import make_response +from fastapi import FastAPI +from backend.routers import v1, v2 app = FastAPI(title="Wiki Crawler API") -# ========================================== -# 工具函数 -# ========================================== - - -# ========================================== -# V1 Router: 原始的底层接口 (Manual Control) -# ========================================== -router_v1 = APIRouter() - -@router_v1.post("/register") -async def register(req: RegisterRequest): - try: - # Service 返回: {'task_id': 1, 'is_new_task': True, 'msg': '...'} - res = crawler_sql_service.register_task(req.url) - # 使用 pop 将 msg 提取出来作为响应的 msg,剩下的作为 data - return make_response(1, res.pop("msg", "Success"), res) - except 
Exception as e: - return make_response(0, str(e)) - -@router_v1.post("/add_urls") -async def add_urls(req: AddUrlsRequest): - try: - urls = req.urls_obj["urls"] - res = crawler_sql_service.add_urls(req.task_id, urls=urls) - return make_response(1, res.pop("msg", "Success"), res) - except Exception as e: - return make_response(0, str(e)) - -@router_v1.post("/pending_urls") -async def pending_urls(req: PendingRequest): - try: - res = crawler_sql_service.get_pending_urls(req.task_id, req.limit) - # 即使队列为空,Service 也会返回 msg="Queue is empty" - return make_response(1, res.pop("msg", "Success"), res) - except Exception as e: - return make_response(0, str(e)) - -@router_v1.post("/save_results") -async def save_results(req: SaveResultsRequest): - try: - res = crawler_sql_service.save_results(req.task_id, req.results) - return make_response(1, res.pop("msg", "Success"), res) - except Exception as e: - return make_response(0, str(e)) - -@router_v1.post("/search") -async def search_v1(req: SearchRequest): - """V1 搜索:客户端手动传向量""" - try: - vector = req.query_embedding['vector'] - if not vector: - return make_response(2, "Vector is empty", None) - - # Service 现在返回 {'results': [...], 'msg': 'Found ...'} - res = crawler_sql_service.search_knowledge( - query_embedding=vector, - task_id=req.task_id, - limit=req.limit - ) - return make_response(1, res.pop("msg", "Search Done"), res) - except Exception as e: - return make_response(0, str(e)) - - -# ========================================== -# V2 Router: 自动化工作流 (Automated Workflow) -# ========================================== -router_v2 = APIRouter() - -@router_v2.post("/crawler/map") -async def auto_map(req: AutoMapRequest): - """ - [同步] 输入首页 URL,自动调用 Firecrawl Map 并入库 - """ - try: - # Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...} - res = workflow.map_and_ingest(req.url) - return make_response(1, res.pop("msg", "Mapping Started"), res) - except Exception as e: - return make_response(0, str(e)) - -@router_v2.post("/crawler/process") -async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTasks): - """ - [异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库 - """ - try: - # 将耗时操作放入后台任务 - background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size) - - # 因为是后台任务,无法立即获取 Service 的返回值 msg,只能返回通用消息 - return make_response(1, "Background processing started", {"task_id": req.task_id}) - except Exception as e: - return make_response(0, str(e)) - -@router_v2.post("/search") -async def search_v2(req: TextSearchRequest): - """ - [智能] 输入自然语言文本 -> 后端转向量 -> 搜索 - """ - try: - # Workflow 返回 {'results': [...], 'msg': '...'} - res = workflow.search_with_embedding(req.query, req.task_id, req.limit) - return make_response(1, res.pop("msg", "Search Success"), res) - except Exception as e: - return make_response(0, f"Search Failed: {str(e)}") - - -# ========================================== # 挂载路由 -# ========================================== -app.include_router(router_v1, prefix="/api/v1", tags=["V1 Manual API"]) -app.include_router(router_v2, prefix="/api/v2", tags=["V2 Automated Workflow"]) +app.include_router(v1.router) +app.include_router(v2.router) if __name__ == "__main__": import uvicorn + # 提示:运行方式 uv run backend/main.py 或 uvicorn backend.main:app --reload uvicorn.run(app, host="0.0.0.0", port=8000, reload=True) \ No newline at end of file diff --git a/backend/routers/v1.py b/backend/routers/v1.py new file mode 100644 index 0000000..e587896 --- /dev/null +++ b/backend/routers/v1.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter 
+from backend.services.data_service import data_service +from backend.utils.common import make_response +from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest + +router = APIRouter(prefix="/api/v1", tags=["V1 Manual"]) + +@router.post("/register") +async def register(req: RegisterRequest): + try: + res = data_service.register_task(req.url) + return make_response(1, res.pop("msg", "Success"), res) + except Exception as e: + return make_response(0, str(e)) + +@router.post("/add_urls") +async def add_urls(req: AddUrlsRequest): + try: + res = data_service.add_urls(req.task_id, req.urls_obj["urls"]) + return make_response(1, res.pop("msg", "Success"), res) + except Exception as e: + return make_response(0, str(e)) + +@router.post("/search") +async def search_manual(req: SearchRequest): + try: + res = data_service.search(req.query_embedding['vector'], req.task_id, req.limit) + return make_response(1, res.pop("msg", "Success"), res) + except Exception as e: + return make_response(0, str(e)) \ No newline at end of file diff --git a/backend/routers/v2.py b/backend/routers/v2.py new file mode 100644 index 0000000..981b7a0 --- /dev/null +++ b/backend/routers/v2.py @@ -0,0 +1,30 @@ +from fastapi import APIRouter, BackgroundTasks +from backend.services.crawler_service import crawler_service +from backend.utils.common import make_response +from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest + +router = APIRouter(prefix="/api/v2", tags=["V2 Automated"]) + +@router.post("/crawler/map") +async def auto_map(req: AutoMapRequest): + try: + res = crawler_service.map_site(req.url) + return make_response(1, res.pop("msg", "Started"), res) + except Exception as e: + return make_response(0, str(e)) + +@router.post("/crawler/process") +async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks): + try: + bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size) + return make_response(1, "Background processing started", {"task_id": req.task_id}) + except Exception as e: + return make_response(0, str(e)) + +@router.post("/search") +async def search_smart(req: TextSearchRequest): + try: + res = crawler_service.search(req.query, req.task_id, req.limit) + return make_response(1, res.pop("msg", "Success"), res) + except Exception as e: + return make_response(0, str(e)) \ No newline at end of file diff --git a/backend/schemas.py b/backend/schemas/schemas.py similarity index 100% rename from backend/schemas.py rename to backend/schemas/schemas.py diff --git a/backend/services/automated_crawler.py b/backend/services/automated_crawler.py deleted file mode 100644 index 3073b8c..0000000 --- a/backend/services/automated_crawler.py +++ /dev/null @@ -1,201 +0,0 @@ -import dashscope -from http import HTTPStatus -from firecrawl import FirecrawlApp -from langchain_text_splitters import RecursiveCharacterTextSplitter -from ..config import settings -from .crawler_sql_service import crawler_sql_service - -# 初始化配置 -dashscope.api_key = settings.DASHSCOPE_API_KEY - -class AutomatedCrawler: - def __init__(self): - self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY) - self.splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=100, - separators=["\n\n", "\n", "。", "!", "?", " ", ""] - ) - - def _get_embedding(self, text: str): - """内部方法:调用 Dashscope 生成向量""" - # 注意:此方法是内部辅助,出错返回 None,由调用方处理状态 - embedding = None - try: - resp = dashscope.TextEmbedding.call( - 
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 确认你的模型版本 - input=text, - dimension=1536 - ) - if resp.status_code == HTTPStatus.OK: - embedding = resp.output['embeddings'][0]['embedding'] - else: - print(f"Embedding API Error: {resp}") - except Exception as e: - print(f"Embedding Exception: {e}") - - return embedding - - def map_and_ingest(self, start_url: str): - """ - V2 步骤1: 地图式扫描并入库 - """ - print(f"[WorkFlow] Start mapping: {start_url}") - result = {} - - try: - # 1. 在数据库注册任务 - task_info = crawler_sql_service.register_task(start_url) - task_id = task_info['task_id'] - is_new_task = task_info['is_new_task'] - - # 2. 调用 Firecrawl Map - if is_new_task: - map_result = self.firecrawl.map(start_url) - - urls = [] - # 兼容 firecrawl sdk 不同版本的返回结构 - # 如果 map_result 是对象且有 links 属性 - if hasattr(map_result, 'links'): - for link in map_result.links: - # 假设 link 是对象或字典,视具体 SDK 版本而定 - # 如果 link 是字符串直接 append - if isinstance(link, str): - urls.append(link) - else: - urls.append(getattr(link, 'url', str(link))) - # 如果是字典 - elif isinstance(map_result, dict): - urls = map_result.get('links', []) - - print(f"[WorkFlow] Found {len(urls)} links") - - # 3. 批量入库 - res = {"msg": "No urls found to add"} - if urls: - res = crawler_sql_service.add_urls(task_id, urls) - - result = { - "msg": "Task successfully mapped and URLs added", - "task_id": task_id, - "is_new_task": is_new_task, - "url_count": len(urls), - "map_detail": res - } - else: - result = { - "msg": "Task already exists, skipped mapping", - "task_id": task_id, - "is_new_task": False, - "url_count": 0, - "map_detail": {} - } - - except Exception as e: - print(f"[WorkFlow] Map Error: {e}") - # 向上抛出异常,由 main.py 捕获并返回错误 Response - raise e - - return result - - def process_task_queue(self, task_id: int, limit: int = 10): - """ - V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储 - """ - processed_count = 0 - total_chunks_saved = 0 - result = {} - - # 1. 获取待处理 URL - pending = crawler_sql_service.get_pending_urls(task_id, limit) - urls = pending['urls'] - - if not urls: - result = {"msg": "Queue is empty, no processing needed", "processed_count": 0} - else: - for url in urls: - try: - print(f"[WorkFlow] Processing: {url}") - # 2. 单页抓取 - scrape_res = self.firecrawl.scrape( - url, - params={'formats': ['markdown'], 'onlyMainContent': True} - ) - - # 兼容 SDK 返回类型 (对象或字典) - content = "" - metadata = {} - - if isinstance(scrape_res, dict): - content = scrape_res.get('markdown', '') - metadata = scrape_res.get('metadata', {}) - else: - content = getattr(scrape_res, 'markdown', '') - metadata = getattr(scrape_res, 'metadata', {}) - if not metadata and hasattr(scrape_res, 'metadata_dict'): - metadata = scrape_res.metadata_dict - - title = metadata.get('title', url) - - if not content: - print(f"[WorkFlow] Skip empty content: {url}") - continue - - # 3. 切片 - chunks = self.splitter.split_text(content) - results_to_save = [] - - # 4. 向量化 - for idx, chunk_text in enumerate(chunks): - vector = self._get_embedding(chunk_text) - if vector: - results_to_save.append({ - "source_url": url, - "chunk_index": idx, - "title": title, - "content": chunk_text, - "embedding": vector - }) - - # 5. 
保存 - if results_to_save: - save_res = crawler_sql_service.save_results(task_id, results_to_save) - processed_count += 1 - total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated'] - - except Exception as e: - print(f"[WorkFlow] Error processing {url}: {e}") - # 此处不抛出异常,以免打断整个批次的循环 - # 实际生产建议在这里调用 service 将 url 标记为 failed - - result = { - "msg": f"Batch processing complete. URLs processed: {processed_count}", - "processed_urls": processed_count, - "total_chunks_saved": total_chunks_saved - } - - return result - - def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5): - """ - V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库 - """ - result = {} - - # 1. 获取向量 - vector = self._get_embedding(query_text) - - if not vector: - result = { - "msg": "Failed to generate embedding for query", - "results": [] - } - else: - # 2. 执行搜索 - # search_knowledge 现在已经返回带 msg 的字典了 - result = crawler_sql_service.search_knowledge(vector, task_id, limit) - - return result - -# 单例模式 -workflow = AutomatedCrawler() \ No newline at end of file diff --git a/backend/services/crawler_service.py b/backend/services/crawler_service.py new file mode 100644 index 0000000..0293c22 --- /dev/null +++ b/backend/services/crawler_service.py @@ -0,0 +1,150 @@ +# backend/services/crawler_service.py +from firecrawl import FirecrawlApp +from backend.core.config import settings +from backend.services.data_service import data_service +from backend.services.llm_service import llm_service +from backend.utils.text_process import text_processor + +class CrawlerService: + """ + 爬虫编排服务 + 协调 Firecrawl, LLM, 和 DataService + """ + def __init__(self): + self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY) + + def map_site(self, start_url: str): + print(f"[INFO] Mapping: {start_url}") + try: + # 1. 注册任务 + task_res = data_service.register_task(start_url) + + # 2. 无论是否新任务,都尝试把 start_url 加入队列 + urls_to_add = [start_url] + + # 3. 调用 Firecrawl Map + try: + map_res = self.firecrawl.map(start_url) + + # 兼容不同 SDK 版本的返回 (对象 或 字典) + found_links = [] + if isinstance(map_res, dict): + found_links = map_res.get('links', []) + else: + found_links = getattr(map_res, 'links', []) + + for link in found_links: + # 链接可能是字符串,也可能是对象 + u = link if isinstance(link, str) else getattr(link, 'url', str(link)) + urls_to_add.append(u) + + print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.") + except Exception as e: + print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}") + + # 4. 批量入库 + if urls_to_add: + data_service.add_urls(task_res['task_id'], urls_to_add) + + return { + "msg": "Mapped successfully", + "count": len(urls_to_add), + "task_id": task_res['task_id'] + } + + except Exception as e: + print(f"[ERROR] Map failed: {e}") + raise e + + def process_queue(self, task_id: int, batch_size: int = 5): + """ + 处理队列:爬取 -> 清洗 -> 切分 -> 向量化 -> 存储 + """ + pending = data_service.get_pending_urls(task_id, batch_size) + urls = pending['urls'] + if not urls: return {"msg": "Queue empty"} + + processed = 0 + for url in urls: + try: + print(f"[INFO] Scraping: {url}") + + # 调用 Firecrawl + scrape_res = self.firecrawl.scrape( + url, + formats=['markdown'], + only_main_content=True + ) + + # ===================================================== + # 【核心修复】兼容字典和对象两种返回格式 + # ===================================================== + raw_md = "" + metadata = None + + # 1. 
提取 markdown 和 metadata 本体 + if isinstance(scrape_res, dict): + raw_md = scrape_res.get('markdown', '') + metadata = scrape_res.get('metadata', {}) + else: + # 如果是对象 (ScrapeResponse) + raw_md = getattr(scrape_res, 'markdown', '') + metadata = getattr(scrape_res, 'metadata', {}) + + # 2. 从 metadata 中安全提取 title + title = url # 默认用 url 当标题 + if metadata: + if isinstance(metadata, dict): + # 如果 metadata 是字典 + title = metadata.get('title', url) + else: + # 如果 metadata 是对象 (DocumentMetadata) -> 这里就是你报错的地方 + title = getattr(metadata, 'title', url) + + # ===================================================== + + if not raw_md: + print(f"[WARN] Empty content for {url}") + continue + + # 1. 清洗 + clean_md = text_processor.clean_markdown(raw_md) + if not clean_md: continue + + # 2. 切分 + chunks = text_processor.split_markdown(clean_md) + + chunks_data = [] + for i, chunk in enumerate(chunks): + content = chunk['content'] + headers = chunk['metadata'] + header_path = " > ".join(headers.values()) + + # 3. 向量化 (Title + Header + Content) + emb_input = f"{title}\n{header_path}\n{content}" + vector = llm_service.get_embedding(emb_input) + + if vector: + chunks_data.append({ + "index": i, + "content": content, + "embedding": vector, + "meta_info": {"header_path": header_path, "headers": headers} + }) + + # 4. 存储 + if chunks_data: + data_service.save_chunks(task_id, url, title, chunks_data) + processed += 1 + except Exception as e: + # 打印详细错误方便调试 + print(f"[ERROR] Process URL {url} failed: {e}") + + return {"msg": "Batch processed", "count": processed} + + def search(self, query: str, task_id: int, limit: int): + vector = llm_service.get_embedding(query) + if not vector: return {"msg": "Embedding failed", "results": []} + return data_service.search(vector, task_id, limit) + +crawler_service = CrawlerService() \ No newline at end of file diff --git a/backend/services/crawler_sql_service.py b/backend/services/crawler_sql_service.py deleted file mode 100644 index 5ae6f83..0000000 --- a/backend/services/crawler_sql_service.py +++ /dev/null @@ -1,230 +0,0 @@ -from sqlalchemy import select, insert, update, and_ -from ..database import db_instance -from ..utils import normalize_url - -class CrawlerSqlService: - def __init__(self): - self.db = db_instance - - def register_task(self, url: str): - """完全使用库 API 实现的注册""" - clean_url = normalize_url(url) - result = {} - - with self.db.engine.begin() as conn: - # 使用 select() API - query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url) - existing = conn.execute(query).fetchone() - - if existing: - result = { - "task_id": existing[0], - "is_new_task": False, - "msg": "Task already exists" - } - else: - # 使用 insert() API - stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id) - new_task = conn.execute(stmt).fetchone() - result = { - "task_id": new_task[0], - "is_new_task": True, - "msg": "New task created successfully" - } - - return result - - def add_urls(self, task_id: int, urls: list[str]): - """通用 API 实现的批量添加(含详细返回)""" - success_urls, skipped_urls, failed_urls = [], [], [] - - with self.db.engine.begin() as conn: - for url in urls: - clean_url = normalize_url(url) - try: - # 检查队列中是否已存在该 URL - check_q = select(self.db.queue).where( - and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url) - ) - if conn.execute(check_q).fetchone(): - skipped_urls.append(clean_url) - continue - - # 插入新 URL - conn.execute(insert(self.db.queue).values( - task_id=task_id, url=clean_url, status='pending' - )) - 
success_urls.append(clean_url) - except Exception: - failed_urls.append(clean_url) - - # 构造返回消息 - msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}" - - return { - "success_urls": success_urls, - "skipped_urls": skipped_urls, - "failed_urls": failed_urls, - "msg": msg - } - - def get_pending_urls(self, task_id: int, limit: int): - """原子锁定 API 实现""" - result = {} - - with self.db.engine.begin() as conn: - query = select(self.db.queue.c.url).where( - and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending') - ).limit(limit) - - urls = [r[0] for r in conn.execute(query).fetchall()] - - if urls: - upd = update(self.db.queue).where( - and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls)) - ).values(status='processing') - conn.execute(upd) - result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"} - else: - result = {"urls": [], "msg": "Queue is empty"} - - return result - - def save_results(self, task_id: int, results: list): - """ - 保存同一 URL 的多个切片。 - """ - if not results: - return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}} - - # 1. 基础信息提取 - first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__ - target_url = normalize_url(first_item.get('source_url')) - - # 结果统计容器 - inserted_chunks = [] - updated_chunks = [] - failed_chunks = [] - is_page_update = False - - with self.db.engine.begin() as conn: - # 2. 判断该 URL 是否已经有切片存在 - check_page_stmt = select(self.db.chunks.c.id).where( - and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url) - ).limit(1) - if conn.execute(check_page_stmt).fetchone(): - is_page_update = True - - # 3. 逐个处理切片 - for res in results: - data = res if isinstance(res, dict) else res.__dict__ - c_idx = data.get('chunk_index') - - try: - # 检查切片是否存在 - find_chunk_stmt = select(self.db.chunks.c.id).where( - and_( - self.db.chunks.c.task_id == task_id, - self.db.chunks.c.source_url == target_url, - self.db.chunks.c.chunk_index == c_idx - ) - ) - existing_chunk = conn.execute(find_chunk_stmt).fetchone() - - if existing_chunk: - # 覆盖更新 - upd_stmt = update(self.db.chunks).where( - self.db.chunks.c.id == existing_chunk[0] - ).values( - title=data.get('title'), - content=data.get('content'), - embedding=data.get('embedding') - ) - conn.execute(upd_stmt) - updated_chunks.append(c_idx) - else: - # 插入新切片 - ins_stmt = insert(self.db.chunks).values( - task_id=task_id, - source_url=target_url, - chunk_index=c_idx, - title=data.get('title'), - content=data.get('content'), - embedding=data.get('embedding') - ) - conn.execute(ins_stmt) - inserted_chunks.append(c_idx) - - except Exception as e: - print(f"Chunk {c_idx} failed: {e}") - failed_chunks.append(c_idx) - - # 4. 更新队列状态 - conn.execute( - update(self.db.queue).where( - and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url) - ).values(status='completed') - ) - - # 构造返回 - msg = f"Saved results for {target_url}. 
Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}" - - return { - "source_url": target_url, - "is_page_update": is_page_update, - "detail": { - "inserted_chunk_indexes": inserted_chunks, - "updated_chunk_indexes": updated_chunks, - "failed_chunk_indexes": failed_chunks - }, - "counts": { - "inserted": len(inserted_chunks), - "updated": len(updated_chunks), - "failed": len(failed_chunks) - }, - "msg": msg - } - - def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5): - """ - 高性能向量搜索方法 - """ - results = [] - msg = "" - - with self.db.engine.connect() as conn: - stmt = select( - self.db.chunks.c.task_id, - self.db.chunks.c.source_url, - self.db.chunks.c.title, - self.db.chunks.c.content, - self.db.chunks.c.chunk_index - ) - - if task_id is not None: - stmt = stmt.where(self.db.chunks.c.task_id == task_id) - - stmt = stmt.order_by( - self.db.chunks.c.embedding.cosine_distance(query_embedding) - ).limit(limit) - - rows = conn.execute(stmt).fetchall() - - for r in rows: - results.append({ - "task_id": r[0], - "source_url": r[1], - "title": r[2], - "content": r[3], - "chunk_index": r[4] - }) - - if results: - msg = f"Found {len(results)} matches" - else: - msg = "No matching content found" - - return {"results": results, "msg": msg} - - -crawler_sql_service = CrawlerSqlService() \ No newline at end of file diff --git a/backend/services/data_service.py b/backend/services/data_service.py new file mode 100644 index 0000000..a52d5a5 --- /dev/null +++ b/backend/services/data_service.py @@ -0,0 +1,111 @@ +from sqlalchemy import select, insert, update, and_ +from backend.core.database import db +from backend.utils.common import normalize_url + +class DataService: + """ + 数据持久化服务层 + 只负责数据库 CRUD 操作,不包含外部 API 调用逻辑 + """ + def __init__(self): + self.db = db + + def register_task(self, url: str): + clean_url = normalize_url(url) + with self.db.engine.begin() as conn: + query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url) + existing = conn.execute(query).fetchone() + + if existing: + return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"} + else: + stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id) + new_task = conn.execute(stmt).fetchone() + return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"} + + def add_urls(self, task_id: int, urls: list[str]): + success_urls = [] + with self.db.engine.begin() as conn: + for url in urls: + clean_url = normalize_url(url) + try: + check_q = select(self.db.queue).where( + and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url) + ) + if not conn.execute(check_q).fetchone(): + conn.execute(insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending')) + success_urls.append(clean_url) + except Exception: + pass + return {"msg": f"Added {len(success_urls)} new urls"} + + def get_pending_urls(self, task_id: int, limit: int): + with self.db.engine.begin() as conn: + query = select(self.db.queue.c.url).where( + and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending') + ).limit(limit) + urls = [r[0] for r in conn.execute(query).fetchall()] + + if urls: + conn.execute(update(self.db.queue).where( + and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls)) + ).values(status='processing')) + + return {"urls": urls, "msg": "Fetched pending urls"} + + def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list): + """ + 保存切片数据 (Phase 
1.5: 支持 meta_info) + """ + clean_url = normalize_url(source_url) + count = 0 + with self.db.engine.begin() as conn: + for item in chunks_data: + # item 结构: {'index': int, 'content': str, 'embedding': list, 'meta_info': dict} + idx = item['index'] + meta = item.get('meta_info', {}) + + # 检查是否存在 + existing = conn.execute(select(self.db.chunks.c.id).where( + and_(self.db.chunks.c.task_id == task_id, + self.db.chunks.c.source_url == clean_url, + self.db.chunks.c.chunk_index == idx) + )).fetchone() + + values = { + "task_id": task_id, "source_url": clean_url, "chunk_index": idx, + "title": title, "content": item['content'], "embedding": item['embedding'], + "meta_info": meta + } + + if existing: + conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values)) + else: + conn.execute(insert(self.db.chunks).values(**values)) + count += 1 + + # 标记队列完成 + conn.execute(update(self.db.queue).where( + and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url) + ).values(status='completed')) + + return {"msg": f"Saved {count} chunks", "count": count} + + def search(self, vector: list, task_id: int = None, limit: int = 5): + with self.db.engine.connect() as conn: + stmt = select( + self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title, + self.db.chunks.c.content, self.db.chunks.c.meta_info + ).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit) + + if task_id: + stmt = stmt.where(self.db.chunks.c.task_id == task_id) + + rows = conn.execute(stmt).fetchall() + results = [ + {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]} + for r in rows + ] + return {"results": results, "msg": f"Found {len(results)}"} + +data_service = DataService() \ No newline at end of file diff --git a/backend/services/llm_service.py b/backend/services/llm_service.py new file mode 100644 index 0000000..f309e32 --- /dev/null +++ b/backend/services/llm_service.py @@ -0,0 +1,30 @@ +import dashscope +from http import HTTPStatus +from backend.core.config import settings + +class LLMService: + """ + LLM 服务封装层 + 负责与 DashScope 或其他模型供应商交互 + """ + def __init__(self): + dashscope.api_key = settings.DASHSCOPE_API_KEY + + def get_embedding(self, text: str, dimension: int = 1536): + """生成文本向量""" + try: + resp = dashscope.TextEmbedding.call( + model=dashscope.TextEmbedding.Models.text_embedding_v4, + input=text, + dimension=dimension + ) + if resp.status_code == HTTPStatus.OK: + return resp.output['embeddings'][0]['embedding'] + else: + print(f"[ERROR] Embedding API Error: {resp}") + return None + except Exception as e: + print(f"[ERROR] Embedding Exception: {e}") + return None + +llm_service = LLMService() \ No newline at end of file diff --git a/backend/utils.py b/backend/utils.py deleted file mode 100644 index 2dfb654..0000000 --- a/backend/utils.py +++ /dev/null @@ -1,39 +0,0 @@ -from urllib.parse import urlparse, urlunparse -from sqlalchemy import create_engine, MetaData, Table, select, update, and_ -# backend/llm_service.py -import dashscope -from http import HTTPStatus -from .config import settings - -# 初始化 Dashscope -dashscope.api_key = settings.DASHSCOPE_API_KEY - -def get_embeddings(texts: list[str]): - """调用通义千问 embedding 模型""" - resp = dashscope.TextEmbedding.call( - model=dashscope.TextEmbedding.Models.text_embedding_v3, # 或其他模型 - input=texts - ) - if resp.status_code == HTTPStatus.OK: - return [item['embedding'] for item in resp.output['embeddings']] - else: - print(f"Embedding Error: {resp}") - return 
[] -def normalize_url(url: str) -> str: - if not url: return "" - url = url.strip() - parsed = urlparse(url) - scheme = parsed.scheme.lower() - netloc = parsed.netloc.lower() - path = parsed.path.rstrip('/') - if not path: path = "" - return urlunparse((scheme, netloc, path, parsed.params, parsed.query, "")) - -def make_response(code: int, msg: str = "Success", data: any = None): - """ - 统一响应格式 - :param code: 1 成功, 0 失败, 其他自定义 - :param msg: 提示信息 - :param data: 返回数据 - """ - return {"code": code, "msg": msg, "data": data} \ No newline at end of file diff --git a/backend/utils/common.py b/backend/utils/common.py new file mode 100644 index 0000000..6313dd6 --- /dev/null +++ b/backend/utils/common.py @@ -0,0 +1,26 @@ +from urllib.parse import urlparse, urlunparse + +def make_response(code: int, msg: str = "Success", data: any = None): + """统一 API 响应格式封装""" + return {"code": code, "msg": msg, "data": data} + +def normalize_url(url: str) -> str: + """ + URL 标准化处理 + 1. 去除首尾空格 + 2. 移除 fragment (#后面的内容) + 3. 移除 query 参数 (视业务需求而定,这里假设不同 query 是同一页面) + 4. 移除尾部斜杠 + """ + if not url: + return "" + + parsed = urlparse(url.strip()) + # 重新组合:scheme, netloc, path, params, query, fragment + # 这里我们只保留 scheme, netloc, path + clean_path = parsed.path.rstrip('/') + + # 构造新的 parsed 对象 (param, query, fragment 置空) + new_parsed = parsed._replace(path=clean_path, params='', query='', fragment='') + + return urlunparse(new_parsed) \ No newline at end of file diff --git a/backend/utils/text_process.py b/backend/utils/text_process.py new file mode 100644 index 0000000..665d330 --- /dev/null +++ b/backend/utils/text_process.py @@ -0,0 +1,61 @@ +import re +from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter + +class TextProcessor: + """文本处理工具类:负责 Markdown 清洗和切分""" + + def __init__(self): + # 基于 Markdown 标题的语义切分器 + self.md_splitter = MarkdownHeaderTextSplitter( + headers_to_split_on=[ + ("#", "h1"), + ("##", "h2"), + ("###", "h3"), + ], + strip_headers=False + ) + + # 备用的字符切分器 + self.char_splitter = RecursiveCharacterTextSplitter( + chunk_size=800, + chunk_overlap=100, + separators=["\n\n", "\n", "。", "!", "?", " ", ""] + ) + + def clean_markdown(self, text: str) -> str: + """清洗 Markdown 中的网页噪音""" + if not text: return "" + + # 去除 'Skip to main content' + text = re.sub(r'\[Skip to main content\].*?\n', '', text, flags=re.IGNORECASE) + # 去除页脚导航 (Previous / Next) + text = re.sub(r'\[Previous\].*?\[Next\].*', '', text, flags=re.DOTALL | re.IGNORECASE) + + return text.strip() + + def split_markdown(self, text: str): + """执行切分策略:先按标题切,过长则按字符切""" + md_chunks = self.md_splitter.split_text(text) + final_chunks = [] + + for chunk in md_chunks: + # chunk.page_content 是文本 + # chunk.metadata 是标题层级 + + if len(chunk.page_content) > 1000: + sub_texts = self.char_splitter.split_text(chunk.page_content) + for sub in sub_texts: + final_chunks.append({ + "content": sub, + "metadata": chunk.metadata + }) + else: + final_chunks.append({ + "content": chunk.page_content, + "metadata": chunk.metadata + }) + + return final_chunks + +# 单例工具 +text_processor = TextProcessor() \ No newline at end of file diff --git a/docs/docker.md b/docs/docker.md index 5eae05c..894999f 100644 --- a/docs/docker.md +++ b/docs/docker.md @@ -1,4 +1,3 @@ - # Wiki Crawler Backend 部署操作手册 ## 核心配置信息 (每次只需修改这里) @@ -7,7 +6,7 @@ | **字段** | **当前值 (示例)** | **说明** | **每次要改吗?** | | -------------------- | ---------------------------------- | ------------------------------ | ---------------------- | -| **Version** | **v1.0.3** | **镜像的版本标签 
(Tag)** | **是 (必须改)** | +| **Version** | **v1.0.4** | **镜像的版本标签 (Tag)** | **是 (必须改)** | | **Image Name** | **wiki-crawl-backend** | **镜像/容器的名字** | **否 (固定)** | | **Namespace** | **qg-demo** | **阿里云命名空间** | **否 (固定)** | | **Registry** | **crpi-1rwd6fvain6t49g2...** | **阿里云仓库地址** | **否 (固定)** | @@ -20,18 +19,18 @@ ### 1. 构建镜像 (Build) -**修改命令最后的版本号** **v1.0.3** +**修改命令最后的版本号** **v1.0.4** ```powershell -docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 . +docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 . ``` ### 2. 推送镜像 (Push) -**修改命令最后的版本号** **v1.0.3** +**修改命令最后的版本号** **v1.0.4** ```powershell -docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 +docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 ``` > **成功标准:** **看到进度条走完,且最后显示** **Pushed**。 @@ -44,10 +43,10 @@ docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/w ### 1. 拉取新镜像 (Pull) -**修改命令最后的版本号** **v1.0.3** +**修改命令最后的版本号** **v1.0.4** ```bash -docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 +docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 ``` ### 2. 停止并删除旧容器 @@ -61,7 +60,7 @@ docker rm wiki-crawl-backend ### 3. 启动新容器 (Run) - 关键步骤 -**修改命令最后的版本号** **v1.0.3** +**修改命令最后的版本号** **v1.0.4** **code**Bash @@ -69,7 +68,7 @@ docker rm wiki-crawl-backend docker run -d --name wiki-crawl-backend \ -e PYTHONUNBUFFERED=1 \ -p 80:8000 \ - crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 + crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 ``` ### 4. 验证与日志查看 @@ -132,9 +131,9 @@ docker image prune -a -f ### 5. 
那个超长的 URL -**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3** +**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4** * **Registry (仓库地址)**: **crpi-1rwd...aliyuncs.com** **-> 你的专属阿里云仓库服务器。** * **Namespace (命名空间)**: **qg-demo** **-> 你在仓库里划出的个人地盘。** * **Image Name (镜像名)**: **wiki-crawl-backend** **-> 这个项目的名字。** -* **Tag (标签)**: **v1.0.3** **-> 相当于软件的版本号。如果不写 Tag,默认就是** **latest**。**生产环境强烈建议写明确的版本号**,方便回滚(比如 1.0.3 挂了,你可以立马用 1.0.2 启动)。 +* **Tag (标签)**: **v1.0.4** **-> 相当于软件的版本号。如果不写 Tag,默认就是** **latest**。**生产环境强烈建议写明确的版本号**,方便回滚(比如 1.0.3 挂了,你可以立马用 1.0.2 启动)。 diff --git a/pyproject.toml b/pyproject.toml index 1c7287a..8644877 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "pinecone>=8.0.0", "pip>=25.3", "psycopg2-binary>=2.9.11", + "pydantic-settings>=2.12.0", "pymilvus>=2.6.5", "qdrant-client==1.10.1", "redis>=7.1.0", diff --git a/scripts/test_apis.py b/scripts/test_apis.py index 54bd56c..7503fda 100644 --- a/scripts/test_apis.py +++ b/scripts/test_apis.py @@ -1,96 +1,167 @@ import requests +import time import json -import random +import sys -# 配置后端地址 -BASE_URL = "http://47.122.127.178" +# ================= 配置区域 ================= +BASE_URL = "http://127.0.0.1:8000" -def log_res(name, response): - print(f"\n=== 测试接口: {name} ===") - if response.status_code == 200: - res_json = response.json() - print(f"状态: 成功 (HTTP 200)") - print(f"返回数据: {json.dumps(res_json, indent=2, ensure_ascii=False)}") - return res_json - else: - print(f"状态: 失败 (HTTP {response.status_code})") - print(f"错误信息: {response.text}") - return None +# 使用 Dify 文档作为测试对象 (结构清晰,适合验证 Markdown 切分) +TEST_URL = "https://docs.dify.ai/en/use-dify/knowledge/create-knowledge/import-text-data/readme" -def run_tests(): - # 测试数据准备 - test_root_url = f"https://example.com/wiki_{random.randint(1000, 9999)}" +# 测试查询词 (确保能命中上面的页面) +TEST_QUERY = "upload size limit" +# =========================================== + +class Colors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + +def log(step: str, msg: str, color=Colors.OKBLUE): + print(f"{color}[{step}] {msg}{Colors.ENDC}") + +def run_e2e_test(): + print(f"{Colors.HEADER}=== 开始 Wiki Crawler E2E 完整测试 ==={Colors.ENDC}") + + # 0. 后端健康检查 + try: + requests.get(f"{BASE_URL}/docs", timeout=3) + except Exception: + log("FATAL", "无法连接后端,请确保 main.py 正在运行 (http://127.0.0.1:8000)", Colors.FAIL) + sys.exit(1) + + # --------------------------------------------------------- + # Step 1: 地图式扫描 (Map) + # --------------------------------------------------------- + # log("STEP 1", f"注册任务并扫描链接: {TEST_URL}") - # 1. 测试 /register - print("步骤 1: 注册新任务...") - res = requests.post(f"{BASE_URL}/register", json={"url": test_root_url}) - data = log_res("注册任务", res) - if not data or data['code'] != 1: return - task_id = data['data']['task_id'] + # task_id = None + # try: + # res = requests.post(f"{BASE_URL}/api/v2/crawler/map", json={"url": TEST_URL}) + # res_json = res.json() + + # # 验证响应状态 + # if res_json.get('code') != 1: + # log("FAIL", f"Map 接口返回错误: {res_json}", Colors.FAIL) + # sys.exit(1) + + # data = res_json['data'] + # task_id = data['task_id'] + # count = data.get('count', 0) + + # log("SUCCESS", f"任务注册成功。Task ID: {task_id}, 待爬取链接数: {count}", Colors.OKGREEN) + + # except Exception as e: + # log("FAIL", f"请求异常: {e}", Colors.FAIL) + # sys.exit(1) - # 2. 
测试 /add_urls - print("\n步骤 2: 模拟爬虫发现了新链接,存入队列...") - sub_urls = [ - f"{test_root_url}/page1", - f"{test_root_url}/page2", - f"{test_root_url}/page1" # 故意重复一个,测试后端去重 - ] - res = requests.post(f"{BASE_URL}/add_urls", json={ - "task_id": task_id, - "urls": sub_urls - }) - log_res("存入新链接", res) - - # 3. 测试 /pending_urls - print("\n步骤 3: 模拟爬虫节点获取待处理任务...") - res = requests.post(f"{BASE_URL}/pending_urls", json={ - "task_id": task_id, - "limit": 2 - }) - data = log_res("获取待处理URL", res) - if not data or not data['data']['urls']: - print("没有获取到待处理URL,停止后续测试") - return + # --------------------------------------------------------- + # Step 2: 触发后台处理 (Process) + # --------------------------------------------------------- + # task_id = 6 + # log("STEP 2", f"触发后台处理 -> Task ID: {task_id}") - target_url = data['data']['urls'][0] + # try: + # res = requests.post( + # f"{BASE_URL}/api/v2/crawler/process", + # json={"task_id": task_id, "batch_size": 5} + # ) + # res_json = res.json() + + # if res_json.get('code') == 1: + # log("SUCCESS", "后台处理任务已启动...", Colors.OKGREEN) + # else: + # log("FAIL", f"启动失败: {res_json}", Colors.FAIL) + # sys.exit(1) + + # except Exception as e: + # log("FAIL", f"请求异常: {e}", Colors.FAIL) + # sys.exit(1) - # 4. 测试 /save_results - print("\n步骤 4: 模拟爬虫抓取完成,存入知识片段和向量...") - # 模拟一个 1536 维的向量(已处理精度) - mock_embedding = [round(random.uniform(-1, 1), 8) for _ in range(1536)] + # --------------------------------------------------------- + # Step 3: 轮询搜索结果 (Polling) + # --------------------------------------------------------- + log("STEP 3", "轮询搜索接口,等待数据入库...") - payload = { - "task_id": task_id, - "results": [ - { - "source_url": target_url, - "chunk_index": 0, - "title": "测试页面标题 - 切片1", - "content": "这是模拟抓取到的第一段网页内容...", - "embedding": mock_embedding - }, - { - "source_url": target_url, - "chunk_index": 1, - "title": "测试页面标题 - 切片2", - "content": "这是模拟抓取到的第二段网页内容...", - "embedding": mock_embedding - } - ] + max_retries = 12 + found_data = False + search_results = [] + + for i in range(max_retries): + print(f" ⏳ 第 {i+1}/{max_retries} 次尝试搜索...", end="\r") + time.sleep(5) # 每次等待 5 秒,给爬虫和 Embedding 一点时间 + + try: + # 调用 V2 智能搜索接口 + search_res = requests.post( + f"{BASE_URL}/api/v2/search", + json={ + "query": TEST_QUERY, + "task_id": task_id, + "limit": 3 + } + ) + resp_json = search_res.json() + + # 解析响应结构: {code: 1, msg: "...", data: {results: [...]}} + if resp_json['code'] == 1: + data_body = resp_json['data'] + # 兼容性检查:确保 results 存在且不为空 + if data_body and 'results' in data_body and len(data_body['results']) > 0: + search_results = data_body['results'] + found_data = True + print("") # 换行 + log("SUCCESS", f"✅ 成功搜索到 {len(search_results)} 条相关切片!", Colors.OKGREEN) + break + except Exception as e: + # 忽略网络抖动,继续重试 + pass + + if not found_data: + print("") + log("FAIL", "❌ 超时:未能在规定时间内搜索到数据。请检查后端日志是否有报错。", Colors.FAIL) + sys.exit(1) + + # --------------------------------------------------------- + # Step 4: 验证 Phase 1.5 成果 (Meta Info) + # --------------------------------------------------------- + log("STEP 4", "验证结构化数据 (Phase 1.5 Check)") + + first_result = search_results[0] + + # 打印第一条结果用于人工确认 + print(f"\n{Colors.WARNING}--- 检索结果样本 ---{Colors.ENDC}") + print(f"Title: {first_result.get('title')}") + print(f"URL: {first_result.get('source_url')}") + print(f"Meta: {json.dumps(first_result.get('meta_info', {}), ensure_ascii=False)}") + print(f"Content Preview: {first_result.get('content')[:50]}...") + print(f"{Colors.WARNING}----------------------{Colors.ENDC}\n") + + # 自动化断言 + checks = { + "Has 
Content": bool(first_result.get('content')), + "Has Meta Info": 'meta_info' in first_result, + "Has Header Path": 'header_path' in first_result.get('meta_info', {}), + "Headers Dict Exists": 'headers' in first_result.get('meta_info', {}) } - res = requests.post(f"{BASE_URL}/save_results", json=payload) - log_res("保存结果", res) - # 5. 测试 /search - print("\n步骤 5: 测试基于向量的搜索...") - query = [round(random.uniform(-1, 1), 8) for _ in range(1536)] - res = requests.post(f"{BASE_URL}/search", json={ - "task_id": None, - "query_embedding": query, - "limit": 5 - }) - log_res("基于向量的搜索", res) - - print("\n✅ 所有 API 流程测试完成!") + + all_pass = True + for name, passed in checks.items(): + status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if passed else f"{Colors.FAIL}FAIL{Colors.ENDC}" + print(f"检查项 [{name}]: {status}") + if not passed: + all_pass = False + + if all_pass: + meta = first_result['meta_info'] + print(f"\n{Colors.OKBLUE}🎉 测试通过!系统已具备 Phase 1.5 (结构化 RAG) 能力。{Colors.ENDC}") + print(f"提取到的上下文路径: {Colors.HEADER}{meta.get('header_path', 'N/A')}{Colors.ENDC}") + else: + print(f"\n{Colors.FAIL}❌ 测试未完全通过:缺少必要的元数据字段。请检查 crawler_service.py 或 update_db.py。{Colors.ENDC}") if __name__ == "__main__": - run_tests() \ No newline at end of file + run_e2e_test() \ No newline at end of file diff --git a/scripts/update_sql.py b/scripts/update_sql.py new file mode 100644 index 0000000..faf4126 --- /dev/null +++ b/scripts/update_sql.py @@ -0,0 +1,82 @@ +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from sqlalchemy import create_engine, text +from backend.core.config import settings + +def update_database_schema(): + """ + 数据库无损升级脚本 + """ + print(f"🔌 连接数据库: {settings.DB_NAME}...") + engine = create_engine(settings.DATABASE_URL) + + commands = [ + # 1. 安全添加 meta_info 列 (旧数据会自动填充为 {}) + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='meta_info') THEN + ALTER TABLE knowledge_chunks ADD COLUMN meta_info JSONB DEFAULT '{}'; + RAISE NOTICE '已添加 meta_info 列'; + END IF; + END $$; + """, + + # 2. 安全添加 content_tsvector 列 + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='content_tsvector') THEN + ALTER TABLE knowledge_chunks ADD COLUMN content_tsvector TSVECTOR; + RAISE NOTICE '已添加 content_tsvector 列'; + END IF; + END $$; + """, + + # 3. 创建索引 (不影响现有数据) + "CREATE INDEX IF NOT EXISTS idx_chunks_meta ON knowledge_chunks USING GIN (meta_info);", + "CREATE INDEX IF NOT EXISTS idx_chunks_tsvector ON knowledge_chunks USING GIN (content_tsvector);", + + # 4. 创建触发器函数 (用于新插入的数据) + """ + CREATE OR REPLACE FUNCTION chunks_tsvector_trigger() RETURNS trigger AS $$ + BEGIN + new.content_tsvector := to_tsvector('english', coalesce(new.title, '') || ' ' || new.content); + return new; + END + $$ LANGUAGE plpgsql; + """, + + # 5. 绑定触发器 + """ + DO $$ + BEGIN + IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'tsvectorupdate') THEN + CREATE TRIGGER tsvectorupdate BEFORE INSERT OR UPDATE + ON knowledge_chunks FOR EACH ROW EXECUTE PROCEDURE chunks_tsvector_trigger(); + END IF; + END $$; + """, + + # 6. 
【新增】回填旧数据 + # 让以前存的 task_id=6 的数据也能生成关键词索引 + """ + UPDATE knowledge_chunks + SET content_tsvector = to_tsvector('english', coalesce(title, '') || ' ' || content) + WHERE content_tsvector IS NULL; + """ + ] + + with engine.begin() as conn: + for cmd in commands: + try: + conn.execute(text(cmd)) + except Exception as e: + print(f"⚠️ 执行警告 (通常可忽略): {e}") + + print("✅ 数据库结构升级完成!旧数据已保留并兼容。") + +if __name__ == "__main__": + update_database_schema() \ No newline at end of file diff --git a/uv.lock b/uv.lock index 9e61ac4..dadc55c 100644 --- a/uv.lock +++ b/uv.lock @@ -2305,6 +2305,7 @@ dependencies = [ { name = "pinecone" }, { name = "pip" }, { name = "psycopg2-binary" }, + { name = "pydantic-settings" }, { name = "pymilvus" }, { name = "qdrant-client" }, { name = "redis" }, @@ -2326,6 +2327,7 @@ requires-dist = [ { name = "pinecone", specifier = ">=8.0.0" }, { name = "pip", specifier = ">=25.3" }, { name = "psycopg2-binary", specifier = ">=2.9.11" }, + { name = "pydantic-settings", specifier = ">=2.12.0" }, { name = "pymilvus", specifier = ">=2.6.5" }, { name = "qdrant-client", specifier = "==1.10.1" }, { name = "redis", specifier = ">=7.1.0" },
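
Note on the new full-text columns: scripts/update_sql.py above adds `meta_info` and `content_tsvector` (with GIN indexes, a trigger, and a backfill), but no service in this patch queries `content_tsvector` yet. The sketch below shows one way the keyword index could be combined with the existing pgvector search; it reuses the `db` and `llm_service` singletons introduced in this patch, while the function name, the 0.1 weighting factor, and the use of `plainto_tsquery` are illustrative assumptions rather than part of the commit.

```python
# hybrid_search_sketch.py -- illustrative only; assumes the schema produced by scripts/update_sql.py
from sqlalchemy import text
from backend.core.database import db
from backend.services.llm_service import llm_service

def hybrid_search(query: str, task_id: int, limit: int = 5):
    """Rank chunks by pgvector cosine distance, nudged by the tsvector keyword rank."""
    vector = llm_service.get_embedding(query)
    if not vector:
        return []

    # pgvector accepts the '[x,y,...]' text form, cast to the vector type in SQL
    vec_param = "[" + ",".join(str(x) for x in vector) + "]"

    sql = text("""
        SELECT source_url, title, content,
               embedding <=> CAST(:vec AS vector)                        AS vec_dist,
               ts_rank(content_tsvector, plainto_tsquery('english', :q)) AS kw_rank
        FROM knowledge_chunks
        WHERE task_id = :task_id
        ORDER BY (embedding <=> CAST(:vec AS vector))
                 - 0.1 * ts_rank(content_tsvector, plainto_tsquery('english', :q))
        LIMIT :limit
    """)

    with db.engine.connect() as conn:
        rows = conn.execute(sql, {"vec": vec_param, "q": query,
                                  "task_id": task_id, "limit": limit})
        return [dict(r._mapping) for r in rows]
```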