变更项目架构,提高扩展性
This commit is contained in:
7
.env.example
Normal file
7
.env.example
Normal file
@@ -0,0 +1,7 @@
|
||||
DB_USER=postgres
|
||||
DB_PASS=
|
||||
DB_HOST=localhost
|
||||
DB_PORT=5432
|
||||
DB_NAME=wiki_crawler
|
||||
DASHSCOPE_API_KEY=
|
||||
FIRECRAWL_API_KEY=
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
__pycache__/
|
||||
.venv
|
||||
wiki_backend.tar
|
||||
.env
|
||||
@@ -1,26 +0,0 @@
|
||||
import os
|
||||
|
||||
class Settings:
|
||||
# 数据库配置
|
||||
DB_USER: str = "postgres"
|
||||
DB_PASS: str = "DXC_welcome001"
|
||||
DB_HOST: str = "8.155.144.6"
|
||||
DB_PORT: str = "25432"
|
||||
DB_NAME: str = "wiki_crawler"
|
||||
|
||||
DASHSCOPE_API_KEY: str = "sk-8b091493de594c5e9eb42f12f1cc5805"
|
||||
FIRECRAWL_API_KEY: str = "fc-8a2af3fb6a014a27a57dfbc728cb7365"
|
||||
@property # property 方法,意义:将方法转换为属性,调用时不需要加括号
|
||||
def DATABASE_URL(self) -> str:
|
||||
url = f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
|
||||
return url
|
||||
|
||||
def API_KEY(self, type: str) -> str:
|
||||
if type == "dashscope":
|
||||
return self.DASHSCOPE_API_KEY
|
||||
elif type == "firecrawl":
|
||||
return self.FIRECRAWL_API_KEY
|
||||
else:
|
||||
raise ValueError(f"Unknown API type: {type}")
|
||||
|
||||
settings = Settings()
|
||||
24
backend/core/config.py
Normal file
24
backend/core/config.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """
    Application configuration.

    Fields are populated automatically from environment variables or a
    ``.env`` file (see ``model_config`` below).
    """
    # Database connection settings
    DB_USER: str
    DB_PASS: str
    DB_HOST: str
    DB_PORT: str = "5432"  # kept as str: only ever interpolated into the URL
    DB_NAME: str

    # External API credentials
    DASHSCOPE_API_KEY: str
    FIRECRAWL_API_KEY: str

    # Ignore unrecognised environment variables; read .env as UTF-8.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8')

    @property
    def DATABASE_URL(self) -> str:
        """SQLAlchemy connection URL assembled from the individual DB_* fields."""
        return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"


