完成RAG测试
This commit is contained in:
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/rag_benchmark/__init__.py
Normal file
0
tests/rag_benchmark/__init__.py
Normal file
BIN
tests/rag_benchmark/benchmark_report.png
Normal file
BIN
tests/rag_benchmark/benchmark_report.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 100 KiB |
86
tests/rag_benchmark/dataset.json
Normal file
86
tests/rag_benchmark/dataset.json
Normal file
@@ -0,0 +1,86 @@
|
||||
[
|
||||
{
|
||||
"id": 1,
|
||||
"type": "core_function",
|
||||
"query": "What is the difference between /scrape and /map endpoints?",
|
||||
"ground_truth": "/map is used to crawl a website and retrieve all URLs, while /scrape is used to extract content from a specific URL.",
|
||||
"keywords": ["URL", "content", "specific", "retrieve"]
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"type": "new_feature",
|
||||
"query": "What is the Deep Research feature?",
|
||||
"ground_truth": "Deep Research is an alpha feature allowing agents to perform iterative research tasks.",
|
||||
"keywords": ["alpha", "iterative", "research", "agent"]
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"type": "integration",
|
||||
"query": "How can I integrate Firecrawl with ChatGPT?",
|
||||
"ground_truth": "Firecrawl can be integrated via MCP (Model Context Protocol).",
|
||||
"keywords": ["MCP", "Model Context Protocol", "setup"]
|
||||
},
|
||||
{
|
||||
"id": 4,
|
||||
"type": "multilingual_zh",
|
||||
"query": "如何进行私有化部署 (Self-host)?",
|
||||
"ground_truth": "你需要使用 Docker Compose 进行部署,文档位于 /self-host/quick-start/docker-compose。",
|
||||
"keywords": ["Docker", "Compose", "self-host", "deploy"]
|
||||
},
|
||||
{
|
||||
"id": 5,
|
||||
"type": "api_detail",
|
||||
"query": "What parameters are available for the /extract endpoint?",
|
||||
"ground_truth": "The extract endpoint allows defining a schema for structured data extraction.",
|
||||
"keywords": ["schema", "structured", "prompt"]
|
||||
},
|
||||
{
|
||||
"id": 6,
|
||||
"type": "numeric",
|
||||
"query": "How do credits work for the scrape endpoint?",
|
||||
"ground_truth": "Specific credit usage details are in the /credits endpoint documentation (usually 1 credit per page for basic scrape).",
|
||||
"keywords": ["credit", "usage", "cost"]
|
||||
},
|
||||
{
|
||||
"id": 7,
|
||||
"type": "negative_test",
|
||||
"query": "Does Firecrawl support scraping video content from YouTube?",
|
||||
"ground_truth": "The documentation does not mention video scraping support.",
|
||||
"keywords": []
|
||||
},
|
||||
{
|
||||
"id": 8,
|
||||
"type": "advanced",
|
||||
"query": "How to use batch scrape?",
|
||||
"ground_truth": "Use the /batch/scrape endpoint to submit multiple URLs at once.",
|
||||
"keywords": ["batch", "multiple", "URLs"]
|
||||
},
|
||||
{
|
||||
"id": 9,
|
||||
"type": "automation",
|
||||
"query": "Is there an n8n integration guide?",
|
||||
"ground_truth": "Yes, there is a workflow automation guide for n8n.",
|
||||
"keywords": ["n8n", "workflow", "automation"]
|
||||
},
|
||||
{
|
||||
"id": 10,
|
||||
"type": "security",
|
||||
"query": "Where can I find information about webhook security?",
|
||||
"ground_truth": "Information is available in the Webhooks Security section.",
|
||||
"keywords": ["webhook", "security", "signature"]
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"type": "cross_lingual_trap",
|
||||
"query": "Explain the crawl features in French.",
|
||||
"ground_truth": "The system should ideally retrieve the French document (/fr/features/crawl) and answer in French.",
|
||||
"keywords": ["fonctionnalités", "crawl", "fr"]
|
||||
},
|
||||
{
|
||||
"id": 12,
|
||||
"type": "api_history",
|
||||
"query": "How to check historical token usage?",
|
||||
"ground_truth": "Use the /token-usage-historical endpoint.",
|
||||
"keywords": ["token", "usage", "historical"]
|
||||
}
|
||||
]
|
||||
74
tests/rag_benchmark/evaluator.py
Normal file
74
tests/rag_benchmark/evaluator.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import json
|
||||
import logging
|
||||
from backend.services.llm_service import llm_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RAGEvaluator:
    """Score one RAG answer: keyword recall for retrieval, LLM judge for generation.

    The judge is reached through ``self.llm.chat`` (the shared ``llm_service``).
    """

    def __init__(self):
        self.llm = llm_service

    def calculate_retrieval_metrics(self, retrieved_docs, dataset_item):
        """Compute the keyword recall of the retrieved chunks for one dataset item.

        Args:
            retrieved_docs: Iterable of dicts carrying a ``content`` string.
                ``None`` is treated as "nothing retrieved".
            dataset_item: Dataset entry; its optional ``keywords`` list is matched
                case-insensitively as substrings of the concatenated content.

        Returns:
            ``{"keyword_recall": float, "hit": bool}``. ``hit`` is true as soon
            as at least one keyword was found. Items without keywords count as
            fully recalled.
        """
        required_keywords = dataset_item.get("keywords", [])
        if not required_keywords:
            # No keyword requirement: treat the retrieval as correct by default.
            return {"keyword_recall": 1.0, "hit": True}

        # Concatenate every retrieved chunk and lower-case it once.
        full_context = " ".join(
            (doc.get("content") or "") for doc in (retrieved_docs or [])
        ).lower()

        found_count = sum(1 for kw in required_keywords if kw.lower() in full_context)
        recall = found_count / len(required_keywords)

        # Deliberately lenient: any matched keyword counts as a hit.
        return {"keyword_recall": recall, "hit": recall > 0}

    @staticmethod
    def _parse_judge_reply(raw):
        """Extract and normalise the JSON verdict contained in the judge's reply.

        The reply may be bare JSON or wrapped in a Markdown fence, with or
        without a ``json`` language tag, so the outermost ``{...}`` span is
        parsed instead of splitting on the fence.

        Returns:
            The parsed dict with a numeric ``score`` and a string ``reason``.

        Raises:
            ValueError: If no JSON object is present or ``score`` is missing or
                not numeric. ``json.JSONDecodeError`` is a ``ValueError`` too.
        """
        text = str(raw)
        start = text.find("{")
        end = text.rfind("}")
        if start == -1 or end < start:
            raise ValueError("judge reply contains no JSON object")

        verdict = json.loads(text[start:end + 1])

        score = verdict.get("score")
        if isinstance(score, str):
            score = float(score.strip())
            if score.is_integer():
                score = int(score)
        if isinstance(score, bool) or not isinstance(score, (int, float)):
            raise ValueError("judge reply has no numeric 'score'")

        verdict["score"] = score
        verdict["reason"] = str(verdict.get("reason", ""))
        return verdict

    def evaluate_generation_quality(self, question, generated_answer, ground_truth_answer, q_type):
        """Ask the LLM judge to grade the generated answer on a 1-5 scale.

        Args:
            question: The user query that was answered.
            generated_answer: The answer produced by the system under test.
            ground_truth_answer: The reference answer from the dataset.
            q_type: Dataset category (e.g. ``negative_test``), shown to the judge.

        Returns:
            Dict with ``score`` and ``reason``. If the judge call or the parsing
            fails, the error is logged and ``{"score": 0, ...}`` is returned.
        """
        prompt = f"""
        你是一名RAG系统的自动化测试裁判。请根据以下信息对“系统回答”进行评分(1-5分)。

        【测试类型】: {q_type}
        【用户问题】: {question}
        【标准答案 (Ground Truth)】: {ground_truth_answer}
        【系统回答】: {generated_answer}

        评分标准:
        - 5分: 含义与标准答案完全一致,逻辑正确,无幻觉。
        - 4分: 核心意思正确,但缺少部分细节或废话较多。
        - 3分: 回答了一部分正确信息,但有遗漏或轻微错误。
        - 2分: 包含大量错误信息或严重答非所问。
        - 1分: 完全错误,或产生了严重幻觉(例如在负向测试中编造了不存在的功能)。

        注意:对于"negative_test"(负向测试),如果标准答案是“不支持/文档未提及”,而系统回答诚实地说“未找到相关信息”或“不支持”,应给满分。

        请仅返回JSON格式: {{"score": 5, "reason": "理由..."}}
        """

        try:
            # The system prompt constrains the judge to JSON-only output.
            result_str = self.llm.chat(prompt, system_prompt="你是一个只输出JSON的评测机器人。")
            return self._parse_judge_reply(result_str)
        except Exception as e:
            logger.error("Eval LLM failed: %s", e)
            # Degrade gracefully so a single bad verdict does not abort the run.
            return {"score": 0, "reason": "Evaluation Script Error"}
