208 lines
7.8 KiB
Python
208 lines
7.8 KiB
Python
|
|
import json
import os
import re
import sys
import time

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm import tqdm
|
|||
|
|
|
|||
|
|
# Path hack: put the project root (two levels up) on sys.path so `backend`
# is importable when this script is run directly.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, '../../'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Import the service singletons directly (direct call — no HTTP layer).
# NOTE: these imports must come AFTER the sys.path tweak above.
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service

# ================= Configuration =================
TEST_TASK_ID = 19  # change this to a real Task ID
DATASET_PATH = os.path.join(current_dir, 'dataset.json')
OUTPUT_IMG = os.path.join(current_dir, 'benchmark_report.png')
# =================================================
|
|||
|
|
|
|||
|
|
class RAGEvaluator:
    """Evaluation helpers: keyword recall plus LLM-as-judge answer scoring."""

    def __init__(self):
        self.llm = llm_service

    def calculate_recall(self, retrieved_docs, keywords):
        """Return the fraction of `keywords` present in the retrieved text.

        Args:
            retrieved_docs: list of dicts, each with a 'content' string.
            keywords: list of keyword strings; empty/None means no requirement.

        Returns:
            float in [0, 1]; matching is case-insensitive substring search.
        """
        if not keywords:
            return 1.0  # questions with no keyword requirement get full marks

        full_text = " ".join(d['content'] for d in retrieved_docs).lower()
        hit_count = sum(1 for k in keywords if k.lower() in full_text)
        return hit_count / len(keywords)

    def judge_answer(self, query, answer, ground_truth):
        """Ask the LLM to grade the system answer on a 1-5 scale.

        Returns:
            int in [1, 5]; 1 on any failure (LLM error or unparseable reply).
        """
        prompt = f"""
作为 RAG 评测员,请对【系统回答】打分 (1-5)。
用户问题: {query}
标准答案: {ground_truth}
系统回答: {answer}

标准:
5: 含义完全一致,无幻觉。
3: 包含核心信息,但有遗漏。
1: 错误或严重幻觉。

只返回数字 (1, 2, 3, 4, 5)。
"""
        try:
            # `chat` is the plain-completion method added to llm_service.
            res = self.llm.chat(prompt)
            # Take the FIRST digit in the reply. The old version concatenated
            # every digit, so a reply like "4/5" became 45 and clamped to 5,
            # and a digit-free reply raised ValueError on int('').
            m = re.search(r'\d', res)
            if m is None:
                return 1
            return min(max(int(m.group()), 1), 5)  # clamp to the 1-5 scale
        except Exception:
            # Narrowed from a bare `except:` so Ctrl-C / SystemExit propagate.
            return 1  # fall back to the minimum score on failure
|
|||
|
|
|
|||
|
|
class Visualizer:
    """Renders the benchmark dashboard (bars, scatter, heatmap) to OUTPUT_IMG."""

    def plot_dashboard(self, df):
        """Draw a three-chart dashboard from the per-question results.

        Args:
            df: DataFrame with columns config, type, score, recall, latency
                (one row per question per config).
        """
        sns.set_theme(style="whitegrid")
        # Font fallback chain: SimHei renders Chinese when available,
        # otherwise Latin fonts take over.
        plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial', 'DejaVu Sans']
        plt.rcParams['axes.unicode_minus'] = False

        fig = plt.figure(figsize=(18, 10))
        gs = fig.add_gridspec(2, 2)

        # Chart 1: mean score & recall per config (bar chart).
        ax1 = fig.add_subplot(gs[0, 0])
        # Reshape to long format so seaborn groups bars by metric.
        df_summary = df.groupby('config')[['score', 'recall']].mean().reset_index()
        df_melt = df_summary.melt(id_vars='config', var_name='Metric', value_name='Value')
        # Recall is 0-1; scale by 5 so it shares the 0-5 score axis.
        df_melt.loc[df_melt['Metric'] == 'recall', 'Value'] *= 5

        sns.barplot(data=df_melt, x='config', y='Value', hue='Metric', ax=ax1, palette="viridis")
        ax1.set_title('Overall Performance (Score & Recall)', fontsize=14, fontweight='bold')
        ax1.set_ylabel('Score (1-5) / Recall (x5)')
        ax1.set_ylim(0, 5.5)
        for container in ax1.containers:
            ax1.bar_label(container, fmt='%.1f')

        # Chart 2: latency vs quality trade-off (scatter).
        ax2 = fig.add_subplot(gs[0, 1])
        df_latency = df.groupby('config')[['score', 'latency']].mean().reset_index()
        sns.scatterplot(data=df_latency, x='latency', y='score', hue='config', s=200, ax=ax2, palette="deep")

        # Label each point with its config name. itertuples avoids the old
        # positional `.latency[i]` lookups, which assumed a clean 0..n-1 index.
        for row in df_latency.itertuples():
            ax2.text(row.latency + 0.05, row.score, row.config, fontsize=10)
        ax2.set_title('Trade-off: Latency vs Quality', fontsize=14, fontweight='bold')
        ax2.set_xlabel('Avg Latency (seconds)')
        ax2.set_ylabel('Avg Quality Score (1-5)')

        # Chart 3: per-category heatmap across the full bottom row — exposes
        # the weakest question category for each config.
        ax3 = fig.add_subplot(gs[1, :])
        pivot_data = df.pivot_table(index='config', columns='type', values='score', aggfunc='mean')
        sns.heatmap(pivot_data, annot=True, cmap="RdYlGn", center=3, fmt=".1f", ax=ax3, linewidths=.5)
        ax3.set_title('Category Breakdown (Find the Weakest Link)', fontsize=14, fontweight='bold')
        ax3.set_xlabel('')
        ax3.set_ylabel('')

        fig.tight_layout()
        fig.savefig(OUTPUT_IMG)
        # Close explicitly so repeated calls don't accumulate open figures.
        plt.close(fig)
        print(f"\n📊 报表已生成: {OUTPUT_IMG}")
