新增RAG测试脚本

This commit is contained in:
2026-01-13 10:37:19 +08:00
parent d5ee00d404
commit e5ac2dde03
5 changed files with 386 additions and 21 deletions

View File

@@ -142,9 +142,41 @@ class CrawlerService:
return {"msg": "Batch processed", "count": processed} return {"msg": "Batch processed", "count": processed}
def search(self, query: str, task_id, return_num: int):
    """
    Full-pipeline search: embed query -> hybrid retrieval (coarse
    ranking) -> rerank (fine ranking).

    Args:
        query: The user's natural-language question.
        task_id: Optional crawl-task scope, passed through to the data layer.
        return_num: Number of results the caller ultimately wants back.

    Returns:
        dict with 'results' (reranked chunk list) and a status 'msg'.
    """
    # 1. Generate the query vector
    vector = llm_service.get_embedding(query)
    if not vector:
        return {"msg": "Embedding failed", "results": []}
    # 2. Coarse-recall size: at least 50 candidates, or 10x the requested
    #    amount when the caller asks for many, so the reranker has headroom.
    coarse_limit = max(return_num * 10, 50)
    # 3. Hybrid retrieval (coarse ranking)
    coarse_results = data_service.search(
        query_text=query,
        query_vector=vector,
        task_id=task_id,
        candidates_num=coarse_limit,  # computed coarse-recall size
    )
    candidates = coarse_results.get('results', [])
    if not candidates:
        return {"msg": "No documents found", "results": []}
    # 4. Rerank (fine ranking) down to what the caller asked for
    final_results = llm_service.rerank(
        query=query,
        documents=candidates,
        top_n=return_num,
    )
    return {
        "results": final_results,
        "msg": f"Reranked {len(final_results)} from {len(candidates)} candidates"
    }
crawler_service = CrawlerService() crawler_service = CrawlerService()

View File

@@ -91,24 +91,20 @@ class DataService:
return {"msg": f"Saved {count} chunks", "count": count} return {"msg": f"Saved {count} chunks", "count": count}
def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5): def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 5):
""" """
Phase 2: 混合检索 (Hybrid Search) Phase 2: 混合检索 (Hybrid Search)
综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword)
""" """
# 向量格式清洗
if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist()
if query_vector and isinstance(query_vector, list) and len(query_vector) > 0:
if isinstance(query_vector[0], list): query_vector = query_vector[0]
results = [] results = []
with self.db.engine.connect() as conn: with self.db.engine.connect() as conn:
# 定义混合检索的 SQL 逻辑
# 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法)
keyword_query = func.websearch_to_tsquery('english', query_text) keyword_query = func.websearch_to_tsquery('english', query_text)
vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector)) vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector))
keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query) keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query)
# 综合打分列: 0.7 * Vector + 0.3 * Keyword
# coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0
final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score") final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score")
stmt = select( stmt = select(
@@ -123,8 +119,8 @@ class DataService:
if task_id: if task_id:
stmt = stmt.where(self.db.chunks.c.task_id == task_id) stmt = stmt.where(self.db.chunks.c.task_id == task_id)
# 按综合分数倒序 # 使用 candidates_num 控制召回数量
stmt = stmt.order_by(desc("score")).limit(limit) stmt = stmt.order_by(desc("score")).limit(candidates_num)
try: try:
rows = conn.execute(stmt).fetchall() rows = conn.execute(stmt).fetchall()
@@ -141,23 +137,19 @@ class DataService:
] ]
except Exception as e: except Exception as e:
print(f"[ERROR] Hybrid search failed: {e}") print(f"[ERROR] Hybrid search failed: {e}")
return self._fallback_vector_search(query_vector, task_id, limit) return self._fallback_vector_search(query_vector, task_id, candidates_num)
return {"results": results, "msg": f"Hybrid found {len(results)}"} return {"results": results, "msg": f"Hybrid found {len(results)}"}
def _fallback_vector_search(self, vector, task_id, limit):
    """
    Degraded fallback: pure vector (cosine-distance) search.

    Used when the hybrid-search SQL fails; keyword scoring is skipped.

    Args:
        vector: Query embedding.
        task_id: Optional task scope; filters chunks when truthy.
        limit: Maximum number of rows to return.

    Returns:
        dict with 'results' (content + meta_info per chunk) and a 'msg'.
    """
    print("[WARN] Fallback to pure vector search")
    with self.db.engine.connect() as conn:
        stmt = select(
            self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
            self.db.chunks.c.content, self.db.chunks.c.meta_info
        )
        # Apply the optional scope filter before ordering/limiting so the
        # builder chain reads in the same order as the emitted SQL
        # (previously .where() was tacked on after .limit()).
        if task_id:
            stmt = stmt.where(self.db.chunks.c.task_id == task_id)
        stmt = stmt.order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
        rows = conn.execute(stmt).fetchall()
        return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"}
