"""RAG benchmark runner: compares keyword, vector, hybrid, and hybrid+rerank
retrieval configurations side by side on the same test dataset."""

import sys
import os
import time
import json
from collections import defaultdict

from tabulate import tabulate  # requires: pip install tabulate

# Path hack: make the backend package importable
sys.path.append(os.path.join(os.path.dirname(__file__), '../../../'))

from backend.services.data_service import data_service
from backend.services.llm_service import llm_service
from tests.rag_benchmark.evaluator import RAGEvaluator

# ================= Configuration =================
# Set this to a task_id that actually exists in your database
# and contains crawled data.
TEST_TASK_ID = 19
# =================================================


def run_experiment(config_name, dataset, retrieve_func, generate_func):
    """Run a single experiment configuration and collect detailed metrics."""
    print(f"\n🚀 Testing configuration: [ {config_name} ]")
    evaluator = RAGEvaluator()
    results = []
    total_latency = 0

    # Per-category statistics (e.g. how core_function vs. negative_test cases score)
    category_stats = defaultdict(lambda: {"count": 0, "score_sum": 0, "recall_sum": 0})

    for item in dataset:
        start_time = time.time()

        # 1. Retrieval
        retrieved_docs = retrieve_func(item['query'])

        # 2. Generation: build the context; if nothing was retrieved, say so explicitly
        if retrieved_docs:
            context_str = "\n---\n".join(
                f"Source: {d.get('source_url', 'unknown')}\nContent: {d['content']}"
                for d in retrieved_docs
            )
        else:
            context_str = "No relevant documents were retrieved."

        answer = generate_func(item['query'], context_str)
        latency = time.time() - start_time
        total_latency += latency

        # 3. Evaluation
        # Keyword recall of the retrieved documents
        retrieval_metric = evaluator.calculate_retrieval_metrics(retrieved_docs, item)
        # LLM-judged answer quality
        gen_eval = evaluator.evaluate_generation_quality(
            item['query'], answer, item['ground_truth'], item['type']
        )

        # Record the per-item result
        results.append({
            "id": item['id'],
            "type": item['type'],
            "query": item['query'],
            "recall": retrieval_metric['keyword_recall'],
            "score": gen_eval['score'],
            "reason": gen_eval.get('reason', '')[:50] + "...",  # truncated for display
            "latency": latency
        })

        # Accumulate per-category statistics
        cat = item['type']
        category_stats[cat]["count"] += 1
        category_stats[cat]["score_sum"] += gen_eval['score']
        category_stats[cat]["recall_sum"] += retrieval_metric['keyword_recall']

        # Live progress output (compact)
        status_icon = "✅" if gen_eval['score'] >= 4 else "⚠️" if gen_eval['score'] >= 3 else "❌"
        print(f"  {status_icon} ID:{item['id']} [{item['type'][:10]}] "
              f"Score:{gen_eval['score']} | Recall:{retrieval_metric['keyword_recall']:.1f}")

    # --- Aggregate results for this run ---
    avg_score = sum(r['score'] for r in results) / len(results)
    avg_recall = sum(r['recall'] for r in results) / len(results)
    avg_latency = total_latency / len(results)

    # Per-category score breakdown
    cat_report = [
        f"{cat}: {data['score_sum'] / data['count']:.1f} pts"
        for cat, data in category_stats.items()
    ]
    print("  📂 Category breakdown: " + " | ".join(cat_report))

    return {
        "Config": config_name,
        "Avg Score (1-5)": f"{avg_score:.2f}",
        "Avg Recall": f"{avg_recall:.2%}",
        "Avg Latency": f"{avg_latency:.3f}s",
        "Weakest Category": min(
            category_stats,
            key=lambda k: category_stats[k]['score_sum'] / category_stats[k]['count']
        )
    }


def main():
    # 1. Load the dataset
    dataset_path = os.path.join(os.path.dirname(__file__), 'dataset.json')
    if not os.path.exists(dataset_path):
        print("Error: dataset.json not found.")
        return
    with open(dataset_path, 'r', encoding='utf-8') as f:
        dataset = json.load(f)
    if not dataset:
        print("Error: dataset.json is empty.")
        return
    print(f"Loaded {len(dataset)} test cases; starting side-by-side evaluation...")

    # 2. Define the experiment variants (retrieval function + generation function)

    # === Exp A: Keyword only (simulates traditional search) ===
    def retrieve_keyword(query):
        # vector_weight=0 forces SQL TSVector (full-text) matching.
        # Note: the API still expects a vector, so pass a zero-filled placeholder.
        dummy_vec = [0.0] * 1536
        res = data_service.search(query, dummy_vec, task_id=TEST_TASK_ID,
                                  vector_weight=0.0, candidates_num=5)
        return res['results']

    # === Exp B: Vector only (semantic retrieval) ===
    def retrieve_vector(query):
        vec = llm_service.get_embedding(query)
        # vector_weight=1 ignores keyword matching entirely
        res = data_service.search(query, vec, task_id=TEST_TASK_ID,
                                  vector_weight=1.0, candidates_num=5)
        return res['results']

    # === Exp C: Hybrid retrieval ===
    def retrieve_hybrid(query):
        vec = llm_service.get_embedding(query)
        # Default blend: 0.7 vector + 0.3 keyword
        res = data_service.search(query, vec, task_id=TEST_TASK_ID,
                                  vector_weight=0.7, candidates_num=5)
        return res['results']

    # === Exp D: Hybrid + rerank ===
    def retrieve_rerank(query):
        vec = llm_service.get_embedding(query)
        # 1. Widen the recall pool (top 30 candidates)
        res = data_service.search(query, vec, task_id=TEST_TASK_ID,
                                  vector_weight=0.7, candidates_num=30)
        initial_docs = res['results']
        # 2. Rerank down to the top 5
        return llm_service.rerank(query, initial_docs, top_n=5)

    # === Shared generation function ===
    def generate_answer(query, context):
        system_prompt = (
            "You are an intelligent assistant. Answer the user's question strictly "
            "based on the provided context. If the context does not contain the "
            "answer, reply exactly 'No relevant information found'."
        )
        prompt = f"Reference context:\n{context}\n\nUser question: {query}"
        return llm_service.chat(prompt, system_prompt=system_prompt)

    # 3. Run all experiments
    final_report = [
        run_experiment("1. Keyword Only (BM25)", dataset, retrieve_keyword, generate_answer),
        run_experiment("2. Vector Only", dataset, retrieve_vector, generate_answer),
        run_experiment("3. Hybrid (Base)", dataset, retrieve_hybrid, generate_answer),
        run_experiment("4. Hybrid + Rerank", dataset, retrieve_rerank, generate_answer),
    ]

    # 4. Print the final report
    print("\n\n📊 ================= Final Comparison Report ================= 📊")
    print(tabulate(final_report, headers="keys", tablefmt="github"))

    print("\n💡 How to read the results:")
    print("1. A low 'Avg Recall' means chunking or the retrieval algorithm cannot find the material.")
    print("2. High recall but a low 'Avg Score' points to LLM hallucination or a weak prompt.")
    print("3. 'Weakest Category' exposes blind spots (e.g. multilingual or negative tests).")


if __name__ == "__main__":
    main()