# 168 lines, 6.8 KiB, Python
import sys
import os
import time
import json
from collections import defaultdict

from tabulate import tabulate  # requires: pip install tabulate

# Path hack: make the `backend` package importable when this file is run
# directly as a script from tests/rag_benchmark/.
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from tests.rag_benchmark.evaluator import RAGEvaluator

# ================= Configuration =================
# Set this to a task_id that really exists in your database and has
# crawled data attached to it.
TEST_TASK_ID = 19
# =================================================
def run_experiment(config_name, dataset, retrieve_func, generate_func):
    """Run one retrieval+generation configuration over the dataset.

    For each test case: retrieve documents, generate an answer from the
    retrieved context, then score both retrieval (keyword recall) and
    generation quality, printing a compact progress line as it goes.

    Args:
        config_name: Human-readable label for this configuration.
        dataset: List of test-case dicts; each needs at least 'id', 'type',
            'query' and 'ground_truth' (plus whatever the evaluator reads).
        retrieve_func: Callable(query) -> list of retrieved document dicts.
        generate_func: Callable(query, context_str) -> answer string.

    Returns:
        dict with the run summary: "Config", "Avg Score (1-5)",
        "Avg Recall", "Avg Latency" and "Weakest Category".

    Raises:
        ValueError: if `dataset` is empty (previously this crashed later
            with ZeroDivisionError when computing the averages).
    """
    if not dataset:
        raise ValueError("dataset is empty; nothing to evaluate")

    print(f"\n🚀 开始测试配置: [ {config_name} ]")
    evaluator = RAGEvaluator()
    results = []

    total_latency = 0

    # Per-category aggregates (e.g. how core_function vs negative_test score).
    category_stats = defaultdict(lambda: {"count": 0, "score_sum": 0, "recall_sum": 0})

    for item in dataset:
        start_time = time.time()

        # 1. Retrieval
        retrieved_docs = retrieve_func(item['query'])

        # 2. Generation — build the context string; when nothing was
        #    retrieved, hand the LLM an explicit "nothing found" notice.
        if retrieved_docs:
            context_str = "\n---\n".join(
                [f"Source: {d.get('source_url', 'unknown')}\nContent: {d['content']}" for d in retrieved_docs]
            )
        else:
            context_str = "没有检索到任何相关文档。"

        answer = generate_func(item['query'], context_str)

        latency = time.time() - start_time
        total_latency += latency

        # 3. Evaluation
        # Keyword recall over the retrieved documents.
        retrieval_metric = evaluator.calculate_retrieval_metrics(retrieved_docs, item)
        # LLM answer quality (score plus a textual reason).
        gen_eval = evaluator.evaluate_generation_quality(
            item['query'], answer, item['ground_truth'], item['type']
        )

        # Record the per-case result row.
        reason = gen_eval.get('reason', '')
        row = {
            "id": item['id'],
            "type": item['type'],
            "query": item['query'],
            "recall": retrieval_metric['keyword_recall'],
            "score": gen_eval['score'],
            # Truncate for display; only show "..." when text was actually cut.
            "reason": reason[:50] + ("..." if len(reason) > 50 else ""),
            "latency": latency
        }
        results.append(row)

        # Accumulate per-category statistics.
        cat = item['type']
        category_stats[cat]["count"] += 1
        category_stats[cat]["score_sum"] += gen_eval['score']
        category_stats[cat]["recall_sum"] += retrieval_metric['keyword_recall']

        # Live progress line (compact).
        status_icon = "✅" if gen_eval['score'] >= 4 else "⚠️" if gen_eval['score'] >= 3 else "❌"
        print(f" {status_icon} ID:{item['id']} [{item['type'][:10]}] Score:{gen_eval['score']} | Recall:{retrieval_metric['keyword_recall']:.1f}")

    # --- Aggregate this run ---
    avg_score = sum(r['score'] for r in results) / len(results)
    avg_recall = sum(r['recall'] for r in results) / len(results)
    avg_latency = total_latency / len(results)

    return {
        "Config": config_name,
        "Avg Score (1-5)": f"{avg_score:.2f}",
        "Avg Recall": f"{avg_recall:.2%}",
        "Avg Latency": f"{avg_latency:.3f}s",
        # Category with the lowest mean score — the benchmark's weak spot.
        "Weakest Category": min(
            category_stats,
            key=lambda k: category_stats[k]['score_sum'] / category_stats[k]['count']
        )
    }
def main():
    """Entry point: load the dataset and benchmark four retrieval setups."""
    # 1. Load the test dataset from a JSON file next to this script.
    dataset_path = os.path.join(os.path.dirname(__file__), 'dataset.json')
    if not os.path.exists(dataset_path):
        print("Error: dataset.json not found.")
        return

    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)

    print(f"载入 {len(dataset)} 条测试用例,准备开始横向评测...")

    # 2. Experiment variables: one retrieval strategy per experiment,
    #    all sharing the same generation function.

    # === Exp A: keyword-only (emulates classic search) ===
    def retrieve_keyword(query):
        # vector_weight=0 forces the SQL TSVector path; the search API
        # still requires an embedding argument, so a zero vector is
        # passed purely as a placeholder.
        dummy_vec = [0.0] * 1536
        res = data_service.search(query, dummy_vec, task_id=TEST_TASK_ID, vector_weight=0.0, candidates_num=5)
        return res['results']

    # === Exp B: vector-only (semantic retrieval) ===
    def retrieve_vector(query):
        vec = llm_service.get_embedding(query)
        # vector_weight=1 ignores keyword matching entirely.
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=1.0, candidates_num=5)
        return res['results']

    # === Exp C: hybrid retrieval ===
    def retrieve_hybrid(query):
        vec = llm_service.get_embedding(query)
        # Default blend: 0.7 vector + 0.3 keyword.
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=5)
        return res['results']

    # === Exp D: hybrid + rerank ===
    def retrieve_rerank(query):
        vec = llm_service.get_embedding(query)
        # Step 1: widen the recall set (top 30) ...
        res = data_service.search(query, vec, task_id=TEST_TASK_ID, vector_weight=0.7, candidates_num=30)
        # Step 2: ... then rerank it down to the top 5.
        return llm_service.rerank(query, res['results'], top_n=5)

    # === Shared generation function ===
    def generate_answer(query, context):
        system_prompt = "你是一个智能助手。请严格根据提供的上下文回答用户问题。如果上下文中没有答案,请直接说'未找到相关信息'。"
        prompt = f"参考上下文:\n{context}\n\n用户问题:{query}"
        return llm_service.chat(prompt, system_prompt=system_prompt)

    # 3. Run every experiment in a fixed order.
    experiments = [
        ("1. Keyword Only (BM25)", retrieve_keyword),
        ("2. Vector Only", retrieve_vector),
        ("3. Hybrid (Base)", retrieve_hybrid),
        ("4. Hybrid + Rerank", retrieve_rerank),
    ]
    final_report = [
        run_experiment(label, dataset, retriever, generate_answer)
        for label, retriever in experiments
    ]

    # 4. Print the final side-by-side comparison table.
    print("\n\n📊 ================= 最终横向对比报告 (Final Report) ================= 📊")
    print(tabulate(final_report, headers="keys", tablefmt="github"))
    print("\n💡 解读建议:")
    print("1. 如果 'Avg Recall' 低,说明切片(Chunking)或检索算法找不到资料。")
    print("2. 如果 Recall 高但 'Avg Score' 低,说明 LLM 产生了幻觉或 Prompt 没写好。")
    print("3. 'Weakest Category' 帮你发现短板(如多语言或负向测试)。")


if __name__ == "__main__":
    main()