# wiki_crawler/tests/rag_benchmark/run_benchmark.py

import sys
import os
import time
import json
from collections import defaultdict
from tabulate import tabulate  # requires: pip install tabulate

# Path hack: make sure the backend package is importable
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))
from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from tests.rag_benchmark.evaluator import RAGEvaluator

# ================= Configuration =================
# Set this to a task_id that actually exists in your database
# and already contains crawled data.
TEST_TASK_ID = 19
# =================================================
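
# Each dataset.json item is expected to look roughly like the sketch below,
# inferred from the fields consumed in run_experiment. Any keyword fields read
# inside RAGEvaluator.calculate_retrieval_metrics are an assumption; check
# evaluator.py for the authoritative schema.
# {
#     "id": 1,
#     "type": "core_function",      # category label, e.g. core_function / negative_test
#     "query": "How do I ...?",
#     "ground_truth": "The expected answer ...",
#     "keywords": ["..."]           # assumed: used for keyword_recall
# }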

def run_experiment(config_name, dataset, retrieve_func, generate_func):
    """
    Run one experiment configuration and collect detailed metrics.
    """
    print(f"\n🚀 Running configuration: [ {config_name} ]")
    evaluator = RAGEvaluator()
    results = []
    total_latency = 0

    # Per-category stats (e.g. how core_function scores vs. negative_test)
    category_stats = defaultdict(lambda: {"count": 0, "score_sum": 0, "recall_sum": 0})

    for item in dataset:
        start_time = time.time()

        # 1. Retrieval
        retrieved_docs = retrieve_func(item['query'])

        # 2. Generation
        # Build the context string; fall back to a hint if nothing was retrieved
        if retrieved_docs:
            context_str = "\n---\n".join(
                [f"Source: {d.get('source_url', 'unknown')}\nContent: {d['content']}" for d in retrieved_docs]
            )
        else:
            context_str = "No relevant documents were retrieved."

        answer = generate_func(item['query'], context_str)
        latency = time.time() - start_time
        total_latency += latency

        # 3. Evaluation
        # Keyword recall of the retrieved documents
        retrieval_metric = evaluator.calculate_retrieval_metrics(retrieved_docs, item)
        # LLM-judged answer quality
        gen_eval = evaluator.evaluate_generation_quality(
            item['query'], answer, item['ground_truth'], item['type']
        )

        # Record this item's result
        row = {
            "id": item['id'],
            "type": item['type'],
            "query": item['query'],
            "recall": retrieval_metric['keyword_recall'],
            "score": gen_eval['score'],
            "reason": gen_eval.get('reason', '')[:50] + "...",  # truncated for display
            "latency": latency
        }
        results.append(row)

        # Accumulate the per-category stats
        cat = item['type']
        category_stats[cat]["count"] += 1
        category_stats[cat]["score_sum"] += gen_eval['score']
        category_stats[cat]["recall_sum"] += retrieval_metric['keyword_recall']

        # Live progress output (compact form)
        status_icon = "✅" if gen_eval['score'] >= 4 else "⚠️" if gen_eval['score'] >= 3 else "❌"
        print(f"  {status_icon} ID:{item['id']} [{item['type'][:10]}] Score:{gen_eval['score']} | Recall:{retrieval_metric['keyword_recall']:.1f}")

    # --- Aggregate this run ---
    avg_score = sum(r['score'] for r in results) / len(results)
    avg_recall = sum(r['recall'] for r in results) / len(results)
    avg_latency = total_latency / len(results)

    # Per-category score summary
    cat_report = []
    for cat, data in category_stats.items():
        cat_report.append(f"{cat}: {data['score_sum'] / data['count']:.1f}")

    return {
        "Config": config_name,
        "Avg Score (1-5)": f"{avg_score:.2f}",
        "Avg Recall": f"{avg_recall:.2%}",
        "Avg Latency": f"{avg_latency:.3f}s",
        "Category Scores": " | ".join(cat_report),
        "Weakest Category": min(category_stats, key=lambda k: category_stats[k]['score_sum'] / category_stats[k]['count'])
    }

def main():
    # 1. Load the dataset
    dataset_path = os.path.join(os.path.dirname(__file__), 'dataset.json')
    if not os.path.exists(dataset_path):
        print("Error: dataset.json not found.")
        return

    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
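
    # Optional sanity check (sketch): run_experiment assumes each item carries
    # these fields, so failing fast here gives a clearer error than a KeyError
    # halfway through a benchmark run.
    required_keys = {"id", "type", "query", "ground_truth"}
    bad_items = [it.get("id") for it in dataset if not required_keys <= it.keys()]
    if bad_items:
        print(f"Warning: dataset items missing required fields: {bad_items}")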
print(f"载入 {len(dataset)} 条测试用例,准备开始横向评测...")
# 2. 定义实验变量 (检索函数 + 生成函数)
# === Exp A: 纯关键词 (模拟传统搜索) ===
def retrieve_keyword(query):
# vector_weight=0 强制使用 SQL TSVector
# 注意: 需要传递一个假向量给接口占位
dummy_vec = [0.0] * 1536
res = data_service.search(query, dummy_vec, task_id=TEST_TASK_ID, vector_weight=0.0, candidates_num=5)
return res['results']

    # === Exp B: vector-only (semantic retrieval) ===
    def retrieve_vector(query):
        vec = llm_service.get_embedding(query)
        # vector_weight=1 ignores keyword matching
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=1.0, candidates_num=5)
        return res['results']
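
    # Note: get_embedding must use the same embedding model (and hence the same
    # dimensionality, 1536 per dummy_vec above) that was used when the task's
    # documents were indexed; a model mismatch would distort the similarity
    # scores for Exp B-D.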

    # === Exp C: hybrid retrieval ===
    def retrieve_hybrid(query):
        vec = llm_service.get_embedding(query)
        # Default blend: 0.7 vector + 0.3 keyword
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=5)
        return res['results']
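
    # Assuming data_service.search fuses the two signals linearly (the 0.7/0.3
    # blend above suggests so), the effective ranking score would be roughly:
    #   final_score = vector_weight * vector_score + (1 - vector_weight) * keyword_score
    # which makes Exp A and Exp B the two endpoints of this blend.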

    # === Exp D: hybrid + rerank ===
    def retrieve_rerank(query):
        vec = llm_service.get_embedding(query)
        # 1. Widen the recall net (top 30)
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=30)
        initial_docs = res['results']
        # 2. Rerank down to the final top 5
        reranked = llm_service.rerank(query, initial_docs, top_n=5)
        return reranked
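
    # The recall-then-rerank pattern trades latency for precision: the wide
    # top-30 pass keeps retrieval recall high, and the reranker reorders that
    # pool so only the 5 most relevant chunks reach the prompt. Expect Exp D
    # to show the highest "Avg Latency" of the four configurations.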

    # === Shared generation function ===
    def generate_answer(query, context):
        system_prompt = (
            "You are an intelligent assistant. Answer the user's question strictly "
            "from the provided context. If the context does not contain the answer, "
            "reply exactly with 'No relevant information found'."
        )
        prompt = f"Reference context:\n{context}\n\nUser question: {query}"
        return llm_service.chat(prompt, system_prompt=system_prompt)
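
    # The strict "answer only from context" instruction is presumably what the
    # negative_test category exercises: for queries with no supporting
    # documents, a correct run should produce the refusal string rather than a
    # hallucinated answer.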

    # 3. Run all experiments
    final_report = []
    final_report.append(run_experiment("1. Keyword Only (BM25)", dataset, retrieve_keyword, generate_answer))
    final_report.append(run_experiment("2. Vector Only", dataset, retrieve_vector, generate_answer))
    final_report.append(run_experiment("3. Hybrid (Base)", dataset, retrieve_hybrid, generate_answer))
    final_report.append(run_experiment("4. Hybrid + Rerank", dataset, retrieve_rerank, generate_answer))

    # 4. Print the final comparison report
    print("\n\n📊 ================= Final Head-to-Head Report ================= 📊")
    print(tabulate(final_report, headers="keys", tablefmt="github"))
    print("\n💡 How to read the results:")
    print("1. A low 'Avg Recall' means the chunking or the retrieval algorithm fails to find the material.")
    print("2. High recall but a low 'Avg Score' points to LLM hallucination or a weak prompt.")
    print("3. 'Weakest Category' surfaces blind spots (e.g. multilingual or negative tests).")

if __name__ == "__main__":
    main()