|
||||
168
tests/rag_benchmark/run_benchmark.py
Normal file
168
tests/rag_benchmark/run_benchmark.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from tabulate import tabulate # 需要 pip install tabulate
|
||||
|
||||
# Path hack: this file lives in tests/rag_benchmark/, so the project root that
# holds the `backend` and `tests` packages is two levels up (not three).
# The path is normalised so it works when the script is launched via a relative path.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')))
|
||||
|
||||
from backend.services.data_service import data_service
|
||||
from backend.services.llm_service import llm_service
|
||||
from tests.rag_benchmark.evaluator import RAGEvaluator
|
||||
|
||||
# ================= 配置区 =================
|
||||
# 请填入你数据库中真实存在的、包含爬取数据的 task_id
|
||||
TEST_TASK_ID = 19
|
||||
# ========================================
|
||||
|
||||
def run_experiment(config_name, dataset, retrieve_func, generate_func):
|
||||
"""
|
||||
运行单组实验并收集详细指标
|
||||
"""
|
||||
print(f"\n🚀 开始测试配置: [ {config_name} ]")
|
||||
evaluator = RAGEvaluator()
|
||||
results = []
|
||||
|
||||
total_latency = 0
|
||||
|
||||
# 用于分类统计 (例如 core_function 得分多少, negative_test 得分多少)
|
||||
category_stats = defaultdict(lambda: {"count": 0, "score_sum": 0, "recall_sum": 0})
|
||||
|
||||
for item in dataset:
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 检索 (Retrieval)
|
||||
retrieved_docs = retrieve_func(item['query'])
|
||||
|
||||
# 2. 生成 (Generation)
|
||||
# 构造 Context,如果没有检索到内容,给一个提示
|
||||
if retrieved_docs:
|
||||
context_str = "\n---\n".join([f"Source: {d.get('source_url', 'unknown')}\nContent: {d['content']}" for d in retrieved_docs])
|
||||
else:
|
||||
context_str = "没有检索到任何相关文档。"
|
||||
|
||||
answer = generate_func(item['query'], context_str)
|
||||
|
||||
latency = time.time() - start_time
|
||||
total_latency += latency
|
||||
|
||||
# 3. 评测 (Evaluation)
|
||||
# 计算关键词召回率
|
||||
retrieval_metric = evaluator.calculate_retrieval_metrics(retrieved_docs, item)
|
||||
# 计算LLM回答质量
|
||||
gen_eval = evaluator.evaluate_generation_quality(
|
||||
item['query'], answer, item['ground_truth'], item['type']
|
||||
)
|
||||
|
||||
# 记录单条结果
|
||||
row = {
|
||||
"id": item['id'],
|
||||
"type": item['type'],
|
||||
"query": item['query'],
|
||||
"recall": retrieval_metric['keyword_recall'],
|
||||
"score": gen_eval['score'],
|
||||
"reason": gen_eval.get('reason', '')[:50] + "...", # 截断显示
|
||||
"latency": latency
|
||||
}
|
||||
results.append(row)
|
||||
|
||||
# 累加分类统计
|
||||
cat = item['type']
|
||||
category_stats[cat]["count"] += 1
|
||||
category_stats[cat]["score_sum"] += gen_eval['score']
|
||||
category_stats[cat]["recall_sum"] += retrieval_metric['keyword_recall']
|
||||
|
||||
# 实时打印进度 (简洁版)
|
||||
status_icon = "✅" if gen_eval['score'] >= 4 else "⚠️" if gen_eval['score'] >= 3 else "❌"
|
||||
print(f" {status_icon} ID:{item['id']} [{item['type'][:10]}] Score:{gen_eval['score']} | Recall:{retrieval_metric['keyword_recall']:.1f}")
|
||||
|
||||
# --- 汇总本轮实验数据 ---
|
||||
avg_score = sum(r['score'] for r in results) / len(results)
|
||||
avg_recall = sum(r['recall'] for r in results) / len(results)
|
||||
avg_latency = total_latency / len(results)
|
||||
|
||||
# 格式化分类报告
|
||||
cat_report = []
|
||||
for cat, data in category_stats.items():
|
||||
cat_report.append(f"{cat}: {data['score_sum']/data['count']:.1f}分")
|
||||
|
||||
return {
|
||||
"Config": config_name,
|
||||
"Avg Score (1-5)": f"{avg_score:.2f}",
|
||||
"Avg Recall": f"{avg_recall:.2%}",
|
||||
"Avg Latency": f"{avg_latency:.3f}s",
|
||||
"Weakest Category": min(category_stats, key=lambda k: category_stats[k]['score_sum']/category_stats[k]['count'])
|
||||
}
|
||||
|
||||
def main():
    """Load dataset.json, run the four retrieval experiments and print a comparison table.

    The experiments differ only in how documents are retrieved; all of them
    share the same answer-generation prompt.
    """
    # 1. Load the dataset that sits next to this script.
    dataset_path = os.path.join(os.path.dirname(__file__), 'dataset.json')
    if not os.path.exists(dataset_path):
        print("Error: dataset.json not found.")
        return

    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    print(f"载入 {len(dataset)} 条测试用例,准备开始横向评测...")

    # 2. Experiment variables: retrieval functions + one shared generator.
    def _search(query, embedding, weight, limit):
        """Call data_service.search for the test task and return its result list."""
        response = data_service.search(
            query, embedding, task_id=TEST_TASK_ID, vector_weight=weight, candidates_num=limit
        )
        return response['results']

    def retrieve_keyword(query):
        # Exp A: keyword only. vector_weight=0.0 is meant to switch the vector
        # side off (the original author notes it forces SQL tsvector matching,
        # which cannot be confirmed from this file). A dummy vector is passed
        # because the search interface requires an embedding argument.
        return _search(query, [0.0] * 1536, 0.0, 5)

    def retrieve_vector(query):
        # Exp B: vector only (vector_weight=1.0).
        return _search(query, llm_service.get_embedding(query), 1.0, 5)

    def retrieve_hybrid(query):
        # Exp C: hybrid, vector weight 0.7.
        return _search(query, llm_service.get_embedding(query), 0.7, 5)

    def retrieve_rerank(query):
        # Exp D: hybrid recall of 30 candidates, then rerank down to the top 5.
        candidates = _search(query, llm_service.get_embedding(query), 0.7, 30)
        return llm_service.rerank(query, candidates, top_n=5)

    def generate_answer(query, context):
        """Answer the query strictly from the supplied context."""
        system_prompt = "你是一个智能助手。请严格根据提供的上下文回答用户问题。如果上下文中没有答案,请直接说'未找到相关信息'。"
        prompt = f"参考上下文:\n{context}\n\n用户问题:{query}"
        return llm_service.chat(prompt, system_prompt=system_prompt)

    # 3. Run every experiment, in order.
    experiments = (
        ("1. Keyword Only (BM25)", retrieve_keyword),
        ("2. Vector Only", retrieve_vector),
        ("3. Hybrid (Base)", retrieve_hybrid),
        ("4. Hybrid + Rerank", retrieve_rerank),
    )
    final_report = []
    for label, retriever in experiments:
        final_report.append(run_experiment(label, dataset, retriever, generate_answer))

    # 4. Print the final comparison table and reading hints.
    print("\n\n📊 ================= 最终横向对比报告 (Final Report) ================= 📊")
    print(tabulate(final_report, headers="keys", tablefmt="github"))
    print("\n💡 解读建议:")
    print("1. 如果 'Avg Recall' 低,说明切片(Chunking)或检索算法找不到资料。")
    print("2. 如果 Recall 高但 'Avg Score' 低,说明 LLM 产生了幻觉或 Prompt 没写好。")
    print("3. 'Weakest Category' 帮你发现短板(如多语言或负向测试)。")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