settings = Settings()
|
||||
@@ -1,14 +1,19 @@
|
||||
from sqlalchemy import create_engine, MetaData, Table, event
|
||||
from pgvector.sqlalchemy import Vector # 必须导入这个
|
||||
from sqlalchemy import create_engine, MetaData, Table
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from .config import settings
|
||||
|
||||
class Database:
|
||||
"""
|
||||
数据库单例类
|
||||
负责初始化连接池并反射加载现有的表结构
|
||||
"""
|
||||
def __init__(self):
|
||||
# 1. 创建引擎
|
||||
# pool_pre_ping=True 用于解决数据库连接长时间空闲后断开的问题
|
||||
self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
|
||||
|
||||
# 2. 【核心修复】手动注册 vector 类型,让反射能识别它
|
||||
# 这告诉 SQLAlchemy:如果在数据库里看到名为 "vector" 的类型,请使用 pgvector 库的 Vector 类来处理
|
||||
# 2. 注册 pgvector 类型
|
||||
# 这是为了让 SQLAlchemy 反射机制能识别数据库中的 'vector' 类型
|
||||
self.engine.dialect.ischema_names['vector'] = Vector
|
||||
|
||||
self.metadata = MetaData()
|
||||
@@ -19,14 +24,15 @@ class Database:
|
||||
self._reflect_tables()
|
||||
|
||||
def _reflect_tables(self):
|
||||
"""自动从数据库加载表定义"""
|
||||
try:
|
||||
# 自动从数据库加载表结构
|
||||
# 因为上面注册了 ischema_names,现在 chunks_table.c.embedding 就能被正确识别为 Vector 类型了
|
||||
# autoload_with 会查询数据库元数据,自动填充 Column 信息
|
||||
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
|
||||
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
|
||||
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
|
||||
print("[INFO] Database tables reflected successfully.")
|
||||
except Exception as e:
|
||||
print(f"❌ 数据库表加载失败: {e}")
|
||||
print(f"[ERROR] Failed to reflect tables: {e}")
|
||||
|
||||
# 全局单例
|
||||
db_instance = Database()
|
||||
# 全局数据库实例
|
||||
db = Database()
|
||||
128
backend/main.py
128
backend/main.py
@@ -1,131 +1,13 @@
|
||||
# backend/main.py
|
||||
from fastapi import FastAPI, APIRouter, BackgroundTasks
|
||||
# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow
|
||||
from .services.crawler_sql_service import crawler_sql_service
|
||||
from .services.automated_crawler import workflow
|
||||
from .schemas import (
|
||||
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
|
||||
AutoMapRequest, AutoProcessRequest, TextSearchRequest
|
||||
)
|
||||
from .utils import make_response
|
||||
from fastapi import FastAPI
|
||||
from backend.routers import v1, v2
|
||||
|
||||
app = FastAPI(title="Wiki Crawler API")
|
||||
|
||||
# ==========================================
|
||||
# 工具函数
|
||||
# ==========================================
|
||||
|
||||
|
||||
# ==========================================
|
||||
# V1 Router: 原始的底层接口 (Manual Control)
|
||||
# ==========================================
|
||||
router_v1 = APIRouter()
|
||||
|
||||
@router_v1.post("/register")
|
||||
async def register(req: RegisterRequest):
|
||||
try:
|
||||
# Service 返回: {'task_id': 1, 'is_new_task': True, 'msg': '...'}
|
||||
res = crawler_sql_service.register_task(req.url)
|
||||
# 使用 pop 将 msg 提取出来作为响应的 msg,剩下的作为 data
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/add_urls")
|
||||
async def add_urls(req: AddUrlsRequest):
|
||||
try:
|
||||
urls = req.urls_obj["urls"]
|
||||
res = crawler_sql_service.add_urls(req.task_id, urls=urls)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/pending_urls")
|
||||
async def pending_urls(req: PendingRequest):
|
||||
try:
|
||||
res = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
|
||||
# 即使队列为空,Service 也会返回 msg="Queue is empty"
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/save_results")
|
||||
async def save_results(req: SaveResultsRequest):
|
||||
try:
|
||||
res = crawler_sql_service.save_results(req.task_id, req.results)
|
||||
return make_response(1, res.pop("msg", "Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v1.post("/search")
|
||||
async def search_v1(req: SearchRequest):
|
||||
"""V1 搜索:客户端手动传向量"""
|
||||
try:
|
||||
vector = req.query_embedding['vector']
|
||||
if not vector:
|
||||
return make_response(2, "Vector is empty", None)
|
||||
|
||||
# Service 现在返回 {'results': [...], 'msg': 'Found ...'}
|
||||
res = crawler_sql_service.search_knowledge(
|
||||
query_embedding=vector,
|
||||
task_id=req.task_id,
|
||||
limit=req.limit
|
||||
)
|
||||
return make_response(1, res.pop("msg", "Search Done"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
|
||||
# ==========================================
|
||||
# V2 Router: 自动化工作流 (Automated Workflow)
|
||||
# ==========================================
|
||||
router_v2 = APIRouter()
|
||||
|
||||
@router_v2.post("/crawler/map")
|
||||
async def auto_map(req: AutoMapRequest):
|
||||
"""
|
||||
[同步] 输入首页 URL,自动调用 Firecrawl Map 并入库
|
||||
"""
|
||||
try:
|
||||
# Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...}
|
||||
res = workflow.map_and_ingest(req.url)
|
||||
return make_response(1, res.pop("msg", "Mapping Started"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v2.post("/crawler/process")
|
||||
async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
|
||||
"""
|
||||
try:
|
||||
# 将耗时操作放入后台任务
|
||||
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
|
||||
|
||||
# 因为是后台任务,无法立即获取 Service 的返回值 msg,只能返回通用消息
|
||||
return make_response(1, "Background processing started", {"task_id": req.task_id})
|
||||
except Exception as e:
|
||||
return make_response(0, str(e))
|
||||
|
||||
@router_v2.post("/search")
|
||||
async def search_v2(req: TextSearchRequest):
|
||||
"""
|
||||
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
|
||||
"""
|
||||
try:
|
||||
# Workflow 返回 {'results': [...], 'msg': '...'}
|
||||
res = workflow.search_with_embedding(req.query, req.task_id, req.limit)
|
||||
return make_response(1, res.pop("msg", "Search Success"), res)
|
||||
except Exception as e:
|
||||
return make_response(0, f"Search Failed: {str(e)}")
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 挂载路由
|
||||
# ==========================================
|
||||
app.include_router(router_v1, prefix="/api/v1", tags=["V1 Manual API"])
|
||||
app.include_router(router_v2, prefix="/api/v2", tags=["V2 Automated Workflow"])
|
||||
app.include_router(v1.router)
|
||||
app.include_router(v2.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
# 提示:运行方式 uv run backend/main.py 或 uvicorn backend.main:app --reload
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
|
||||
30
backend/routers/v1.py
Normal file
30
backend/routers/v1.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter
from backend.services.data_service import data_service
from backend.utils.common import make_response
from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest

router = APIRouter(prefix="/api/v1", tags=["V1 Manual"])


@router.post("/register")
async def register(req: RegisterRequest):
    """Register (or look up) a crawl task for the given root URL."""
    try:
        outcome = data_service.register_task(req.url)
        message = outcome.pop("msg", "Success")
        return make_response(1, message, outcome)
    except Exception as e:
        return make_response(0, str(e))


@router.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
    """Append a batch of URLs to an existing task's crawl queue."""
    try:
        outcome = data_service.add_urls(req.task_id, req.urls_obj["urls"])
        message = outcome.pop("msg", "Success")
        return make_response(1, message, outcome)
    except Exception as e:
        return make_response(0, str(e))


@router.post("/search")
async def search_manual(req: SearchRequest):
    """Vector search where the caller supplies the query embedding directly."""
    try:
        outcome = data_service.search(req.query_embedding['vector'], req.task_id, req.limit)
        message = outcome.pop("msg", "Success")
        return make_response(1, message, outcome)
    except Exception as e:
        return make_response(0, str(e))
|
||||
30
backend/routers/v2.py
Normal file
30
backend/routers/v2.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from fastapi import APIRouter, BackgroundTasks
from backend.services.crawler_service import crawler_service
from backend.utils.common import make_response
from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest

router = APIRouter(prefix="/api/v2", tags=["V2 Automated"])


@router.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
    """Synchronous: map the site rooted at the given URL and enqueue the links."""
    try:
        outcome = crawler_service.map_site(req.url)
        message = outcome.pop("msg", "Started")
        return make_response(1, message, outcome)
    except Exception as e:
        return make_response(0, str(e))


@router.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
    """Asynchronous: kick off queue processing (scrape -> embed -> store) in the background."""
    try:
        # The heavy lifting runs after the response is sent, so only a
        # generic acknowledgement can be returned here.
        bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size)
        return make_response(1, "Background processing started", {"task_id": req.task_id})
    except Exception as e:
        return make_response(0, str(e))


@router.post("/search")
async def search_smart(req: TextSearchRequest):
    """Natural-language search: the backend embeds the query text before searching."""
    try:
        outcome = crawler_service.search(req.query, req.task_id, req.limit)
        message = outcome.pop("msg", "Success")
        return make_response(1, message, outcome)
    except Exception as e:
        return make_response(0, str(e))
|
||||
@@ -1,201 +0,0 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from firecrawl import FirecrawlApp
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from ..config import settings
|
||||
from .crawler_sql_service import crawler_sql_service
|
||||
|
||||
# 初始化配置
|
||||
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
class AutomatedCrawler:
|
||||
def __init__(self):
|
||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||
self.splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500,
|
||||
chunk_overlap=100,
|
||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
|
||||
)
|
||||
|
||||
def _get_embedding(self, text: str):
|
||||
"""内部方法:调用 Dashscope 生成向量"""
|
||||
# 注意:此方法是内部辅助,出错返回 None,由调用方处理状态
|
||||
embedding = None
|
||||
try:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 确认你的模型版本
|
||||
input=text,
|
||||
dimension=1536
|
||||
)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
embedding = resp.output['embeddings'][0]['embedding']
|
||||
else:
|
||||
print(f"Embedding API Error: {resp}")
|
||||
except Exception as e:
|
||||
print(f"Embedding Exception: {e}")
|
||||
|
||||
return embedding
|
||||
|
||||
def map_and_ingest(self, start_url: str):
|
||||
"""
|
||||
V2 步骤1: 地图式扫描并入库
|
||||
"""
|
||||
print(f"[WorkFlow] Start mapping: {start_url}")
|
||||
result = {}
|
||||
|
||||
try:
|
||||
# 1. 在数据库注册任务
|
||||
task_info = crawler_sql_service.register_task(start_url)
|
||||
task_id = task_info['task_id']
|
||||
is_new_task = task_info['is_new_task']
|
||||
|
||||
# 2. 调用 Firecrawl Map
|
||||
if is_new_task:
|
||||
map_result = self.firecrawl.map(start_url)
|
||||
|
||||
urls = []
|
||||
# 兼容 firecrawl sdk 不同版本的返回结构
|
||||
# 如果 map_result 是对象且有 links 属性
|
||||
if hasattr(map_result, 'links'):
|
||||
for link in map_result.links:
|
||||
# 假设 link 是对象或字典,视具体 SDK 版本而定
|
||||
# 如果 link 是字符串直接 append
|
||||
if isinstance(link, str):
|
||||
urls.append(link)
|
||||
else:
|
||||
urls.append(getattr(link, 'url', str(link)))
|
||||
# 如果是字典
|
||||
elif isinstance(map_result, dict):
|
||||
urls = map_result.get('links', [])
|
||||
|
||||
print(f"[WorkFlow] Found {len(urls)} links")
|
||||
|
||||
# 3. 批量入库
|
||||
res = {"msg": "No urls found to add"}
|
||||
if urls:
|
||||
res = crawler_sql_service.add_urls(task_id, urls)
|
||||
|
||||
result = {
|
||||
"msg": "Task successfully mapped and URLs added",
|
||||
"task_id": task_id,
|
||||
"is_new_task": is_new_task,
|
||||
"url_count": len(urls),
|
||||
"map_detail": res
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
"msg": "Task already exists, skipped mapping",
|
||||
"task_id": task_id,
|
||||
"is_new_task": False,
|
||||
"url_count": 0,
|
||||
"map_detail": {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WorkFlow] Map Error: {e}")
|
||||
# 向上抛出异常,由 main.py 捕获并返回错误 Response
|
||||
raise e
|
||||
|
||||
return result
|
||||
|
||||
def process_task_queue(self, task_id: int, limit: int = 10):
|
||||
"""
|
||||
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
|
||||
"""
|
||||
processed_count = 0
|
||||
total_chunks_saved = 0
|
||||
result = {}
|
||||
|
||||
# 1. 获取待处理 URL
|
||||
pending = crawler_sql_service.get_pending_urls(task_id, limit)
|
||||
urls = pending['urls']
|
||||
|
||||
if not urls:
|
||||
result = {"msg": "Queue is empty, no processing needed", "processed_count": 0}
|
||||
else:
|
||||
for url in urls:
|
||||
try:
|
||||
print(f"[WorkFlow] Processing: {url}")
|
||||
# 2. 单页抓取
|
||||
scrape_res = self.firecrawl.scrape(
|
||||
url,
|
||||
params={'formats': ['markdown'], 'onlyMainContent': True}
|
||||
)
|
||||
|
||||
# 兼容 SDK 返回类型 (对象或字典)
|
||||
content = ""
|
||||
metadata = {}
|
||||
|
||||
if isinstance(scrape_res, dict):
|
||||
content = scrape_res.get('markdown', '')
|
||||
metadata = scrape_res.get('metadata', {})
|
||||
else:
|
||||
content = getattr(scrape_res, 'markdown', '')
|
||||
metadata = getattr(scrape_res, 'metadata', {})
|
||||
if not metadata and hasattr(scrape_res, 'metadata_dict'):
|
||||
metadata = scrape_res.metadata_dict
|
||||
|
||||
title = metadata.get('title', url)
|
||||
|
||||
if not content:
|
||||
print(f"[WorkFlow] Skip empty content: {url}")
|
||||
continue
|
||||
|
||||
# 3. 切片
|
||||
chunks = self.splitter.split_text(content)
|
||||
results_to_save = []
|
||||
|
||||
# 4. 向量化
|
||||
for idx, chunk_text in enumerate(chunks):
|
||||
vector = self._get_embedding(chunk_text)
|
||||
if vector:
|
||||
results_to_save.append({
|
||||
"source_url": url,
|
||||
"chunk_index": idx,
|
||||
"title": title,
|
||||
"content": chunk_text,
|
||||
"embedding": vector
|
||||
})
|
||||
|
||||
# 5. 保存
|
||||
if results_to_save:
|
||||
save_res = crawler_sql_service.save_results(task_id, results_to_save)
|
||||
processed_count += 1
|
||||
total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated']
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WorkFlow] Error processing {url}: {e}")
|
||||
# 此处不抛出异常,以免打断整个批次的循环
|
||||
# 实际生产建议在这里调用 service 将 url 标记为 failed
|
||||
|
||||
result = {
|
||||
"msg": f"Batch processing complete. URLs processed: {processed_count}",
|
||||
"processed_urls": processed_count,
|
||||
"total_chunks_saved": total_chunks_saved
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. 获取向量
|
||||
vector = self._get_embedding(query_text)
|
||||
|
||||
if not vector:
|
||||
result = {
|
||||
"msg": "Failed to generate embedding for query",
|
||||
"results": []
|
||||
}
|
||||
else:
|
||||
# 2. 执行搜索
|
||||
# search_knowledge 现在已经返回带 msg 的字典了
|
||||
result = crawler_sql_service.search_knowledge(vector, task_id, limit)
|
||||
|
||||
return result
|
||||
|
||||
# 单例模式
|
||||
workflow = AutomatedCrawler()
|
||||
150
backend/services/crawler_service.py
Normal file
150
backend/services/crawler_service.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# backend/services/crawler_service.py
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor


class CrawlerService:
    """
    Crawl orchestration service.

    Coordinates Firecrawl (scraping/mapping), the LLM service (embeddings)
    and DataService (persistence). Performs no direct database access.
    """
    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)

    def map_site(self, start_url: str):
        """Register a task for ``start_url``, map the site and enqueue found links.

        Returns: dict with ``msg``, ``count`` (URLs enqueued incl. the seed)
        and ``task_id``. Re-raises on fatal errors (e.g. DB failure).
        """
        print(f"[INFO] Mapping: {start_url}")
        try:
            # 1. Register (or look up) the crawl task.
            task_res = data_service.register_task(start_url)

            # 2. Always seed the queue with the start URL itself,
            #    even when the task already existed.
            urls_to_add = [start_url]

            # 3. Ask Firecrawl for the site map. A failure here is non-fatal:
            #    we proceed with just the seed URL.
            try:
                map_res = self.firecrawl.map(start_url)

                # The SDK may return either a dict or a response object,
                # depending on its version.
                found_links = []
                if isinstance(map_res, dict):
                    found_links = map_res.get('links', [])
                else:
                    found_links = getattr(map_res, 'links', [])

                for link in found_links:
                    # Each link may be a plain string or a link object.
                    u = link if isinstance(link, str) else getattr(link, 'url', str(link))
                    urls_to_add.append(u)

                print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
            except Exception as e:
                print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")

            # 4. Bulk-insert into the crawl queue (duplicates skipped downstream).
            if urls_to_add:
                data_service.add_urls(task_res['task_id'], urls_to_add)

            return {
                "msg": "Mapped successfully",
                "count": len(urls_to_add),
                "task_id": task_res['task_id']
            }

        except Exception as e:
            print(f"[ERROR] Map failed: {e}")
            raise e

    def process_queue(self, task_id: int, batch_size: int = 5):
        """
        Drain up to ``batch_size`` queued URLs: scrape -> clean -> split ->
        embed -> store. Per-URL failures are logged and skipped so one bad
        page cannot abort the batch.

        Returns: dict with ``msg`` and ``count`` (pages fully persisted).
        """
        pending = data_service.get_pending_urls(task_id, batch_size)
        urls = pending['urls']
        if not urls:
            # Fix: keep the return shape consistent with the success path
            # (previously this branch omitted the "count" key).
            return {"msg": "Queue empty", "count": 0}

        processed = 0
        for url in urls:
            try:
                print(f"[INFO] Scraping: {url}")

                scrape_res = self.firecrawl.scrape(
                    url,
                    formats=['markdown'],
                    only_main_content=True
                )

                # The SDK may return a dict or a ScrapeResponse object —
                # support both when extracting markdown and metadata.
                raw_md = ""
                metadata = None
                if isinstance(scrape_res, dict):
                    raw_md = scrape_res.get('markdown', '')
                    metadata = scrape_res.get('metadata', {})
                else:
                    raw_md = getattr(scrape_res, 'markdown', '')
                    metadata = getattr(scrape_res, 'metadata', {})

                # metadata itself may also be a dict or an object
                # (e.g. DocumentMetadata); extract the title defensively.
                title = url  # fall back to the URL as the title
                if metadata:
                    if isinstance(metadata, dict):
                        title = metadata.get('title', url)
                    else:
                        title = getattr(metadata, 'title', url)

                if not raw_md:
                    print(f"[WARN] Empty content for {url}")
                    continue

                # 1. Clean the raw markdown.
                clean_md = text_processor.clean_markdown(raw_md)
                if not clean_md:
                    continue

                # 2. Split into header-aware chunks.
                chunks = text_processor.split_markdown(clean_md)

                chunks_data = []
                for i, chunk in enumerate(chunks):
                    content = chunk['content']
                    headers = chunk['metadata']
                    header_path = " > ".join(headers.values())

                    # 3. Embed title + header path + content together so the
                    #    vector carries the page context, not just the chunk.
                    emb_input = f"{title}\n{header_path}\n{content}"
                    vector = llm_service.get_embedding(emb_input)

                    if vector:
                        chunks_data.append({
                            "index": i,
                            "content": content,
                            "embedding": vector,
                            "meta_info": {"header_path": header_path, "headers": headers}
                        })

                # 4. Persist the chunks (DataService also marks the queue row done).
                if chunks_data:
                    data_service.save_chunks(task_id, url, title, chunks_data)
                    processed += 1
            except Exception as e:
                # Log with detail for debugging; do not abort the batch.
                print(f"[ERROR] Process URL {url} failed: {e}")

        return {"msg": "Batch processed", "count": processed}

    def search(self, query: str, task_id: int, limit: int):
        """Embed the query text and run a vector search via DataService."""
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}
        return data_service.search(vector, task_id, limit)


crawler_service = CrawlerService()
|
||||
@@ -1,230 +0,0 @@
|
||||
from sqlalchemy import select, insert, update, and_
|
||||
from ..database import db_instance
|
||||
from ..utils import normalize_url
|
||||
|
||||
class CrawlerSqlService:
|
||||
def __init__(self):
|
||||
self.db = db_instance
|
||||
|
||||
def register_task(self, url: str):
|
||||
"""完全使用库 API 实现的注册"""
|
||||
clean_url = normalize_url(url)
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 使用 select() API
|
||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
result = {
|
||||
"task_id": existing[0],
|
||||
"is_new_task": False,
|
||||
"msg": "Task already exists"
|
||||
}
|
||||
else:
|
||||
# 使用 insert() API
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
result = {
|
||||
"task_id": new_task[0],
|
||||
"is_new_task": True,
|
||||
"msg": "New task created successfully"
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
"""通用 API 实现的批量添加(含详细返回)"""
|
||||
success_urls, skipped_urls, failed_urls = [], [], []
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
# 检查队列中是否已存在该 URL
|
||||
check_q = select(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
if conn.execute(check_q).fetchone():
|
||||
skipped_urls.append(clean_url)
|
||||
continue
|
||||
|
||||
# 插入新 URL
|
||||
conn.execute(insert(self.db.queue).values(
|
||||
task_id=task_id, url=clean_url, status='pending'
|
||||
))
|
||||
success_urls.append(clean_url)
|
||||
except Exception:
|
||||
failed_urls.append(clean_url)
|
||||
|
||||
# 构造返回消息
|
||||
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
|
||||
|
||||
return {
|
||||
"success_urls": success_urls,
|
||||
"skipped_urls": skipped_urls,
|
||||
"failed_urls": failed_urls,
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
"""原子锁定 API 实现"""
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
).limit(limit)
|
||||
|
||||
urls = [r[0] for r in conn.execute(query).fetchall()]
|
||||
|
||||
if urls:
|
||||
upd = update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing')
|
||||
conn.execute(upd)
|
||||
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
|
||||
else:
|
||||
result = {"urls": [], "msg": "Queue is empty"}
|
||||
|
||||
return result
|
||||
|
||||
def save_results(self, task_id: int, results: list):
|
||||
"""
|
||||
保存同一 URL 的多个切片。
|
||||
"""
|
||||
if not results:
|
||||
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
|
||||
|
||||
# 1. 基础信息提取
|
||||
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
||||
target_url = normalize_url(first_item.get('source_url'))
|
||||
|
||||
# 结果统计容器
|
||||
inserted_chunks = []
|
||||
updated_chunks = []
|
||||
failed_chunks = []
|
||||
is_page_update = False
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 2. 判断该 URL 是否已经有切片存在
|
||||
check_page_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
||||
).limit(1)
|
||||
if conn.execute(check_page_stmt).fetchone():
|
||||
is_page_update = True
|
||||
|
||||
# 3. 逐个处理切片
|
||||
for res in results:
|
||||
data = res if isinstance(res, dict) else res.__dict__
|
||||
c_idx = data.get('chunk_index')
|
||||
|
||||
try:
|
||||
# 检查切片是否存在
|
||||
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(
|
||||
self.db.chunks.c.task_id == task_id,
|
||||
self.db.chunks.c.source_url == target_url,
|
||||
self.db.chunks.c.chunk_index == c_idx
|
||||
)
|
||||
)
|
||||
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
||||
|
||||
if existing_chunk:
|
||||
# 覆盖更新
|
||||
upd_stmt = update(self.db.chunks).where(
|
||||
self.db.chunks.c.id == existing_chunk[0]
|
||||
).values(
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(upd_stmt)
|
||||
updated_chunks.append(c_idx)
|
||||
else:
|
||||
# 插入新切片
|
||||
ins_stmt = insert(self.db.chunks).values(
|
||||
task_id=task_id,
|
||||
source_url=target_url,
|
||||
chunk_index=c_idx,
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(ins_stmt)
|
||||
inserted_chunks.append(c_idx)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Chunk {c_idx} failed: {e}")
|
||||
failed_chunks.append(c_idx)
|
||||
|
||||
# 4. 更新队列状态
|
||||
conn.execute(
|
||||
update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
||||
).values(status='completed')
|
||||
)
|
||||
|
||||
# 构造返回
|
||||
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
|
||||
|
||||
return {
|
||||
"source_url": target_url,
|
||||
"is_page_update": is_page_update,
|
||||
"detail": {
|
||||
"inserted_chunk_indexes": inserted_chunks,
|
||||
"updated_chunk_indexes": updated_chunks,
|
||||
"failed_chunk_indexes": failed_chunks
|
||||
},
|
||||
"counts": {
|
||||
"inserted": len(inserted_chunks),
|
||||
"updated": len(updated_chunks),
|
||||
"failed": len(failed_chunks)
|
||||
},
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
高性能向量搜索方法
|
||||
"""
|
||||
results = []
|
||||
msg = ""
|
||||
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id,
|
||||
self.db.chunks.c.source_url,
|
||||
self.db.chunks.c.title,
|
||||
self.db.chunks.c.content,
|
||||
self.db.chunks.c.chunk_index
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
stmt = stmt.order_by(
|
||||
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||
).limit(limit)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
|
||||
for r in rows:
|
||||
results.append({
|
||||
"task_id": r[0],
|
||||
"source_url": r[1],
|
||||
"title": r[2],
|
||||
"content": r[3],
|
||||
"chunk_index": r[4]
|
||||
})
|
||||
|
||||
if results:
|
||||
msg = f"Found {len(results)} matches"
|
||||
else:
|
||||
msg = "No matching content found"
|
||||
|
||||
return {"results": results, "msg": msg}
|
||||
|
||||
|
||||
crawler_sql_service = CrawlerSqlService()
|
||||
111
backend/services/data_service.py
Normal file
111
backend/services/data_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from sqlalchemy import select, insert, update, and_
from backend.core.database import db
from backend.utils.common import normalize_url


class DataService:
    """
    Persistence service layer.

    Pure CRUD over the reflected tables (tasks / queue / chunks); contains
    no external API calls.
    """
    def __init__(self):
        self.db = db

    def register_task(self, url: str):
        """Create a crawl task for ``url``, or return the existing one (idempotent)."""
        clean_url = normalize_url(url)
        with self.db.engine.begin() as conn:
            query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
            existing = conn.execute(query).fetchone()

            if existing:
                return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
            else:
                stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
                new_task = conn.execute(stmt).fetchone()
                return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}

    def add_urls(self, task_id: int, urls: list[str]):
        """Enqueue ``urls`` for ``task_id``, skipping URLs already queued.

        Best-effort per URL: a failure on one URL is logged and does not
        abort the batch. (Fix: errors were previously swallowed silently
        by ``except Exception: pass``, hiding data loss.)
        """
        success_urls = []
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    check_q = select(self.db.queue).where(
                        and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
                    )
                    if not conn.execute(check_q).fetchone():
                        conn.execute(insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending'))
                        success_urls.append(clean_url)
                except Exception as e:
                    # Keep best-effort semantics, but make the failure visible.
                    print(f"[WARN] Failed to enqueue {clean_url}: {e}")
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Fetch up to ``limit`` pending URLs and mark them 'processing' in one transaction."""
        with self.db.engine.begin() as conn:
            query = select(self.db.queue.c.url).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
            ).limit(limit)
            urls = [r[0] for r in conn.execute(query).fetchall()]

            if urls:
                # Lock the claimed rows so concurrent workers skip them.
                conn.execute(update(self.db.queue).where(
                    and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
                ).values(status='processing'))

        return {"urls": urls, "msg": "Fetched pending urls"}

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """
        Upsert chunk rows for one page, then mark its queue entry completed.

        Each item of ``chunks_data`` has the shape:
        {'index': int, 'content': str, 'embedding': list, 'meta_info': dict}
        """
        clean_url = normalize_url(source_url)
        count = 0
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})

                # Existing (task, url, chunk_index) row -> update; else insert.
                existing = conn.execute(select(self.db.chunks.c.id).where(
                    and_(self.db.chunks.c.task_id == task_id,
                         self.db.chunks.c.source_url == clean_url,
                         self.db.chunks.c.chunk_index == idx)
                )).fetchone()

                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta
                }

                if existing:
                    conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(self.db.chunks).values(**values))
                count += 1

            # Mark the source URL as done in the crawl queue.
            conn.execute(update(self.db.queue).where(
                and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
            ).values(status='completed'))

        return {"msg": f"Saved {count} chunks", "count": count}
