Created the new backend outright and finished all of its logic
backend/config.py
@@ -8,9 +8,19 @@ class Settings:
    DB_PORT: str = "25432"
    DB_NAME: str = "wiki_crawler"

-    @property
+    DASHSCOPE_API_KEY: str = "sk-8b091493de594c5e9eb42f12f1cc5805"
+    FIRECRAWL_API_KEY: str = "fc-8a2af3fb6a014a27a57dfbc728cb7365"
+    @property  # @property turns the method into an attribute, so it is read without parentheses
    def DATABASE_URL(self) -> str:
        url = f"postgresql+psycopg2://{self.DB_USER}:{self.DB_PASS}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
        return url

+    def API_KEY(self, type: str) -> str:
+        if type == "dashscope":
+            return self.DASHSCOPE_API_KEY
+        elif type == "firecrawl":
+            return self.FIRECRAWL_API_KEY
+        else:
+            raise ValueError(f"Unknown API type: {type}")
+
settings = Settings()
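As a usage sketch (assuming the package is importable as `backend`, which this commit does not show): the property is read like a plain attribute, while API_KEY stays an ordinary method.

from backend.config import settings

url = settings.DATABASE_URL          # @property: attribute access, no parentheses
key = settings.API_KEY("firecrawl")  # regular method: called with the provider name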
116 backend/main.py
@@ -1,79 +1,119 @@
-from fastapi import FastAPI
-from .service import crawler_service
-from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest
+# backend/main.py
+from fastapi import FastAPI, APIRouter, BackgroundTasks
+from .service import crawler_sql_service
+from .workflow import workflow
+from .schemas import (
+    RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest,
+    AutoMapRequest, AutoProcessRequest, TextSearchRequest
+)
from .utils import make_response

app = FastAPI(title="Wiki Crawler API")

-@app.post("/register")
+# ==========================================
+# V1 Router: the original low-level endpoints (Manual Control)
+# ==========================================
+router_v1 = APIRouter()
+
+@router_v1.post("/register")
async def register(req: RegisterRequest):
    try:
-        data = crawler_service.register_task(req.url)
+        data = crawler_sql_service.register_task(req.url)
        return make_response(1, "Success", data)
    except Exception as e:
        return make_response(0, str(e))

-@app.post("/add_urls")
+@router_v1.post("/add_urls")
async def add_urls(req: AddUrlsRequest):
    try:
        urls = req.urls_obj["urls"]
-        data = crawler_service.add_urls(req.task_id, urls=urls)
+        data = crawler_sql_service.add_urls(req.task_id, urls=urls)
        return make_response(1, "Success", data)
    except Exception as e:
        return make_response(0, str(e))

-@app.post("/pending_urls")
+@router_v1.post("/pending_urls")
async def pending_urls(req: PendingRequest):
    try:
-        data = crawler_service.get_pending_urls(req.task_id, req.limit)
+        data = crawler_sql_service.get_pending_urls(req.task_id, req.limit)
        msg = "Success" if data["urls"] else "Queue Empty"
        return make_response(1, msg, data)
    except Exception as e:
        return make_response(0, str(e))

-@app.post("/save_results")
+@router_v1.post("/save_results")
async def save_results(req: SaveResultsRequest):
    try:
-        data = crawler_service.save_results(req.task_id, req.results)
+        data = crawler_sql_service.save_results(req.task_id, req.results)
        return make_response(1, "Success", data)
    except Exception as e:
        return make_response(0, str(e))

@app.post("/search")
|
||||
async def search(req: SearchRequest):
|
||||
"""
|
||||
通用搜索接口:
|
||||
支持基于 task_id 的局部搜索,也支持不传 task_id 的全库搜索。
|
||||
"""
|
||||
@router_v1.post("/search")
|
||||
async def search_v1(req: SearchRequest):
|
||||
"""V1 搜索:需要客户端自己传向量"""
|
||||
try:
|
||||
# 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536)
|
||||
vector = req.query_embedding['vector']
|
||||
# 注意:这里需要确认你数据库的向量维度。TextEmbedding V3 可能是 1024,V2 是 1536。
|
||||
# 请根据你的 PGVector 设置进行匹配。
|
||||
if not vector:
|
||||
return make_response(2, "Vector is empty", None)
|
||||
|
||||
if not vector or len(vector) != 1536:
|
||||
return make_response(
|
||||
code=2,
|
||||
msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
|
||||
data=None
|
||||
)
|
||||
|
||||
# 2. 调用业务类执行搜索
|
||||
data = crawler_service.search_knowledge(
|
||||
data = crawler_sql_service.search_knowledge(
|
||||
query_embedding=vector,
|
||||
task_id=req.task_id,
|
||||
limit=req.limit
|
||||
)
|
||||
|
||||
# 3. 统一返回
|
||||
return make_response(
|
||||
code=1,
|
||||
msg="搜索完成",
|
||||
data=data
|
||||
)
|
||||
|
||||
return make_response(1, "Search Done", data)
|
||||
except Exception as e:
|
||||
# 记录日志并返回失败信息
|
||||
print(f"搜索接口异常: {str(e)}")
|
||||
return make_response(code=0, msg=f"搜索失败: {str(e)}")
|
||||
return make_response(0, str(e))
|
||||
|
||||
|
||||
+# ==========================================
+# V2 Router: Automated Workflow
+# ==========================================
+router_v2 = APIRouter()
+
+@router_v2.post("/auto/map")
+async def auto_map(req: AutoMapRequest, background_tasks: BackgroundTasks):
+    """
+    [Async] Takes a start URL, automatically calls Firecrawl Map and stores the discovered links
+    """
+    # This could also go into background_tasks, but map is usually fast; here we return the task ID synchronously
+    try:
+        # To avoid blocking the main thread, move this into background_tasks if map turns out to be slow
+        # It is called synchronously here so the task_id is visible immediately (Firecrawl Map is fairly fast)
+        data = workflow.map_and_ingest(req.url)
+        return make_response(1, "Mapping Started", data)
+    except Exception as e:
+        return make_response(0, str(e))
+
+@router_v2.post("/auto/process")
+async def auto_process(req: AutoProcessRequest, background_tasks: BackgroundTasks):
+    """
+    [Async] Triggers a background job: consume the queue -> scrape -> embed -> store
+    """
+    # Hand the time-consuming work to a background task
+    background_tasks.add_task(workflow.process_task_queue, req.task_id, req.batch_size)
+    return make_response(1, "Processing started in background", {"task_id": req.task_id})
+
+@router_v2.post("/search")
+async def search_v2(req: TextSearchRequest):
+    """
+    [Smart] Takes natural-language text -> backend converts it to a vector -> search
+    """
+    try:
+        data = workflow.search_with_embedding(req.query, req.task_id, req.limit)
+        return make_response(1, "Search Success", data)
+    except Exception as e:
+        return make_response(0, f"Search Failed: {str(e)}")
+
+
+# ==========================================
+# Mount the routers
+# ==========================================
+app.include_router(router_v1, prefix="/api/v1", tags=["V1 Manual API"])
+app.include_router(router_v2, prefix="/api/v2", tags=["V2 Automated Workflow"])

if __name__ == "__main__":
    import uvicorn
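For orientation, a rough client-side sketch of how the two routers are meant to be called once mounted. Host, port and payload values are illustrative, and the response shape assumes make_response nests the result under a "data" key, which this diff does not show.

import requests

BASE = "http://localhost:8000"  # assumed dev address, not specified in the commit

# V2 automated flow: map the site, then let the background task scrape/embed/store it
resp = requests.post(f"{BASE}/api/v2/auto/map", json={"url": "https://example-wiki.org"}).json()
task_id = resp["data"]["task_id"]   # assumes make_response wraps the payload under "data"
requests.post(f"{BASE}/api/v2/auto/process", json={"task_id": task_id, "batch_size": 5})

# V2 search: plain text in, the backend generates the embedding
requests.post(f"{BASE}/api/v2/search", json={"query": "installation steps", "task_id": task_id, "limit": 5})

# V1 search: the caller must already have a query vector of the right dimension
vec = [0.0] * 1536  # placeholder for a real embedding
requests.post(f"{BASE}/api/v1/search", json={"query_embedding": {"vector": vec}, "task_id": task_id, "limit": 5})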
backend/schemas.py
@@ -1,5 +1,6 @@
from pydantic import BaseModel
-from typing import List, Optional
+from typing import Optional, List, Any


class RegisterRequest(BaseModel):
    url: str
@@ -28,4 +29,22 @@ class SearchRequest(BaseModel):
    # If task_id is omitted, search the whole database
    task_id: Optional[int] = None
    query_embedding: dict
    limit: Optional[int] = 5
+
+
+# ... (the existing schemas are kept: RegisterRequest, AddUrlsRequest, etc.) ...
+
+# === V2 New Schemas ===
+class AutoMapRequest(BaseModel):
+    url: str
+
+class AutoProcessRequest(BaseModel):
+    task_id: int
+    batch_size: Optional[int] = 5
+
+class TextSearchRequest(BaseModel):
+    query: str  # the user sends plain text; no vector is required anymore
+    task_id: Optional[int] = None
+    limit: Optional[int] = 5
backend/service.py
@@ -3,7 +3,7 @@ from sqlalchemy import select, insert, update, delete, and_
from .database import db_instance
from .utils import normalize_url

-class CrawlerService:
+class CrawlerSqlService:
    def __init__(self):
        self.db = db_instance

@@ -202,4 +202,4 @@ class CrawlerService:
        return results


-crawler_service = CrawlerService()
+crawler_sql_service = CrawlerSqlService()
@@ -1,6 +1,24 @@
from urllib.parse import urlparse, urlunparse
from sqlalchemy import create_engine, MetaData, Table, select, update, and_
+# backend/llm_service.py
+import dashscope
+from http import HTTPStatus
+from .config import settings
+
+# Initialise Dashscope
+dashscope.api_key = settings.DASHSCOPE_API_KEY
+
+def get_embeddings(texts: list[str]):
+    """Call the Tongyi Qianwen embedding model"""
+    resp = dashscope.TextEmbedding.call(
+        model=dashscope.TextEmbedding.Models.text_embedding_v3,  # or another model
+        input=texts
+    )
+    if resp.status_code == HTTPStatus.OK:
+        return [item['embedding'] for item in resp.output['embeddings']]
+    else:
+        print(f"Embedding Error: {resp}")
+        return []
def normalize_url(url: str) -> str:
    if not url: return ""
    url = url.strip()
140 backend/workflow.py (new file)
@@ -0,0 +1,140 @@
# backend/workflow.py
import dashscope
from http import HTTPStatus
from firecrawl import FirecrawlApp
from langchain_text_splitters import RecursiveCharacterTextSplitter
from .config import settings
from .service import crawler_sql_service

# Initialise configuration
dashscope.api_key = settings.DASHSCOPE_API_KEY

class AutomatedCrawler:
    def __init__(self):
        self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
        # Text splitter: 500-character chunks with 100-character overlap to keep the chunks semantically coherent
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=100,
            separators=["\n\n", "\n", "。", "!", "?", " ", ""]
        )

    def _get_embedding(self, text: str):
        """Call Dashscope to generate an embedding vector"""
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=1536
            )
            if resp.status_code == HTTPStatus.OK:
                # Alibaba's text-embedding-v3 defaults to 1024 dimensions, v2 to 1536
                # If your database column is 1536-dimensional, use text_embedding_v2 or adjust the parameter
                return resp.output['embeddings'][0]['embedding']
            else:
                print(f"Embedding API Error: {resp}")
                return None
        except Exception as e:
            print(f"Embedding Exception: {e}")
            return None
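    # Note (added for clarity, not in the original commit): the `dimension=1536` argument above
    # must agree with the width of the pgvector column that save_results writes into; if the
    # column were created with a different dimension (e.g. 1024, the v3 default), inserts and
    # similarity search against it would fail.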

    def map_and_ingest(self, start_url: str):
        """
        V2 step 1: map-scan the site and ingest the discovered URLs
        """
        print(f"[WorkFlow] Start mapping: {start_url}")

        # 1. Register the task in the database
        task_info = crawler_sql_service.register_task(start_url)
        task_id = task_info['task_id']

        try:
            # 2. Call Firecrawl Map
            map_result = self.firecrawl.map(start_url)

            urls = []
            for link in map_result.links:
                urls.append(link.url)

            print(f"[WorkFlow] Found {len(urls)} links")

            # 3. Bulk-insert the URLs (status = pending)
            res = None  # avoids a NameError below when the map returns no links
            if urls:
                res = crawler_sql_service.add_urls(task_id, urls)

            return {
                "task_id": task_id,
                "map_result": res
            }
        except Exception as e:
            print(f"[WorkFlow] Map Error: {e}")
            raise e

    def process_task_queue(self, task_id: int, limit: int = 10):
        """
        V2 step 2: consume the queue -> scrape -> chunk -> embed -> store
        """
        # 1. Fetch the pending URLs
        pending = crawler_sql_service.get_pending_urls(task_id, limit)
        urls = pending['urls']

        if not urls:
            return {"msg": "No pending urls", "count": 0}

        processed_count = 0

        for url in urls:
            try:
                print(f"[WorkFlow] Processing: {url}")
                # 2. Scrape the single page (Markdown)
                scrape_res = self.firecrawl.scrape(url, formats=['markdown'], only_main_content=True)

                content = scrape_res.markdown

                metadata = scrape_res.metadata_dict
                title = metadata.get('title', url)

                if not content:
                    print(f"[WorkFlow] Skip empty content: {url}")
                    continue

                # 3. Split into chunks
                chunks = self.splitter.split_text(content)

                results_to_save = []

                # 4. Embed each chunk individually (could be optimised into one batched embedding call)
                for idx, chunk_text in enumerate(chunks):
                    vector = self._get_embedding(chunk_text)
                    if vector:
                        results_to_save.append({
                            "source_url": url,
                            "chunk_index": idx,
                            "title": title,
                            "content": chunk_text,
                            "embedding": vector
                        })

                # 5. Save the results
                if results_to_save:
                    crawler_sql_service.save_results(task_id, results_to_save)
                    processed_count += 1

            except Exception as e:
                print(f"[WorkFlow] Error processing {url}: {e}")
                # In production the URL status should be set to 'failed' here

        return {"processed_urls": processed_count, "total_chunks": "dynamic"}

    def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
        """
        V2 search: text in -> automatically converted to a vector -> search the database
        """
        vector = self._get_embedding(query_text)
        if not vector:
            raise Exception("Failed to generate embedding for query")

        return crawler_sql_service.search_knowledge(vector, task_id, limit)

# Singleton instance
workflow = AutomatedCrawler()
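Taken together, the intended end-to-end flow of this module looks roughly like the sketch below. It is illustrative only: the URL is a placeholder, and in the running service these calls are driven by the /api/v2 endpoints rather than invoked directly.

# Illustrative driver for the automated pipeline (not part of the commit)
task = workflow.map_and_ingest("https://example-wiki.org")       # step 1: discover URLs, queue them as pending
stats = workflow.process_task_queue(task["task_id"], limit=5)    # step 2: scrape, chunk, embed, store
hits = workflow.search_with_embedding("How do I install it?",    # step 3: text query -> vector -> pgvector search
                                       task_id=task["task_id"], limit=5)
print(stats, hits)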