新增RAG测试脚本
This commit is contained in:
192
scripts/evaluate_rag.py
Normal file
192
scripts/evaluate_rag.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import sys
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
import time
|
||||
import numpy as np
|
||||
from time import sleep
|
||||
|
||||
# Add the project root to sys.path so that `backend` can be imported
# when this script is run directly from the scripts/ directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from backend.core.config import settings

# ================= ⚙️ Configuration =================
# Base URL of the locally running backend API.
BASE_URL = "http://127.0.0.1:8000"
# ⚠️ Change this to the Task ID of the crawl whose data you want to evaluate.
TASK_ID = 19
# OS-independent path to the test dataset that sits next to this script.
TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_dataset.json")
# ====================================================
||||
class Colors:
    """ANSI escape codes for colored terminal output in the report."""
    HEADER = '\033[95m'   # magenta — section headers
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'   # cyan — hit, but not at rank 1
    OKGREEN = '\033[92m'  # green — hit at rank 1 / good result
    WARNING = '\033[93m'  # yellow — warnings and tuning hints
    FAIL = '\033[91m'     # red — errors and misses
    ENDC = '\033[0m'      # reset attributes
    BOLD = '\033[1m'
|
||||
def get_rag_results(query):
    """
    Call the search API for *query* and measure the round-trip time.

    Args:
        query: Natural-language query string sent to the retrieval endpoint.

    Returns:
        Tuple ``(chunks, latency)`` where ``chunks`` is the list of result
        dicts from the API (empty on any failure) and ``latency`` is the
        measured elapsed time in milliseconds.
    """
    start_ts = time.time()
    try:
        # Call the V2 endpoint; it already chains hybrid retrieval -> rerank.
        res = requests.post(
            f"{BASE_URL}/api/v2/search",
            json={"query": query, "task_id": TASK_ID, "limit": 5},  # fetch Top 5
            timeout=15
        )
        latency = (time.time() - start_ts) * 1000  # ms

        if res.status_code != 200:
            print(f"{Colors.FAIL}❌ API Error {res.status_code}: {res.text}{Colors.ENDC}")
            # BUGFIX: report the real elapsed time instead of a hard-coded 0,
            # so failed requests no longer skew Avg/P95 latency downward.
            return [], latency

        res_json = res.json()
        chunks = res_json.get('data', {}).get('results', [])
        return chunks, latency
    except Exception as e:
        # Boundary catch for a CLI script: connection errors, timeouts,
        # and malformed JSON all end up here.
        latency = (time.time() - start_ts) * 1000  # ms (BUGFIX: was 0)
        print(f"{Colors.FAIL}❌ 请求异常: {e}{Colors.ENDC}")
        return [], latency
|
||||
def check_hit(content, keywords):
    """
    Lightweight relevance check against ground-truth keywords.

    A chunk counts as a hit when any keyword occurs (case-insensitively)
    in its content. An empty keyword list always counts as a hit — these
    are refusal or open-ended cases with no keyword ground truth. Empty
    content can never match a non-empty keyword list.
    """
    if not keywords:
        return True
    if not content:
        return False

    haystack = content.lower()
    return any(term.lower() in haystack for term in keywords)