|
||||
|
||||
def search(self, vector: list, task_id: int = None, limit: int = 5):
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info
|
||||
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
|
||||
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
results = [
|
||||
{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]}
|
||||
for r in rows
|
||||
]
|
||||
return {"results": results, "msg": f"Found {len(results)}"}
|
||||
|
||||
data_service = DataService()
|
||||
30
backend/services/llm_service.py
Normal file
30
backend/services/llm_service.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from backend.core.config import settings
|
||||
|
||||
class LLMService:
    """
    Thin wrapper around the DashScope SDK.

    Centralises model-provider interaction so the rest of the backend never
    talks to vendor APIs directly.
    """

    def __init__(self):
        # Configure the SDK once, at service-construction time.
        dashscope.api_key = settings.DASHSCOPE_API_KEY

    def get_embedding(self, text: str, dimension: int = 1536):
        """Return the embedding vector for *text*, or ``None`` on any failure."""
        try:
            response = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=dimension,
            )
            if response.status_code != HTTPStatus.OK:
                # Non-200 from the provider: log and signal failure to the caller.
                print(f"[ERROR] Embedding API Error: {response}")
                return None
            return response.output['embeddings'][0]['embedding']
        except Exception as e:
            # Network/SDK/shape errors are all reported the same way.
            print(f"[ERROR] Embedding Exception: {e}")
            return None


