import sys
import os
import re
import time
import json
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

# Path hack: make sure the backend package is importable
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import the service singletons directly (Direct Call)
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service

# ================= Configuration =================
TEST_TASK_ID = 19  # Change this to a real Task ID
DATASET_PATH = os.path.join(current_dir, 'dataset.json')
OUTPUT_IMG = os.path.join(current_dir, 'benchmark_report.png')
# =================================================


class RAGEvaluator:
    """Evaluation helpers: keyword recall plus LLM-as-judge scoring."""

    def __init__(self):
        self.llm = llm_service

    def calculate_recall(self, retrieved_docs, keywords):
        """Compute keyword recall over the concatenated retrieved chunks."""
        if not keywords:
            return 1.0  # Questions with no required keywords get full marks
        full_text = " ".join([d['content'] for d in retrieved_docs]).lower()
        hit_count = sum(1 for k in keywords if k.lower() in full_text)
        return hit_count / len(keywords)

    def judge_answer(self, query, answer, ground_truth):
        """Ask the LLM to grade the generated answer on a 1-5 scale."""
        prompt = f"""
As a RAG evaluator, grade the [System Answer] on a scale of 1-5.
User question: {query}
Ground truth: {ground_truth}
System answer: {answer}

Criteria:
5: Meaning fully matches the ground truth, no hallucination.
3: Contains the core information but with omissions.
1: Wrong or severely hallucinated.
Return only a single digit (1, 2, 3, 4, 5).
"""
        try:
            # Calls the chat method added to llm_service
            res = self.llm.chat(prompt)
            # Extract the first 1-5 digit; concatenating all digits would
            # turn a reply like "4/5" into 45
            match = re.search(r'[1-5]', res)
            return int(match.group()) if match else 1
        except Exception:
            return 1  # Minimum score as a fallback on failure


class Visualizer:
    """Plotting helpers for the benchmark dashboard."""

    def plot_dashboard(self, df):
        # Style
        sns.set_theme(style="whitegrid")
        # Chinese glyph handling: use SimHei if available, else Latin fallbacks
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(18, 10))
        gs = fig.add_gridspec(2, 2)

        # Chart 1: overall metrics comparison (bar chart)
        ax1 = fig.add_subplot(gs[0, 0])
        # Melt to long format for grouped bars
        df_summary = df.groupby('config')[['score', 'recall']].mean().reset_index()
        df_melt = df_summary.melt(id_vars='config', var_name='Metric', value_name='Value')
        # Scale recall (0-1) by 5 so both metrics share one axis;
        # the alternative would be a secondary y-axis
        df_melt.loc[df_melt['Metric'] == 'recall', 'Value'] *= 5

        sns.barplot(data=df_melt, x='config', y='Value', hue='Metric', ax=ax1, palette="viridis")
        ax1.set_title('Overall Performance (Score & Recall)', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Score (1-5) / Recall (x5)')
        ax1.set_ylim(0, 5.5)
        for container in ax1.containers:
            ax1.bar_label(container, fmt='%.1f')

        # Chart 2: latency vs quality (scatter plot)
        ax2 = fig.add_subplot(gs[0, 1])
        df_latency = df.groupby('config')[['score', 'latency']].mean().reset_index()
        sns.scatterplot(data=df_latency, x='latency', y='score', hue='config',
                        s=200, ax=ax2, palette="deep")
        # Label each point with its config name
        for i in range(df_latency.shape[0]):
            ax2.text(df_latency.latency[i] + 0.05, df_latency.score[i],
                     df_latency.config[i], fontsize=10)
        ax2.set_title('Trade-off: Latency vs Quality', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Avg Latency (seconds)')
        ax2.set_ylabel('Avg Quality Score (1-5)')

        # Chart 3: per-category heatmap, to spot the weakest category
        ax3 = fig.add_subplot(gs[1, :])  # Span the entire bottom row
        pivot_data = df.pivot_table(index='config', columns='type', values='score', aggfunc='mean')
        sns.heatmap(pivot_data, annot=True, cmap="RdYlGn", center=3, fmt=".1f",
                    ax=ax3, linewidths=.5)
        ax3.set_title('Category Breakdown (Find the Weakest Link)', fontsize=14, fontweight='bold')
        ax3.set_xlabel('')
        ax3.set_ylabel('')

        plt.tight_layout()
        plt.savefig(OUTPUT_IMG)
        print(f"\n📊 Report saved to: {OUTPUT_IMG}")
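
# For reference, a hypothetical dataset.json entry with exactly the fields
# this script reads ("id", "type", "query", "ground_truth", "keywords");
# the values below are illustrative assumptions, not real benchmark data:
# [
#   {
#     "id": 1,
#     "type": "factual",
#     "query": "When was the project first released?",
#     "ground_truth": "The project was first released in 2021.",
#     "keywords": ["2021", "released"]
#   }
# ]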
def main():
    # 1. Load the dataset
    if not os.path.exists(DATASET_PATH):
        print("Dataset not found!")
        return
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    # 2. Define the experiment configurations (direct service calls)
    configs = [
        {
            "name": "1. BM25 (Keyword)",
            "retriever": lambda q: data_service.search(
                q, [0.0] * 1536, task_id=TEST_TASK_ID,
                vector_weight=0.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "2. Vector Only",
            "retriever": lambda q: data_service.search(
                q, llm_service.get_embedding(q), task_id=TEST_TASK_ID,
                vector_weight=1.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "3. Hybrid (Base)",
            "retriever": lambda q: data_service.search(
                q, llm_service.get_embedding(q), task_id=TEST_TASK_ID,
                vector_weight=0.7, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "4. Hybrid + Rerank",
            # Recall the top 30 candidates so the reranker has room to work
            "retriever": lambda q: data_service.search(
                q, llm_service.get_embedding(q), task_id=TEST_TASK_ID,
                vector_weight=0.7, candidates_num=30)['results'],
            "rerank": True
        },
    ]

    evaluator = RAGEvaluator()
    all_results = []

    print("🚀 Starting automated evaluation (visualization mode)...")

    # 3. Run everything (nested loop: config -> dataset item),
    #    with tqdm tracking overall progress
    total_steps = len(configs) * len(dataset)
    pbar = tqdm(total=total_steps, desc="Running Experiments")

    for cfg in configs:
        for item in dataset:
            pbar.set_description(f"Testing {cfg['name']}")
            start_time = time.time()

            # A. Retrieve
            docs = cfg['retriever'](item['query'])

            # B. Rerank (if enabled for this config)
            if cfg['rerank']:
                docs = llm_service.rerank(item['query'], docs, top_n=5)

            # C. Generate
            context = "\n".join([d['content'] for d in docs]) if docs else ""
            if not context:
                answer = "No relevant information found"
            else:
                prompt = f"Context:\n{context}\n\nQuestion: {item['query']}"
                answer = llm_service.chat(prompt)  # Generation call

            latency = time.time() - start_time

            # D. Evaluate
            recall = evaluator.calculate_recall(docs, item.get('keywords', []))
            score = evaluator.judge_answer(item['query'], answer, item['ground_truth'])

            # E. Collect the record
            all_results.append({
                "config": cfg['name'],
                "id": item['id'],
                "type": item['type'],  # Category field, used by the heatmap
                "recall": recall,
                "score": score,
                "latency": latency
            })
            pbar.update(1)

    pbar.close()

    # 4. Build the DataFrame and render the dashboard
    df = pd.DataFrame(all_results)
    viz = Visualizer()
    viz.plot_dashboard(df)


if __name__ == "__main__":
    main()
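
# --- Hypothetical stubs for an offline dry run (not part of the backend) ---
# The real data_service / llm_service come from backend.services; the sketch
# below only mirrors the call signatures this script relies on (search,
# get_embedding, chat, rerank) so the loop can be smoke-tested without a
# running backend, by swapping these in for the imports at the top. Every
# returned value is a placeholder assumption, not a meaningful result.
class _StubDataService:
    def search(self, query, embedding, task_id=None, vector_weight=0.7,
               candidates_num=5):
        # Shape matches what the retriever lambdas unpack: {'results': [...]}
        return {"results": [{"content": "placeholder chunk"}]}


class _StubLLMService:
    def get_embedding(self, query):
        return [0.0] * 1536  # Same dimensionality the BM25 config passes

    def chat(self, prompt):
        return "3"  # Canned reply; parses as a mid-scale judge score

    def rerank(self, query, docs, top_n=5):
        return docs[:top_n]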