185 lines
5.9 KiB
Python
185 lines
5.9 KiB
Python
|
|
import json
|
|||
|
|
import re
|
|||
|
|
import requests
|
|||
|
|
def embedding_alibaba(
    texts: list[str],
    api_key: str,
    batch_size: int = 10,
) -> "list[list[float] | None]":
    """Embed texts with the Alibaba Bailian (DashScope) text-embedding API.

    Reference:
    https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details

    The texts are sent in consecutive batches of at most ``batch_size`` items,
    because the service limits how many texts one request may carry.
    NOTE(review): the default of 10 is the documented per-request limit for
    text-embedding-v3/v4 as far as I know -- confirm against the current API
    reference before changing the model or the default.

    Args:
        texts: Strings to embed. An empty list short-circuits to ``[]``.
        api_key: DashScope API key, sent as a Bearer token.
        batch_size: Maximum number of texts per HTTP request; values below 1
            are treated as 1.

    Returns:
        One embedding per input text, in input order. This function is
        best-effort and never raises: if any request fails or any response
        has an unexpected shape, it prints a diagnostic and returns
        ``[None] * len(texts)`` so the caller's pipeline keeps running.
    """
    if not texts:
        return []

    # Model identifier; switch DashScope embedding model versions by editing
    # this string only.
    MODEL_NAME = "text-embedding-v4"

    url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    if batch_size < 1:
        batch_size = 1

    vectors = []
    # The broad catch is deliberate: callers rely on a None-filled list rather
    # than an exception, whatever goes wrong (network, HTTP status, bad JSON).
    try:
        for offset in range(0, len(texts), batch_size):
            batch = texts[offset:offset + batch_size]
            payload = {
                "model": MODEL_NAME,
                "input": {
                    "texts": batch,
                },
                "parameters": {
                    "text_type": "document",
                    "dimension": 1536,
                },
            }

            response = requests.post(url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()
            result = response.json()

            # Expected response shape:
            # {
            #   "output": {
            #     "embeddings": [
            #       {"embedding": [...], "text_index": 0},
            #       {"embedding": [...], "text_index": 1}
            #     ]
            #   },
            #   "usage": ...
            # }
            output = result.get("output") if isinstance(result, dict) else None
            items = output.get("embeddings") if isinstance(output, dict) else None
            if not isinstance(items, list) or len(items) != len(batch):
                print(f"Alibaba API Response Format Warning: {result}")
                return [None] * len(texts)

            # text_index is relative to the batch; sort on it so the vectors
            # line up with the inputs even if the service reorders them.
            ordered = sorted(items, key=lambda item: item["text_index"])
            vectors.extend([item["embedding"] for item in ordered])
    except Exception as e:
        print(f"Alibaba Embedding Error: {e}")
        # Keep the all-or-nothing contract: a partial result is never returned.
        return [None] * len(texts)

    return vectors
|
|||
|
|
|
|||
|
|
def _extract_page(res_json):
    """Return ``(markdown, title, url)`` for the first Firecrawl result, or None.

    Accepts a JSON string, a dict wrapping the results under ``'res_json'``,
    a list of results, or a single result. Returns None when there is no
    usable first result or its markdown is empty. Malformed nesting (for
    example ``"data": None``) raises and is handled by the caller.
    """
    raw_data = res_json
    if isinstance(raw_data, str):
        try:
            raw_data = json.loads(raw_data)
        except ValueError:
            # Not JSON: keep the raw string; it is rejected below because the
            # first result must be a dict.
            pass

    if isinstance(raw_data, dict) and 'res_json' in raw_data:
        data_list = raw_data['res_json']
    elif isinstance(raw_data, list):
        data_list = raw_data
    else:
        data_list = [raw_data]

    if not data_list or not isinstance(data_list, list):
        return None

    first_result = data_list[0]
    if not isinstance(first_result, dict):
        return None

    data_obj = first_result.get("data", {})
    metadata = data_obj.get("metadata", {})

    text = data_obj.get("markdown", "")
    title = metadata.get("title", "No Title")
    url = metadata.get("sourceURL", metadata.get("url", ""))

    if not text:
        return None
    return text, title, url


def _clean_markdown(text):
    """Strip images, link targets, HTML tags and excess blank lines from markdown."""
    # Drop markdown images entirely: ![alt](src)
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
    # Keep only the label of markdown links: [text](url) -> text
    text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text)
    # Remove HTML tags (simple noise reduction, not a full HTML parser).
    text = re.sub(r'<[^>]+>', '', text)
    # Remove zero-width spaces.
    text = text.replace('\u200b', '')
    # Collapse runs of 3+ newlines to 2: keeps paragraph breaks, drops large gaps.
    text = re.sub(r'\n{3,}', '\n\n', text)
    return text.strip()