llm_service = LLMService()
|
||||
@@ -1,39 +0,0 @@
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
from sqlalchemy import create_engine, MetaData, Table, select, update, and_
|
||||
# backend/llm_service.py
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from .config import settings
|
||||
|
||||
# 初始化 Dashscope
|
||||
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
def get_embeddings(texts: list[str]):
|
||||
"""调用通义千问 embedding 模型"""
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 或其他模型
|
||||
input=texts
|
||||
)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
return [item['embedding'] for item in resp.output['embeddings']]
|
||||
else:
|
||||
print(f"Embedding Error: {resp}")
|
||||
return []
|
||||
def normalize_url(url: str) -> str:
|
||||
if not url: return ""
|
||||
url = url.strip()
|
||||
parsed = urlparse(url)
|
||||
scheme = parsed.scheme.lower()
|
||||
netloc = parsed.netloc.lower()
|
||||
path = parsed.path.rstrip('/')
|
||||
if not path: path = ""
|
||||
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
|
||||
|
||||
def make_response(code: int, msg: str = "Success", data: any = None):
|
||||
"""
|
||||
统一响应格式
|
||||
:param code: 1 成功, 0 失败, 其他自定义
|
||||
:param msg: 提示信息
|
||||
:param data: 返回数据
|
||||
"""
|
||||
return {"code": code, "msg": msg, "data": data}
|
||||
26
backend/utils/common.py
Normal file
26
backend/utils/common.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import Any
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
def make_response(code: int, msg: str = "Success", data: Any = None) -> dict:
    """
    Build the unified API response envelope.

    :param code: 1 for success, 0 for failure; other values are caller-defined.
    :param msg: Human-readable status message.
    :param data: Optional payload (any JSON-serialisable value).

    BUG FIX: the annotation previously used the builtin ``any`` (a function),
    not a type; ``typing.Any`` is the correct annotation.
    """
    return {"code": code, "msg": msg, "data": data}
