变更项目架构,提高扩展性

This commit is contained in:
2026-01-13 01:37:26 +08:00
parent b9dbf1e8f7
commit 9190fee16f
22 changed files with 740 additions and 723 deletions

View File

@@ -1,26 +0,0 @@
import os
class Settings:
    """Static configuration for the crawler backend.

    Every value may be overridden by an environment variable of the same
    name; the literals are fallbacks that keep the previous behaviour when
    the variable is unset.
    """

    # SECURITY: credentials and API keys that live in source control must be
    # treated as leaked. Rotate them and supply real values via the environment.

    # Database configuration
    DB_USER: str = os.environ.get("DB_USER", "postgres")
    DB_PASS: str = os.environ.get("DB_PASS", "DXC_welcome001")
    DB_HOST: str = os.environ.get("DB_HOST", "8.155.144.6")
    DB_PORT: str = os.environ.get("DB_PORT", "25432")
    DB_NAME: str = os.environ.get("DB_NAME", "wiki_crawler")

    # External service keys
    DASHSCOPE_API_KEY: str = os.environ.get(
        "DASHSCOPE_API_KEY", "sk-8b091493de594c5e9eb42f12f1cc5805"
    )
    FIRECRAWL_API_KEY: str = os.environ.get(
        "FIRECRAWL_API_KEY", "fc-8a2af3fb6a014a27a57dfbc728cb7365"
    )

    @property  # exposed as an attribute: ``settings.DATABASE_URL`` (no call)
    def DATABASE_URL(self) -> str:
        """Return the SQLAlchemy URL for the PostgreSQL database (psycopg2 driver).

        User name and password are percent-encoded so that characters such as
        ``@``, ``:`` or ``/`` cannot corrupt the URL structure.
        """
        from urllib.parse import quote

        user = quote(self.DB_USER, safe="")
        password = quote(self.DB_PASS, safe="")
        return (
            f"postgresql+psycopg2://{user}:{password}"
            f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
        )

    def API_KEY(self, type: str) -> str:
        """Return the API key for the named provider.

        Args:
            type: ``"dashscope"`` or ``"firecrawl"``.

        Raises:
            ValueError: if ``type`` is not a known provider name.
        """
        if type == "dashscope":
            return self.DASHSCOPE_API_KEY
        if type == "firecrawl":
            return self.FIRECRAWL_API_KEY
        raise ValueError(f"Unknown API type: {type}")


settings = Settings()

24
backend/core/config.py Normal file
View File

@@ -0,0 +1,24 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    """Application settings.

    Values are read from environment variables or from a ``.env`` file.
    Fields without a default are required.
    """

    DB_USER: str
    DB_PASS: str
    DB_HOST: str
    DB_PORT: str = "5432"
    DB_NAME: str
    DASHSCOPE_API_KEY: str
    FIRECRAWL_API_KEY: str

    # Ignore unrelated environment variables and read .env as UTF-8.
    model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8')

    @property
    def DATABASE_URL(self) -> str:
        """Return the SQLAlchemy URL for the PostgreSQL database (psycopg2 driver).

        User name and password are percent-encoded so that characters such as
        ``@``, ``:`` or ``/`` in a credential cannot corrupt the URL structure.
        """
        from urllib.parse import quote

        user = quote(self.DB_USER, safe="")
        password = quote(self.DB_PASS, safe="")
        return (
            f"postgresql+psycopg2://{user}:{password}"
            f"@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
        )


settings = Settings()

View File

@@ -1,14 +1,19 @@
from sqlalchemy import create_engine, MetaData, Table, event
from pgvector.sqlalchemy import Vector # 必须导入这个
from sqlalchemy import create_engine, MetaData, Table
from pgvector.sqlalchemy import Vector
from .config import settings
class Database:
"""
Database singleton class.
Initialises the connection pool and reflects the existing table structure.
"""
def __init__(self):
# 1. Create the engine.
# pool_pre_ping=True guards against connections that were dropped after sitting idle for a long time.
self.engine = create_engine(settings.DATABASE_URL, pool_pre_ping=True)
# 2. [Core fix] Register the vector type manually so that reflection can recognise it.
# This tells SQLAlchemy: when a column type named "vector" is found in the database, handle it with pgvector's Vector class.
# 2. Register the pgvector type.
# This lets SQLAlchemy's reflection recognise the 'vector' column type in the database.
self.engine.dialect.ischema_names['vector'] = Vector
self.metadata = MetaData()
@@ -19,14 +24,15 @@ class Database:
self._reflect_tables()
def _reflect_tables(self):
"""Load the table definitions from the database by reflection."""
try:
# Load the table structure from the database automatically.
# Because 'vector' was registered in ischema_names, chunks_table.c.embedding is now recognised as a Vector type.
# autoload_with queries the database metadata and fills in the Column information automatically.
self.tasks = Table('crawl_tasks', self.metadata, autoload_with=self.engine)
self.queue = Table('crawl_queue', self.metadata, autoload_with=self.engine)
self.chunks = Table('knowledge_chunks', self.metadata, autoload_with=self.engine)
print("[INFO] Database tables reflected successfully.")
except Exception as e:
# NOTE(review): the failure is only printed, so self.tasks / self.queue /
# self.chunks may be left unset afterwards -- confirm callers tolerate that.
print(f"❌ 数据库表加载失败: {e}")
print(f"[ERROR] Failed to reflect tables: {e}")
# 全局
db_instance = Database()
# Global database instance
db = Database()