data_service = DataService() data_service = DataService()

View File

@@ -5,16 +5,16 @@ from backend.core.config import settings
class LLMService: class LLMService:
""" """
LLM 服务封装层 LLM 服务封装层
负责与 DashScope 或其他模型供应商交互 负责与 DashScope (通义千问/GTE) 交互,包括 Embedding 和 Rerank
""" """
def __init__(self): def __init__(self):
dashscope.api_key = settings.DASHSCOPE_API_KEY dashscope.api_key = settings.DASHSCOPE_API_KEY
def get_embedding(self, text: str, dimension: int = 1536): def get_embedding(self, text: str, dimension: int = 1536):
"""生成文本向量""" """生成文本向量 (Bi-Encoder)"""
try: try:
resp = dashscope.TextEmbedding.call( resp = dashscope.TextEmbedding.call(
model=dashscope.TextEmbedding.Models.text_embedding_v4, model=dashscope.TextEmbedding.Models.text_embedding_v4, # 或 v4视你的数据库维度而定
input=text, input=text,
dimension=dimension dimension=dimension
) )
@@ -27,4 +27,67 @@ class LLMService:
print(f"[ERROR] Embedding Exception: {e}") print(f"[ERROR] Embedding Exception: {e}")
return None return None
def rerank(self, query: str, documents: list, top_n: int = 5, max_candidates: int = 50):
    """
    Rerank coarse-recall results with a Cross-Encoder.

    Args:
        query: The user's question.
        documents: Coarse-recall chunk list (List[dict]); each dict must
            contain a 'content' field.
        top_n: How many results to ultimately return.
        max_candidates: Cap on how many documents are sent to the rerank
            API in one call (too many risks timeouts or API errors).

    Returns:
        List[dict]: Top-N documents in reranked order. Each returned dict
        is a shallow copy of the input enriched with 'score' (rerank
        relevance) and 'reranked' = True; on API failure the first top_n
        inputs are returned unchanged as a degraded fallback.
    """
    if not documents:
        return []
    # The rerank API takes a plain text list, but we must keep the full
    # dicts (meta_info, ids) around so the returned indices can be mapped
    # back. Truncate up front to avoid API timeouts on huge batches.
    documents = documents[:max_candidates]
    doc_contents = [doc.get('content', '') for doc in documents]
    try:
        # Call DashScope GTE-Rerank (Cross-Encoder)
        resp = dashscope.TextReRank.call(
            model='gte-rerank',
            query=query,
            documents=doc_contents,
            top_n=top_n,
            return_documents=False  # only need indices + scores back
        )
        if resp.status_code == HTTPStatus.OK:
            # Example payload:
            # output.results = [{'index': 2, 'relevance_score': 0.98}, {'index': 0, ...}]
            reranked_results = []
            for item in resp.output.results:
                # Build a copy rather than mutating the caller's dict in
                # place; attach the precise rerank confidence (~0..1) and
                # a provenance marker for debugging.
                reranked_results.append({
                    **documents[item.index],
                    'score': item.relevance_score,
                    'reranked': True,
                })
            return reranked_results
        else:
            print(f"[ERROR] Rerank API Error: {resp}")
            # Degraded fallback: return the coarse top-N as-is
            return documents[:top_n]
    except Exception as e:
        print(f"[ERROR] Rerank Exception: {e}")
        # Degraded fallback
        return documents[:top_n]
