Files
wiki_crawler/backend/services/llm_service.py
2026-01-13 10:37:19 +08:00

93 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import dashscope
from http import HTTPStatus
from backend.core.config import settings
class LLMService:
    """
    Thin wrapper around the DashScope SDK (Tongyi Qianwen / GTE).

    Provides two operations: text embedding (bi-encoder) and reranking of
    retrieved chunks (cross-encoder). Failures are reported on stdout and
    turned into a fallback return value rather than raised.
    """

    # Model name passed to dashscope.TextReRank.call.
    RERANK_MODEL = 'gte-rerank'
    # Upper bound on how many candidates are sent to the rerank API in one
    # call; longer inputs are truncated to avoid API timeouts or errors.
    MAX_RERANK_CANDIDATES = 50

    def __init__(self):
        # The SDK reads its key from a module-level attribute, so this
        # affects every dashscope call in the process.
        dashscope.api_key = settings.DASHSCOPE_API_KEY

    def get_embedding(self, text: str, dimension: int = 1536):
        """
        Generate an embedding vector for ``text`` (bi-encoder).

        Args:
            text: The text to embed.
            dimension: Requested vector size, forwarded to the API.
                NOTE(review): presumably must match the vector column of
                the database that stores the result -- confirm.

        Returns:
            The embedding taken from the first entry of the API response,
            or None when the API returns a non-OK status or any exception
            is raised while calling it or reading the response.
        """
        try:
            resp = dashscope.TextEmbedding.call(
                model=dashscope.TextEmbedding.Models.text_embedding_v4,
                input=text,
                dimension=dimension,
            )
            if resp.status_code == HTTPStatus.OK:
                return resp.output['embeddings'][0]['embedding']
            print(f"[ERROR] Embedding API Error: {resp}")
            return None
        except Exception as e:
            # Best-effort boundary: callers get None instead of an exception.
            print(f"[ERROR] Embedding Exception: {e}")
            return None

    @staticmethod
    def _apply_rerank_scores(candidates: list, results) -> list:
        """
        Map rerank results back onto copies of the candidate documents.

        Args:
            candidates: The documents that were sent to the rerank API, in
                the same order as the submitted text list.
            results: Iterable of result items exposing ``index`` (position
                in ``candidates``) and ``relevance_score`` attributes.

        Returns:
            A new list, in ``results`` order, of shallow copies of the
            referenced candidates with 'score' replaced by the rerank score
            and 'reranked' set to True. ``candidates`` and the dicts in it
            are never modified, so a failure part-way through cannot leave
            the caller's documents half-updated.

        Raises:
            IndexError: If an item's ``index`` is outside ``candidates``.
        """
        reranked = []
        for item in results:
            doc = dict(candidates[item.index])
            doc['score'] = item.relevance_score
            # Marker so downstream/debugging code can tell the score came
            # from the reranker rather than from coarse retrieval.
            doc['reranked'] = True
            reranked.append(doc)
        return reranked

    def rerank(self, query: str, documents: list, top_n: int = 5) -> list:
        """
        Rerank coarse-retrieval results against ``query`` (cross-encoder).

        Args:
            query: The user question.
            documents: Coarse-ranked chunks (list of dicts). The text sent
                to the API is each dict's 'content' value (empty string if
                the key is missing). Only the first MAX_RERANK_CANDIDATES
                entries are considered.
            top_n: How many results to return.

        Returns:
            On success, copies of the selected documents in the order the
            API returned them, each with a new 'score' and 'reranked': True.
            If the API returns a non-OK status or anything raises, the
            first ``top_n`` candidates in their original order (degraded
            mode). An empty list when ``documents`` is empty or ``top_n``
            is not positive.
        """
        if not documents:
            return []
        # A non-positive top_n can never select anything; returning early
        # also keeps the fallback slice below from turning a negative value
        # into "everything except the last N".
        if top_n is not None and top_n <= 0:
            return []

        # Truncate first so the API payload and the index -> document
        # mapping are built from exactly the same list.
        candidates = documents[:self.MAX_RERANK_CANDIDATES]

        # The API takes plain text only; the full dicts (ids, metadata) are
        # kept on our side and re-attached through the returned indexes.
        doc_contents = [doc.get('content', '') for doc in candidates]

        try:
            resp = dashscope.TextReRank.call(
                model=self.RERANK_MODEL,
                query=query,
                documents=doc_contents,
                top_n=top_n,
                # Only indexes and scores are needed; no need to have the
                # texts echoed back.
                return_documents=False,
            )
            if resp.status_code == HTTPStatus.OK:
                # Expected shape per the original author's note:
                # output.results = [{'index': 2, 'relevance_score': 0.98}, ...]
                return self._apply_rerank_scores(candidates, resp.output.results)
            print(f"[ERROR] Rerank API Error: {resp}")
        except Exception as e:
            # Best-effort boundary: reranking is an optimisation, so any
            # failure degrades to the coarse ranking instead of raising.
            print(f"[ERROR] Rerank Exception: {e}")

        # Degraded mode: reranker unavailable, return the coarse top N.
        return candidates[:top_n]
# Module-level shared instance. Constructing it runs LLMService.__init__,
# which assigns dashscope.api_key from settings when this module is imported.
llm_service = LLMService()