|
||||
|
||||
def normalize_url(url: str) -> str:
    """
    Canonicalise a URL for storage and deduplication.

    Steps:
      1. strip surrounding whitespace;
      2. drop params, query string and fragment (pages differing only in
         query are treated as the same page by design);
      3. remove any trailing slash from the path.
    """
    if not url:
        return ""

    parts = urlparse(url.strip())
    # Reassemble keeping only scheme, netloc and the slash-trimmed path;
    # params, query and fragment are deliberately emptied.
    return urlunparse((parts.scheme, parts.netloc, parts.path.rstrip('/'), '', '', ''))
|
||||
61
backend/utils/text_process.py
Normal file
61
backend/utils/text_process.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import re
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
|
||||
|
||||
class TextProcessor:
    """Utility class for cleaning raw Markdown and splitting it into chunks."""

    def __init__(self):
        # Primary splitter: semantic, driven by Markdown heading levels.
        self.md_splitter = MarkdownHeaderTextSplitter(
            headers_to_split_on=[("#", "h1"), ("##", "h2"), ("###", "h3")],
            strip_headers=False,
        )
        # Fallback splitter: fixed-size character windows with overlap.
        self.char_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""],
        )

    def clean_markdown(self, text: str) -> str:
        """Strip common web-page noise from crawled Markdown."""
        if not text:
            return ""
        # Drop "Skip to main content" accessibility links.
        cleaned = re.sub(r'\[Skip to main content\].*?\n', '', text, flags=re.IGNORECASE)
        # Drop footer navigation (Previous / Next links and everything after).
        cleaned = re.sub(r'\[Previous\].*?\[Next\].*', '', cleaned, flags=re.DOTALL | re.IGNORECASE)
        return cleaned.strip()

    def split_markdown(self, text: str):
        """
        Split *text* by headings; any section longer than 1000 characters is
        re-split with the character splitter.  Returns a list of dicts with
        ``content`` (the text) and ``metadata`` (the heading hierarchy).
        """
        results = []
        for section in self.md_splitter.split_text(text):
            # section.page_content is the body; section.metadata holds the
            # heading levels captured by the Markdown splitter.
            body, meta = section.page_content, section.metadata
            if len(body) > 1000:
                results.extend(
                    {"content": piece, "metadata": meta}
                    for piece in self.char_splitter.split_text(body)
                )
            else:
                results.append({"content": body, "metadata": meta})
        return results