llm_service = LLMService() llm_service = LLMService()

192
scripts/evaluate_rag.py Normal file
View File

@@ -0,0 +1,192 @@
import sys
import os
import json
import requests
import time
import numpy as np
from time import sleep
# 将项目根目录加入路径
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from backend.core.config import settings
# ================= ⚙️ Configuration =================
BASE_URL = "http://127.0.0.1:8000"
TASK_ID = 19  # ⚠️ Change this to the Task ID of the data you actually crawled
# OS-independent path to the test dataset sitting next to this script
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_dataset.json")
# ==============================================
class Colors:
    """ANSI escape codes used to colorize terminal report output."""
    HEADER = '\033[95m'   # bright magenta — section headers
    OKBLUE = '\033[94m'   # blue
    OKCYAN = '\033[96m'   # cyan — hits ranked below first place
    OKGREEN = '\033[92m'  # green — rank-1 hits / positive verdicts
    WARNING = '\033[93m'  # yellow — warnings (e.g. cross-lingual risk)
    FAIL = '\033[91m'     # red — misses and errors
    ENDC = '\033[0m'      # reset all attributes
    BOLD = '\033[1m'      # bold text
def get_rag_results(query):
    """
    Call the search endpoint and measure request latency.

    Args:
        query: Natural-language query string sent to /api/v2/search.

    Returns:
        (chunks, latency_ms): retrieved chunk dicts (empty list on any
        failure) and the measured latency in milliseconds.
    """
    start_ts = time.time()
    try:
        # The v2 endpoint already runs hybrid retrieval -> rerank internally.
        res = requests.post(
            f"{BASE_URL}/api/v2/search",
            json={"query": query, "task_id": TASK_ID, "limit": 5},  # Top 5
            timeout=15
        )
        latency = (time.time() - start_ts) * 1000  # ms
        if res.status_code != 200:
            print(f"{Colors.FAIL}❌ API Error {res.status_code}: {res.text}{Colors.ENDC}")
            # Return the real latency (previously 0) so failed requests do
            # not artificially drag down the avg/P95 latency statistics.
            return [], latency
        res_json = res.json()
        chunks = res_json.get('data', {}).get('results', [])
        return chunks, latency
    except Exception as e:
        latency = (time.time() - start_ts) * 1000  # measure even on failure
        print(f"{Colors.FAIL}❌ 请求异常: {e}{Colors.ENDC}")
        return [], latency
def check_hit(content, keywords):
    """
    Relevance check: does the chunk text contain any expected keyword?

    Lightweight ground-truth validation using case-insensitive substring
    matching with OR semantics across keywords.

    Args:
        content: Chunk text to inspect (may be None or empty).
        keywords: Expected keyword list; an empty list marks an
            open-ended / refusal case and is always treated as a hit.

    Returns:
        bool: True when any keyword occurs in content (or keywords is empty).
    """
    if not keywords:
        return True  # refusal / open-ended cases skip the keyword check
    if not content:
        return False
    content_lower = content.lower()
    # any() short-circuits on the first match, like the original loop.
    return any(k.lower() in content_lower for k in keywords)
