# wiki_crawler/tests/rag_benchmark/visual_benchmark.py
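"""Visual RAG benchmark.

Runs four retrieval configurations (BM25, vector-only, hybrid, and
hybrid + rerank) over a local dataset, measures keyword recall,
LLM-judged answer quality (1-5), and end-to-end latency, then renders
a three-panel dashboard to benchmark_report.png: overall bars, a
latency/quality scatter, and a per-category heatmap.
"""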
import sys
import os
import time
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
# Path hack: make sure `backend` is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import the service classes directly (direct call)
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service

# ================= Configuration =================
TEST_TASK_ID = 19  # Change this to a real Task ID
DATASET_PATH = os.path.join(current_dir, 'dataset.json')
OUTPUT_IMG = os.path.join(current_dir, 'benchmark_report.png')
# =================================================
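# The dataset is expected to be a JSON list whose entries carry the fields
# consumed in main() (id, type, query, ground_truth, keywords). A minimal
# hypothetical entry, assuming that schema:
#
#   {
#     "id": 1,
#     "type": "factual",
#     "query": "Who maintains the crawler?",
#     "ground_truth": "The crawler is maintained by ...",
#     "keywords": ["crawler"]
#   }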
class RAGEvaluator:
    """Evaluation helper: computes keyword recall and asks an LLM to grade answers."""

    def __init__(self):
        self.llm = llm_service

    def calculate_recall(self, retrieved_docs, keywords):
        """Compute keyword recall over the retrieved documents."""
        if not keywords:
            return 1.0  # Questions with no keyword requirements get full marks
        full_text = " ".join([d['content'] for d in retrieved_docs]).lower()
        hit_count = sum(1 for k in keywords if k.lower() in full_text)
        return hit_count / len(keywords)

    def judge_answer(self, query, answer, ground_truth):
        """Ask the LLM to grade the generated answer on a 1-5 scale."""
        prompt = f"""
As a RAG evaluator, score the [System Answer] on a scale of 1-5.
User question: {query}
Reference answer: {ground_truth}
System answer: {answer}
Rubric:
5: Meaning fully matches, no hallucination.
3: Core information present, but with omissions.
1: Wrong or severely hallucinated.
Return only the digit (1, 2, 3, 4, 5).
"""
        try:
            # Calls the chat method added to llm_service
            res = self.llm.chat(prompt)
            # Light cleanup: take the first digit in the reply (joining every
            # digit would turn a reply like "3/5" into 35 before clamping)
            score = int(next(ch for ch in res if ch.isdigit()))
            return min(max(score, 1), 5)  # Clamp to the 1-5 range
        except Exception:
            return 1  # Fall back to the minimum score on failure
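# Quick sanity check for the recall metric (hypothetical doc/keyword values):
#   RAGEvaluator().calculate_recall(
#       [{'content': 'The FAISS index is rebuilt nightly'}],
#       ['faiss', 'nightly', 'sharding'],
#   )  # -> 2/3, since 'sharding' never appears in the retrieved text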
class Visualizer:
    """Plotting helper."""

    def plot_dashboard(self, df):
        # Styling
        sns.set_theme(style="whitegrid")
        # Fix CJK rendering: use SimHei when the environment has it,
        # otherwise fall back to English-capable fonts
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(18, 10))
        gs = fig.add_gridspec(2, 2)

        # Chart 1: overall metric comparison (bar chart)
        ax1 = fig.add_subplot(gs[0, 0])
        # Reshape the data to long format for plotting
        df_summary = df.groupby('config')[['score', 'recall']].mean().reset_index()
        df_melt = df_summary.melt(id_vars='config', var_name='Metric', value_name='Value')
        # Scale recall to 0-5 so both metrics fit on one axis (or use twin axes);
        # the simple approach here is recall * 5
        df_melt.loc[df_melt['Metric'] == 'recall', 'Value'] *= 5
        sns.barplot(data=df_melt, x='config', y='Value', hue='Metric', ax=ax1, palette="viridis")
        ax1.set_title('Overall Performance (Score & Recall)', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Score (1-5) / Recall (x5)')
        ax1.set_ylim(0, 5.5)
        for container in ax1.containers:
            ax1.bar_label(container, fmt='%.1f')

        # Chart 2: latency vs quality (scatter plot)
        ax2 = fig.add_subplot(gs[0, 1])
        df_latency = df.groupby('config')[['score', 'latency']].mean().reset_index()
        sns.scatterplot(data=df_latency, x='latency', y='score', hue='config', s=200, ax=ax2, palette="deep")
        # Label each point with its config name
        for i in range(df_latency.shape[0]):
            ax2.text(
                df_latency.latency[i] + 0.05,
                df_latency.score[i],
                df_latency.config[i],
                fontsize=10
            )
        ax2.set_title('Trade-off: Latency vs Quality', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Avg Latency (seconds)')
        ax2.set_ylabel('Avg Quality Score (1-5)')

        # Chart 3: per-category heatmap - visualizes your weakest category
        ax3 = fig.add_subplot(gs[1, :])  # Spans the full bottom row
        pivot_data = df.pivot_table(index='config', columns='type', values='score', aggfunc='mean')
        sns.heatmap(pivot_data, annot=True, cmap="RdYlGn", center=3, fmt=".1f", ax=ax3, linewidths=.5)
        ax3.set_title('Category Breakdown (Find the Weakest Link)', fontsize=14, fontweight='bold')
        ax3.set_xlabel('')
        ax3.set_ylabel('')

        plt.tight_layout()
        plt.savefig(OUTPUT_IMG)
        print(f"\n📊 Report saved to: {OUTPUT_IMG}")
def main():
    # 1. Load the dataset
    if not os.path.exists(DATASET_PATH):
        print("Dataset not found!")
        return
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    # 2. Define the experiment configurations (direct call)
    configs = [
        {
            "name": "1. BM25 (Keyword)",
            "retriever": lambda q: data_service.search(q, [0.0] * 1536, task_id=TEST_TASK_ID, vector_weight=0.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "2. Vector Only",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=1.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "3. Hybrid (Base)",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "4. Hybrid + Rerank",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=30)['results'],  # Recall top 30
            "rerank": True
        }
    ]
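    # Note: the BM25 config passes an all-zero query vector purely as a
    # placeholder, since vector_weight=0.0 disables the vector score entirely;
    # 1536 is assumed to match the embedding dimension data_service.search expects.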
    evaluator = RAGEvaluator()
    all_results = []
    print("🚀 Starting automated evaluation (visualization mode)...")

    # 3. Run the experiments (nested loop: config -> dataset item)
    # tqdm tracks overall progress
    total_steps = len(configs) * len(dataset)
    pbar = tqdm(total=total_steps, desc="Running Experiments")
    for cfg in configs:
        for item in dataset:
            pbar.set_description(f"Testing {cfg['name']}")
            start_time = time.time()
            # A. Retrieve
            docs = cfg['retriever'](item['query'])
            # B. Rerank (when enabled for this config)
            if cfg['rerank']:
                docs = llm_service.rerank(item['query'], docs, top_n=5)
            # C. Generate
            context = "\n".join([d['content'] for d in docs]) if docs else ""
            if not context:
                answer = "No relevant information found"
            else:
                prompt = f"Context:\n{context}\n\nQuestion: {item['query']}"
                answer = llm_service.chat(prompt)  # Call the generation endpoint
            latency = time.time() - start_time
            # D. Score the run
            recall = evaluator.calculate_recall(docs, item.get('keywords', []))
            score = evaluator.judge_answer(item['query'], answer, item['ground_truth'])
            # E. Collect the data point
            all_results.append({
                "config": cfg['name'],
                "id": item['id'],
                "type": item['type'],  # Category field
                "recall": recall,
                "score": score,
                "latency": latency
            })
            pbar.update(1)
    pbar.close()
    # 4. Process the results and plot
    df = pd.DataFrame(all_results)
    viz = Visualizer()
    viz.plot_dashboard(df)


if __name__ == "__main__":
    main()
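# Usage sketch (assumes the backend services and dataset.json are in place):
#   python wiki_crawler/tests/rag_benchmark/visual_benchmark.py
# The dashboard is written to benchmark_report.png next to this script.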