Files
wiki_crawler/scripts/chunk.py
2025-12-19 00:52:32 +08:00

185 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import re
import requests
def embedding_alibaba(texts: list[str], api_key: str) -> list[list[float]]:
"""
调用阿里百炼 (DashScope) Embedding API
文档参考: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details
"""
if not texts:
return []
# 配置模型名称,阿里目前主力是 v2 和 v3
# 如果后续阿里发布了 v4直接在这里改字符串即可
MODEL_NAME = "text-embedding-v4"
url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
payload = {
"model": MODEL_NAME,
"input": {
"texts": texts
},
"parameters": {
"text_type": "document",
"dimension": 1536
}
}
try:
response = requests.post(url, headers=headers, json=payload, timeout=60)
response.raise_for_status()
result = response.json()
# 阿里 API 返回结构:
# {
# "output": {
# "embeddings": [
# { "embedding": [...], "text_index": 0 },
# { "embedding": [...], "text_index": 1 }
# ]
# },
# "usage": ...
# }
if "output" in result and "embeddings" in result["output"]:
# 确保按 text_index 排序,防止乱序
embeddings_list = result["output"]["embeddings"]
embeddings_list.sort(key=lambda x: x["text_index"])
return [item["embedding"] for item in embeddings_list]
else:
print(f"Alibaba API Response Format Warning: {result}")
return [None] * len(texts)
except Exception as e:
print(f"Alibaba Embedding Error: {e}")
# 出错时返回 None 列表,确保流程不中断
return [None] * len(texts)
def main(res_json: list, DASHSCOPE_API_KEY: str) -> dict:
"""
输入: res_json (Firecrawl结果), DASHSCOPE_API_KEY (阿里API Key)
"""
# --- 1. 解析 Firecrawl JSON (通用容错解析) ---
try:
raw_data = res_json
if isinstance(raw_data, str):
try: raw_data = json.loads(raw_data)
except: pass
data_list = []
if isinstance(raw_data, dict) and 'res_json' in raw_data: data_list = raw_data['res_json']
elif isinstance(raw_data, list): data_list = raw_data
else: data_list = [raw_data]
if not data_list or not isinstance(data_list, list): return {"sql_values": "[]"}
try:
first_result = data_list[0]
if not isinstance(first_result, dict): return {"sql_values": "[]"}
data_obj = first_result.get("data", {})
metadata = data_obj.get("metadata", {})
# 获取原始内容
text = data_obj.get("markdown", "")
title = metadata.get("title", "No Title")
url = metadata.get("sourceURL", metadata.get("url", ""))
if not text: return {"sql_values": "[]"}
except IndexError: return {"sql_values": "[]"}
except Exception as e:
return {"sql_values": "[]", "error": f"Parse Error: {str(e)}"}
# =======================================================
# --- 2. 通用 Markdown 清洗 (Generic Cleaning) ---
# =======================================================
# 2.1 移除 Markdown 图片 (![alt](url)) -> 也就是删掉图片
text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
# 2.2 移除 Markdown 链接格式,保留文本 ([text](url) -> text)
text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text)
# 2.3 移除 HTML 标签 (简单的防噪)
text = re.sub(r'<[^>]+>', '', text)
# 2.4 清洗特殊字符和零宽空格
text = text.replace('\u200b', '')
# 2.5 压缩空行 (通用逻辑)
# 将连续的换行符(3个以上)替换为2个保持段落感但去除大片空白
text = re.sub(r'\n{3,}', '\n\n', text)
# 2.6 去除首尾空白
text = text.strip()
# --- 3. 安全切片 (Safe Chunking) ---
# 800 字符切片100 字符重叠
chunk_size = 800
overlap = 100
step = chunk_size - overlap
chunks = []
text_len = len(text)
if text_len < 50:
chunks.append(text)
else:
start = 0
while start < text_len:
end = min(start + chunk_size, text_len)
chunk_content = text[start:end]
# 防止切出过短的碎片,或者是最后一块
if len(chunk_content) > 50 or start + step >= text_len:
chunks.append(chunk_content)
start += step
# --- 4. 向量化 (Call Alibaba) ---
vectors = []
if chunks:
# 这里传入 DASHSCOPE_API_KEY
vectors = embedding_alibaba(chunks, DASHSCOPE_API_KEY)
# 双重保险:确保向量列表长度一致
if len(vectors) != len(chunks):
vectors = [None] * len(chunks)
# --- 5. 构造 SQL 数据 ---
result_list = []
# 简单的 SQL 转义,防止单引号报错
safe_title = str(title).replace("'", "''")
for idx, content in enumerate(chunks):
clean_content = content.strip()
if not clean_content: continue
result_list.append({
"url": url,
"title": safe_title,
"content": clean_content.replace("'", "''"),
"chunk_index": idx,
"embedding": vectors[idx]
})
return {
"sql_values": json.dumps(result_list)
}
if __name__ == "__main__":
    import os

    # SECURITY: never hard-code API keys in source — the previous revision
    # embedded a live-looking DashScope key here. Read it from the
    # environment instead (rotate any key that was committed).
    key = os.environ.get("DASHSCOPE_API_KEY", "")
    # Raw string so the Windows backslashes are unambiguous (the old
    # literal only worked because \w and \c are invalid escapes that
    # Python leaves untouched).
    with open(r"anyscript\wiki_crawler\chunk.json", "r", encoding="utf-8") as f:
        data = json.load(f)
    res = main(data, key)
    result = json.loads(res["sql_values"])
    print(result)