def run_evaluation():
    """Run the full quantitative RAG evaluation over the test dataset.

    Loads scripts/test_dataset.json, queries the search API once per case,
    and reports Precision@1, HitRate@5, MRR, and latency (avg + P95),
    followed by heuristic diagnostic suggestions.
    """
    # 1. Load the test set
    if not os.path.exists(TEST_FILE):
        print(f"{Colors.FAIL}❌ 找不到测试文件: {TEST_FILE}{Colors.ENDC}")
        print("请确保 scripts/test_dataset.json 文件存在。")
        return
    with open(TEST_FILE, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    print(f"{Colors.HEADER}🚀 开始全维度量化评测 (Task ID: {TASK_ID}){Colors.ENDC}")
    print(f"📄 测试集包含 {len(dataset)} 个样本\n")
    # === Metric accumulators ===
    metrics = {
        "p_at_1": [],  # Precision@1: correct answer ranked first
        "hit_at_5": [],  # HitRate@5: correct answer within top 5
        "mrr": [],  # MRR: reciprocal-rank score
        "latency": []  # request latency (ms)
    }
    # === Iterate over test cases ===
    for i, item in enumerate(dataset):
        query = item['query']
        print(f"📝 Case {i+1}: {Colors.BOLD}{query}{Colors.ENDC}")
        # Run retrieval
        chunks, latency = get_rag_results(query)
        metrics['latency'].append(latency)
        # Per-case metrics
        is_hit_at_5 = 0
        p_at_1 = 0
        reciprocal_rank = 0.0
        hit_position = -1
        hit_chunk = None
        # Scan the top-5 results
        for idx, chunk in enumerate(chunks):
            if check_hit(chunk['content'], item['keywords']):
                # Hit!
                is_hit_at_5 = 1
                hit_position = idx
                reciprocal_rank = 1.0 / (idx + 1)
                hit_chunk = chunk
                # Hit on the very first result
                if idx == 0:
                    p_at_1 = 1
                # Stop at the first hit (MRR only needs the first correct rank)
                break
        # Record metrics
        metrics['p_at_1'].append(p_at_1)
        metrics['hit_at_5'].append(is_hit_at_5)
        metrics['mrr'].append(reciprocal_rank)
        # Print a one-line per-case verdict
        if is_hit_at_5:
            rank_display = f"Rank {hit_position + 1}"
            color = Colors.OKGREEN if hit_position == 0 else Colors.OKCYAN
            source = hit_chunk.get('source_url', 'Unknown')
            # Cross-lingual contamination check (simple URL heuristics)
            warning = ""
            if "/es/" in source and "Spanish" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}"
            elif "/zh/" in source and "如何" not in query and "什么" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}"
            print(f" {color}✅ 命中 ({rank_display}){Colors.ENDC} | MRR: {reciprocal_rank:.2f} | 耗时: {latency:.0f}ms {warning}")
        else:
            print(f" {Colors.FAIL}❌ 未命中{Colors.ENDC} | 预期关键词: {item['keywords']}")
        # Small delay to avoid tripping API rate limits
        sleep(0.1)
    # === Final aggregation ===
    count = len(dataset)
    if count == 0: return
    avg_p1 = np.mean(metrics['p_at_1']) * 100
    avg_hit5 = np.mean(metrics['hit_at_5']) * 100
    avg_mrr = np.mean(metrics['mrr'])
    avg_latency = np.mean(metrics['latency'])
    p95_latency = np.percentile(metrics['latency'], 95)
    print("\n" + "="*60)
    print(f"{Colors.HEADER}📊 最终量化评估报告 (Evaluation Report){Colors.ENDC}")
    print("="*60)
    # 1. Precision@1 (the most important metric)
    print(f"🥇 {Colors.BOLD}Precision@1 (首位精确率): {avg_p1:.1f}%{Colors.ENDC}")
    print(f" - 意义: 用户能否直接得到正确答案。引入 Rerank 后此项应显著提高。")
    # 2. Hit Rate / Recall@5
    print(f"🥈 Hit Rate@5 (前五召回率): {avg_hit5:.1f}%")
    print(f" - 意义: 数据库是否真的包含答案。如果此项低,说明爬虫没爬全或混合检索漏了。")
    # 3. MRR
    print(f"🥉 MRR (平均倒数排名): {avg_mrr:.3f} / 1.0")
    # 4. Latency
    print(f"⚡ Avg Latency (平均耗时): {avg_latency:.0f} ms")
    print(f"⚡ P95 Latency (95%分位): {p95_latency:.0f} ms")
    print("="*60)
    # === Heuristic diagnostics ===
    print(f"{Colors.HEADER}💡 诊断建议:{Colors.ENDC}")
    if avg_p1 < avg_hit5:
        gap = avg_hit5 - avg_p1
        print(f"{Colors.WARNING}排序优化空间大{Colors.ENDC}: 召回了但没排第一的情况占 {gap:.1f}%。")
        print(" -> 你的 Rerank 模型生效了吗?或者 Rerank 的 Top N 截断是否太早?")
    elif avg_p1 > 80:
        print(f"{Colors.OKGREEN}排序效果优秀{Colors.ENDC}: 绝大多数正确答案都排在第一位。")
    if avg_hit5 < 50:
        print(f"{Colors.FAIL}召回率过低{Colors.ENDC}: 可能是测试集关键词太生僻,或者 TS_RANK 权重过低。")
    if avg_latency > 2000:
        print(f"{Colors.WARNING}系统响应慢{Colors.ENDC}: 2秒以上。检查是否因为 Rerank 文档过多(建议 <= 50个)。")


