mcp调试完成
This commit is contained in:
@@ -6,6 +6,15 @@
|
|||||||
|
|
||||||
完成wiki网页爬取和向量化与知识库查找
|
完成wiki网页爬取和向量化与知识库查找
|
||||||
|
|
||||||
|
mcp调试命令
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npx @modelcontextprotocol/inspector uv run backend/mcp_server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
需要nodejs环境
|
||||||
|
打开页面后在Environment Variables中添加`PYTHONIOENCODING = utf-8`来防止编码问题(视具体情况而定,如果可以正常运行,也可以不加)
|
||||||
|
|
||||||
## 当前状况
|
## 当前状况
|
||||||
|
|
||||||
1. chunk分段逻辑:根据返回的markdown进行分割,按照#、##进行标题的分类,增加JSONB格式字段meta_info,有下面两个字段,分别可以用于数据库查询和LLM上下文认知资料来源
|
1. chunk分段逻辑:根据返回的markdown进行分割,按照#、##进行标题的分类,增加JSONB格式字段meta_info,有下面两个字段,分别可以用于数据库查询和LLM上下文认知资料来源
|
||||||
|
|||||||
@@ -1,27 +1,26 @@
|
|||||||
|
import os
|
||||||
|
from typing import ClassVar # <--- 1. 导入这个
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
import logging
|
|
||||||
|
|
||||||
# 获取当前模块的专用 Logger
|
|
||||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
"""
|
|
||||||
系统配置类
|
|
||||||
自动读取环境变量或 .env 文件
|
|
||||||
"""
|
|
||||||
CANDIDATE_NUM: int = 10
|
|
||||||
|
|
||||||
DB_USER: str
|
DB_USER: str
|
||||||
DB_PASS: str
|
DB_PASS: str
|
||||||
DB_HOST: str
|
DB_HOST: str
|
||||||
DB_PORT: str = "5432"
|
DB_PORT: str = "5432"
|
||||||
DB_NAME: str
|
DB_NAME: str
|
||||||
|
|
||||||
DASHSCOPE_API_KEY: str
|
DASHSCOPE_API_KEY: str
|
||||||
FIRECRAWL_API_KEY: str
|
FIRECRAWL_API_KEY: str
|
||||||
|
|
||||||
# 配置:忽略多余的环境变量,指定编码
|
CANDIDATE_NUM: int = 50
|
||||||
model_config = SettingsConfigDict(env_file=".env", extra="ignore", env_file_encoding='utf-8')
|
|
||||||
|
# =========================================================
|
||||||
|
# 【核心修复】加上 ClassVar 类型注解
|
||||||
|
# =========================================================
|
||||||
|
BASE_DIR: ClassVar[str] = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
ENV_PATH: ClassVar[str] = os.path.join(BASE_DIR, ".env")
|
||||||
|
|
||||||
|
# 使用绝对路径加载
|
||||||
|
model_config = SettingsConfigDict(env_file=ENV_PATH, extra="ignore")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def DATABASE_URL(self) -> str:
|
def DATABASE_URL(self) -> str:
|
||||||
|
|||||||
@@ -1,64 +1,145 @@
|
|||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import logging
|
||||||
# 路径兼容
|
from typing import Optional # 确保引入 Optional
|
||||||
|
import threading
|
||||||
|
# 1. 路径兼容 (确保能找到 backend 包)
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
from backend.core.logger import setup_logging
|
|
||||||
|
|
||||||
# 初始化日志 (写在 FastMCP 初始化之前)
|
|
||||||
setup_logging()
|
|
||||||
from mcp.server.fastmcp import FastMCP
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from backend.core.logger import setup_logging
|
||||||
from backend.services.crawler_service import crawler_service
|
from backend.services.crawler_service import crawler_service
|
||||||
|
|
||||||
|
# 2. 初始化日志 (必须走 stderr)
|
||||||
|
setup_logging()
|
||||||
|
logger = logging.getLogger("mcp_server")
|
||||||
|
|
||||||
|
# 3. 初始化 MCP 服务
|
||||||
mcp = FastMCP("WikiCrawler-V3")
|
mcp = FastMCP("WikiCrawler-V3")
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def kb_add_website(url: str) -> str:
|
async def kb_add_website(url: str) -> str:
|
||||||
"""[Admin] Add a website map task."""
|
"""
|
||||||
|
[Admin] Input a URL to map and register a task.
|
||||||
|
This is the first step to add a knowledge base.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: The root URL of the website (e.g., https://docs.firecrawl.dev).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Task ID and count of found links.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
res = crawler_service.map_site(url)
|
res = crawler_service.map_site(url)
|
||||||
return f"Task Registered. ID: {res['task_id']}, Links Found: {res['count']}"
|
return f"Task Registered. ID: {res['task_id']}, Links Found: {res['count']}, Is New: {res['is_new']}"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error(f"Add website failed: {e}", exc_info=True)
|
||||||
return f"Error: {e}"
|
return f"Error: {e}"
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def kb_check_status(task_id: int) -> str:
|
async def kb_check_status(task_id: int) -> str:
|
||||||
"""[Monitor] Check detailed progress and active threads."""
|
"""
|
||||||
|
[Monitor] Check detailed progress and active threads.
|
||||||
|
Use this to see if the crawler is still running or finished.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task to check.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A formatted report including progress stats and currently crawling URLs.
|
||||||
|
"""
|
||||||
data = crawler_service.get_task_status(task_id)
|
data = crawler_service.get_task_status(task_id)
|
||||||
if not data: return "Task not found."
|
if not data: return "Task not found."
|
||||||
|
|
||||||
s = data['stats']
|
s = data['stats']
|
||||||
threads = data['active_threads']
|
threads = data['active_threads']
|
||||||
|
|
||||||
|
# 格式化输出给 LLM 阅读
|
||||||
report = (
|
report = (
|
||||||
|
f"--- Task {task_id} Status ---\n"
|
||||||
|
f"Root URL: {data['root_url']}\n"
|
||||||
f"Progress: {s['completed']}/{s['total']} (Pending: {s['pending']})\n"
|
f"Progress: {s['completed']}/{s['total']} (Pending: {s['pending']})\n"
|
||||||
f"Active Threads: {len(threads)}\n"
|
f"Active Threads (Running): {len(threads)}\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
if threads:
|
if threads:
|
||||||
report += "Currently Crawling:\n" + "\n".join([f"- {t}" for t in threads[:5]])
|
report += "Currently Crawling:\n" + "\n".join([f"- {t}" for t in threads[:5]])
|
||||||
|
if len(threads) > 5:
|
||||||
|
report += f"\n... and {len(threads)-5} more."
|
||||||
|
|
||||||
return report
|
return report
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def kb_run_crawler(task_id: int, batch_size: int = 5) -> str:
|
async def kb_run_crawler(task_id: int, batch_size: int = 20) -> str:
|
||||||
"""[Action] Trigger crawler batch."""
|
"""
|
||||||
# MCP 同步调用以获得反馈
|
[Action] Trigger the crawler in BACKGROUND mode.
|
||||||
res = crawler_service.process_queue_concurrent(task_id, batch_size)
|
This returns immediately, so you can use 'kb_check_status' to monitor progress.
|
||||||
return f"Batch Finished. Count: {res.get('count', 0)}"
|
|
||||||
|
Args:
|
||||||
|
task_id: The ID of the task.
|
||||||
|
batch_size: Number of URLs to process (suggest 10-20).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Status message confirming start.
|
||||||
|
"""
|
||||||
|
# 定义一个在后台跑的包装函数
|
||||||
|
def background_task():
|
||||||
|
try:
|
||||||
|
logger.info(f"Background batch started for Task {task_id}")
|
||||||
|
# 这里是阻塞操作,但它现在跑在独立线程里
|
||||||
|
crawler_service.process_queue_concurrent(task_id, batch_size)
|
||||||
|
logger.info(f"Background batch finished for Task {task_id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Background task failed: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 2. 创建并启动线程
|
||||||
|
thread = threading.Thread(target=background_task)
|
||||||
|
thread.daemon = True # 设置为守护线程,防止主程序退出时卡死
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# 3. 立即返回,不等待爬取结束
|
||||||
|
return f"🚀 Background crawler started for Task {task_id} (Batch Size: {batch_size}). You can now check status."
|
||||||
|
|
||||||
|
|
||||||
@mcp.tool()
|
@mcp.tool()
|
||||||
async def kb_search(query: str, task_id: int = None) -> str:
|
async def kb_search(query: str, task_id: Optional[int] = None, limit: int = 5) -> str:
|
||||||
"""[User] Search knowledge base."""
|
"""
|
||||||
res = crawler_service.search(query, task_id, 5)
|
[User] Search knowledge base with Hybrid Search & Rerank.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The user's question or search keywords.
|
||||||
|
task_id: (Optional) Limit search to a specific task ID.
|
||||||
|
limit: (Optional) Number of results to return (default 5).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Ranked content blocks with source paths.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
res = crawler_service.search(query, task_id, limit)
|
||||||
results = res.get('results', [])
|
results = res.get('results', [])
|
||||||
if not results: return "No results."
|
|
||||||
|
if not results: return "No results found."
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
for i, r in enumerate(results):
|
for i, r in enumerate(results):
|
||||||
score_display = f"{r['score']:.4f}" + (" (Reranked)" if r.get('reranked') else "")
|
score_display = f"{r['score']:.4f}" + (" (Reranked)" if r.get('reranked') else "")
|
||||||
meta = r.get('meta_info', {})
|
meta = r.get('meta_info', {})
|
||||||
path = meta.get('header_path', 'Root')
|
path = meta.get('header_path', 'Root')
|
||||||
output.append(f"[{i+1}] Score: {score_display}\nPath: {path}\nContent: {r['content'][:200]}...")
|
|
||||||
|
# 格式化单个结果块
|
||||||
|
block = (
|
||||||
|
f"[{i+1}] Score: {score_display}\n"
|
||||||
|
f"Path: {path}\n"
|
||||||
|
f"Content: {r['content'][:300]}..." # 限制长度防止 Context 溢出
|
||||||
|
)
|
||||||
|
output.append(block)
|
||||||
|
|
||||||
return "\n\n".join(output)
|
return "\n\n".join(output)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Search failed: {e}", exc_info=True)
|
||||||
|
return f"Search Error: {e}"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# 启动 MCP 服务
|
||||||
mcp.run()
|
mcp.run()
|
||||||
@@ -1,41 +1,64 @@
|
|||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import threading
|
import threading
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, List, Optional, Union
|
||||||
|
|
||||||
from firecrawl import FirecrawlApp
|
from firecrawl import FirecrawlApp
|
||||||
from backend.core.config import settings
|
from backend.core.config import settings
|
||||||
from backend.services.data_service import data_service
|
from backend.services.data_service import data_service
|
||||||
from backend.services.llm_service import llm_service
|
from backend.services.llm_service import llm_service
|
||||||
from backend.utils.text_process import text_processor
|
from backend.utils.text_process import text_processor
|
||||||
import logging
|
|
||||||
|
|
||||||
# 获取当前模块的专用 Logger
|
# 获取当前模块的专用 Logger
|
||||||
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class CrawlerService:
|
class CrawlerService:
|
||||||
|
"""
|
||||||
|
爬虫业务服务层 (Crawler Service)
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 协调外部 API (Firecrawl) 和内部服务 (DataService, LLMService)。
|
||||||
|
2. 管理多线程爬取任务及其状态。
|
||||||
|
3. 提供统一的搜索入口 (混合检索 + Rerank)。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||||
self.max_workers = 5
|
self.max_workers = 5 # 线程池最大并发数
|
||||||
|
|
||||||
# [新增] 内存状态追踪
|
# 内存状态追踪: { task_id: set([url1, url2]) }
|
||||||
# 结构: { task_id: { url: "status_desc" } }
|
self._active_workers: Dict[int, set] = {}
|
||||||
self._active_workers = {}
|
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
def _track_start(self, task_id, url):
|
def _track_start(self, task_id: int, url: str):
|
||||||
"""开始追踪某个URL"""
|
"""[Internal] 标记某个URL开始处理"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if task_id not in self._active_workers:
|
if task_id not in self._active_workers:
|
||||||
self._active_workers[task_id] = set()
|
self._active_workers[task_id] = set()
|
||||||
self._active_workers[task_id].add(url)
|
self._active_workers[task_id].add(url)
|
||||||
|
|
||||||
def _track_end(self, task_id, url):
|
def _track_end(self, task_id: int, url: str):
|
||||||
"""结束追踪某个URL"""
|
"""[Internal] 标记某个URL处理结束"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if task_id in self._active_workers:
|
if task_id in self._active_workers:
|
||||||
self._active_workers[task_id].discard(url)
|
self._active_workers[task_id].discard(url)
|
||||||
|
|
||||||
def get_task_status(self, task_id: int):
|
def get_task_status(self, task_id: int) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
[综合监控] 获取全量状态 = 数据库统计 + 实时线程列表
|
获取任务的实时综合状态。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (int): 任务 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含数据库统计和实时线程信息的字典。如果任务不存在返回 None。
|
||||||
|
结构示例:
|
||||||
|
{
|
||||||
|
"root_url": "https://example.com",
|
||||||
|
"stats": {"pending": 10, "processing": 2, "completed": 5, "failed": 0},
|
||||||
|
"active_threads": ["https://example.com/page1"],
|
||||||
|
"active_thread_count": 1
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
# 1. 获取数据库层面的统计 (宏观)
|
# 1. 获取数据库层面的统计 (宏观)
|
||||||
db_data = data_service.get_task_monitor_data(task_id)
|
db_data = data_service.get_task_monitor_data(task_id)
|
||||||
@@ -45,19 +68,33 @@ class CrawlerService:
|
|||||||
# 2. 获取内存层面的活跃线程 (微观)
|
# 2. 获取内存层面的活跃线程 (微观)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
active_urls = list(self._active_workers.get(task_id, []))
|
active_urls = list(self._active_workers.get(task_id, []))
|
||||||
# 输出情况
|
|
||||||
|
# 日志输出当前状态
|
||||||
logger.info(f"Task {task_id} active threads: {active_urls}")
|
logger.info(f"Task {task_id} active threads: {active_urls}")
|
||||||
logger.info(f"Task {task_id} stats: {db_data['db_stats']}") # 打印数据库统计信息
|
logger.info(f"Task {task_id} stats: {db_data['db_stats']}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"root_url": db_data["root_url"],
|
"root_url": db_data["root_url"],
|
||||||
"stats": db_data["db_stats"], # Pending, Completed, Failed 等
|
"stats": db_data["db_stats"],
|
||||||
"active_threads": active_urls, # 当前 CPU/网络 正在处理的 URL
|
"active_threads": active_urls,
|
||||||
"active_thread_count": len(active_urls)
|
"active_thread_count": len(active_urls)
|
||||||
}
|
}
|
||||||
|
|
||||||
def map_site(self, start_url: str):
|
def map_site(self, start_url: str) -> Dict[str, Any]:
|
||||||
"""阶段1:站点地图扫描"""
|
"""
|
||||||
|
第一阶段:站点地图扫描 (Map)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
start_url (str): 目标网站的根 URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含任务 ID 和发现链接数的字典。
|
||||||
|
{
|
||||||
|
"task_id": 123,
|
||||||
|
"count": 50,
|
||||||
|
"is_new": True
|
||||||
|
}
|
||||||
|
"""
|
||||||
logger.info(f"Mapping: {start_url}")
|
logger.info(f"Mapping: {start_url}")
|
||||||
try:
|
try:
|
||||||
task_res = data_service.register_task(start_url)
|
task_res = data_service.register_task(start_url)
|
||||||
@@ -75,7 +112,9 @@ class CrawlerService:
|
|||||||
# 新任务执行 Map
|
# 新任务执行 Map
|
||||||
try:
|
try:
|
||||||
map_res = self.firecrawl.map(start_url)
|
map_res = self.firecrawl.map(start_url)
|
||||||
|
# 兼容不同版本的 SDK 返回结构
|
||||||
found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
|
found_links = map_res.get('links', []) if isinstance(map_res, dict) else getattr(map_res, 'links', [])
|
||||||
|
|
||||||
for link in found_links:
|
for link in found_links:
|
||||||
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
|
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
|
||||||
urls_to_add.append(u)
|
urls_to_add.append(u)
|
||||||
@@ -96,7 +135,7 @@ class CrawlerService:
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
def _process_single_url(self, task_id: int, url: str):
|
def _process_single_url(self, task_id: int, url: str):
|
||||||
"""[Worker] 单个 URL 处理线程"""
|
"""[Internal Worker] 单个 URL 处理线程逻辑"""
|
||||||
# 1. 内存标记:开始
|
# 1. 内存标记:开始
|
||||||
self._track_start(task_id, url)
|
self._track_start(task_id, url)
|
||||||
logger.info(f"[THREAD START] {url}")
|
logger.info(f"[THREAD START] {url}")
|
||||||
@@ -107,6 +146,7 @@ class CrawlerService:
|
|||||||
url, formats=['markdown'], only_main_content=True
|
url, formats=['markdown'], only_main_content=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 兼容性提取
|
||||||
raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
|
raw_md = getattr(scrape_res, 'markdown', '') if not isinstance(scrape_res, dict) else scrape_res.get('markdown', '')
|
||||||
metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
|
metadata = getattr(scrape_res, 'metadata', {}) if not isinstance(scrape_res, dict) else scrape_res.get('metadata', {})
|
||||||
title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
|
title = getattr(metadata, 'title', url) if not isinstance(metadata, dict) else metadata.get('title', url)
|
||||||
@@ -124,6 +164,7 @@ class CrawlerService:
|
|||||||
headers = chunk['metadata']
|
headers = chunk['metadata']
|
||||||
path = " > ".join(headers.values())
|
path = " > ".join(headers.values())
|
||||||
emb_input = f"{title}\n{path}\n{chunk['content']}"
|
emb_input = f"{title}\n{path}\n{chunk['content']}"
|
||||||
|
|
||||||
vector = llm_service.get_embedding(emb_input)
|
vector = llm_service.get_embedding(emb_input)
|
||||||
if vector:
|
if vector:
|
||||||
chunks_data.append({
|
chunks_data.append({
|
||||||
@@ -145,34 +186,72 @@ class CrawlerService:
|
|||||||
# 5. 内存标记:结束 (无论成功失败都要移除)
|
# 5. 内存标记:结束 (无论成功失败都要移除)
|
||||||
self._track_end(task_id, url)
|
self._track_end(task_id, url)
|
||||||
|
|
||||||
def process_queue_concurrent(self, task_id: int, batch_size: int = 10):
|
def process_queue_concurrent(self, task_id: int, batch_size: int = 10) -> Dict[str, Any]:
|
||||||
"""阶段2:多线程并发处理"""
|
"""
|
||||||
|
第二阶段:多线程并发处理 (Process)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_id (int): 任务 ID
|
||||||
|
batch_size (int): 本次批次处理的 URL 数量(会分配给线程池并发执行)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 处理结果概览
|
||||||
|
{
|
||||||
|
"msg": "Batch completed",
|
||||||
|
"count": 10
|
||||||
|
}
|
||||||
|
"""
|
||||||
urls = data_service.get_pending_urls(task_id, limit=batch_size)
|
urls = data_service.get_pending_urls(task_id, limit=batch_size)
|
||||||
if not urls: return {"msg": "No pending urls"}
|
if not urls: return {"msg": "No pending urls", "count": 0}
|
||||||
|
|
||||||
logger.info(f"Batch started: {len(urls)} urls")
|
logger.info(f"Batch started: {len(urls)} urls")
|
||||||
|
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||||
# 提交任务到线程池
|
# 提交任务到线程池
|
||||||
futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
|
futures = {executor.submit(self._process_single_url, task_id, url): url for url in urls}
|
||||||
# 等待完成
|
# 等待完成 (阻塞直到所有线程结束)
|
||||||
concurrent.futures.wait(futures)
|
concurrent.futures.wait(futures)
|
||||||
|
|
||||||
return {"msg": "Batch completed", "count": len(urls)}
|
return {"msg": "Batch completed", "count": len(urls)}
|
||||||
|
|
||||||
def search(self, query: str, task_id, return_num: int):
|
def search(self, query: str, task_id: Optional[int], return_num: int) -> Dict[str, Any]:
|
||||||
"""阶段3:搜索"""
|
"""
|
||||||
|
第三阶段:智能搜索 (Search)
|
||||||
|
流程:用户问题 -> Embedding -> 数据库混合检索(粗排) -> Rerank模型(精排) -> 结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): 用户问题
|
||||||
|
task_id (Optional[int]): 指定搜索的任务 ID,None 为全库搜索
|
||||||
|
return_num (int): 最终返回给用户的条数 (Top K)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 搜索结果列表
|
||||||
|
{
|
||||||
|
"results": [
|
||||||
|
{"content": "...", "score": 0.98, "meta_info": {...}},
|
||||||
|
...
|
||||||
|
]
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# 1. 生成向量
|
||||||
vector = llm_service.get_embedding(query)
|
vector = llm_service.get_embedding(query)
|
||||||
if not vector: return {"msg": "Embedding failed", "results": []}
|
if not vector: return {"msg": "Embedding failed", "results": []}
|
||||||
|
|
||||||
|
# 2. 数据库粗排 (召回 10 倍数量或至少 50 条)
|
||||||
coarse_limit = min(return_num * 10, 100)
|
coarse_limit = min(return_num * 10, 100)
|
||||||
coarse_limit = max(coarse_limit, 50)
|
coarse_limit = max(coarse_limit, 50)
|
||||||
|
|
||||||
coarse_res = data_service.search(query, vector, task_id, coarse_limit)
|
coarse_res = data_service.search(
|
||||||
|
query_text=query,
|
||||||
|
query_vector=vector,
|
||||||
|
task_id=task_id,
|
||||||
|
candidates_num=coarse_limit
|
||||||
|
)
|
||||||
candidates = coarse_res.get('results', [])
|
candidates = coarse_res.get('results', [])
|
||||||
|
|
||||||
if not candidates: return {"results": []}
|
if not candidates: return {"results": []}
|
||||||
|
|
||||||
|
# 3. LLM 精排 (Rerank)
|
||||||
final_res = llm_service.rerank(query, candidates, return_num)
|
final_res = llm_service.rerank(query, candidates, return_num)
|
||||||
return {"results": final_res}
|
return {"results": final_res}
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ def normalize_url(url: str) -> str:
|
|||||||
2. 移除 fragment (#后面的内容)
|
2. 移除 fragment (#后面的内容)
|
||||||
3. 移除 query 参数 (视业务需求而定,这里假设不同 query 是同一页面)
|
3. 移除 query 参数 (视业务需求而定,这里假设不同 query 是同一页面)
|
||||||
4. 移除尾部斜杠
|
4. 移除尾部斜杠
|
||||||
|
示例:
|
||||||
|
"https://www.example.com/path/" -> "https://www.example.com/path"
|
||||||
|
"https://www.example.com/path?query=1" -> "https://www.example.com/path"
|
||||||
"""
|
"""
|
||||||
if not url:
|
if not url:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
Reference in New Issue
Block a user