|
||||
def run_evaluation():
    """
    Run the full quantitative RAG evaluation end to end.

    Loads the JSON test set, queries the search API once per case,
    accumulates Precision@1 / HitRate@5 / MRR / latency, then prints a
    summary report plus rule-based tuning hints to the console.
    Produces console output only; returns None.
    """
    # 1. Load the test set
    if not os.path.exists(TEST_FILE):
        print(f"{Colors.FAIL}❌ 找不到测试文件: {TEST_FILE}{Colors.ENDC}")
        print("请确保 scripts/test_dataset.json 文件存在。")
        return

    with open(TEST_FILE, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    print(f"{Colors.HEADER}🚀 开始全维度量化评测 (Task ID: {TASK_ID}){Colors.ENDC}")
    print(f"📄 测试集包含 {len(dataset)} 个样本\n")

    # === Metric accumulators (one entry appended per test case) ===
    metrics = {
        "p_at_1": [],   # Precision@1: correct answer ranked first
        "hit_at_5": [], # HitRate@5: correct answer anywhere in the top 5
        "mrr": [],      # MRR: reciprocal rank of the first correct answer
        "latency": []   # per-request latency in ms
    }

    # === Iterate over the test cases ===
    for i, item in enumerate(dataset):
        query = item['query']
        print(f"📝 Case {i+1}: {Colors.BOLD}{query}{Colors.ENDC}")

        # Run retrieval
        chunks, latency = get_rag_results(query)
        metrics['latency'].append(latency)

        # Per-case metric values (defaults assume a miss)
        is_hit_at_5 = 0
        p_at_1 = 0
        reciprocal_rank = 0.0
        hit_position = -1
        hit_chunk = None

        # Scan the top-5 results for the first relevant chunk
        for idx, chunk in enumerate(chunks):
            if check_hit(chunk['content'], item['keywords']):
                # Hit!
                is_hit_at_5 = 1
                hit_position = idx
                reciprocal_rank = 1.0 / (idx + 1)
                hit_chunk = chunk

                # Hit on the very first result
                if idx == 0:
                    p_at_1 = 1

                # Stop at the first hit — MRR only needs the position
                # of the first correct answer.
                break

        # Record this case's metrics
        metrics['p_at_1'].append(p_at_1)
        metrics['hit_at_5'].append(is_hit_at_5)
        metrics['mrr'].append(reciprocal_rank)

        # Print the one-line per-case result
        if is_hit_at_5:
            rank_display = f"Rank {hit_position + 1}"
            color = Colors.OKGREEN if hit_position == 0 else Colors.OKCYAN
            source = hit_chunk.get('source_url', 'Unknown')

            # Cross-language contamination check (simple heuristic on the
            # source URL's language prefix vs. the query's language cues)
            warning = ""
            if "/es/" in source and "Spanish" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}"
            elif "/zh/" in source and "如何" not in query and "什么" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}"

            print(f" {color}✅ 命中 ({rank_display}){Colors.ENDC} | MRR: {reciprocal_rank:.2f} | 耗时: {latency:.0f}ms {warning}")
        else:
            print(f" {Colors.FAIL}❌ 未命中{Colors.ENDC} | 预期关键词: {item['keywords']}")

        # Small pause to avoid tripping API rate limits
        sleep(0.1)

    # === Final aggregation ===
    count = len(dataset)
    if count == 0: return

    avg_p1 = np.mean(metrics['p_at_1']) * 100
    avg_hit5 = np.mean(metrics['hit_at_5']) * 100
    avg_mrr = np.mean(metrics['mrr'])
    avg_latency = np.mean(metrics['latency'])
    p95_latency = np.percentile(metrics['latency'], 95)

    print("\n" + "="*60)
    print(f"{Colors.HEADER}📊 最终量化评估报告 (Evaluation Report){Colors.ENDC}")
    print("="*60)

    # 1. Precision@1 (the key metric)
    print(f"🥇 {Colors.BOLD}Precision@1 (首位精确率): {avg_p1:.1f}%{Colors.ENDC}")
    print(f" - 意义: 用户能否直接得到正确答案。引入 Rerank 后此项应显著提高。")

    # 2. Hit Rate / Recall@5
    print(f"🥈 Hit Rate@5 (前五召回率): {avg_hit5:.1f}%")
    print(f" - 意义: 数据库是否真的包含答案。如果此项低,说明爬虫没爬全或混合检索漏了。")

    # 3. MRR
    print(f"🥉 MRR (平均倒数排名): {avg_mrr:.3f} / 1.0")

    # 4. Latency
    print(f"⚡ Avg Latency (平均耗时): {avg_latency:.0f} ms")
    print(f"⚡ P95 Latency (95%分位): {p95_latency:.0f} ms")
    print("="*60)

    # === Rule-based diagnostics derived from the aggregate metrics ===
    print(f"{Colors.HEADER}💡 诊断建议:{Colors.ENDC}")

    if avg_p1 < avg_hit5:
        # Answers are retrieved but not ranked first -> reranking issue
        gap = avg_hit5 - avg_p1
        print(f" • {Colors.WARNING}排序优化空间大{Colors.ENDC}: 召回了但没排第一的情况占 {gap:.1f}%。")
        print(" -> 你的 Rerank 模型生效了吗?或者 Rerank 的 Top N 截断是否太早?")
    elif avg_p1 > 80:
        print(f" • {Colors.OKGREEN}排序效果优秀{Colors.ENDC}: 绝大多数正确答案都排在第一位。")

    if avg_hit5 < 50:
        print(f" • {Colors.FAIL}召回率过低{Colors.ENDC}: 可能是测试集关键词太生僻,或者 TS_RANK 权重过低。")

    if avg_latency > 2000:
        print(f" • {Colors.WARNING}系统响应慢{Colors.ENDC}: 2秒以上。检查是否因为 Rerank 文档过多(建议 <= 50个)。")
