From e5ac2dde03f6a502f0b54e66ee82a52b7ad44513 Mon Sep 17 00:00:00 2001 From: QingGang Date: Tue, 13 Jan 2026 10:37:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9ERAG=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/services/crawler_service.py | 36 +++++- backend/services/data_service.py | 24 ++-- backend/services/llm_service.py | 69 +++++++++- scripts/evaluate_rag.py | 192 ++++++++++++++++++++++++++++ scripts/test_dataset.json | 86 +++++++++++++ 5 files changed, 386 insertions(+), 21 deletions(-) create mode 100644 scripts/evaluate_rag.py create mode 100644 scripts/test_dataset.json diff --git a/backend/services/crawler_service.py b/backend/services/crawler_service.py index 3eed083..70172af 100644 --- a/backend/services/crawler_service.py +++ b/backend/services/crawler_service.py @@ -142,9 +142,41 @@ class CrawlerService: return {"msg": "Batch processed", "count": processed} - def search(self, query: str, task_id, limit: int): + def search(self, query: str, task_id, return_num: int): + """ + 全链路搜索:向量生成 -> 混合检索(粗排) -> 重排序(精排) + """ + # 1. 生成查询向量 vector = llm_service.get_embedding(query) if not vector: return {"msg": "Embedding failed", "results": []} - return data_service.search(query_text=query, query_vector=vector, task_id=task_id, limit=limit) + # 2. 计算粗排召回数量 + # 逻辑:至少召回 50 个,如果用户要很多,则召回 10 倍 + coarse_limit = return_num * 10 if return_num * 10 > 50 else 50 + + # 3. 执行混合检索 (粗排) + coarse_results = data_service.search( + query_text=query, + query_vector=vector, + task_id=task_id, + candidates_num=coarse_limit # 使用计算出的粗排数量 + ) + + candidates = coarse_results.get('results', []) + + if not candidates: + return {"msg": "No documents found", "results": []} + + # 4. 执行重排序 (精排) + final_results = llm_service.rerank( + query=query, + documents=candidates, + top_n=return_num # 最终返回用户需要的数量 + ) + + return { + "results": final_results, + "msg": f"Reranked {len(final_results)} from {len(candidates)} candidates" + } + crawler_service = CrawlerService() \ No newline at end of file diff --git a/backend/services/data_service.py b/backend/services/data_service.py index 1e75700..cbaac1a 100644 --- a/backend/services/data_service.py +++ b/backend/services/data_service.py @@ -91,24 +91,20 @@ class DataService: return {"msg": f"Saved {count} chunks", "count": count} - def search(self, query_text: str, query_vector: list, task_id = None, limit: int = 5): + def search(self, query_text: str, query_vector: list, task_id=None, candidates_num: int = 5): """ Phase 2: 混合检索 (Hybrid Search) - 综合 向量相似度 (Semantic) 和 关键词匹配度 (Keyword) """ + # 向量格式清洗 + if hasattr(query_vector, 'tolist'): query_vector = query_vector.tolist() + if query_vector and isinstance(query_vector, list) and len(query_vector) > 0: + if isinstance(query_vector[0], list): query_vector = query_vector[0] results = [] with self.db.engine.connect() as conn: - # 定义混合检索的 SQL 逻辑 - - # 使用 websearch_to_tsquery 处理用户输入 (支持 "firecrawl or dify" 这种语法) keyword_query = func.websearch_to_tsquery('english', query_text) - vector_score = (1 - self.db.chunks.c.embedding.cosine_distance(query_vector)) keyword_score = func.ts_rank(self.db.chunks.c.content_tsvector, keyword_query) - - # 综合打分列: 0.7 * Vector + 0.3 * Keyword - # coalesce 确保如果关键词得分为 NULL (无匹配),则视为 0 final_score = (vector_score * 0.7 + func.coalesce(keyword_score, 0) * 0.3).label("score") stmt = select( @@ -123,8 +119,8 @@ class DataService: if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id) - # 按综合分数倒序 - stmt = stmt.order_by(desc("score")).limit(limit) + # 使用 candidates_num 控制召回数量 + stmt = stmt.order_by(desc("score")).limit(candidates_num) try: rows = conn.execute(stmt).fetchall() @@ -141,23 +137,19 @@ class DataService: ] except Exception as e: print(f"[ERROR] Hybrid search failed: {e}") - return self._fallback_vector_search(query_vector, task_id, limit) + return self._fallback_vector_search(query_vector, task_id, candidates_num) return {"results": results, "msg": f"Hybrid found {len(results)}"} def _fallback_vector_search(self, vector, task_id, limit): - """降级兜底:纯向量搜索""" print("[WARN] Fallback to pure vector search") with self.db.engine.connect() as conn: stmt = select( self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title, self.db.chunks.c.content, self.db.chunks.c.meta_info ).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit) - if task_id: stmt = stmt.where(self.db.chunks.c.task_id == task_id) - rows = conn.execute(stmt).fetchall() return {"results": [{"content": r[3], "meta_info": r[4]} for r in rows], "msg": "Fallback found"} - data_service = DataService() \ No newline at end of file diff --git a/backend/services/llm_service.py b/backend/services/llm_service.py index f309e32..3279f72 100644 --- a/backend/services/llm_service.py +++ b/backend/services/llm_service.py @@ -5,16 +5,16 @@ from backend.core.config import settings class LLMService: """ LLM 服务封装层 - 负责与 DashScope 或其他模型供应商交互 + 负责与 DashScope (通义千问/GTE) 交互,包括 Embedding 和 Rerank """ def __init__(self): dashscope.api_key = settings.DASHSCOPE_API_KEY def get_embedding(self, text: str, dimension: int = 1536): - """生成文本向量""" + """生成文本向量 (Bi-Encoder)""" try: resp = dashscope.TextEmbedding.call( - model=dashscope.TextEmbedding.Models.text_embedding_v4, + model=dashscope.TextEmbedding.Models.text_embedding_v4, # 或 v4,视你的数据库维度而定 input=text, dimension=dimension ) @@ -27,4 +27,67 @@ class LLMService: print(f"[ERROR] Embedding Exception: {e}") return None + def rerank(self, query: str, documents: list, top_n: int = 5): + """ + 执行重排序 (Cross-Encoder) + + Args: + query: 用户问题 + documents: 粗排召回的切片列表 (List[dict]),必须包含 'content' 字段 + top_n: 最终返回多少个结果 + + Returns: + List[dict]: 排序后并截取 Top N 的文档列表,包含新的 'score' + """ + if not documents: + return [] + + # 1. 准备输入数据 + # Rerank API 需要纯文本列表,但我们需要保留 documents 里的 meta_info 和 id + # 所以我们提取 content 给 API,拿到 index 后再映射回去 + doc_contents = [doc.get('content', '') for doc in documents] + + # 如果文档太多(比如超过 100 个),建议先截断,避免 API 超时或报错 + if len(doc_contents) > 50: + doc_contents = doc_contents[:50] + documents = documents[:50] + + try: + # 2. 调用 DashScope GTE-Rerank + resp = dashscope.TextReRank.call( + model='gte-rerank', + query=query, + documents=doc_contents, + top_n=top_n, + return_documents=False # 我们只需要索引和分数,不需要它把文本再传回来 + ) + + if resp.status_code == HTTPStatus.OK: + # 3. 结果重组 + # API 返回结构示例: output.results = [{'index': 2, 'relevance_score': 0.98}, {'index': 0, ...}] + reranked_results = [] + + for item in resp.output.results: + # 根据 API 返回的 index 找到原始文档对象 + original_doc = documents[item.index] + + # 更新分数为 Rerank 的精准分数 (通常是 0~1 之间的置信度) + original_doc['score'] = item.relevance_score + + # 标记来源,方便调试知道这是 Rerank 过的 + original_doc['reranked'] = True + + reranked_results.append(original_doc) + + return reranked_results + else: + print(f"[ERROR] Rerank API Error: {resp}") + # 降级策略:如果 Rerank 挂了,直接返回粗排的前 N 个 + return documents[:top_n] + + except Exception as e: + print(f"[ERROR] Rerank Exception: {e}") + # 降级策略 + return documents[:top_n] + llm_service = LLMService() \ No newline at end of file diff --git a/scripts/evaluate_rag.py b/scripts/evaluate_rag.py new file mode 100644 index 0000000..2d78885 --- /dev/null +++ b/scripts/evaluate_rag.py @@ -0,0 +1,192 @@ +import sys +import os +import json +import requests +import time +import numpy as np +from time import sleep + +# 将项目根目录加入路径 +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +from backend.core.config import settings + +# ================= ⚙️ 配置区域 ================= +BASE_URL = "http://127.0.0.1:8000" +TASK_ID = 19 # ⚠️ 请修改为你实际爬取数据的 Task ID +# 自动适配操作系统路径 +TEST_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_dataset.json") +# ============================================== + +class Colors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + +def get_rag_results(query): + """ + 调用搜索接口并记录耗时 + """ + start_ts = time.time() + try: + # 调用 V2 接口,该接口内部已集成 混合检索 -> Rerank + res = requests.post( + f"{BASE_URL}/api/v2/search", + json={"query": query, "task_id": TASK_ID, "limit": 5}, # 获取 Top 5 + timeout=15 + ) + latency = (time.time() - start_ts) * 1000 # ms + + if res.status_code != 200: + print(f"{Colors.FAIL}❌ API Error {res.status_code}: {res.text}{Colors.ENDC}") + return [], 0 + + res_json = res.json() + chunks = res_json.get('data', {}).get('results', []) + return chunks, latency + except Exception as e: + print(f"{Colors.FAIL}❌ 请求异常: {e}{Colors.ENDC}") + return [], 0 + +def check_hit(content, keywords): + """ + 检查切片相关性 (Relevance Check) + 使用关键词匹配作为 Ground Truth 的轻量级验证。 + """ + if not keywords: return True # 拒答题或开放性题目跳过关键词检查 + if not content: return False + + content_lower = content.lower() + for k in keywords: + if k.lower() in content_lower: + return True + return False + +def run_evaluation(): + # 1. 加载测试集 + if not os.path.exists(TEST_FILE): + print(f"{Colors.FAIL}❌ 找不到测试文件: {TEST_FILE}{Colors.ENDC}") + print("请确保 scripts/test_dataset.json 文件存在。") + return + + with open(TEST_FILE, 'r', encoding='utf-8') as f: + dataset = json.load(f) + + print(f"{Colors.HEADER}🚀 开始全维度量化评测 (Task ID: {TASK_ID}){Colors.ENDC}") + print(f"📄 测试集包含 {len(dataset)} 个样本\n") + + # === 统计容器 === + metrics = { + "p_at_1": [], # Precision@1: 正确答案排第1 + "hit_at_5": [], # HitRate@5: 正确答案在前5 + "mrr": [], # MRR: 倒数排名分数 + "latency": [] # 耗时 + } + + # === 开始循环测试 === + for i, item in enumerate(dataset): + query = item['query'] + print(f"📝 Case {i+1}: {Colors.BOLD}{query}{Colors.ENDC}") + + # 执行检索 + chunks, latency = get_rag_results(query) + metrics['latency'].append(latency) + + # 计算单次指标 + is_hit_at_5 = 0 + p_at_1 = 0 + reciprocal_rank = 0.0 + hit_position = -1 + hit_chunk = None + + # 遍历 Top 5 结果 + for idx, chunk in enumerate(chunks): + if check_hit(chunk['content'], item['keywords']): + # 命中! + is_hit_at_5 = 1 + hit_position = idx + reciprocal_rank = 1.0 / (idx + 1) + hit_chunk = chunk + + # 如果是第1个就命中了 + if idx == 0: + p_at_1 = 1 + + # 找到即停止 (MRR计算只需知道第一个正确答案的位置) + break + + # 记录指标 + metrics['p_at_1'].append(p_at_1) + metrics['hit_at_5'].append(is_hit_at_5) + metrics['mrr'].append(reciprocal_rank) + + # 打印单行结果 + if is_hit_at_5: + rank_display = f"Rank {hit_position + 1}" + color = Colors.OKGREEN if hit_position == 0 else Colors.OKCYAN + source = hit_chunk.get('source_url', 'Unknown') + + # 跨语言污染检查 (简单规则) + warning = "" + if "/es/" in source and "Spanish" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}" + elif "/zh/" in source and "如何" not in query and "什么" not in query: warning = f"{Colors.WARNING}[跨语言风险]{Colors.ENDC}" + + print(f" {color}✅ 命中 ({rank_display}){Colors.ENDC} | MRR: {reciprocal_rank:.2f} | 耗时: {latency:.0f}ms {warning}") + else: + print(f" {Colors.FAIL}❌ 未命中{Colors.ENDC} | 预期关键词: {item['keywords']}") + + # 稍微间隔,避免触发 API 频率限制 + sleep(0.1) + + # === 最终计算 === + count = len(dataset) + if count == 0: return + + avg_p1 = np.mean(metrics['p_at_1']) * 100 + avg_hit5 = np.mean(metrics['hit_at_5']) * 100 + avg_mrr = np.mean(metrics['mrr']) + avg_latency = np.mean(metrics['latency']) + p95_latency = np.percentile(metrics['latency'], 95) + + print("\n" + "="*60) + print(f"{Colors.HEADER}📊 最终量化评估报告 (Evaluation Report){Colors.ENDC}") + print("="*60) + + # 1. Precision@1 (最关键指标) + print(f"🥇 {Colors.BOLD}Precision@1 (首位精确率): {avg_p1:.1f}%{Colors.ENDC}") + print(f" - 意义: 用户能否直接得到正确答案。引入 Rerank 后此项应显著提高。") + + # 2. Hit Rate / Recall@5 + print(f"🥈 Hit Rate@5 (前五召回率): {avg_hit5:.1f}%") + print(f" - 意义: 数据库是否真的包含答案。如果此项低,说明爬虫没爬全或混合检索漏了。") + + # 3. MRR + print(f"🥉 MRR (平均倒数排名): {avg_mrr:.3f} / 1.0") + + # 4. Latency + print(f"⚡ Avg Latency (平均耗时): {avg_latency:.0f} ms") + print(f"⚡ P95 Latency (95%分位): {p95_latency:.0f} ms") + print("="*60) + + # === 智能诊断 === + print(f"{Colors.HEADER}💡 诊断建议:{Colors.ENDC}") + + if avg_p1 < avg_hit5: + gap = avg_hit5 - avg_p1 + print(f" • {Colors.WARNING}排序优化空间大{Colors.ENDC}: 召回了但没排第一的情况占 {gap:.1f}%。") + print(" -> 你的 Rerank 模型生效了吗?或者 Rerank 的 Top N 截断是否太早?") + elif avg_p1 > 80: + print(f" • {Colors.OKGREEN}排序效果优秀{Colors.ENDC}: 绝大多数正确答案都排在第一位。") + + if avg_hit5 < 50: + print(f" • {Colors.FAIL}召回率过低{Colors.ENDC}: 可能是测试集关键词太生僻,或者 TS_RANK 权重过低。") + + if avg_latency > 2000: + print(f" • {Colors.WARNING}系统响应慢{Colors.ENDC}: 2秒以上。检查是否因为 Rerank 文档过多(建议 <= 50个)。") + +if __name__ == "__main__": + run_evaluation() \ No newline at end of file diff --git a/scripts/test_dataset.json b/scripts/test_dataset.json new file mode 100644 index 0000000..9a7846a --- /dev/null +++ b/scripts/test_dataset.json @@ -0,0 +1,86 @@ +[ + { + "id": 1, + "type": "core_function", + "query": "What is the difference between /scrape and /map endpoints?", + "ground_truth": "/map is used to crawl a website and retrieve all URLs, while /scrape is used to extract content from a specific URL.", + "keywords": ["URL", "content", "specific", "retrieve"] + }, + { + "id": 2, + "type": "new_feature", + "query": "What is the Deep Research feature?", + "ground_truth": "Deep Research is an alpha feature allowing agents to perform iterative research tasks.", + "keywords": ["alpha", "iterative", "research", "agent"] + }, + { + "id": 3, + "type": "integration", + "query": "How can I integrate Firecrawl with ChatGPT?", + "ground_truth": "Firecrawl can be integrated via MCP (Model Context Protocol).", + "keywords": ["MCP", "Model Context Protocol", "setup"] + }, + { + "id": 4, + "type": "multilingual_zh", + "query": "如何进行私有化部署 (Self-host)?", + "ground_truth": "你需要使用 Docker Compose 进行部署,文档位于 /self-host/quick-start/docker-compose。", + "keywords": ["Docker", "Compose", "self-host", "deploy"] + }, + { + "id": 5, + "type": "api_detail", + "query": "What parameters are available for the /extract endpoint?", + "ground_truth": "The extract endpoint allows defining a schema for structured data extraction.", + "keywords": ["schema", "structured", "prompt"] + }, + { + "id": 6, + "type": "numeric", + "query": "How do credits work for the scrape endpoint?", + "ground_truth": "Specific credit usage details are in the /credits endpoint documentation (usually 1 credit per page for basic scrape).", + "keywords": ["credit", "usage", "cost"] + }, + { + "id": 7, + "type": "negative_test", + "query": "Does Firecrawl support scraping video content from YouTube?", + "ground_truth": "The documentation does not mention video scraping support.", + "keywords": [] + }, + { + "id": 8, + "type": "advanced", + "query": "How to use batch scrape?", + "ground_truth": "Use the /batch/scrape endpoint to submit multiple URLs at once.", + "keywords": ["batch", "multiple", "URLs"] + }, + { + "id": 9, + "type": "automation", + "query": "Is there an n8n integration guide?", + "ground_truth": "Yes, there is a workflow automation guide for n8n.", + "keywords": ["n8n", "workflow", "automation"] + }, + { + "id": 10, + "type": "security", + "query": "Where can I find information about webhook security?", + "ground_truth": "Information is available in the Webhooks Security section.", + "keywords": ["webhook", "security", "signature"] + }, + { + "id": 11, + "type": "cross_lingual_trap", + "query": "Explain the crawl features in French.", + "ground_truth": "The system should ideally retrieve the French document (/fr/features/crawl) and answer in French.", + "keywords": ["fonctionnalités", "crawl", "fr"] + }, + { + "id": 12, + "type": "api_history", + "query": "How to check historical token usage?", + "ground_truth": "Use the /token-usage-historical endpoint.", + "keywords": ["token", "usage", "historical"] + } +] \ No newline at end of file