View File

@@ -1,131 +1,13 @@
# backend/main.py
from fastapi import FastAPI, APIRouter, BackgroundTasks
# 确保导入路径与你的文件名一致,如果文件名是 workflow.py 则用 workflow
from .services.crawler_sql_service import crawler_sql_service
from .services.automated_crawler import workflow
from .schemas import (
RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
AutoMapRequest, AutoProcessRequest, TextSearchRequest
)
from .utils import make_response
from fastapi import FastAPI
from backend.routers import v1, v2
app = FastAPI(title="Wiki Crawler API")
# ==========================================
# Utility functions
# ==========================================
# ==========================================
# V1 Router: the original low-level endpoints (Manual Control)
# ==========================================
router_v1 = APIRouter()
@router_v1.post("/register")
async def register(req: RegisterRequest):
    """Register a crawl task for ``req.url`` and return the standard response envelope.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = crawler_sql_service.register_task(req.url)
        # The service puts its message under "msg"; it becomes the response
        # message and the remaining keys become the response data.
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v1.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
    """Add the URLs found under ``req.urls_obj["urls"]`` to the queue of ``req.task_id``.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        url_list = req.urls_obj["urls"]
        payload = crawler_sql_service.add_urls(req.task_id, urls=url_list)
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v1.post("/pending_urls")
async def pending_urls(req: PendingRequest):
    """Fetch up to ``req.limit`` pending URLs for ``req.task_id``.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
        # The default only applies when the service result carries no "msg".
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v1.post("/save_results")
async def save_results(req: SaveResultsRequest):
    """Persist ``req.results`` for ``req.task_id`` through the SQL service.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = crawler_sql_service.save_results(req.task_id, req.results)
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v1.post("/search")
async def search_v1(req: SearchRequest):
    """V1 search: the client supplies the query vector itself.

    Returns code 2 when the supplied vector is empty, code 1 with the search
    payload on success, and code 0 with the exception text on any error.
    """
    try:
        embedding = req.query_embedding['vector']
        if embedding:
            payload = crawler_sql_service.search_knowledge(
                query_embedding=embedding,
                task_id=req.task_id,
                limit=req.limit,
            )
            message = payload.pop("msg", "Search Done")
            return make_response(1, message, payload)
        return make_response(2, "Vector is empty", None)
    except Exception as exc:
        return make_response(0, str(exc))
# ==========================================
# V2 Router: automated workflow (Automated Workflow)
# ==========================================
router_v2 = APIRouter()
@router_v2.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
    """[Synchronous] Take a home-page URL, run the workflow's map step and store the result.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = workflow.map_and_ingest(req.url)
        message = payload.pop("msg", "Mapping Started")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v2.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTasks):
    """[Asynchronous] Schedule queue processing for ``req.task_id`` as a background task.

    The work runs after the response is sent, so only a generic message can be
    returned here rather than the workflow's own result.
    """
    try:
        background_tasks.add_task(
            workflow.process_task_queue, req.task_id, req.batch_size
        )
        data = {"task_id": req.task_id}
        return make_response(1, "Background processing started", data)
    except Exception as exc:
        return make_response(0, str(exc))
@router_v2.post("/search")
async def search_v2(req: TextSearchRequest):
    """[Smart] Take natural-language text, let the backend embed it, then search.

    Any exception is reported as code 0 with a ``Search Failed: ...`` message.
    """
    try:
        payload = workflow.search_with_embedding(req.query, req.task_id, req.limit)
        message = payload.pop("msg", "Search Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, f"Search Failed: {str(exc)}")
# ==========================================
# Mount the routers
# ==========================================
app.include_router(router_v1, prefix="/api/v1", tags=["V1 Manual API"])
app.include_router(router_v2, prefix="/api/v2", tags=["V2 Automated Workflow"])
app.include_router(v1.router)
app.include_router(v2.router)
if __name__ == "__main__":
    import uvicorn

    # Run with: uv run backend/main.py  or  uvicorn backend.main:app --reload
    # uvicorn only honours reload=True when the application is given as an
    # import string; when handed the app object it logs a warning and exits.
    uvicorn.run("backend.main:app", host="0.0.0.0", port=8000, reload=True)

30
backend/routers/v1.py Normal file
View File

@@ -0,0 +1,30 @@
from fastapi import APIRouter
from backend.services.data_service import data_service
from backend.utils.common import make_response
from backend.schemas.schemas import RegisterRequest, AddUrlsRequest, PendingRequest, SearchRequest
router = APIRouter(prefix="/api/v1", tags=["V1 Manual"])
@router.post("/register")
async def register(req: RegisterRequest):
    """Register a crawl task for ``req.url`` and return the standard response envelope.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = data_service.register_task(req.url)
        # "msg" is lifted out of the service result; the rest becomes the data.
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
    """Add the URLs found under ``req.urls_obj["urls"]`` to the queue of ``req.task_id``.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        url_list = req.urls_obj["urls"]
        payload = data_service.add_urls(req.task_id, url_list)
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/search")
async def search_manual(req: SearchRequest):
    """Manual search: the client supplies the query vector itself.

    Returns code 2 when the supplied vector is empty (so no meaningless
    distance query is sent to the database), code 1 with the search payload on
    success, and code 0 with the exception text on any error.
    """
    try:
        vector = req.query_embedding['vector']
        if not vector:
            return make_response(2, "Vector is empty", None)
        res = data_service.search(vector, req.task_id, req.limit)
        return make_response(1, res.pop("msg", "Success"), res)
    except Exception as e:
        return make_response(0, str(e))

30
backend/routers/v2.py Normal file
View File

@@ -0,0 +1,30 @@
from fastapi import APIRouter, BackgroundTasks
from backend.services.crawler_service import crawler_service
from backend.utils.common import make_response
from backend.schemas.schemas import AutoMapRequest, AutoProcessRequest, TextSearchRequest
router = APIRouter(prefix="/api/v2", tags=["V2 Automated"])
@router.post("/crawler/map")
async def auto_map(req: AutoMapRequest):
    """Map the site at ``req.url`` through the crawler service and return its summary.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = crawler_service.map_site(req.url)
        message = payload.pop("msg", "Started")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/crawler/process")
