# wiki_crawler/backend/services/llm_service.py
import dashscope
from http import HTTPStatus
from backend.core.config import settings
import logging
# 获取当前模块的专用 Logger
# __name__ 会自动识别为 "backend.services.crawler_service" 这样的路径
logger = logging.getLogger(__name__)
class LLMService:
    """
    LLM service wrapper.

    Encapsulates all interaction with DashScope (Qwen / GTE), covering
    text embedding (bi-encoder), rerank (cross-encoder) and chat generation.
    """

    def __init__(self):
        # Configure the (global) DashScope client with the project API key.
        dashscope.api_key = settings.DASHSCOPE_API_KEY

    def get_embedding(self, text: str, dimension: int = 1536):
        """
        Generate a dense vector for ``text`` (bi-encoder).

        Args:
            text: Input text to embed.
            dimension: Output vector size; must match the vector DB schema.

        Returns:
            The embedding (list of floats) on success, or None on any
            API error or exception (callers are expected to handle None).
        """
        try:
            resp = dashscope.TextEmbedding.call(
                # text_embedding_v4 -- choose the version matching your DB's dimension
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=dimension
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            logger.error("Embedding API Error: %s", resp)
            return None
        except Exception as e:
            # Network/SDK failures degrade to None so callers can skip the item.
            logger.error("Embedding Exception: %s", e)
            return None

    def rerank(self, query: str, documents: list, top_n: int = 5,
               max_candidates: int = 50):
        """
        Rerank coarse-recall chunks with DashScope GTE-Rerank (cross-encoder).

        Args:
            query: The user question.
            documents: Coarse-recall chunks (List[dict]); each must carry a
                'content' field. Other keys (id, meta_info, ...) are preserved.
            top_n: How many results to return.
            max_candidates: Cap on documents sent to the API; oversized
                batches can time out or be rejected.

        Returns:
            List[dict]: the top-N documents, with 'score' replaced by the
            rerank relevance score (typically 0~1) and 'reranked' set to
            True. If the API fails, degrades to the first top_n coarse
            results unchanged.
        """
        if not documents:
            return []
        # The rerank API takes plain strings, but we must keep the original
        # dicts (meta_info, id, ...) — send the contents, then map the
        # returned indices back onto the source documents.
        doc_contents = [doc.get('content', '') for doc in documents]
        # Truncate oversized candidate lists to avoid API timeouts/errors.
        if len(doc_contents) > max_candidates:
            doc_contents = doc_contents[:max_candidates]
            documents = documents[:max_candidates]
        try:
            # Call DashScope GTE-Rerank.
            resp = dashscope.TextReRank.call(
                model='gte-rerank',
                query=query,
                documents=doc_contents,
                top_n=top_n,
                return_documents=False  # indices + scores only; we already hold the text
            )
            if resp.status_code == HTTPStatus.OK:
                # Response shape:
                # output.results = [{'index': 2, 'relevance_score': 0.98}, {'index': 0, ...}]
                reranked_results = []
                for item in resp.output.results:
                    # Map the API index back to the original document object.
                    original_doc = documents[item.index]
                    # Replace the coarse score with the cross-encoder score.
                    original_doc['score'] = item.relevance_score
                    # Provenance flag so debugging can tell reranked results apart.
                    original_doc['reranked'] = True
                    reranked_results.append(original_doc)
                return reranked_results
            logger.error("Rerank API Error: %s", resp)
            # Degradation strategy: if rerank is down, return the coarse top N.
            return documents[:top_n]
        except Exception as e:
            logger.error("Rerank Exception: %s", e)
            # Degradation strategy (same as above).
            return documents[:top_n]

    def chat(self, prompt: str, system_prompt: str = None, model: str = "qwen-turbo") -> str:
        """
        General chat-completion endpoint.

        Used for the final RAG answer generation, or as a judge model in tests.

        Args:
            prompt: The user message.
            system_prompt: Optional system message prepended to the conversation.
            model: DashScope generation model name.

        Returns:
            The assistant reply text, or "Error generating response." on
            any API error or exception.
        """
        messages = []
        if system_prompt:
            messages.append({'role': 'system', 'content': system_prompt})
        messages.append({'role': 'user', 'content': prompt})
        try:
            resp = dashscope.Generation.call(
                model=model,
                messages=messages,
                result_format='message'
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output.choices[0].message.content
            logger.error("Chat API Error: %s", resp)
            return "Error generating response."
        except Exception as e:
            logger.error("Chat Exception: %s", e)
            return "Error generating response."
# Module-level singleton shared by the rest of the backend.
llm_service = LLMService()