Compare commits

...

2 Commits

Author SHA1 Message Date
d5ee00d404 Implement hybrid retrieval 2026-01-13 02:23:27 +08:00
9190fee16f Restructure the project architecture to improve extensibility 2026-01-13 01:37:26 +08:00
22 changed files with 792 additions and 723 deletions

7
.env.example Normal file
View File

@@ -0,0 +1,7 @@
DB_USER=postgres
DB_PASS=
DB_HOST=localhost
DB_PORT=5432
DB_NAME=wiki_crawler
DASHSCOPE_API_KEY=
FIRECRAWL_API_KEY=

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
__pycache__/
.venv
wiki_backend.tar
.env

View File

@@ -1,26 +0,0 @@
import os
class Settings:
# Database configuration
DB_USER: str = "postgres"
DB_PASS: str = "DXC_welcome001"
DB_HOST: str = "8.155.144.6"
DB_PORT: str = "25432"
DB_NAME: str = "wiki_crawler"
DASHSCOPE_API_KEY: str = "sk-8b091493de594c5e9eb42f12f1cc5805"
FIRECRAWL_API_KEY: str = "fc-8a2af3fb6a014a27a57dfbc728cb7365"
@property # turns the method into an attribute, so it can be accessed without parentheses
def DATABASE_URL(self) -> str:
url = f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
return url
def API_KEY(self, type: str) -> str:
if type == "dashscope":
return self.DASHSCOPE_API_KEY
elif type == "firecrawl":
return self.FIRECRAWL_API_KEY
else:
raise ValueError(f"Unknown API type: {type}")
settings = Settings()

24
backend/core/config.py Normal file
View File

@@ -0,0 +1,24 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""
系统配置类
自动读取环境变量或 .env 文件
"""
DB_USER: str
DB_PASS: str
DB_HOST: str
DB_PORT: str = "5432"
DB_NAME: str
DASHSCOPE_API_KEY: str
FIRECRAWL_API_KEY: str
# Config: ignore extra environment variables and set the file encoding
model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8')
@property
def DATABASE_URL(self) -> str:
return f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
settings = Settings()
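A minimal usage sketch (an illustration, not part of the diff): with a `.env` shaped like the `.env.example` above in the working directory, each field resolves from real environment variables first, then from the file.

```python
# Sketch: consuming the Settings class above; assumes a .env like .env.example exists.
from backend.core.config import settings

print(settings.DB_PORT)       # "5432" unless DB_PORT is set in the environment or .env
print(settings.DATABASE_URL)  # postgresql+psycopg2://<user>:<pass>@<host>:<port>/<db>
```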

View File

@@ -1,14 +1,19 @@
from sqlalchemy import create_engine, MetaData, Table, event
from pgvector.sqlalchemy import Vector # this import is required
from sqlalchemy import create_engine, MetaData, Table
from pgvector.sqlalchemy import Vector
from .config import settings
class Database:
"""
数据库单例类
负责初始化连接池并反射加载现有的表结构
"""
def __init__(self):
# 1. 创建引擎
# pool_pre_ping=True 用于解决数据库连接长时间空闲后断开的问题
self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
# 2. 【核心修复】手动注册 vector 类型,让反射能识别它
# 这告诉 SQLAlchemy:如果在数据库里看到名为 "vector" 的类型,请使用 pgvector 库的 Vector 类来处理
# 2. 注册 pgvector 类型
# 这是为了让 SQLAlchemy 反射机制能识别数据库中的 'vector' 类型
self.engine.dialect.ischema_names['vector'] = Vector
self.metadata = MetaData()
@@ -19,14 +24,15 @@ class Database:
self._reflect_tables()
def _reflect_tables(self):
"""自动从数据库加载表定义"""
try:
# 自动从数据库加载表结构
# 因为上面注册了 ischema_names现在 chunks_table.c.embedding 就能被正确识别为 Vector 类型了
# autoload_with 会查询数据库元数据,自动填充 Column 信息
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
print("[INFO] Database tables reflected successfully.")
except Exception as e:
print(f"❌ 数据库表加载失败: {e}")
print(f"[ERROR] Failed to reflect tables: {e}")
# Global instance
db_instance = Database()
# Global database instance
db = Database()
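A quick sanity check for the reflection step (a sketch, assuming the three tables exist and the database is reachable): because `'vector'` was registered in `ischema_names`, the reflected embedding column should come back typed as pgvector's `Vector` rather than a generic `NullType`.

```python
# Sketch: verify that reflection picked up the pgvector column type.
from pgvector.sqlalchemy import Vector
from backend.core.database import db

print(isinstance(db.chunks.c.embedding.type, Vector))    # expected: True
print(sorted(t.name for t in db.metadata.sorted_tables))  # the three reflected tables
```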

View File

@@ -1,131 +1,13 @@
# backend/main.py
from fastapi import FastAPI, APIRouter, BackgroundTasks
# Make sure the import path matches your file name; if the file is workflow.py, import workflow
from .services.crawler_sql_service import crawler_sql_service
from .services.automated_crawler import workflow
from .schemas import (
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
AutoMapRequest, AutoProcessRequest, TextSearchRequest
)
from .utils import make_response
from fastapi import FastAPI
from backend.routers import v1, v2
app = FastAPI(title="Wiki Crawler API")
# ==========================================
# Utility functions
# ==========================================
# ==========================================
# V1 Router: original low-level endpoints (Manual Control)
# ==========================================
router_v1 = APIRouter()
@router_v1.post("/register")
async def register(req: RegisterRequest):
try:
# Service returns: {'task_id': 1, 'is_new_task': True, 'msg': '...'}
res = crawler_sql_service.register_task(req.url)
# pop msg out of the result to use as the response msg; the rest becomes data
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
try:
urls = req.urls_obj["urls"]
res = crawler_sql_service.add_urls(req.task_id, urls=urls)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/pending_urls")
async def pending_urls(req: PendingRequest):
try:
res = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
# Even when the queue is empty, the Service still returns msg="Queue is empty"
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/save_results")
async def save_results(req: SaveResultsRequest):
try:
res = crawler_sql_service.save_results(req.task_id, req.results)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router_v1.post("/search")
async def search_v1(req: SearchRequest):
"""V1 搜索:客户端手动传向量"""
try:
vector = req.query_embedding['vector']
if not vector:
return make_response(2, "Vector is empty", None)
# The Service now returns {'results': [...], 'msg': 'Found ...'}
res = crawler_sql_service.search_knowledge(
query_embedding=vector,
task_id=req.task_id,
limit=req.limit
)
return make_response(1, res.pop("msg", "Search Done"), res)
except Exception as e:
return make_response(0, str(e))
# ==========================================
# V2 Router: Automated Workflow
# ==========================================
router_v2 = APIRouter()
@router_v2.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
"""
[同步] 输入首页 URL自动调用 Firecrawl Map 并入库
"""
try:
# Workflow 返回: {'task_id':..., 'msg': 'Task mapped...', ...}
res = workflow.map_and_ingest(req.url)
return make_response(1, res.pop("msg", "Mapping Started"), res)
except Exception as e:
return make_response(0, str(e))
@router_v2.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTasks):
"""
[异步] 触发后台任务:消费队列 -> 抓取 -> Embedding -> 入库
"""
try:
# 将耗时操作放入后台任务
background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
# 因为是后台任务,无法立即获取 Service 的返回值 msg只能返回通用消息
return make_response(1, "Background processing started", {"task_id": req.task_id})
except Exception as e:
return make_response(0, str(e))
@router_v2.post("/search")
async def search_v2(req: TextSearchRequest):
"""
[智能] 输入自然语言文本 -> 后端转向量 -> 搜索
"""
try:
# Workflow 返回 {'results': [...], 'msg': '...'}
res = workflow.search_with_embedding(req.query, req.task_id, req.limit)
return make_response(1, res.pop("msg", "Search Success"), res)
except Exception as e:
return make_response(0, f"Search Failed: {str(e)}")
# ==========================================
# Mount the routers
# ==========================================
app.include_router(router_v1, prefix="/api/v1", tags=["V1 Manual API"])
app.include_router(router_v2, prefix="/api/v2", tags=["V2 Automated Workflow"])
app.include_router(v1.router)
app.include_router(v2.router)
if __name__ == "__main__":
import uvicorn
# Tip: run with `uv run backend/main.py` or `uvicorn backend.main:app --reload`
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)

30
backend/routers/v1.py Normal file
View File

@@ -0,0 +1,30 @@
from fastapi import APIRouter
from backend.services.data_service import data_service
from backend.utils.common import make_response
from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest
router = APIRouter(prefix="/api/v1", tags=["V1 Manual"])
@router.post("/register")
async def register(req: RegisterRequest):
try:
res = data_service.register_task(req.url)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
try:
res = data_service.add_urls(req.task_id, req.urls_obj["urls"])
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
@router.post("/search")
async def search_manual(req: SearchRequest):
try:
res = data_service.search(req.query_embedding['vector'], req.task_id, req.limit)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))

30
backend/routers/v2.py Normal file
View File

@@ -0,0 +1,30 @@
from fastapi import APIRouter, BackgroundTasks
from backend.services.crawler_service import crawler_service
from backend.utils.common import make_response
from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest
router = APIRouter(prefix="/api/v2", tags=["V2 Automated"])
@router.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
try:
res = crawler_service.map_site(req.url)
return make_response(1, res.pop("msg", "Started"), res)
except Exception as e:
return make_response(0, str(e))
@router.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
try:
bg_tasks.add_task(crawler_service.process_queue, req.task_id, req.batch_size)
return make_response(1, "Background processing started", {"task_id": req.task_id})
except Exception as e:
return make_response(0, str(e))
@router.post("/search")
async def search_smart(req: TextSearchRequest):
try:
res = crawler_service.search(req.query, req.task_id, req.limit)
return make_response(1, res.pop("msg", "Success"), res)
except Exception as e:
return make_response(0, str(e))
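A hedged end-to-end sketch of the V2 flow from a client's point of view (it assumes the API is running locally on port 8000, as in the test script further down):

```python
# Sketch: exercising the V2 endpoints; assumes the API at http://127.0.0.1:8000.
import requests

BASE = "http://127.0.0.1:8000/api/v2"

# 1. Map a site and enqueue its links
task = requests.post(f"{BASE}/crawler/map", json={"url": "https://docs.dify.ai"}).json()
task_id = task["data"]["task_id"]

# 2. Kick off background processing of a small batch
requests.post(f"{BASE}/crawler/process", json={"task_id": task_id, "batch_size": 5})

# 3. Once chunks are stored, run a hybrid search with plain text
hits = requests.post(f"{BASE}/search", json={"query": "upload size limit", "task_id": task_id, "limit": 3}).json()
print(hits["data"]["results"])
```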

View File

@@ -1,201 +0,0 @@
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter
from ..config import settings
from .crawler_sql_service import crawler_sql_service
# Initialize configuration
dashscope.api_key = settings.DASHSCOPE_API_KEY
class AutomatedCrawler:
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
self.splitter = RecursiveCharacterTextSplitter(
chunk_size=500,
chunk_overlap=100,
separators=["\n\n", "\n", "", "", "", " ", ""]
)
def _get_embedding(self, text: str):
"""内部方法:调用 Dashscope 生成向量"""
# 注意:此方法是内部辅助,出错返回 None由调用方处理状态
embedding = None
try:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v3, # confirm this matches your model version
input=text,
dimension=1536
)
if resp.status_code == HTTPStatus.OK:
embedding = resp.output['embeddings'][0]['embedding']
else:
print(f"Embedding API Error: {resp}")
except Exception as e:
print(f"Embedding Exception: {e}")
return embedding
def map_and_ingest(self, start_url: str):
"""
V2 步骤1: 地图式扫描并入库
"""
print(f"[WorkFlow] Start mapping: {start_url}")
result = {}
try:
# 1. Register the task in the database
task_info = crawler_sql_service.register_task(start_url)
task_id = task_info['task_id']
is_new_task = task_info['is_new_task']
# 2. Call Firecrawl Map
if is_new_task:
map_result = self.firecrawl.map(start_url)
urls = []
# Handle the return structures of different firecrawl SDK versions
# If map_result is an object with a links attribute
if hasattr(map_result, 'links'):
for link in map_result.links:
# link may be an object or a dict, depending on the SDK version
# If link is a string, append it directly
if isinstance(link, str):
urls.append(link)
else:
urls.append(getattr(link, 'url', str(link)))
# If it is a dict
elif isinstance(map_result, dict):
urls = map_result.get('links', [])
print(f"[WorkFlow] Found {len(urls)} links")
# 3. Add the URLs to the queue in bulk
res = {"msg": "No urls found to add"}
if urls:
res = crawler_sql_service.add_urls(task_id, urls)
result = {
"msg": "Task successfully mapped and URLs added",
"task_id": task_id,
"is_new_task": is_new_task,
"url_count": len(urls),
"map_detail": res
}
else:
result = {
"msg": "Task already exists, skipped mapping",
"task_id": task_id,
"is_new_task": False,
"url_count": 0,
"map_detail": {}
}
except Exception as e:
print(f"[WorkFlow] Map Error: {e}")
# Re-raise so main.py can catch it and return an error response
raise e
return result
def process_task_queue(self, task_id: int, limit: int = 10):
"""
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
"""
processed_count = 0
total_chunks_saved = 0
result = {}
# 1. Fetch pending URLs
pending = crawler_sql_service.get_pending_urls(task_id, limit)
urls = pending['urls']
if not urls:
result = {"msg": "Queue is empty, no processing needed", "processed_count": 0}
else:
for url in urls:
try:
print(f"[WorkFlow] Processing: {url}")
# 2. Scrape the single page
scrape_res = self.firecrawl.scrape(
url,
params={'formats': ['markdown'], 'onlyMainContent': True}
)
# Handle both SDK return types (object or dict)
content = ""
metadata = {}
if isinstance(scrape_res, dict):
content = scrape_res.get('markdown', '')
metadata = scrape_res.get('metadata', {})
else:
content = getattr(scrape_res, 'markdown', '')
metadata = getattr(scrape_res, 'metadata', {})
if not metadata and hasattr(scrape_res, 'metadata_dict'):
metadata = scrape_res.metadata_dict
title = metadata.get('title', url)
if not content:
print(f"[WorkFlow] Skip empty content: {url}")
continue
# 3. Split into chunks
chunks = self.splitter.split_text(content)
results_to_save = []
# 4. Embed each chunk
for idx, chunk_text in enumerate(chunks):
vector = self._get_embedding(chunk_text)
if vector:
results_to_save.append({
"source_url": url,
"chunk_index": idx,
"title": title,
"content": chunk_text,
"embedding": vector
})
# 5. Save
if results_to_save:
save_res = crawler_sql_service.save_results(task_id, results_to_save)
processed_count += 1
total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated']
except Exception as e:
print(f"[WorkFlow] Error processing {url}: {e}")
# Do not re-raise here, so a single failure does not break the whole batch loop
# In production, call the service here to mark the url as failed
result = {
"msg": f"Batch processing complete. URLs processed: {processed_count}",
"processed_urls": processed_count,
"total_chunks_saved": total_chunks_saved
}
return result
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
"""
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
"""
result = {}
# 1. Get the embedding
vector = self._get_embedding(query_text)
if not vector:
result = {
"msg": "Failed to generate embedding for query",
"results": []
}
else:
# 2. Run the search
# search_knowledge already returns a dict that includes msg
result = crawler_sql_service.search_knowledge(vector, task_id, limit)
return result
# Singleton instance
workflow = AutomatedCrawler()

View File

@@ -0,0 +1,150 @@
# backend/services/crawler_service.py
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor
class CrawlerService:
"""
爬虫编排服务
协调 Firecrawl, LLM, 和 DataService
"""
def __init__(self):
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
def map_site(self, start_url: str):
print(f"[INFO] Mapping: {start_url}")
try:
# 1. Register the task
task_res = data_service.register_task(start_url)
# 2. Whether or not the task is new, try to enqueue start_url itself
urls_to_add = [start_url]
# 3. Call Firecrawl Map
try:
map_res = self.firecrawl.map(start_url)
# Handle the return shapes of different SDK versions (object or dict)
found_links = []
if isinstance(map_res, dict):
found_links = map_res.get('links', [])
else:
found_links = getattr(map_res, 'links', [])
for link in found_links:
# A link may be a string or an object
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
urls_to_add.append(u)
print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
except Exception as e:
print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")
# 4. Add the URLs to the queue in bulk
if urls_to_add:
data_service.add_urls(task_res['task_id'], urls_to_add)
return {
"msg": "Mapped successfully",
"count": len(urls_to_add),
"task_id": task_res['task_id']
}
except Exception as e:
print(f"[ERROR] Map failed: {e}")
raise e
def process_queue(self, task_id: int, batch_size: int = 5):
"""
处理队列:爬取 -> 清洗 -> 切分 -> 向量化 -> 存储
"""
pending = data_service.get_pending_urls(task_id, batch_size)
urls = pending['urls']
if not urls: return {"msg": "Queue empty"}
processed = 0
for url in urls:
try:
print(f"[INFO] Scraping: {url}")
# Call Firecrawl
scrape_res = self.firecrawl.scrape(
url,
formats=['markdown'],
only_main_content=True
)
# =====================================================
# [Core fix] Handle both dict and object return formats
# =====================================================
raw_md = ""
metadata = None
# 1. Extract the markdown body and metadata
if isinstance(scrape_res, dict):
raw_md = scrape_res.get('markdown', '')
metadata = scrape_res.get('metadata', {})
else:
# If it is an object (ScrapeResponse)
raw_md = getattr(scrape_res, 'markdown', '')
metadata = getattr(scrape_res, 'metadata', {})
# 2. Safely extract the title from metadata
title = url # default to the url as the title
if metadata:
if isinstance(metadata, dict):
# metadata is a dict
title = metadata.get('title', url)
else:
# metadata is an object (DocumentMetadata) -> this is where the earlier error occurred
title = getattr(metadata, 'title', url)
# =====================================================
if not raw_md:
print(f"[WARN] Empty content for {url}")
continue
# 1. Clean
clean_md = text_processor.clean_markdown(raw_md)
if not clean_md: continue
# 2. Split
chunks = text_processor.split_markdown(clean_md)
chunks_data = []
for i, chunk in enumerate(chunks):
content = chunk['content']
headers = chunk['metadata']
header_path = " > ".join(headers.values())
# 3. Embed (Title + Header + Content)
emb_input = f"{title}\n{header_path}\n{content}"
vector = llm_service.get_embedding(emb_input)
if vector:
chunks_data.append({
"index": i,
"content": content,
"embedding": vector,
"meta_info": {"header_path": header_path, "headers": headers}
})
# 4. Store
if chunks_data:
data_service.save_chunks(task_id, url, title, chunks_data)
processed += 1
except Exception as e:
# Print the detailed error to aid debugging
print(f"[ERROR] Process URL {url} failed: {e}")
return {"msg": "Batch processed", "count": processed}
def search(self, query: str, task_id, limit: int):
vector = llm_service.get_embedding(query)
if not vector: return {"msg": "Embedding failed", "results": []}
return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit)
crawler_service = CrawlerService()

View File

@@ -1,230 +0,0 @@
from sqlalchemy import select, insert, update, and_
from ..database import db_instance
from ..utils import normalize_url
class CrawlerSqlService:
def __init__(self):
self.db = db_instance
def register_task(self, url: str):
"""完全使用库 API 实现的注册"""
clean_url = normalize_url(url)
result = {}
with self.db.engine.begin() as conn:
# Use the select() API
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
existing = conn.execute(query).fetchone()
if existing:
result = {
"task_id": existing[0],
"is_new_task": False,
"msg": "Task already exists"
}
else:
# Use the insert() API
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
new_task = conn.execute(stmt).fetchone()
result = {
"task_id": new_task[0],
"is_new_task": True,
"msg": "New task created successfully"
}
return result
def add_urls(self, task_id: int, urls: list[str]):
"""通用 API 实现的批量添加(含详细返回)"""
success_urls, skipped_urls, failed_urls = [], [], []
with self.db.engine.begin() as conn:
for url in urls:
clean_url = normalize_url(url)
try:
# Check whether the URL already exists in the queue
check_q = select(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
)
if conn.execute(check_q).fetchone():
skipped_urls.append(clean_url)
continue
# Insert the new URL
conn.execute(insert(self.db.queue).values(
task_id=task_id, url=clean_url, status='pending'
))
success_urls.append(clean_url)
except Exception:
failed_urls.append(clean_url)
# Build the return message
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
return {
"success_urls": success_urls,
"skipped_urls": skipped_urls,
"failed_urls": failed_urls,
"msg": msg
}
def get_pending_urls(self, task_id: int, limit: int):
"""原子锁定 API 实现"""
result = {}
with self.db.engine.begin() as conn:
query = select(self.db.queue.c.url).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
).limit(limit)
urls = [r[0] for r in conn.execute(query).fetchall()]
if urls:
upd = update(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
).values(status='processing')
conn.execute(upd)
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
else:
result = {"urls": [], "msg": "Queue is empty"}
return result
def save_results(self, task_id: int, results: list):
"""
保存同一 URL 的多个切片。
"""
if not results:
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
# 1. Extract basic info
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
target_url = normalize_url(first_item.get('source_url'))
# Result counters
inserted_chunks = []
updated_chunks = []
failed_chunks = []
is_page_update = False
with self.db.engine.begin() as conn:
# 2. Check whether chunks already exist for this URL
check_page_stmt = select(self.db.chunks.c.id).where(
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
).limit(1)
if conn.execute(check_page_stmt).fetchone():
is_page_update = True
# 3. Process the chunks one by one
for res in results:
data = res if isinstance(res, dict) else res.__dict__
c_idx = data.get('chunk_index')
try:
# Check whether this chunk already exists
find_chunk_stmt = select(self.db.chunks.c.id).where(
and_(
self.db.chunks.c.task_id == task_id,
self.db.chunks.c.source_url == target_url,
self.db.chunks.c.chunk_index == c_idx
)
)
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
if existing_chunk:
# Overwrite with an update
upd_stmt = update(self.db.chunks).where(
self.db.chunks.c.id == existing_chunk[0]
).values(
title=data.get('title'),
content=data.get('content'),
embedding=data.get('embedding')
)
conn.execute(upd_stmt)
updated_chunks.append(c_idx)
else:
# Insert a new chunk
ins_stmt = insert(self.db.chunks).values(
task_id=task_id,
source_url=target_url,
chunk_index=c_idx,
title=data.get('title'),
content=data.get('content'),
embedding=data.get('embedding')
)
conn.execute(ins_stmt)
inserted_chunks.append(c_idx)
except Exception as e:
print(f"Chunk {c_idx} failed: {e}")
failed_chunks.append(c_idx)
# 4. Update the queue status
conn.execute(
update(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
).values(status='completed')
)
# Build the return value
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
return {
"source_url": target_url,
"is_page_update": is_page_update,
"detail": {
"inserted_chunk_indexes": inserted_chunks,
"updated_chunk_indexes": updated_chunks,
"failed_chunk_indexes": failed_chunks
},
"counts": {
"inserted": len(inserted_chunks),
"updated": len(updated_chunks),
"failed": len(failed_chunks)
},
"msg": msg
}
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
"""
高性能向量搜索方法
"""
results = []
msg = ""
with self.db.engine.connect() as conn:
stmt = select(
self.db.chunks.c.task_id,
self.db.chunks.c.source_url,
self.db.chunks.c.title,
self.db.chunks.c.content,
self.db.chunks.c.chunk_index
)
if task_id is not None:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
stmt = stmt.order_by(
self.db.chunks.c.embedding.cosine_distance(query_embedding)
).limit(limit)
rows = conn.execute(stmt).fetchall()
for r in rows:
results.append({
"task_id": r[0],
"source_url": r[1],
"title": r[2],
"content": r[3],
"chunk_index": r[4]
})
if results:
msg = f"Found {len(results)} matches"
else:
msg = "No matching content found"
return {"results": results, "msg": msg}
crawler_sql_service = CrawlerSqlService()

View File

@@ -0,0 +1,163 @@
from sqlalchemy import select, insert, update, and_, text, func, desc
from backend.core.database import db
from backend.utils.common import normalize_url
class DataService:
"""
数据持久化服务层
只负责数据库 CRUD 操作,不包含外部 API 调用逻辑
"""
def __init__(self):
self.db = db
def register_task(self, url: str):
clean_url = normalize_url(url)
with self.db.engine.begin() as conn:
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
existing = conn.execute(query).fetchone()
if existing:
return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
else:
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
new_task = conn.execute(stmt).fetchone()
return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}
def add_urls(self, task_id: int, urls: list[str]):
success_urls = []
with self.db.engine.begin() as conn:
for url in urls:
clean_url = normalize_url(url)
try:
check_q = select(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
)
if not conn.execute(check_q).fetchone():
conn.execute(insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending'))
success_urls.append(clean_url)
except Exception:
pass
return {"msg": f"Added {len(success_urls)} new urls"}
def get_pending_urls(self, task_id: int, limit: int):
with self.db.engine.begin() as conn:
query = select(self.db.queue.c.url).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
).limit(limit)
urls = [r[0] for r in conn.execute(query).fetchall()]
if urls:
conn.execute(update(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
).values(status='processing'))
return {"urls": urls, "msg": "Fetched pending urls"}
def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
"""
保存切片数据 (Phase 1.5: 支持 meta_info)
"""
clean_url = normalize_url(source_url)
count = 0
with self.db.engine.begin() as conn:
for item in chunks_data:
# item structure: {'index': int, 'content': str, 'embedding': list, 'meta_info': dict}
idx = item['index']
meta = item.get('meta_info', {})
# Check whether the chunk already exists
existing = conn.execute(select(self.db.chunks.c.id).where(
and_(self.db.chunks.c.task_id == task_id,
self.db.chunks.c.source_url == clean_url,
self.db.chunks.c.chunk_index == idx)
)).fetchone()
values = {
"task_id": task_id, "source_url": clean_url, "chunk_index": idx,
"title": title, "content": item['content'], "embedding": item['embedding'],
"meta_info": meta
}
if existing:
conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
else:
conn.execute(insert(self.db.chunks).values(**values))
count += 1
# Mark the queue entry as completed
conn.execute(update(self.db.queue).where(
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
).values(status='completed'))
return {"msg": f"Saved {count} chunks", "count": count}
def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5):
"""
Phase 2: 混合检索 (Hybrid Search)
综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword)
"""
results = []
with self.db.engine.connect() as conn:
# Build the hybrid-search SQL logic
# websearch_to_tsquery parses the user input (supports syntax like "firecrawl or dify")
keyword_query = func.websearch_to_tsquery('english', query_text)
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
# Combined score column: 0.7 * vector + 0.3 * keyword
# coalesce treats a NULL keyword score (no match) as 0
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
stmt = select(
self.db.chunks.c.task_id,
self.db.chunks.c.source_url,
self.db.chunks.c.title,
self.db.chunks.c.content,
self.db.chunks.c.meta_info,
final_score
)
if task_id:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
# Order by the combined score, descending
stmt = stmt.order_by(desc("score")).limit(limit)
try:
rows = conn.execute(stmt).fetchall()
results = [
{
"task_id": r[0],
"source_url": r[1],
"title": r[2],
"content": r[3],
"meta_info": r[4],
"score": float(r[5])
}
for r in rows
]
except Exception as e:
print(f"[ERROR] Hybrid search failed: {e}")
return self._fallback_vector_search(query_vector, task_id, limit)
return {"results": results, "msg": f"Hybrid found {len(results)}"}
def _fallback_vector_search(self, vector, task_id, limit):
"""降级兜底:纯向量搜索"""
print("[WARN] Fallback to pure vector search")
with self.db.engine.connect() as conn:
stmt = select(
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
self.db.chunks.c.content, self.db.chunks.c.meta_info
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
if task_id:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
rows = conn.execute(stmt).fetchall()
return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
data_service = DataService()
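For reference, the statement assembled in `search()` above compiles to roughly the SQL below (a sketch only; it assumes pgvector's `<=>` cosine-distance operator and the `content_tsvector` column added by `scripts/update_sql.py`):

```python
# Rough SQL equivalent of the hybrid query built above (illustrative, not the exact compiled output).
from sqlalchemy import text

hybrid_sql = text("""
    SELECT task_id, source_url, title, content, meta_info,
           (1 - (embedding <=> :query_vec)) * 0.7
           + COALESCE(ts_rank(content_tsvector,
                              websearch_to_tsquery('english', :query_text)), 0) * 0.3 AS score
    FROM knowledge_chunks
    WHERE task_id = :task_id
    ORDER BY score DESC
    LIMIT :limit
""")
```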

View File

@@ -0,0 +1,30 @@
import dashscope
from http import HTTPStatus
from backend.core.config import settings
class LLMService:
"""
LLM 服务封装层
负责与 DashScope 或其他模型供应商交互
"""
def __init__(self):
dashscope.api_key = settings.DASHSCOPE_API_KEY
def get_embedding(self, text: str, dimension: int = 1536):
"""生成文本向量"""
try:
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v4,
input=text,
dimension=dimension
)
if resp.status_code == HTTPStatus.OK:
return resp.output['embeddings'][0]['embedding']
else:
print(f"[ERROR] Embedding API Error: {resp}")
return None
except Exception as e:
print(f"[ERROR] Embedding Exception: {e}")
return None
llm_service = LLMService()
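Minimal usage sketch (assumes `DASHSCOPE_API_KEY` is set via the environment or `.env`):

```python
# Sketch: generating a query embedding with the service above.
from backend.services.llm_service import llm_service

vec = llm_service.get_embedding("What is the upload size limit?")
if vec is not None:
    print(len(vec))  # expected: 1536, matching the default dimension
```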

View File

@@ -1,39 +0,0 @@
from urllib.parse import urlparse, urlunparse
from sqlalchemy import create_engine, MetaData, Table, select, update, and_
# backend/llm_service.py
import dashscope
from http import HTTPStatus
from .config import settings
# Initialize DashScope
dashscope.api_key = settings.DASHSCOPE_API_KEY
def get_embeddings(texts: list[str]):
"""调用通义千问 embedding 模型"""
resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 或其他模型
input=texts
)
if resp.status_code == HTTPStatus.OK:
return [item['embedding'] for item in resp.output['embeddings']]
else:
print(f"Embedding Error: {resp}")
return []
def normalize_url(url: str) -> str:
if not url: return ""
url = url.strip()
parsed = urlparse(url)
scheme = parsed.scheme.lower()
netloc = parsed.netloc.lower()
path = parsed.path.rstrip('/')
if not path: path = ""
return urlunparse((scheme, netloc, path, parsed.params, parsed.query, ""))
def make_response(code: int, msg: str = "Success", data: any = None):
"""
统一响应格式
:param code: 1 成功, 0 失败, 其他自定义
:param msg: 提示信息
:param data: 返回数据
"""
return {"code": code, "msg": msg, "data": data}

26
backend/utils/common.py Normal file
View File

@@ -0,0 +1,26 @@
from urllib.parse import urlparse, urlunparse
def make_response(code: int, msg: str = "Success", data: any = None):
"""统一 API 响应格式封装"""
return {"code": code, "msg": msg, "data": data}
def normalize_url(url: str) -> str:
"""
URL 标准化处理
1. 去除首尾空格
2. 移除 fragment (#后面的内容)
3. 移除 query 参数 (视业务需求而定,这里假设不同 query 是同一页面)
4. 移除尾部斜杠
"""
if not url:
return ""
parsed = urlparse(url.strip())
# Recompose from scheme, netloc, path, params, query, fragment
# Only scheme, netloc, and path are kept here
clean_path = parsed.path.rstrip('/')
# Build a new parsed object (params, query, fragment cleared)
new_parsed = parsed._replace(path=clean_path, params='', query='', fragment='')
return urlunparse(new_parsed)
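An illustrative check of the two helpers above:

```python
# Sketch: expected behaviour of normalize_url / make_response as defined above.
from backend.utils.common import normalize_url, make_response

print(normalize_url("https://docs.dify.ai/guides/?utm=x#intro"))
# -> https://docs.dify.ai/guides  (query, fragment and trailing slash removed)

print(make_response(1, "Success", {"task_id": 1}))
# -> {'code': 1, 'msg': 'Success', 'data': {'task_id': 1}}
```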

View File

@@ -0,0 +1,61 @@
import re
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
class TextProcessor:
"""文本处理工具类:负责 Markdown 清洗和切分"""
def __init__(self):
# 基于 Markdown 标题的语义切分器
self.md_splitter = MarkdownHeaderTextSplitter(
headers_to_split_on=[
("#", "h1"),
("##", "h2"),
("###", "h3"),
],
strip_headers=False
)
# Fallback character splitter
self.char_splitter = RecursiveCharacterTextSplitter(
chunk_size=800,
chunk_overlap=100,
separators=["\n\n", "\n", "", "", "", " ", ""]
)
def clean_markdown(self, text: str) -> str:
"""清洗 Markdown 中的网页噪音"""
if not text: return ""
# 去除 'Skip to main content'
text = re.sub(r'\[Skip to main content\].*?\n', '', text, flags=re.IGNORECASE)
# 去除页脚导航 (Previous / Next)
text = re.sub(r'\[Previous\].*?\[Next\].*', '', text, flags=re.DOTALL | re.IGNORECASE)
return text.strip()
def split_markdown(self, text: str):
"""执行切分策略:先按标题切,过长则按字符切"""
md_chunks = self.md_splitter.split_text(text)
final_chunks = []
for chunk in md_chunks:
# chunk.page_content is the text
# chunk.metadata is the header hierarchy
if len(chunk.page_content) > 1000:
sub_texts = self.char_splitter.split_text(chunk.page_content)
for sub in sub_texts:
final_chunks.append({
"content": sub,
"metadata": chunk.metadata
})
else:
final_chunks.append({
"content": chunk.page_content,
"metadata": chunk.metadata
})
return final_chunks
# Singleton instance
text_processor = TextProcessor()
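A small sketch of the splitting strategy above (the markdown input here is made up purely for illustration):

```python
# Sketch: header-aware splitting; each chunk keeps its header hierarchy as metadata.
from backend.utils.text_process import text_processor

md = "# Guide\n\n## Upload limits\nText about file size limits goes here.\n"
for chunk in text_processor.split_markdown(text_processor.clean_markdown(md)):
    print(chunk["metadata"], "->", chunk["content"][:40])
# e.g. {'h1': 'Guide', 'h2': 'Upload limits'} -> the chunk text with its headers kept
```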

View File

@@ -1,4 +1,3 @@
# Wiki Crawler Backend Deployment Manual
## Core configuration (the only part you need to change each time)
@@ -7,7 +6,7 @@
| **Field** | **Current value (example)** | **Description** | **Change every time?** |
| -------------------- | ---------------------------------- | ------------------------------ | ---------------------- |
| **Version** | **v1.0.3** | **Image version tag** | **Yes (must change)** |
| **Version** | **v1.0.4** | **Image version tag** | **Yes (must change)** |
| **Image Name** | **wiki-crawl-backend** | **Name of the image/container** | **No (fixed)** |
| **Namespace** | **qg-demo** | **Aliyun Container Registry namespace** | **No (fixed)** |
| **Registry** | **crpi-1rwd6fvain6t49g2...** | **Aliyun registry address** | **No (fixed)** |
@@ -20,18 +19,18 @@
### 1. Build the image
**Change the version number at the end of the command** **v1.0.3**
**Change the version number at the end of the command** **v1.0.4**
```powershell
docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3 .
docker build -t crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4 .
```
### 2. Push the image
**Change the version number at the end of the command** **v1.0.3**
**Change the version number at the end of the command** **v1.0.4**
```powershell
docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
```
> **Success criterion:** **the progress bars finish and the last line shows** **Pushed**.
@@ -44,10 +43,10 @@ docker push crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/w
### 1. Pull the new image
**Change the version number at the end of the command** **v1.0.3**
**Change the version number at the end of the command** **v1.0.4**
```bash
docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
docker pull crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
```
### 2. Stop and remove the old container
@@ -61,7 +60,7 @@ docker rm wiki-crawl-backend
### 3. Run the new container - the key step
**Change the version number at the end of the command** **v1.0.3**
**Change the version number at the end of the command** **v1.0.4**
```bash
@@ -69,7 +68,7 @@ docker rm wiki-crawl-backend
docker run -d --name wiki-crawl-backend \
-e PYTHONUNBUFFERED=1 \
-p 80:8000 \
crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3
crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4
```
### 4. Verify and check the logs
@@ -132,9 +131,9 @@ docker image prune -a -f
### 5. That very long URL
**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.3**
**crpi-1rwd6fvain6t49g2.cn-hangzhou.personal.cr.aliyuncs.com/qg-demo/wiki-crawl-backend:v1.0.4**
* **Registry**: **crpi-1rwd...aliyuncs.com** **-> your dedicated Aliyun registry server.**
* **Namespace**: **qg-demo** **-> your own area carved out within the registry.**
* **Image Name**: **wiki-crawl-backend** **-> the name of this project.**
* **Tag**: **v1.0.3** **-> effectively the software version number. Without a tag it defaults to** **latest**. **In production, always use an explicit version number** so you can roll back (e.g. if 1.0.3 breaks, you can immediately start 1.0.2).
* **Tag**: **v1.0.4** **-> effectively the software version number. Without a tag it defaults to** **latest**. **In production, always use an explicit version number** so you can roll back (e.g. if 1.0.3 breaks, you can immediately start 1.0.2).

View File

@@ -17,6 +17,7 @@ dependencies = [
"pinecone>=8.0.0",
"pip>=25.3",
"psycopg2-binary>=2.9.11",
"pydantic-settings>=2.12.0",
"pymilvus>=2.6.5",
"qdrant-client==1.10.1",
"redis>=7.1.0",

View File

@@ -1,96 +1,167 @@
import requests
import time
import json
import random
import sys
# Backend address
BASE_URL = "http://47.122.127.178"
# ================= Configuration =================
BASE_URL = "http://127.0.0.1:8000"
def log_res(name, response):
print(f"\n=== 测试接口: {name} ===")
if response.status_code == 200:
res_json = response.json()
print(f"状态: 成功 (HTTP 200)")
print(f"返回数据: {json.dumps(res_json, indent=2, ensure_ascii=False)}")
return res_json
else:
print(f"状态: 失败 (HTTP {response.status_code})")
print(f"错误信息: {response.text}")
return None
# Use the Dify docs as the test target (clear structure, good for validating Markdown splitting)
TEST_URL = "https://docs.dify.ai/en/use-dify/knowledge/create-knowledge/import-text-data/readme"
def run_tests():
# Prepare test data
test_root_url = f"https://example.com/wiki_{random.randint(1000, 9999)}"
# Test query (chosen so it hits the page above)
TEST_QUERY = "upload size limit"
# ===========================================
class Colors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
def log(step: str, msg: str, color=Colors.OKBLUE):
print(f"{color}[{step}] {msg}{Colors.ENDC}")
def run_e2e_test():
print(f"{Colors.HEADER}=== 开始 Wiki Crawler E2E 完整测试 ==={Colors.ENDC}")
# 0. 后端健康检查
try:
requests.get(f"{BASE_URL}/docs", timeout=3)
except Exception:
log("FATAL", "无法连接后端,请确保 main.py 正在运行 (http://127.0.0.1:8000)", Colors.FAIL)
sys.exit(1)
# ---------------------------------------------------------
# Step 1: Map-style scan
# ---------------------------------------------------------
# log("STEP 1", f"Register the task and scan for links: {TEST_URL}")
# 1. Test /register
print("Step 1: registering a new task...")
res = requests.post(f"{BASE_URL}/register", json={"url": test_root_url})
data = log_res("Register task", res)
if not data or data['code'] != 1: return
task_id = data['data']['task_id']
# task_id = None
# try:
# res = requests.post(f"{BASE_URL}/api/v2/crawler/map", json={"url": TEST_URL})
# res_json = res.json()
# # 验证响应状态
# if res_json.get('code') != 1:
# log("FAIL", f"Map 接口返回错误: {res_json}", Colors.FAIL)
# sys.exit(1)
# data = res_json['data']
# task_id = data['task_id']
# count = data.get('count', 0)
# log("SUCCESS", f"任务注册成功。Task ID: {task_id}, 待爬取链接数: {count}", Colors.OKGREEN)
# except Exception as e:
# log("FAIL", f"请求异常: {e}", Colors.FAIL)
# sys.exit(1)
# 2. Test /add_urls
print("\nStep 2: simulating the crawler discovering new links and queueing them...")
sub_urls = [
f"{test_root_url}/page1",
f"{test_root_url}/page2",
f"{test_root_url}/page1" # deliberate duplicate to test backend deduplication
]
res = requests.post(f"{BASE_URL}/add_urls", json={
"task_id": task_id,
"urls": sub_urls
})
log_res("存入新链接", res)
# 3. 测试 /pending_urls
print("\n步骤 3: 模拟爬虫节点获取待处理任务...")
res = requests.post(f"{BASE_URL}/pending_urls", json={
"task_id": task_id,
"limit": 2
})
data = log_res("获取待处理URL", res)
if not data or not data['data']['urls']:
print("没有获取到待处理URL停止后续测试")
return
# ---------------------------------------------------------
# Step 2: 触发后台处理 (Process)
# ---------------------------------------------------------
# task_id = 6
# log("STEP 2", f"触发后台处理 -> Task ID: {task_id}")
target_url = data['data']['urls'][0]
# try:
# res = requests.post(
# f"{BASE_URL}/api/v2/crawler/process",
# json={"task_id": task_id, "batch_size": 5}
# )
# res_json = res.json()
# if res_json.get('code') == 1:
# log("SUCCESS", "后台处理任务已启动...", Colors.OKGREEN)
# else:
# log("FAIL", f"启动失败: {res_json}", Colors.FAIL)
# sys.exit(1)
# except Exception as e:
# log("FAIL", f"请求异常: {e}", Colors.FAIL)
# sys.exit(1)
# 4. Test /save_results
print("\nStep 4: simulating a finished scrape; storing knowledge chunks and vectors...")
# Mock a 1536-dimensional vector (precision already rounded)
mock_embedding = [round(random.uniform(-1, 1), 8) for _ in range(1536)]
# ---------------------------------------------------------
# Step 3: Poll for search results
# ---------------------------------------------------------
log("STEP 3", "Polling the search endpoint while waiting for data to be stored...")
task_id = 6
max_retries = 12
found_data = False
search_results = []
for i in range(max_retries):
print(f" ⏳ 第 {i+1}/{max_retries} 次尝试搜索...", end="\r")
time.sleep(5) # 每次等待 5 秒,给爬虫和 Embedding 一点时间
try:
# 调用 V2 智能搜索接口
search_res = requests.post(
f"{BASE_URL}/api/v2/search",
json={
"query": TEST_QUERY,
"task_id": task_id,
"limit": 3
}
)
resp_json = search_res.json()
# Parse the response structure: {code: 1, msg: "...", data: {results: [...]}}
if resp_json['code'] == 1:
data_body = resp_json['data']
# Compatibility check: make sure results exists and is not empty
if data_body and 'results' in data_body and len(data_body['results']) > 0:
search_results = data_body['results']
found_data = True
print("") # 换行
log("SUCCESS", f"✅ 成功搜索到 {len(search_results)} 条相关切片!", Colors.OKGREEN)
break
except Exception as e:
# Ignore transient network errors and keep retrying
pass
if not found_data:
print("")
log("FAIL", "❌ 超时:未能在规定时间内搜索到数据。请检查后端日志是否有报错。", Colors.FAIL)
sys.exit(1)
# ---------------------------------------------------------
# Step 4: Verify the Phase 1.5 result (Meta Info)
# ---------------------------------------------------------
log("STEP 4", "Verifying structured data (Phase 1.5 check)")
payload = {
"task_id": task_id,
"results": [
{
"source_url": target_url,
"chunk_index": 0,
"title": "测试页面标题 - 切片1",
"content": "这是模拟抓取到的第一段网页内容...",
"embedding": mock_embedding
},
{
"source_url": target_url,
"chunk_index": 1,
"title": "测试页面标题 - 切片2",
"content": "这是模拟抓取到的第二段网页内容...",
"embedding": mock_embedding
}
]
first_result = search_results[0]
# Print the first result for manual inspection
print(f"\n{Colors.WARNING}--- Sample search result ---{Colors.ENDC}")
print(f"Title: {first_result.get('title')}")
print(f"URL: {first_result.get('source_url')}")
print(f"Meta: {json.dumps(first_result.get('meta_info', {}), ensure_ascii=False)}")
print(f"Content Preview: {first_result.get('content')[:50]}...")
print(f"{Colors.WARNING}----------------------{Colors.ENDC}\n")
# Automated assertions
checks = {
"Has Content": bool(first_result.get('content')),
"Has Meta Info": 'meta_info' in first_result,
"Has Header Path": 'header_path' in first_result.get('meta_info', {}),
"Headers Dict Exists": 'headers' in first_result.get('meta_info', {})
}
res = requests.post(f"{BASE_URL}/save_results", json=payload)
log_res("保存结果", res)
# 5. 测试 /search
print("\n步骤 5: 测试基于向量的搜索...")
query = [round(random.uniform(-1, 1), 8) for _ in range(1536)]
res = requests.post(f"{BASE_URL}/search", json={
"task_id": None,
"query_embedding": query,
"limit": 5
})
log_res("基于向量的搜索", res)
print("\n✅ 所有 API 流程测试完成!")
all_pass = True
for name, passed in checks.items():
status = f"{Colors.OKGREEN}PASS{Colors.ENDC}" if passed else f"{Colors.FAIL}FAIL{Colors.ENDC}"
print(f"检查项 [{name}]: {status}")
if not passed:
all_pass = False
if all_pass:
meta = first_result['meta_info']
print(f"\n{Colors.OKBLUE}🎉 测试通过!系统已具备 Phase 1.5 (结构化 RAG) 能力。{Colors.ENDC}")
print(f"提取到的上下文路径: {Colors.HEADER}{meta.get('header_path', 'N/A')}{Colors.ENDC}")
else:
print(f"\n{Colors.FAIL}❌ 测试未完全通过:缺少必要的元数据字段。请检查 crawler_service.py 或 update_db.py。{Colors.ENDC}")
if __name__ == "__main__":
run_tests()
run_e2e_test()

82
scripts/update_sql.py Normal file
View File

@@ -0,0 +1,82 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from sqlalchemy import create_engine, text
from backend.core.config import settings
def update_database_schema():
"""
Non-destructive database schema upgrade script.
"""
print(f"🔌 Connecting to database: {settings.DB_NAME}...")
engine = create_engine(settings.DATABASE_URL)
commands = [
# 1. Safely add the meta_info column (old rows are filled with {} automatically)
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='meta_info') THEN
ALTER TABLE knowledge_chunks ADD COLUMN meta_info JSONB DEFAULT '{}';
RAISE NOTICE 'Added meta_info column';
END IF;
END $$;
""",
# 2. Safely add the content_tsvector column
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM information_schema.columns WHERE table_name='knowledge_chunks' AND column_name='content_tsvector') THEN
ALTER TABLE knowledge_chunks ADD COLUMN content_tsvector TSVECTOR;
RAISE NOTICE 'Added content_tsvector column';
END IF;
END $$;
""",
# 3. Create indexes (does not affect existing data)
"CREATE INDEX IF NOT EXISTS idx_chunks_meta ON knowledge_chunks USING GIN (meta_info);",
"CREATE INDEX IF NOT EXISTS idx_chunks_tsvector ON knowledge_chunks USING GIN (content_tsvector);",
# 4. Create the trigger function (for newly inserted rows)
"""
CREATE OR REPLACE FUNCTION chunks_tsvector_trigger() RETURNS trigger AS $$
BEGIN
new.content_tsvector := to_tsvector('english', coalesce(new.title, '') || ' ' || new.content);
return new;
END
$$ LANGUAGE plpgsql;
""",
# 5. Attach the trigger
"""
DO $$
BEGIN
IF NOT EXISTS (SELECT 1 FROM pg_trigger WHERE tgname = 'tsvectorupdate') THEN
CREATE TRIGGER tsvectorupdate BEFORE INSERT OR UPDATE
ON knowledge_chunks FOR EACH ROW EXECUTE PROCEDURE chunks_tsvector_trigger();
END IF;
END $$;
""",
# 6. [New] Backfill existing rows
# so that previously stored data (e.g. task_id=6) also gets a keyword index
"""
UPDATE knowledge_chunks
SET content_tsvector = to_tsvector('english', coalesce(title, '') || ' ' || content)
WHERE content_tsvector IS NULL;
"""
]
with engine.begin() as conn:
for cmd in commands:
try:
conn.execute(text(cmd))
except Exception as e:
print(f"⚠️ 执行警告 (通常可忽略): {e}")
print("✅ 数据库结构升级完成!旧数据已保留并兼容。")
if __name__ == "__main__":
update_database_schema()
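A quick verification sketch after running the upgrade (assumes at least one row already exists in `knowledge_chunks`):

```python
# Sketch: confirm the new columns exist and the tsvector backfill ran.
from sqlalchemy import create_engine, text
from backend.core.config import settings

engine = create_engine(settings.DATABASE_URL)
with engine.connect() as conn:
    row = conn.execute(text(
        "SELECT meta_info, content_tsvector IS NOT NULL AS has_tsv FROM knowledge_chunks LIMIT 1"
    )).fetchone()
    print(row)  # expected: (a headers dict or {}, True) once the backfill has run
```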

2
uv.lock generated
View File

@@ -2305,6 +2305,7 @@ dependencies = [
{ name = "pinecone" },
{ name = "pip" },
{ name = "psycopg2-binary" },
{ name = "pydantic-settings" },
{ name = "pymilvus" },
{ name = "qdrant-client" },
{ name = "redis" },
@@ -2326,6 +2327,7 @@ requires-dist = [
{ name = "pinecone", specifier = ">=8.0.0" },
{ name = "pip", specifier = ">=25.3" },
{ name = "psycopg2-binary", specifier = ">=2.9.11" },
{ name = "pydantic-settings", specifier = ">=2.12.0" },
{ name = "pymilvus", specifier = ">=2.6.5" },
{ name = "qdrant-client", specifier = "==1.10.1" },
{ name = "redis", specifier = ">=7.1.0" },