208
tests/rag_benchmark/visual_benchmark.py
Normal file
208
tests/rag_benchmark/visual_benchmark.py
Normal file
@@ -0,0 +1,208 @@
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
from tqdm import tqdm
|
||||
|
||||
# 路径 Hack: 确保能导入 backend
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
# 直接导入服务类 (Direct Call)
|
||||
from backend.services.data_service import data_service
|
||||
from backend.services.llm_service import llm_service
|
||||
|
||||
# ================= 配置区 =================
|
||||
TEST_TASK_ID = 19 # 请修改为真实的 Task ID
|
||||
DATASET_PATH = os.path.join(current_dir, 'dataset.json')
|
||||
OUTPUT_IMG = os.path.join(current_dir, 'benchmark_report.png')
|
||||
# ========================================
|
||||
|
||||
class RAGEvaluator:
    """Evaluation helper: computes keyword recall and asks the LLM for a 1-5 score."""

    def __init__(self):
        self.llm = llm_service

    def calculate_recall(self, retrieved_docs, keywords):
        """Return the fraction of *keywords* found in the retrieved content.

        Matching is a case-insensitive substring test against the concatenated
        ``content`` of all docs. An empty keyword list scores 1.0.
        """
        if not keywords:
            return 1.0

        full_text = " ".join(d['content'] for d in retrieved_docs).lower()
        hit_count = sum(1 for k in keywords if k.lower() in full_text)
        return hit_count / len(keywords)

    @staticmethod
    def _extract_score(reply):
        """Return the first standalone integer in *reply* clamped to 1-5, or None.

        Only the first integer is used, so "4/5" yields 4. Concatenating every
        digit would turn it into 45 and clamp it to 5.
        """
        import re  # local import: keeps this block self-contained

        match = re.search(r"(?<!\d)\d+(?!\d)", str(reply))
        if match is None:
            return None
        return min(max(int(match.group()), 1), 5)

    def judge_answer(self, query, answer, ground_truth):
        """Ask the LLM to grade *answer* against *ground_truth* on a 1-5 scale.

        Returns:
            The clamped integer score, or 1 when the LLM call fails or its reply
            contains no number.
        """
        prompt = f"""
        作为 RAG 评测员,请对【系统回答】打分 (1-5)。
        用户问题: {query}
        标准答案: {ground_truth}
        系统回答: {answer}

        标准:
        5: 含义完全一致,无幻觉。
        3: 包含核心信息,但有遗漏。
        1: 错误或严重幻觉。

        只返回数字 (1, 2, 3, 4, 5)。
        """
        try:
            res = self.llm.chat(prompt)
            score = self._extract_score(res)
        except Exception:
            # Best-effort scoring: a failed judge call falls back to the minimum.
            # KeyboardInterrupt/SystemExit are intentionally not swallowed.
            return 1
        return 1 if score is None else score