if __name__ == "__main__":
    run_evaluation()

86
scripts/test_dataset.json Normal file
View File

@@ -0,0 +1,86 @@
[
{
"id": 1,
"type": "core_function",
"query": "What is the difference between /scrape and /map endpoints?",
"ground_truth": "/map is used to crawl a website and retrieve all URLs, while /scrape is used to extract content from a specific URL.",
"keywords": ["URL", "content", "specific", "retrieve"]
},
{
"id": 2,
"type": "new_feature",
"query": "What is the Deep Research feature?",
"ground_truth": "Deep Research is an alpha feature allowing agents to perform iterative research tasks.",
"keywords": ["alpha", "iterative", "research", "agent"]
},
{
"id": 3,
"type": "integration",
"query": "How can I integrate Firecrawl with ChatGPT?",
"ground_truth": "Firecrawl can be integrated via MCP (Model Context Protocol).",
"keywords": ["MCP", "Model Context Protocol", "setup"]
},
{
"id": 4,
"type": "multilingual_zh",
"query": "如何进行私有化部署 (Self-host)",
"ground_truth": "你需要使用 Docker Compose 进行部署,文档位于 /self-host/quick-start/docker-compose。",
"keywords": ["Docker", "Compose", "self-host", "deploy"]
},
{
"id": 5,
"type": "api_detail",
"query": "What parameters are available for the /extract endpoint?",
"ground_truth": "The extract endpoint allows defining a schema for structured data extraction.",
"keywords": ["schema", "structured", "prompt"]
},
{
"id": 6,
"type": "numeric",
"query": "How do credits work for the scrape endpoint?",
"ground_truth": "Specific credit usage details are in the /credits endpoint documentation (usually 1 credit per page for basic scrape).",
"keywords": ["credit", "usage", "cost"]
},
{
"id": 7,
"type": "negative_test",
"query": "Does Firecrawl support scraping video content from YouTube?",
"ground_truth": "The documentation does not mention video scraping support.",
"keywords": []
},
{
"id": 8,
"type": "advanced",
"query": "How to use batch scrape?",
"ground_truth": "Use the /batch/scrape endpoint to submit multiple URLs at once.",
"keywords": ["batch", "multiple", "URLs"]
},
{
"id": 9,
"type": "automation",
"query": "Is there an n8n integration guide?",
"ground_truth": "Yes, there is a workflow automation guide for n8n.",
"keywords": ["n8n", "workflow", "automation"]
},
{
"id": 10,
"type": "security",
"query": "Where can I find information about webhook security?",
"ground_truth": "Information is available in the Webhooks Security section.",
"keywords": ["webhook", "security", "signature"]
},
{
"id": 11,
"type": "cross_lingual_trap",
"query": "Explain the crawl features in French.",
"ground_truth": "The system should ideally retrieve the French document (/fr/features/crawl) and answer in French.",
"keywords": ["fonctionnalités", "crawl", "fr"]
},
{
"id": 12,
"type": "api_history",
"query": "How to check historical token usage?",
"ground_truth": "Use the /token-usage-historical endpoint.",
"keywords": ["token", "usage", "historical"]
}
]