import json
import os
import re

import requests
def embedding_alibaba(
    texts: list[str],
    api_key: str,
    *,
    model: str = "text-embedding-v4",
    dimension: int = 1536,
    text_type: str = "document",
) -> list:
    """Embed *texts* via the Alibaba DashScope (Bailian) text-embedding API.

    Docs: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details

    Args:
        texts: Input strings to embed. An empty list short-circuits to [].
        api_key: DashScope API key, sent as a Bearer token.
        model: Embedding model name. If Alibaba releases a newer model
            (e.g. a future v5), callers can pass it here instead of editing
            this function.
        dimension: Requested embedding dimension.
        text_type: "document" or "query", per the DashScope API.

    Returns:
        One embedding vector (list[float]) per input text, in input order.
        On any failure (HTTP error, malformed response, length mismatch) a
        list of ``None`` placeholders of the same length as *texts* is
        returned so the calling pipeline does not break.
    """
    if not texts:
        return []

    url = (
        "https://dashscope.aliyuncs.com/api/v1/services/"
        "embeddings/text-embedding/text-embedding"
    )
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": model,
        "input": {"texts": texts},
        "parameters": {"text_type": text_type, "dimension": dimension},
    }

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()

        # Expected response shape:
        # {
        #   "output": {
        #     "embeddings": [
        #       {"embedding": [...], "text_index": 0},
        #       {"embedding": [...], "text_index": 1}
        #     ]
        #   },
        #   "usage": ...
        # }
        if "output" not in result or "embeddings" not in result["output"]:
            print(f"Alibaba API Response Format Warning: {result}")
            return [None] * len(texts)

        # The API may return items out of order; restore input order.
        items = sorted(
            result["output"]["embeddings"], key=lambda item: item["text_index"]
        )
        vectors = [item["embedding"] for item in items]

        # Defensive: a partial response would silently misalign texts and
        # vectors downstream, so treat a length mismatch as a full failure.
        if len(vectors) != len(texts):
            print(
                f"Alibaba Embedding Error: got {len(vectors)} embeddings "
                f"for {len(texts)} texts"
            )
            return [None] * len(texts)

        return vectors

    except Exception as e:  # boundary handler: keep the pipeline alive
        print(f"Alibaba Embedding Error: {e}")
        return [None] * len(texts)
def _extract_page(res_json):
    """Extract ``(markdown_text, title, url)`` from a Firecrawl payload.

    Accepts a JSON string, a ``{"res_json": [...]}`` wrapper, a bare list of
    results, or a single result dict. Returns ``None`` when no usable
    markdown content can be found.
    """
    raw = res_json
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            pass  # keep the raw string; it will fail the dict check below

    if isinstance(raw, dict) and 'res_json' in raw:
        items = raw['res_json']
    elif isinstance(raw, list):
        items = raw
    else:
        items = [raw]

    if not isinstance(items, list) or not items:
        return None
    first = items[0]
    if not isinstance(first, dict):
        return None

    data_obj = first.get("data", {})
    metadata = data_obj.get("metadata", {})
    text = data_obj.get("markdown", "")
    if not text:
        return None
    title = metadata.get("title", "No Title")
    url = metadata.get("sourceURL", metadata.get("url", ""))
    return text, title, url


def _clean_markdown(text: str) -> str:
    """Strip generic Markdown/HTML noise, returning plain prose."""
    # Remove Markdown images entirely: ![alt](url)
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Unwrap Markdown links, keeping the anchor text: [text](url) -> text
    text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text)
    # Drop simple HTML tags.
    text = re.sub(r'<[^>]+>', '', text)
    # Remove zero-width spaces.
    text = text.replace('\u200b', '')
    # Collapse runs of 3+ newlines to a paragraph break.
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()


def _chunk_text(text: str, chunk_size: int = 800, overlap: int = 100) -> list[str]:
    """Split *text* into overlapping chunks; very short texts become one chunk."""
    text_len = len(text)
    if text_len < 50:
        return [text]

    step = chunk_size - overlap
    chunks = []
    start = 0
    while start < text_len:
        piece = text[start:start + chunk_size]
        # Skip sub-50-char fragments unless this is the final slice.
        if len(piece) > 50 or start + step >= text_len:
            chunks.append(piece)
        start += step
    return chunks


def main(res_json: list, DASHSCOPE_API_KEY: str) -> dict:
    """Turn a Firecrawl crawl result into embedded, SQL-ready chunk rows.

    Args:
        res_json: Firecrawl result (string, wrapper dict, list, or dict).
        DASHSCOPE_API_KEY: Alibaba DashScope API key for embedding calls.

    Returns:
        ``{"sql_values": <JSON string of row dicts>}``; on unusable input the
        JSON string is ``"[]"``, and a parse failure adds an ``"error"`` key.
    """
    # --- 1. Parse the Firecrawl JSON (tolerant of several shapes) ---
    try:
        page = _extract_page(res_json)
    except Exception as e:
        return {"sql_values": "[]", "error": f"Parse Error: {str(e)}"}
    if page is None:
        return {"sql_values": "[]"}
    text, title, url = page

    # --- 2. Generic Markdown cleaning ---
    text = _clean_markdown(text)

    # --- 3. Safe chunking: 800 chars per chunk, 100 chars overlap ---
    chunks = _chunk_text(text)

    # --- 4. Vectorize via Alibaba DashScope ---
    vectors = []
    if chunks:
        vectors = embedding_alibaba(chunks, DASHSCOPE_API_KEY)
    # Belt and braces: a mismatch would misalign chunk/embedding pairs.
    if len(vectors) != len(chunks):
        vectors = [None] * len(chunks)

    # --- 5. Build SQL-ready rows ---
    # Minimal SQL escaping: double single-quotes to avoid quoting errors.
    safe_title = str(title).replace("'", "''")

    rows = []
    for idx, content in enumerate(chunks):
        clean_content = content.strip()
        if not clean_content:
            continue
        rows.append({
            "url": url,
            "title": safe_title,
            "content": clean_content.replace("'", "''"),
            "chunk_index": idx,
            "embedding": vectors[idx],
        })

    return {
        "sql_values": json.dumps(rows)
    }
if __name__ == "__main__":
    # Never hard-code API keys in source; read from the environment instead.
    key = os.environ.get("DASHSCOPE_API_KEY", "")

    # Raw string: the original literal contained invalid escape sequences
    # (\w, \c) that Python 3.12+ flags with a SyntaxWarning. The raw form
    # preserves the exact same path bytes.
    with open(r"anyscript\wiki_crawler\chunk.json", "r", encoding="utf-8") as f:
        data = json.load(f)

    res = main(data, key)

    result = json.loads(res["sql_values"])
    print(result)