|
||||
|
||||
class Visualizer:
    """Plotting helper that renders the benchmark results as one dashboard image."""

    def plot_dashboard(self, df):
        """Draw the three-panel dashboard and save it to ``OUTPUT_IMG``.

        Args:
            df: DataFrame with one row per (config, question) and the columns
                ``config``, ``type``, ``score``, ``recall`` and ``latency``.

        The figure is closed after saving (also when plotting fails) so that
        repeated calls do not accumulate open pyplot figures.
        """
        sns.set_theme(style="whitegrid")
        # Font fallback list; SimHei is listed first so CJK text can render
        # where that font is installed.
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(18, 10))
        try:
            gs = fig.add_gridspec(2, 2)

            # Chart 1: overall metrics per config (bar chart).
            ax1 = fig.add_subplot(gs[0, 0])
            # Reshape to long format for seaborn.
            df_summary = df.groupby('config')[['score', 'recall']].mean().reset_index()
            df_melt = df_summary.melt(id_vars='config', var_name='Metric', value_name='Value')
            # Recall is in 0-1; scale by 5 so it shares the 0-5 axis with score.
            df_melt.loc[df_melt['Metric'] == 'recall', 'Value'] *= 5

            sns.barplot(data=df_melt, x='config', y='Value', hue='Metric', ax=ax1, palette="viridis")
            ax1.set_title('Overall Performance (Score & Recall)', fontsize=14, fontweight='bold')
            ax1.set_ylabel('Score (1-5) / Recall (x5)')
            ax1.set_ylim(0, 5.5)
            for container in ax1.containers:
                ax1.bar_label(container, fmt='%.1f')

            # Chart 2: latency vs quality (scatter plot).
            ax2 = fig.add_subplot(gs[0, 1])
            df_latency = df.groupby('config')[['score', 'latency']].mean().reset_index()
            sns.scatterplot(data=df_latency, x='latency', y='score', hue='config', s=200, ax=ax2, palette="deep")

            # Label each point with its config name, nudged right of the marker.
            for row in df_latency.itertuples(index=False):
                ax2.text(row.latency + 0.05, row.score, row.config, fontsize=10)
            ax2.set_title('Trade-off: Latency vs Quality', fontsize=14, fontweight='bold')
            ax2.set_xlabel('Avg Latency (seconds)')
            ax2.set_ylabel('Avg Quality Score (1-5)')

            # Chart 3: per-category heatmap across the full bottom row.
            ax3 = fig.add_subplot(gs[1, :])
            pivot_data = df.pivot_table(index='config', columns='type', values='score', aggfunc='mean')
            sns.heatmap(pivot_data, annot=True, cmap="RdYlGn", center=3, fmt=".1f", ax=ax3, linewidths=.5)
            ax3.set_title('Category Breakdown (Find the Weakest Link)', fontsize=14, fontweight='bold')
            ax3.set_xlabel('')
            ax3.set_ylabel('')

            plt.tight_layout()
            plt.savefig(OUTPUT_IMG)
        finally:
            plt.close(fig)
        print(f"\n📊 报表已生成: {OUTPUT_IMG}")