# Module-level singleton shared by the services.
text_processor = TextProcessor()
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
# Wiki Crawler Backend 部署操作手册
|
||||
|
||||
## 核心配置信息 (每次只需修改这里)
|
||||
@@ -7,7 +6,7 @@
|
||||
|
||||
| **字段** | **当前值 (示例)** | **说明** | **每次要改吗?** |
|
||||
| -------------------- | ---------------------------------- | ------------------------------ | ---------------------- |
|
||||
| **Version** | **v1.0.3** | **镜像的版本标签 (Tag)** | **是 (必须改)** |
|
||||
| **Version** | **v1.0.4** | **镜像的版本标签 (Tag)** | **是 (必须改)** |
|
||||
| **Image Name** | **wiki-crawl-backend** | **镜像/容器的名字** | **否 (固定)** |
|
||||
| **Namespace** | **qg-demo** | **阿里云命名空间** | **否 (固定)** |
|
||||
| **Registry** | **crpi-1rwd6fvain6t49g2...** | **阿里云仓库地址** | **否 (固定)** |
|
||||
@@ -20,18 +19,18 @@
|
||||
|
||||
### 1. 构建镜像 (Build)
|
||||
|
||||
**修改命令最后的版本号** **v1.0.3**
|
||||
**修改命令最后的版本号** **v1.0.4**
|
||||
|
||||
```powershell
|
||||
docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 .
|
||||
docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 .
|
||||
```
|
||||
|
||||
### 2. 推送镜像 (Push)
|
||||
|
||||
**修改命令最后的版本号** **v1.0.3**
|
||||
**修改命令最后的版本号** **v1.0.4**
|
||||
|
||||
```powershell
|
||||
docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
|
||||
docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
|
||||
```
|
||||
|
||||
> **成功标准:** **看到进度条走完,且最后显示** **Pushed**。
|
||||
@@ -44,10 +43,10 @@ docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/w
|
||||
|
||||
### 1. 拉取新镜像 (Pull)
|
||||
|
||||
**修改命令最后的版本号** **v1.0.3**
|
||||
**修改命令最后的版本号** **v1.0.4**
|
||||
|
||||
```bash
|
||||
docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
|
||||
docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
|
||||
```
|
||||
|
||||
### 2. 停止并删除旧容器
|
||||
@@ -61,7 +60,7 @@ docker rm wiki-crawl-backend
|
||||
|
||||
### 3. 启动新容器 (Run) - 关键步骤
|
||||
|
||||
**修改命令最后的版本号** **v1.0.3**
|
||||
**修改命令最后的版本号** **v1.0.4**
|
||||
|
||||
**code**Bash
|
||||
|
||||
@@ -69,7 +68,7 @@ docker rm wiki-crawl-backend
|
||||
docker run -d --name wiki-crawl-backend \
|
||||
-e PYTHONUNBUFFERED=1 \
|
||||
-p 80:8000 \
|
||||
crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
|
||||
crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
|
||||
```
|
||||
|
||||
### 4. 验证与日志查看
|
||||
@@ -132,9 +131,9 @@ docker image prune -a -f
|
||||
|
||||
### 5. 那个超长的 URL
|
||||
|
||||
**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3**
|
||||
**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4**
|
||||
|
||||
* **Registry (仓库地址)**: **crpi-1rwd...aliyuncs.com** **-> 你的专属阿里云仓库服务器。**
|
||||
* **Namespace (命名空间)**: **qg-demo** **-> 你在仓库里划出的个人地盘。**
|
||||
* **Image Name (镜像名)**: **wiki-crawl-backend** **-> 这个项目的名字。**
|
||||
* **Tag (标签)**: **v1.0.3** **-> 相当于软件的版本号。如果不写 Tag,默认就是** **latest**。**生产环境强烈建议写明确的版本号**,方便回滚(比如 1.0.3 挂了,你可以立马用 1.0.2 启动)。
|
||||
* **Tag (标签)**: **v1.0.4** **-> 相当于软件的版本号。如果不写 Tag,默认就是** **latest**。**生产环境强烈建议写明确的版本号**,方便回滚(比如 1.0.4 挂了,你可以立马用 1.0.3 启动)。
|
||||
|
||||
@@ -17,6 +17,7 @@ dependencies = [
|
||||
"pinecone>=8.0.0",
|
||||
"pip>=25.3",
|
||||
"psycopg2-binary>=2.9.11",
|
||||
"pydantic-settings>=2.12.0",
|
||||
"pymilvus>=2.6.5",
|
||||
"qdrant-client==1.10.1",
|
||||
"redis>=7.1.0",
|
||||
|
||||
@@ -1,96 +1,167 @@
|
||||
import requests
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import sys
|
||||
|
||||
# 配置后端地址
|
||||
BASE_URL = "http://47.122.127.178"
|
||||
# ================= 配置区域 =================
|
||||
BASE_URL = "http://127.0.0.1:8000"
|
||||
|
||||
def log_res(name, response):
|
||||
print(f"\n=== 测试接口: {name} ===")
|
||||
if response.status_code == 200:
|
||||
res_json = response.json()
|
||||
print(f"状态: 成功 (HTTP 200)")
|
||||
print(f"返回数据: {json.dumps(res_json, indent=2, ensure_ascii=False)}")
|
||||
return res_json
|
||||
else:
|
||||
print(f"状态: 失败 (HTTP {response.status_code})")
|
||||
print(f"错误信息: {response.text}")
|
||||
return None
|
||||
# 使用 Dify 文档作为测试对象 (结构清晰,适合验证 Markdown 切分)
|
||||
TEST_URL = "https://docs.dify.ai/en/use-dify/knowledge/create-knowledge/import-text-data/readme"
|
||||
|
||||
def run_tests():
|
||||
# 测试数据准备
|
||||
test_root_url = f"https://example.com/wiki_{random.randint(1000, 9999)}"
|
||||
# 测试查询词 (确保能命中上面的页面)
|
||||
TEST_QUERY = "upload size limit"
|
||||
# ===========================================
|
||||
|
||||
# 1. 测试 /register
|
||||
print("步骤 1: 注册新任务...")
|
||||
res = requests.post(f"{BASE_URL}/register", json={"url": test_root_url})
|
||||
data = log_res("注册任务", res)
|
||||
if not data or data['code'] != 1: return
|
||||
task_id = data['data']['task_id']
|
||||
class Colors:
|
||||
HEADER = '\033[95m'
|
||||
OKBLUE = '\033[94m'
|
||||
OKGREEN = '\033[92m'
|
||||
WARNING = '\033[93m'
|
||||
FAIL = '\033[91m'
|
||||
ENDC = '\033[0m'
|
||||
|
||||
# 2. 测试 /add_urls
|
||||
print("\n步骤 2: 模拟爬虫发现了新链接,存入队列...")
|
||||
sub_urls = [
|
||||
f"{test_root_url}/page1",
|
||||
f"{test_root_url}/page2",
|
||||
f"{test_root_url}/page1" # 故意重复一个,测试后端去重
|
||||
]
|
||||
res = requests.post(f"{BASE_URL}/add_urls", json={
|
||||
"task_id": task_id,
|
||||
"urls": sub_urls
|
||||
})
|
||||
log_res("存入新链接", res)
|
||||
def log(step: str, msg: str, color=Colors.OKBLUE):
|
||||
print(f"{color}[{step}] {msg}{Colors.ENDC}")
|
||||
|
||||
# 3. 测试 /pending_urls
|
||||
print("\n步骤 3: 模拟爬虫节点获取待处理任务...")
|
||||
res = requests.post(f"{BASE_URL}/pending_urls", json={
|
||||
"task_id": task_id,
|
||||
"limit": 2
|
||||
})
|
||||
data = log_res("获取待处理URL", res)
|
||||
if not data or not data['data']['urls']:
|
||||
print("没有获取到待处理URL,停止后续测试")
|
||||
return
|
||||
def run_e2e_test():
|
||||
print(f"{Colors.HEADER}=== 开始 Wiki Crawler E2E 完整测试 ==={Colors.ENDC}")
|
||||
|
||||
target_url = data['data']['urls'][0]
|
||||
# 0. 后端健康检查
|
||||
try:
|
||||
requests.get(f"{BASE_URL}/docs", timeout=3)
|
||||
except Exception:
|
||||
log("FATAL", "无法连接后端,请确保 main.py 正在运行 (http://127.0.0.1:8000)", Colors.FAIL)
|
||||
sys.exit(1)
|
||||
|
||||
# 4. 测试 /save_results
|
||||
print("\n步骤 4: 模拟爬虫抓取完成,存入知识片段和向量...")
|
||||
# 模拟一个 1536 维的向量(已处理精度)
|
||||
mock_embedding = [round(random.uniform(-1, 1), 8) for _ in range(1536)]
|
||||
# ---------------------------------------------------------
|
||||
# Step 1: 地图式扫描 (Map)
|
||||
# ---------------------------------------------------------
|
||||
# log("STEP 1", f"注册任务并扫描链接: {TEST_URL}")
|
||||
|
||||
payload = {
|
||||
"task_id": task_id,
|
||||
"results": [
|
||||
{
|
||||
"source_url": target_url,
|
||||
"chunk_index": 0,
|
||||
"title": "测试页面标题 - 切片1",
|
||||
"content": "这是模拟抓取到的第一段网页内容...",
|
||||
"embedding": mock_embedding
|
||||
},
|
||||
{
|
||||
"source_url": target_url,
|
||||
"chunk_index": 1,
|
||||
"title": "测试页面标题 - 切片2",
|
||||
"content": "这是模拟抓取到的第二段网页内容...",
|
||||
"embedding": mock_embedding
|
||||
}
|
||||
]
|
||||
# task_id = None
|
||||
# try:
|
||||
# res = requests.post(f"{BASE_URL}/api/v2/crawler/map", json={"url": TEST_URL})
|
||||
# res_json = res.json()
|
||||
|
||||
# # 验证响应状态
|
||||
# if res_json.get('code') != 1:
|
||||
# log("FAIL", f"Map 接口返回错误: {res_json}", Colors.FAIL)
|
||||
# sys.exit(1)
|
||||
|
||||
# data = res_json['data']
|
||||
# task_id = data['task_id']
|
||||
# count = data.get('count', 0)
|
||||
|
||||
# log("SUCCESS", f"任务注册成功。Task ID: {task_id}, 待爬取链接数: {count}", Colors.OKGREEN)
|
||||
|
||||
# except Exception as e:
|
||||
# log("FAIL", f"请求异常: {e}", Colors.FAIL)
|
||||
# sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Step 2: 触发后台处理 (Process)
|
||||
# ---------------------------------------------------------
|
||||
# task_id = 6
|
||||
# log("STEP 2", f"触发后台处理 -> Task ID: {task_id}")
|
||||
|
||||
# try:
|
||||
# res = requests.post(
|
||||
# f"{BASE_URL}/api/v2/crawler/process",
|
||||
# json={"task_id": task_id, "batch_size": 5}
|
||||
# )
|
||||
# res_json = res.json()
|
||||
|
||||
# if res_json.get('code') == 1:
|
||||
# log("SUCCESS", "后台处理任务已启动...", Colors.OKGREEN)
|
||||
# else:
|
||||
# log("FAIL", f"启动失败: {res_json}", Colors.FAIL)
|
||||
# sys.exit(1)
|
||||
|
||||
# except Exception as e:
|
||||
# log("FAIL", f"请求异常: {e}", Colors.FAIL)
|
||||
# sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Step 3: 轮询搜索结果 (Polling)
|
||||
# ---------------------------------------------------------
|
||||
log("STEP 3", "轮询搜索接口,等待数据入库...")
|
||||
|
||||
max_retries = 12
|
||||
found_data = False
|
||||
search_results = []
|
||||
|
||||
for i in range(max_retries):
|
||||
print(f" ⏳ 第 {i+1}/{max_retries} 次尝试搜索...", end="\r")
|
||||
time.sleep(5) # 每次等待 5 秒,给爬虫和 Embedding 一点时间
|
||||
|
||||
try:
|
||||
# 调用 V2 智能搜索接口
|
||||
search_res = requests.post(
|
||||
f"{BASE_URL}/api/v2/search",
|
||||
json={
|
||||
"query": TEST_QUERY,
|
||||
"task_id": task_id,
|
||||
"limit": 3
|
||||
}
|
||||
)
|
||||
resp_json = search_res.json()
|
||||
|
||||
# 解析响应结构: {code: 1, msg: "...", data: {results: [...]}}
|
||||
if resp_json['code'] == 1:
|
||||
data_body = resp_json['data']
|
||||
# 兼容性检查:确保 results 存在且不为空
|
||||
if data_body and 'results' in data_body and len(data_body['results']) > 0:
|
||||
search_results = data_body['results']
|
||||
found_data = True
|
||||
print("") # 换行
|
||||
log("SUCCESS", f"✅ 成功搜索到 {len(search_results)} 条相关切片!", Colors.OKGREEN)
|
||||
break
|
||||
except Exception as e:
|
||||
# 忽略网络抖动,继续重试
|
||||
pass
|
||||
|
||||
if not found_data:
|
||||
print("")
|
||||
log("FAIL", "❌ 超时:未能在规定时间内搜索到数据。请检查后端日志是否有报错。", Colors.FAIL)
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Step 4: 验证 Phase 1.5 成果 (Meta Info)
|
||||
# ---------------------------------------------------------
|
||||
log("STEP 4", "验证结构化数据 (Phase 1.5 Check)")
|
||||
|
||||
first_result = search_results[0]
|
||||
|
||||
# 打印第一条结果用于人工确认
|
||||
print(f"\n{Colors.WARNING}--- 检索结果样本 ---{Colors.ENDC}")
|
||||
print(f"Title: {first_result.get('title')}")
|
||||
print(f"URL: {first_result.get('source_url')}")
|
||||
print(f"Meta: {json.dumps(first_result.get('meta_info', {}), ensure_ascii=False)}")
|
||||
print(f"Content Preview: {first_result.get('content')[:50]}...")
|
||||
print(f"{Colors.WARNING}----------------------{Colors.ENDC}\n")
|
||||
|
||||
# 自动化断言
|
||||
checks = {
|
||||
"Has Content": bool(first_result.get('content')),
|
||||
"Has Meta Info": 'meta_info' in first_result,
|
||||
"Has Header Path": 'header_path' in first_result.get('meta_info', {}),
|
||||
"Headers Dict Exists": 'headers' in first_result.get('meta_info', {})
|
||||
}
|
||||
res = requests.post(f"{BASE_URL}/save_results", json=payload)
|
||||
log_res("保存结果", res)
|
||||
# 5. 测试 /search
|
||||
print("\n步骤 5: 测试基于向量的搜索...")
|
||||
query = [round(random.uniform(-1, 1), 8) for _ in range(1536)]
|
||||
res = requests.post(f"{BASE_URL}/search", json={
|
||||
"task_id": None,
|
||||
"query_embedding": query,
|
||||
"limit": 5
|
||||
})
|
||||
log_res("基于向量的搜索", res)
|
||||
|
||||
print("\n✅ 所有 API 流程测试完成!")
|
||||
all_pass = True
|
||||
for name, passed in checks.items():
|
||||
status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if passed else f"{Colors.FAIL}FAIL{Colors.ENDC}"
|
||||
print(f"检查项 [{name}]: {status}")
|
||||
if not passed:
|
||||
all_pass = False
|
||||
|
||||
if all_pass:
|
||||
meta = first_result['meta_info']
|
||||
print(f"\n{Colors.OKBLUE}🎉 测试通过!系统已具备 Phase 1.5 (结构化 RAG) 能力。{Colors.ENDC}")
|
||||
print(f"提取到的上下文路径: {Colors.HEADER}{meta.get('header_path', 'N/A')}{Colors.ENDC}")
|
||||
else:
|
||||
print(f"\n{Colors.FAIL}❌ 测试未完全通过:缺少必要的元数据字段。请检查 crawler_service.py 或 update_db.py。{Colors.ENDC}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
run_e2e_test()
|
||||
82
scripts/update_sql.py
Normal file
82
scripts/update_sql.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from backend.core.config import settings
|
||||
|
||||
def update_database_schema():
    """
    Non-destructive, idempotent schema upgrade.

    Adds the ``meta_info`` and ``content_tsvector`` columns, the supporting
    GIN indexes, a trigger that maintains the tsvector on insert/update, and
    backfills existing rows.

    BUG FIX: previously all commands ran inside one ``engine.begin()`` block
    while exceptions were caught per command.  On PostgreSQL the first failed
    statement aborts the whole transaction ("current transaction is aborted"),
    so every later command — and the final commit — would fail too.  Each
    command now runs in its own transaction, so one failure rolls back only
    itself.
    """
    print(f"🔌 连接数据库: {settings.DB_NAME}...")
    engine = create_engine(settings.DATABASE_URL)

    commands = [
        # 1. Safely add the meta_info column (old rows default to {}).
        """
        DO $$
        BEGIN
            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='meta_info') THEN
                ALTER TABLE knowledge_chunks ADD COLUMN meta_info JSONB DEFAULT '{}';
                RAISE NOTICE '已添加 meta_info 列';
            END IF;
        END $$;
        """,

        # 2. Safely add the content_tsvector column.
        """
        DO $$
        BEGIN
            IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='content_tsvector') THEN
                ALTER TABLE knowledge_chunks ADD COLUMN content_tsvector TSVECTOR;
                RAISE NOTICE '已添加 content_tsvector 列';
            END IF;
        END $$;
        """,

        # 3. Create indexes (no effect on existing data if already present).
        "CREATE INDEX IF NOT EXISTS idx_chunks_meta ON knowledge_chunks USING GIN (meta_info);",
        "CREATE INDEX IF NOT EXISTS idx_chunks_tsvector ON knowledge_chunks USING GIN (content_tsvector);",

        # 4. Trigger function keeping content_tsvector fresh on writes.
        """
        CREATE OR REPLACE FUNCTION chunks_tsvector_trigger() RETURNS trigger AS $$
        BEGIN
            new.content_tsvector := to_tsvector('english', coalesce(new.title, '') || ' ' || new.content);
            return new;
        END
        $$ LANGUAGE plpgsql;
        """,

        # 5. Bind the trigger (only once).
        """
        DO $$
        BEGIN
            IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'tsvectorupdate') THEN
                CREATE TRIGGER tsvectorupdate BEFORE INSERT OR UPDATE
                ON knowledge_chunks FOR EACH ROW EXECUTE PROCEDURE chunks_tsvector_trigger();
            END IF;
        END $$;
        """,

        # 6. Backfill: give previously stored rows a keyword index too.
        """
        UPDATE knowledge_chunks
        SET content_tsvector = to_tsvector('english', coalesce(title, '') || ' ' || content)
        WHERE content_tsvector IS NULL;
        """
    ]

    for cmd in commands:
        try:
            # One transaction per command: a failure rolls back only this
            # statement instead of poisoning the rest of the upgrade.
            with engine.begin() as conn:
                conn.execute(text(cmd))
        except Exception as e:
            print(f"⚠️ 执行警告 (通常可忽略): {e}")

    print("✅ 数据库结构升级完成!旧数据已保留并兼容。")


if __name__ == "__main__":
    update_database_schema()
|
||||
2
uv.lock
generated
2
uv.lock
generated
@@ -2305,6 +2305,7 @@ dependencies = [
|
||||
{ name = "pinecone" },
|
||||
{ name = "pip" },
|
||||
{ name = "psycopg2-binary" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pymilvus" },
|
||||
{ name = "qdrant-client" },
|
||||
{ name = "redis" },
|
||||
@@ -2326,6 +2327,7 @@ requires-dist = [
|
||||
{ name = "pinecone", specifier = ">=8.0.0" },
|
||||
{ name = "pip", specifier = ">=25.3" },
|
||||
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.12.0" },
|
||||
{ name = "pymilvus", specifier = ">=2.6.5" },
|
||||
{ name = "qdrant-client", specifier = "==1.10.1" },
|
||||
{ name = "redis", specifier = ">=7.1.0" },
|
||||
|
||||
Reference in New Issue
Block a user