async def auto_process(req: AutoProcessRequest, bg_tasks: BackgroundTasks):
    """Schedule queue processing for ``req.task_id`` as a background task.

    The response is sent before the work runs, so it only confirms scheduling.
    """
    try:
        bg_tasks.add_task(
            crawler_service.process_queue, req.task_id, req.batch_size
        )
        data = {"task_id": req.task_id}
        return make_response(1, "Background processing started", data)
    except Exception as exc:
        return make_response(0, str(exc))
@router.post("/search")
async def search_smart(req: TextSearchRequest):
    """Search with a natural-language query; the crawler service produces the embedding.

    Any exception is reported as code 0 with the exception text as message.
    """
    try:
        payload = crawler_service.search(req.query, req.task_id, req.limit)
        message = payload.pop("msg", "Success")
        return make_response(1, message, payload)
    except Exception as exc:
        return make_response(0, str(exc))

View File

@@ -1,201 +0,0 @@
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter
from ..config import settings
from .crawler_sql_service import crawler_sql_service
# 初始化配置
dashscope.api_key = settings.DASHSCOPE_API_KEY
class AutomatedCrawler:
    """Automated crawl workflow: map a site, scrape queued pages, embed and store chunks."""

    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            # Paragraph, line, CJK sentence terminators, space, then characters.
            # The empty string must stay last: the recursive splitter stops at
            # the first "" it meets and falls back to per-character splitting,
            # so an empty entry mid-list would make " " unreachable.
            separators=["\n\n", "\n", "。", "！", "？", " ", ""],
        )

    def _get_embedding(self, text: str):
        """Return the DashScope embedding for ``text``, or ``None`` on any failure.

        Errors are printed rather than raised; callers decide how to react to
        a missing vector.
        """
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v3,
                input=text,
                dimension=1536,
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            print(f"Embedding API Error: {resp}")
        except Exception as e:
            print(f"Embedding Exception: {e}")
        return None

    @staticmethod
    def _extract_links(map_result):
        """Return the URL list from a Firecrawl map result (object or dict form)."""
        if hasattr(map_result, 'links'):
            # Entries may be plain strings or objects exposing ``url``.
            return [
                link if isinstance(link, str) else getattr(link, 'url', str(link))
                for link in map_result.links
            ]
        if isinstance(map_result, dict):
            return map_result.get('links', [])
        return []

    @staticmethod
    def _extract_page(scrape_res, url: str):
        """Return ``(markdown, title)`` from a scrape result (object or dict form).

        The title falls back to ``url`` when the metadata carries none.
        """
        if isinstance(scrape_res, dict):
            content = scrape_res.get('markdown', '')
            metadata = scrape_res.get('metadata', {})
        else:
            content = getattr(scrape_res, 'markdown', '')
            metadata = getattr(scrape_res, 'metadata', {})
            if not metadata and hasattr(scrape_res, 'metadata_dict'):
                metadata = scrape_res.metadata_dict

        # Metadata may be a dict or an SDK object; calling ``.get`` on the
        # latter would raise AttributeError.
        if isinstance(metadata, dict):
            title = metadata.get('title', url)
        elif metadata:
            title = getattr(metadata, 'title', url)
        else:
            title = url
        return content, title

    def map_and_ingest(self, start_url: str):
        """V2 step 1: register the task, map the site with Firecrawl and queue the URLs.

        Mapping is skipped when a task for ``start_url`` already exists.

        Returns:
            dict with ``msg``, ``task_id``, ``is_new_task``, ``url_count`` and
            ``map_detail``.

        Raises:
            Exception: any failure is printed and re-raised for the caller.
        """
        print(f"[WorkFlow] Start mapping: {start_url}")
        try:
            # 1. Register the task in the database.
            task_info = crawler_sql_service.register_task(start_url)
            task_id = task_info['task_id']
            is_new_task = task_info['is_new_task']

            if not is_new_task:
                return {
                    "msg": "Task already exists, skipped mapping",
                    "task_id": task_id,
                    "is_new_task": False,
                    "url_count": 0,
                    "map_detail": {},
                }

            # 2. Call Firecrawl Map.
            urls = self._extract_links(self.firecrawl.map(start_url))
            print(f"[WorkFlow] Found {len(urls)} links")

            # 3. Bulk insert into the queue.
            res = {"msg": "No urls found to add"}
            if urls:
                res = crawler_sql_service.add_urls(task_id, urls)

            return {
                "msg": "Task successfully mapped and URLs added",
                "task_id": task_id,
                "is_new_task": is_new_task,
                "url_count": len(urls),
                "map_detail": res,
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            # Bare raise keeps the original traceback.
            raise

    def process_task_queue(self, task_id: int, limit: int = 10):
        """V2 step 2: consume the queue -> scrape -> split -> embed -> store.

        A failure on one URL is printed and does not interrupt the batch.

        Returns:
            dict with ``msg`` plus either ``processed_count`` (empty queue) or
            ``processed_urls`` and ``total_chunks_saved``.
        """
        # 1. Fetch the URLs to process.
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']
        if not urls:
            return {"msg": "Queue is empty, no processing needed", "processed_count": 0}

        processed_count = 0
        total_chunks_saved = 0
        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape a single page.
                scrape_res = self.firecrawl.scrape(
                    url,
                    params={'formats': ['markdown'], 'onlyMainContent': True}
                )
                content, title = self._extract_page(scrape_res, url)
                if not content:
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue

                # 3. Split, 4. embed; chunks without a vector are dropped.
                results_to_save = []
                for idx, chunk_text in enumerate(self.splitter.split_text(content)):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector
                        })

                # 5. Save.
                if results_to_save:
                    save_res = crawler_sql_service.save_results(task_id, results_to_save)
                    processed_count += 1
                    counts = save_res['counts']
                    total_chunks_saved += counts['inserted'] + counts['updated']
            except Exception as e:
                # Do not re-raise: one bad URL must not abort the whole batch.
                # NOTE(review): failed or skipped URLs are not reported back to
                # the SQL service here -- consider marking them as failed.
                print(f"[WorkFlow] Error processing {url}: {e}")

        return {
            "msg": f"Batch processing complete. URLs processed: {processed_count}",
            "processed_urls": processed_count,
            "total_chunks_saved": total_chunks_saved
        }

    def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
        """V2 search: embed ``query_text`` and run a vector search.

        Returns:
            The SQL service's search result, or a dict with an error ``msg``
            and empty ``results`` when no embedding could be produced.
        """
        vector = self._get_embedding(query_text)
        if not vector:
            return {
                "msg": "Failed to generate embedding for query",
                "results": []
            }
        return crawler_sql_service.search_knowledge(vector, task_id, limit)