|
|||
|
|
|
|||
|
|
def main():
    """Run every retrieval config over the dataset, score each answer, plot a report."""
    # 1. Load the evaluation dataset.
    if not os.path.exists(DATASET_PATH):
        print("Dataset not found!")
        return
    with open(DATASET_PATH, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    if not dataset:
        # An empty dataset would yield an empty DataFrame and crash plotting.
        print("Dataset is empty!")
        return

    # 2. Experiment configurations (direct service calls, no HTTP).
    configs = [
        {
            "name": "1. BM25 (Keyword)",
            # vector_weight=0 → pure keyword search; the zero vector is a placeholder.
            "retriever": lambda q: data_service.search(q, [0.0]*1536, task_id=TEST_TASK_ID, vector_weight=0.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "2. Vector Only",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=1.0, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "3. Hybrid (Base)",
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=5)['results'],
            "rerank": False
        },
        {
            "name": "4. Hybrid + Rerank",
            # Recall a wide Top-30 pool; the reranker trims it to 5 below.
            "retriever": lambda q: data_service.search(q, llm_service.get_embedding(q), task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=30)['results'],
            "rerank": True
        }
    ]

    evaluator = RAGEvaluator()
    all_results = []

    print("🚀 开始自动化评测 (Visualization Mode)...")

    # 3. Double loop: config × question. The tqdm context manager guarantees
    #    the progress bar is closed even if a service call raises.
    with tqdm(total=len(configs) * len(dataset), desc="Running Experiments") as pbar:
        for cfg in configs:
            for item in dataset:
                pbar.set_description(f"Testing {cfg['name']}")

                start_time = time.time()

                # A. Retrieve
                docs = cfg['retriever'](item['query'])

                # B. Rerank (only if enabled for this config)
                if cfg['rerank']:
                    docs = llm_service.rerank(item['query'], docs, top_n=5)

                # C. Generate
                context = "\n".join(d['content'] for d in docs) if docs else ""
                if not context:
                    answer = "未找到相关信息"
                else:
                    prompt = f"Context:\n{context}\n\nQuestion: {item['query']}"
                    answer = llm_service.chat(prompt)  # generation call

                # Latency covers retrieve + rerank + generate.
                latency = time.time() - start_time

                # D. Metrics
                recall = evaluator.calculate_recall(docs, item.get('keywords', []))
                score = evaluator.judge_answer(item['query'], answer, item['ground_truth'])

                # E. Collect one row per (config, question).
                all_results.append({
                    "config": cfg['name'],
                    "id": item['id'],
                    "type": item['type'],  # question category
                    "recall": recall,
                    "score": score,
                    "latency": latency
                })

                pbar.update(1)

    # 4. Aggregate and render the dashboard.
    df = pd.DataFrame(all_results)
    viz = Visualizer()
    viz.plot_dashboard(df)
|
|||
|
|
|
|||
|
|
# Script entry point: only run the benchmark when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()
|