|
||||
|
||||
def main():
    """Run every retrieval configuration over the dataset and plot the results.

    For each (config, question) pair this retrieves documents, optionally
    reranks them, generates an answer, scores it, and finally hands all rows
    to ``Visualizer.plot_dashboard``.
    """
    # 1. Load the dataset.
    if not os.path.exists(DATASET_PATH):
        print("Dataset not found!")
        return
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    if not dataset:
        # An empty result set would produce a column-less DataFrame and make
        # the plotting code fail on groupby('config').
        print("Dataset is empty; nothing to evaluate.")
        return

    # 2. Experiment configurations (direct service calls).
    configs = [
        {
            "name": "1. BM25 (Keyword)",
            # A dummy vector is passed because search() takes an embedding argument.
            "retriever": lambda q: data_service.search(q, [0.0]*1536, task_id=TEST_TASK_ID, vector_weight=0.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "2. Vector Only",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=1.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "3. Hybrid (Base)",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "4. Hybrid + Rerank",
            # Recall the top 30 candidates; reranking trims them to 5 below.
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=30)['results'],
            "rerank": True
        }
    ]

    evaluator = RAGEvaluator()
    all_results = []

    print("🚀 开始自动化评测 (Visualization Mode)...")

    # 3. Nested loop: configuration -> dataset item. The context manager makes
    #    sure the progress bar is closed even if an experiment raises.
    total_steps = len(configs) * len(dataset)
    with tqdm(total=total_steps, desc="Running Experiments") as pbar:
        for cfg in configs:
            for item in dataset:
                pbar.set_description(f"Testing {cfg['name']}")

                start_time = time.time()

                # A. Retrieval
                docs = cfg['retriever'](item['query'])

                # B. Rerank (only when enabled for this configuration)
                if cfg['rerank']:
                    docs = llm_service.rerank(item['query'], docs, top_n=5)

                # C. Generation; skip the LLM call when there is no context.
                context = "\n".join([d['content'] for d in docs]) if docs else ""
                if not context:
                    answer = "未找到相关信息"
                else:
                    prompt = f"Context:\n{context}\n\nQuestion: {item['query']}"
                    answer = llm_service.chat(prompt)

                # Latency covers retrieval, rerank and generation, not scoring.
                latency = time.time() - start_time

                # D. Metrics
                recall = evaluator.calculate_recall(docs, item.get('keywords', []))
                score = evaluator.judge_answer(item['query'], answer, item['ground_truth'])

                # E. Collect one row per (config, question).
                all_results.append({
                    "config": cfg['name'],
                    "id": item['id'],
                    "type": item['type'],
                    "recall": recall,
                    "score": score,
                    "latency": latency
                })

                pbar.update(1)

    # 4. Build the DataFrame and plot.
    df = pd.DataFrame(all_results)
    viz = Visualizer()
    viz.plot_dashboard(df)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user