# 单例模式
workflow = AutomatedCrawler()

View File

@@ -0,0 +1,150 @@
# backend/services/crawler_service.py
from firecrawl import FirecrawlApp
from backend.core.config import settings
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from backend.utils.text_process import text_processor
class CrawlerService:
    """Crawl orchestration service.

    Coordinates Firecrawl (mapping and scraping), the embedding service and
    the data service.
    """

    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)

    def map_site(self, start_url: str):
        """Register a task for ``start_url`` and queue it with every link Firecrawl maps.

        A failing Firecrawl map call is logged as a warning and only the seed
        URL is queued.

        Returns:
            dict with ``msg``, ``count`` (URLs handed to the data service) and
            ``task_id``.

        Raises:
            Exception: failures outside the map call are printed and re-raised.
        """
        print(f"[INFO] Mapping: {start_url}")
        try:
            # 1. Register the task.
            task_res = data_service.register_task(start_url)
            task_id = task_res['task_id']

            # 2. The seed URL is always queued, new task or not.
            urls_to_add = [start_url]

            # 3. Call Firecrawl Map; the result may be a dict or an object.
            try:
                map_res = self.firecrawl.map(start_url)
                if isinstance(map_res, dict):
                    found_links = map_res.get('links', [])
                else:
                    found_links = getattr(map_res, 'links', [])
                # A link may be a plain string or an object exposing ``url``.
                urls_to_add.extend(
                    link if isinstance(link, str) else getattr(link, 'url', str(link))
                    for link in found_links
                )
                print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
            except Exception as e:
                print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")

            # 4. Bulk insert (the list always holds at least the seed URL).
            data_service.add_urls(task_id, urls_to_add)
            return {
                "msg": "Mapped successfully",
                "count": len(urls_to_add),
                "task_id": task_id
            }
        except Exception as e:
            print(f"[ERROR] Map failed: {e}")
            # Bare raise keeps the original traceback.
            raise

    @staticmethod
    def _extract_page(scrape_res, url: str):
        """Return ``(markdown, title)`` from a scrape result in dict or object form."""
        if isinstance(scrape_res, dict):
            raw_md = scrape_res.get('markdown', '')
            metadata = scrape_res.get('metadata', {})
        else:
            raw_md = getattr(scrape_res, 'markdown', '')
            metadata = getattr(scrape_res, 'metadata', {})

        title = None
        if metadata:
            # Metadata itself may be a dict or an SDK object.
            if isinstance(metadata, dict):
                title = metadata.get('title', url)
            else:
                title = getattr(metadata, 'title', url)
        # A missing or None title falls back to the URL so that "None" never
        # ends up in the embedding text or the stored title.
        return raw_md, (title or url)

    @staticmethod
    def _embed_chunks(title: str, chunks: list):
        """Embed each chunk as title + header path + content; drop chunks without a vector."""
        chunks_data = []
        for i, chunk in enumerate(chunks):
            content = chunk['content']
            headers = chunk['metadata']
            header_path = " > ".join(headers.values())
            emb_input = f"{title}\n{header_path}\n{content}"
            vector = llm_service.get_embedding(emb_input)
            if vector:
                chunks_data.append({
                    "index": i,
                    "content": content,
                    "embedding": vector,
                    "meta_info": {"header_path": header_path, "headers": headers}
                })
        return chunks_data

    def _process_url(self, task_id: int, url: str) -> bool:
        """Scrape, clean, split, embed and store one URL; return True when chunks were saved."""
        print(f"[INFO] Scraping: {url}")
        scrape_res = self.firecrawl.scrape(
            url,
            formats=['markdown'],
            only_main_content=True
        )
        raw_md, title = self._extract_page(scrape_res, url)
        if not raw_md:
            print(f"[WARN] Empty content for {url}")
            return False

        clean_md = text_processor.clean_markdown(raw_md)
        if not clean_md:
            return False

        chunks_data = self._embed_chunks(title, text_processor.split_markdown(clean_md))
        if not chunks_data:
            return False

        data_service.save_chunks(task_id, url, title, chunks_data)
        return True

    def process_queue(self, task_id: int, batch_size: int = 5):
        """Process one batch of the queue: scrape -> clean -> split -> embed -> store.

        A failure on one URL is printed and does not interrupt the batch.

        Returns:
            dict with ``msg`` and ``count`` (number of URLs whose chunks were
            stored); ``count`` is 0 when the queue is empty.
        """
        pending = data_service.get_pending_urls(task_id, batch_size)
        urls = pending['urls']
        if not urls:
            return {"msg": "Queue empty", "count": 0}

        processed = 0
        for url in urls:
            try:
                if self._process_url(task_id, url):
                    processed += 1
            except Exception as e:
                # NOTE(review): skipped or failed URLs are not reported back to
                # the data service here -- confirm whether they must be
                # re-queued or marked as failed.
                print(f"[ERROR] Process URL {url} failed: {e}")
        return {"msg": "Batch processed", "count": processed}

    def search(self, query: str, task_id: int, limit: int):
        """Embed ``query`` and run a vector search through the data service.

        Returns a dict with an error ``msg`` and empty ``results`` when no
        embedding could be produced.
        """
        vector = llm_service.get_embedding(query)
        if not vector:
            return {"msg": "Embedding failed", "results": []}
        return data_service.search(vector, task_id, limit)