def _split_into_chunks(text, chunk_size=800, overlap=100):
    """Split text into ``chunk_size``-character windows overlapping by ``overlap``.

    Splitting stops as soon as a window reaches the end of the text, so the
    last chunk always contributes characters the previous chunk did not cover.
    """
    if not text:
        return []

    step = chunk_size - overlap
    text_len = len(text)
    chunks = []
    start = 0
    while True:
        end = min(start + chunk_size, text_len)
        chunks.append(text[start:end])
        if end >= text_len:
            # Advancing further would only re-emit the overlap already
            # contained in this chunk.
            break
        start += step
    return chunks


def main(res_json: list, DASHSCOPE_API_KEY: str) -> dict:
    """Turn a Firecrawl scrape result into embedded text chunks.

    Pipeline: parse the payload, clean the markdown, split it into 800-char
    chunks with 100-char overlap, embed the chunks via ``embedding_alibaba``,
    and serialize one record per chunk.

    Args:
        res_json: Firecrawl result; a JSON string, a dict with a ``'res_json'``
            key, a list of results, or a single result dict are all accepted.
            Only the first result is processed.
        DASHSCOPE_API_KEY: Alibaba DashScope API key used for embedding.

    Returns:
        A dict whose ``"sql_values"`` entry is a JSON string encoding a list
        of records with keys ``url``, ``title``, ``content``, ``chunk_index``
        and ``embedding`` (``None`` when embedding failed). The list is empty
        (``"[]"``) when there is nothing to embed. If parsing raises, an
        ``"error"`` entry of the form ``"Parse Error: ..."`` is added.
    """
    # --- 1. Parse the Firecrawl JSON (tolerant of several input shapes) ---
    try:
        page = _extract_page(res_json)
    except Exception as e:
        return {"sql_values": "[]", "error": f"Parse Error: {str(e)}"}
    if page is None:
        return {"sql_values": "[]"}
    text, title, url = page

    # --- 2. Generic markdown cleaning ---
    text = _clean_markdown(text)
    if not text:
        # Nothing left after cleaning (e.g. an image-only page): do not spend
        # an API call on an empty string.
        return {"sql_values": "[]"}

    # --- 3. Chunking ---
    chunks = _split_into_chunks(text, chunk_size=800, overlap=100)

    # --- 4. Embedding ---
    vectors = embedding_alibaba(chunks, DASHSCOPE_API_KEY)
    # Safety net: keep the vector list aligned with the chunks.
    if len(vectors) != len(chunks):
        vectors = [None] * len(chunks)

    # --- 5. Build the records ---
    # Single quotes are doubled so the values survive being placed inside SQL
    # string literals downstream. NOTE(review): parameterized queries in the
    # consumer would be safer than relying on this escaping.
    safe_title = str(title).replace("'", "''")
    safe_url = url.replace("'", "''") if isinstance(url, str) else url

    result_list = []
    for idx, content in enumerate(chunks):
        clean_content = content.strip()
        if not clean_content:
            continue

        result_list.append({
            "url": safe_url,
            "title": safe_title,
            "content": clean_content.replace("'", "''"),
            "chunk_index": idx,
            "embedding": vectors[idx],
        })

    return {
        "sql_values": json.dumps(result_list)
    }
|
|||
|
|
if __name__ == "__main__":
    # Manual smoke test: run the pipeline on a saved Firecrawl result and
    # print the decoded records.
    import json
    import os
    from pathlib import Path

    # The API key comes from the environment so no secret lives in source.
    # NOTE(review): a key was previously hard-coded here; treat it as leaked
    # and rotate it.
    key = os.environ.get("DASHSCOPE_API_KEY", "")
    if not key:
        raise SystemExit("Set the DASHSCOPE_API_KEY environment variable before running this script.")

    # Built from parts: portable across operating systems and free of the
    # invalid backslash escapes the old string literal contained.
    sample_path = Path("anyscript") / "wiki_crawler" / "chunk.json"
    with sample_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    res = main(data, key)

    result = json.loads(res["sql_values"])
    print(result)
|