# Script entry point: run the evaluation only when executed directly.
if __name__ == "__main__":
    run_evaluation()
86
scripts/test_dataset.json
Normal file
86
scripts/test_dataset.json
Normal file
@@ -0,0 +1,86 @@
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "core_function",
|
||||
"query": "What is the difference between /scrape and /map endpoints?",
|
||||
"ground_truth": "/map is used to crawl a website and retrieve all URLs, while /scrape is used to extract content from a specific URL.",
|
||||
"keywords": ["URL", "content", "specific", "retrieve"]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "new_feature",
|
||||
"query": "What is the Deep Research feature?",
|
||||
"ground_truth": "Deep Research is an alpha feature allowing agents to perform iterative research tasks.",
|
||||
"keywords": ["alpha", "iterative", "research", "agent"]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "integration",
|
||||
"query": "How can I integrate Firecrawl with ChatGPT?",
|
||||
"ground_truth": "Firecrawl can be integrated via MCP (Model Context Protocol).",
|
||||
"keywords": ["MCP", "Model Context Protocol", "setup"]
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"type": "multilingual_zh",
|
||||
"query": "如何进行私有化部署 (Self-host)?",
|
||||
"ground_truth": "你需要使用 Docker Compose 进行部署,文档位于 /self-host/quick-start/docker-compose。",
|
||||
"keywords": ["Docker", "Compose", "self-host", "deploy"]
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "api_detail",
|
||||
"query": "What parameters are available for the /extract endpoint?",
|
||||
"ground_truth": "The extract endpoint allows defining a schema for structured data extraction.",
|
||||
"keywords": ["schema", "structured", "prompt"]
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "numeric",
|
||||
"query": "How do credits work for the scrape endpoint?",
|
||||
"ground_truth": "Specific credit usage details are in the /credits endpoint documentation (usually 1 credit per page for basic scrape).",
|
||||
"keywords": ["credit", "usage", "cost"]
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "negative_test",
|
||||
"query": "Does Firecrawl support scraping video content from YouTube?",
|
||||
"ground_truth": "The documentation does not mention video scraping support.",
|
||||
"keywords": []
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "advanced",
|
||||
"query": "How to use batch scrape?",
|
||||
"ground_truth": "Use the /batch/scrape endpoint to submit multiple URLs at once.",
|
||||
"keywords": ["batch", "multiple", "URLs"]
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "automation",
|
||||
"query": "Is there an n8n integration guide?",
|
||||
"ground_truth": "Yes, there is a workflow automation guide for n8n.",
|
||||
"keywords": ["n8n", "workflow", "automation"]
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"type": "security",
|
||||
"query": "Where can I find information about webhook security?",
|
||||
"ground_truth": "Information is available in the Webhooks Security section.",
|
||||
"keywords": ["webhook", "security", "signature"]
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "cross_lingual_trap",
|
||||
"query": "Explain the crawl features in French.",
|
||||
"ground_truth": "The system should ideally retrieve the French document (/fr/features/crawl) and answer in French.",
|
||||
"keywords": ["fonctionnalités", "crawl", "fr"]
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "api_history",
|
||||
"query": "How to check historical token usage?",
|
||||
"ground_truth": "Use the /token-usage-historical endpoint.",
|
||||
"keywords": ["token", "usage", "historical"]
|
||||
}
|
||||
]
|
||||
Reference in New Issue
Block a user