crawler_service = CrawlerService()

View File

@@ -1,230 +0,0 @@
from sqlalchemy import select, insert, update, and_
from ..database import db_instance
from ..utils import normalize_url
class CrawlerSqlService:
    """SQL persistence for crawl tasks, the URL queue and knowledge chunks."""

    def __init__(self):
        self.db = db_instance

    def register_task(self, url: str):
        """Return the task for the normalised ``url``, creating it when missing.

        Returns:
            dict with ``task_id``, ``is_new_task`` and ``msg``.
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            if existing:
                return {
                    "task_id": existing[0],
                    "is_new_task": False,
                    "msg": "Task already exists"
                }
            new_task = conn.execute(
                insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            ).fetchone()
            return {
                "task_id": new_task[0],
                "is_new_task": True,
                "msg": "New task created successfully"
            }

    def add_urls(self, task_id: int, urls: list[str]):
        """Queue ``urls`` for ``task_id``, skipping URLs that are already queued.

        Each URL runs inside its own SAVEPOINT: on PostgreSQL a failed
        statement aborts the enclosing transaction, so without a savepoint one
        bad row would silently discard every other insert of the batch.

        Returns:
            dict with ``success_urls``, ``skipped_urls``, ``failed_urls`` and
            ``msg``.
        """
        success_urls, skipped_urls, failed_urls = [], [], []
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    with conn.begin_nested():
                        already_queued = conn.execute(
                            select(queue).where(
                                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
                            )
                        ).fetchone()
                        if not already_queued:
                            conn.execute(insert(queue).values(
                                task_id=task_id, url=clean_url, status='pending'
                            ))
                except Exception as e:
                    print(f"Add url {clean_url} failed: {e}")
                    failed_urls.append(clean_url)
                    continue
                if already_queued:
                    skipped_urls.append(clean_url)
                else:
                    success_urls.append(clean_url)

        msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
        return {
            "success_urls": success_urls,
            "skipped_urls": skipped_urls,
            "failed_urls": failed_urls,
            "msg": msg
        }

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to ``limit`` pending URLs of ``task_id`` and mark them ``processing``.

        The rows are selected ``FOR UPDATE SKIP LOCKED`` so that two workers
        calling this concurrently cannot claim the same URL.

        Returns:
            dict with ``urls`` and ``msg``.
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            query = select(queue.c.url).where(
                and_(queue.c.task_id == task_id, queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            urls = [r[0] for r in conn.execute(query).fetchall()]
            if not urls:
                return {"urls": [], "msg": "Queue is empty"}
            conn.execute(
                update(queue).where(
                    and_(queue.c.task_id == task_id, queue.c.url.in_(urls))
                ).values(status='processing')
            )
            return {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}

    def save_results(self, task_id: int, results: list):
        """Store the chunks of a single URL, updating chunks that already exist.

        The target URL is taken from the first item's ``source_url``. Items may
        be dicts or objects (read through ``__dict__``). Each chunk is written
        inside its own SAVEPOINT so that one failing chunk does not abort the
        surrounding transaction; the queue entry is marked ``completed`` at
        the end.

        Returns:
            dict with ``source_url``, ``is_page_update``, ``detail``,
            ``counts`` and ``msg``.
        """
        if not results:
            return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}

        # 1. Basic information.
        first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
        target_url = normalize_url(first_item.get('source_url'))

        inserted_chunks, updated_chunks, failed_chunks = [], [], []
        chunks = self.db.chunks
        queue = self.db.queue

        with self.db.engine.begin() as conn:
            # 2. Does this URL already have chunks?
            is_page_update = bool(conn.execute(
                select(chunks.c.id).where(
                    and_(chunks.c.task_id == task_id, chunks.c.source_url == target_url)
                ).limit(1)
            ).fetchone())

            # 3. Upsert the chunks one by one.
            for res in results:
                data = res if isinstance(res, dict) else res.__dict__
                c_idx = data.get('chunk_index')
                try:
                    with conn.begin_nested():
                        existing_chunk = conn.execute(
                            select(chunks.c.id).where(
                                and_(
                                    chunks.c.task_id == task_id,
                                    chunks.c.source_url == target_url,
                                    chunks.c.chunk_index == c_idx
                                )
                            )
                        ).fetchone()
                        if existing_chunk:
                            # Overwrite the stored chunk.
                            conn.execute(
                                update(chunks).where(
                                    chunks.c.id == existing_chunk[0]
                                ).values(
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding')
                                )
                            )
                        else:
                            conn.execute(
                                insert(chunks).values(
                                    task_id=task_id,
                                    source_url=target_url,
                                    chunk_index=c_idx,
                                    title=data.get('title'),
                                    content=data.get('content'),
                                    embedding=data.get('embedding')
                                )
                            )
                except Exception as e:
                    print(f"Chunk {c_idx} failed: {e}")
                    failed_chunks.append(c_idx)
                    continue
                if existing_chunk:
                    updated_chunks.append(c_idx)
                else:
                    inserted_chunks.append(c_idx)

            # 4. Mark the queue entry as completed.
            conn.execute(
                update(queue).where(
                    and_(queue.c.task_id == task_id, queue.c.url == target_url)
                ).values(status='completed')
            )

        msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
        return {
            "source_url": target_url,
            "is_page_update": is_page_update,
            "detail": {
                "inserted_chunk_indexes": inserted_chunks,
                "updated_chunk_indexes": updated_chunks,
                "failed_chunk_indexes": failed_chunks
            },
            "counts": {
                "inserted": len(inserted_chunks),
                "updated": len(updated_chunks),
                "failed": len(failed_chunks)
            },
            "msg": msg
        }

    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
        """Return the ``limit`` chunks closest to ``query_embedding`` by cosine distance.

        Args:
            query_embedding: query vector.
            task_id: restrict the search to this task; ``None`` searches all.
            limit: maximum number of rows.

        Returns:
            dict with ``results`` (list of row dicts) and ``msg``.
        """
        chunks = self.db.chunks
        stmt = select(
            chunks.c.task_id,
            chunks.c.source_url,
            chunks.c.title,
            chunks.c.content,
            chunks.c.chunk_index
        )
        if task_id is not None:
            stmt = stmt.where(chunks.c.task_id == task_id)
        stmt = stmt.order_by(
            chunks.c.embedding.cosine_distance(query_embedding)
        ).limit(limit)

        with self.db.engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()

        results = [
            {
                "task_id": r[0],
                "source_url": r[1],
                "title": r[2],
                "content": r[3],
                "chunk_index": r[4]
            }
            for r in rows
        ]
        msg = f"Found {len(results)} matches" if results else "No matching content found"
        return {"results": results, "msg": msg}
crawler_sql_service = CrawlerSqlService()

View File

@@ -0,0 +1,111 @@
from sqlalchemy import select, insert, update, and_
from backend.core.database import db
from backend.utils.common import normalize_url
class DataService:
    """Data persistence layer.

    Only performs database CRUD operations; contains no external API calls.
    """

    def __init__(self):
        self.db = db

    def register_task(self, url: str):
        """Return the task for the normalised ``url``, creating it when missing.

        Returns:
            dict with ``task_id``, ``is_new_task`` and ``msg``.
        """
        clean_url = normalize_url(url)
        tasks = self.db.tasks
        with self.db.engine.begin() as conn:
            existing = conn.execute(
                select(tasks.c.id).where(tasks.c.root_url == clean_url)
            ).fetchone()
            if existing:
                return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
            new_task = conn.execute(
                insert(tasks).values(root_url=clean_url).returning(tasks.c.id)
            ).fetchone()
            return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}

    def add_urls(self, task_id: int, urls: list[str]):
        """Queue ``urls`` for ``task_id``, skipping URLs that are already queued.

        Each URL runs inside its own SAVEPOINT: on PostgreSQL a failed
        statement aborts the enclosing transaction, so without a savepoint one
        bad row would silently discard every other insert of the batch while
        the returned message still reported them as added.

        Returns:
            dict with ``msg`` stating how many new URLs were added.
        """
        success_urls = []
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            for url in urls:
                clean_url = normalize_url(url)
                try:
                    with conn.begin_nested():
                        already_queued = conn.execute(
                            select(queue).where(
                                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
                            )
                        ).fetchone()
                        if not already_queued:
                            conn.execute(insert(queue).values(
                                task_id=task_id, url=clean_url, status='pending'
                            ))
                except Exception as e:
                    # Best effort per URL, but never silently.
                    print(f"[WARN] Failed to queue {clean_url}: {e}")
                    continue
                if not already_queued:
                    success_urls.append(clean_url)
        return {"msg": f"Added {len(success_urls)} new urls"}

    def get_pending_urls(self, task_id: int, limit: int):
        """Claim up to ``limit`` pending URLs of ``task_id`` and mark them ``processing``.

        The rows are selected ``FOR UPDATE SKIP LOCKED`` so that two workers
        calling this concurrently cannot claim the same URL.

        Returns:
            dict with ``urls`` and ``msg``.
        """
        queue = self.db.queue
        with self.db.engine.begin() as conn:
            query = select(queue.c.url).where(
                and_(queue.c.task_id == task_id, queue.c.status == 'pending')
            ).limit(limit).with_for_update(skip_locked=True)
            urls = [r[0] for r in conn.execute(query).fetchall()]
            if urls:
                conn.execute(update(queue).where(
                    and_(queue.c.task_id == task_id, queue.c.url.in_(urls))
                ).values(status='processing'))
            return {"urls": urls, "msg": "Fetched pending urls"}

    def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
        """Upsert the chunks of one page and mark its queue entry ``completed``.

        Args:
            task_id: owning task.
            source_url: page URL; normalised before use.
            title: page title stored on every chunk.
            chunks_data: items shaped like ``{'index': int, 'content': str,
                'embedding': list, 'meta_info': dict}``; ``meta_info`` is
                optional.

        Returns:
            dict with ``msg`` and ``count`` (number of chunks written).
        """
        clean_url = normalize_url(source_url)
        chunks = self.db.chunks
        queue = self.db.queue
        count = 0
        with self.db.engine.begin() as conn:
            for item in chunks_data:
                idx = item['index']
                meta = item.get('meta_info', {})

                existing = conn.execute(select(chunks.c.id).where(
                    and_(chunks.c.task_id == task_id,
                         chunks.c.source_url == clean_url,
                         chunks.c.chunk_index == idx)
                )).fetchone()

                values = {
                    "task_id": task_id, "source_url": clean_url, "chunk_index": idx,
                    "title": title, "content": item['content'], "embedding": item['embedding'],
                    "meta_info": meta
                }
                if existing:
                    conn.execute(update(chunks).where(chunks.c.id == existing[0]).values(**values))
                else:
                    conn.execute(insert(chunks).values(**values))
                count += 1

            # Mark the queue entry as completed.
            conn.execute(update(queue).where(
                and_(queue.c.task_id == task_id, queue.c.url == clean_url)
            ).values(status='completed'))
        return {"msg": f"Saved {count} chunks", "count": count}

    def search(self, vector: list, task_id: int = None, limit: int = 5):
        """Return the ``limit`` chunks closest to ``vector`` by cosine distance.

        Args:
            vector: query embedding.
            task_id: restrict the search to this task; ``None`` searches all
                tasks. The check is ``is not None`` so that a task id of 0 is
                still applied as a filter.
            limit: maximum number of rows.

        Returns:
            dict with ``results`` (list of row dicts) and ``msg``.
        """
        chunks = self.db.chunks
        stmt = select(
            chunks.c.task_id, chunks.c.source_url, chunks.c.title,
            chunks.c.content, chunks.c.meta_info
        )
        if task_id is not None:
            stmt = stmt.where(chunks.c.task_id == task_id)
        stmt = stmt.order_by(chunks.c.embedding.cosine_distance(vector)).limit(limit)

        with self.db.engine.connect() as conn:
            rows = conn.execute(stmt).fetchall()

        results = [
            {"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]}
            for r in rows
        ]
        return {"results": results, "msg": f"Found {len(results)}"}
data_service = DataService()

View File

@@ -0,0 +1,30 @@
import dashscope
from http import HTTPStatus
from backend.core.config import settings
class LLMService:
    """Wrapper around the LLM provider.

    Handles the interaction with DashScope (or another model vendor).
    """

    def __init__(self):
        # The DashScope SDK reads its key from this module-level attribute.
        dashscope.api_key = settings.DASHSCOPE_API_KEY

    def get_embedding(self, text: str, dimension: int = 1536):
        """Return the embedding vector for ``text``, or ``None`` on any failure.

        Both a non-OK API status and a raised exception are printed and
        reported to the caller as ``None``.
        """
        try:
            response = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=dimension,
            )
            if response.status_code != HTTPStatus.OK:
                print(f"[ERROR] Embedding API Error: {response}")
                return None
            return response.output['embeddings'][0]['embedding']
        except Exception as exc:
            print(f"[ERROR] Embedding Exception: {exc}")
            return None
llm_service = LLMService()

View File

@@ -1,39 +0,0 @@
from urllib.parse import urlparse, urlunparse
from sqlalchemy import create_engine, MetaData, Table, select, update, and_
# backend/llm_service.py
import dashscope
from http import HTTPStatus
from .config import settings
# 初始化 Dashscope
dashscope.api_key = settings.DASHSCOPE_API_KEY
def get_embeddings(texts: list[str]):
    """Embed ``texts`` with the DashScope text-embedding model.

    Returns one embedding per entry of the API output, or an empty list (after
    printing the response) when the API status is not OK.
    """
    response = dashscope.TextEmbedding.call(
        model=dashscope.TextEmbedding.Models.text_embedding_v3,  # or another model
        input=texts
    )
    if response.status_code != HTTPStatus.OK:
        print(f"Embedding Error: {response}")
        return []
    embeddings = []
    for entry in response.output['embeddings']:
        embeddings.append(entry['embedding'])
    return embeddings
def normalize_url(url: str) -> str:
    """Return a canonical form of ``url``.

    Surrounding whitespace and the fragment are dropped, scheme and host are
    lower-cased and trailing slashes are removed from the path. Params and
    query are kept. A falsy ``url`` yields ``""``.
    """
    if not url:
        return ""
    parts = urlparse(url.strip())
    components = (
        parts.scheme.lower(),
        parts.netloc.lower(),
        parts.path.rstrip('/'),
        parts.params,
        parts.query,
        "",  # fragment is discarded
    )
    return urlunparse(components)
def make_response(code: int, msg: str = "Success", data: any = None):
    """Build the uniform API response envelope.

    :param code: 1 for success, 0 for failure, other values are custom
    :param msg: human-readable message
    :param data: response payload
    """
    envelope = dict(code=code, msg=msg, data=data)
    return envelope

26
backend/utils/common.py Normal file
View File

@@ -0,0 +1,26 @@
from urllib.parse import urlparse, urlunparse
def make_response(code: int, msg: str = "Success", data: object = None):
    """Wrap a result in the uniform API response envelope.

    Args:
        code: status code of the response.
        msg: human-readable message.
        data: response payload; any value is accepted.

    Returns:
        dict with the keys ``code``, ``msg`` and ``data``.
    """
    # ``data`` is annotated ``object``: the builtin ``any`` is a function, not a type.
    return {"code": code, "msg": msg, "data": data}
def normalize_url(url: str) -> str:
    """Return a canonical form of ``url`` for de-duplication.

    1. Strip surrounding whitespace.
    2. Lower-case scheme and host (both are case-insensitive).
    3. Drop params, query and fragment (depends on business needs; different
       queries are assumed to address the same page).
    4. Remove trailing slashes from the path.

    A falsy ``url`` yields ``""``.
    """
    if not url:
        return ""

    parsed = urlparse(url.strip())
    # Folding the host case keeps "Example.com" and "example.com" from being
    # stored as two different URLs.
    normalized = parsed._replace(
        scheme=parsed.scheme.lower(),
        netloc=parsed.netloc.lower(),
        path=parsed.path.rstrip('/'),
        params='',
        query='',
        fragment='',
    )
    return urlunparse(normalized)

View File

@@ -0,0 +1,61 @@
import re
from langchain_text_splitters import RecursiveCharacterTextSplitter, MarkdownHeaderTextSplitter
class TextProcessor:
    """Text utilities: cleans scraped Markdown and splits it into chunks."""

    # Compiled once instead of on every clean_markdown call.
    _SKIP_LINK_RE = re.compile(r'\[Skip to main content\].*?\n', re.IGNORECASE)
    _FOOTER_NAV_RE = re.compile(r'\[Previous\].*?\[Next\].*', re.DOTALL | re.IGNORECASE)

    # Header sections longer than this many characters are split further.
    _MAX_SECTION_CHARS = 1000

    def __init__(self):
        # Semantic splitter driven by Markdown headings.
        self.md_splitter = MarkdownHeaderTextSplitter(
            headers_to_split_on=[
                ("#", "h1"),
                ("##", "h2"),
                ("###", "h3"),
            ],
            strip_headers=False
        )
        # Fallback character splitter for oversized sections.
        self.char_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
            # Paragraph, line, CJK sentence terminators, space, then characters.
            # The empty string must stay last: the recursive splitter stops at
            # the first "" it meets and falls back to per-character splitting,
            # so an empty entry mid-list would make " " unreachable.
            separators=["\n\n", "\n", "。", "！", "？", " ", ""],
        )

    def clean_markdown(self, text: str) -> str:
        """Strip web-page noise from Markdown.

        Removes the "Skip to main content" link line and everything from a
        "[Previous] ... [Next]" footer navigation to the end of the text, then
        trims surrounding whitespace. A falsy ``text`` yields ``""``.
        """
        if not text:
            return ""
        text = self._SKIP_LINK_RE.sub('', text)
        text = self._FOOTER_NAV_RE.sub('', text)
        return text.strip()

    def split_markdown(self, text: str):
        """Split by Markdown headings first, then by characters when a section is too long.

        Returns:
            list of ``{"content": str, "metadata": dict}`` where ``metadata``
            is the heading metadata of the section the chunk came from.
        """
        final_chunks = []
        for section in self.md_splitter.split_text(text):
            # section.page_content is the text, section.metadata the heading levels.
            if len(section.page_content) > self._MAX_SECTION_CHARS:
                pieces = self.char_splitter.split_text(section.page_content)
            else:
                pieces = [section.page_content]
            final_chunks.extend(
                {"content": piece, "metadata": section.metadata} for piece in pieces
            )
        return final_chunks
# 单例工具
text_processor = TextProcessor()