aaa
This commit is contained in:
185
scripts/chunk.py
Normal file
185
scripts/chunk.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
import re
|
||||
import requests
|
||||
def embedding_alibaba(
    texts: list[str],
    api_key: str,
    batch_size: int = 10,
) -> "list[list[float] | None]":
    """Embed *texts* with the Alibaba Bailian (DashScope) text-embedding API.

    The texts are sent in consecutive batches of at most ``batch_size`` items.
    DashScope documents a cap on the number of texts per request (10 for the
    v3/v4 models at the time of writing -- confirm against the current API
    reference), so a single request carrying every chunk of a long page
    would be rejected.

    API reference:
    https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details

    Args:
        texts: Strings to embed, in the order the vectors should be returned.
        api_key: DashScope API key, sent as a Bearer token.
        batch_size: Maximum number of texts per HTTP request.

    Returns:
        A list with exactly ``len(texts)`` entries, aligned with ``texts``.
        Each entry is the embedding vector, or ``None`` when the request
        for the batch containing that text failed. An empty input yields
        ``[]`` without any network call.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if not texts:
        return []
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    vectors = []
    for offset in range(0, len(texts), batch_size):
        vectors.extend(_embed_batch(texts[offset:offset + batch_size], api_key))
    return vectors


def _embed_batch(texts: list[str], api_key: str) -> "list[list[float] | None]":
    """Embed one batch with a single request; all-``None`` on any failure."""
    # Model identifier sent to the API; change this string to switch models.
    model_name = "text-embedding-v4"

    url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }

    payload = {
        "model": model_name,
        "input": {
            "texts": texts,
        },
        "parameters": {
            "text_type": "document",
            # NOTE(review): presumably matches the width of the vector column
            # the results are stored in -- confirm before changing.
            "dimension": 1536,
        },
    }

    fallback = [None] * len(texts)

    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()

        # Expected response shape:
        # {
        #   "output": {
        #     "embeddings": [
        #       {"embedding": [...], "text_index": 0},
        #       {"embedding": [...], "text_index": 1}
        #     ]
        #   },
        #   "usage": ...
        # }
        if "output" not in result or "embeddings" not in result["output"]:
            print(f"Alibaba API Response Format Warning: {result}")
            return fallback

        # Order by text_index so vectors line up with the input even if the
        # service returns them out of order; sorted() leaves the parsed
        # response untouched.
        ordered = sorted(
            result["output"]["embeddings"],
            key=lambda item: item["text_index"],
        )
        vectors = [item["embedding"] for item in ordered]

        # A short or long response would misalign vectors with their texts.
        if len(vectors) != len(texts):
            print(f"Alibaba API Response Format Warning: {result}")
            return fallback
        return vectors

    except Exception as e:
        # Deliberate best-effort boundary: any failure (network, HTTP status,
        # malformed JSON, missing keys) degrades to None placeholders so the
        # calling pipeline keeps running.
        print(f"Alibaba Embedding Error: {e}")
        return fallback
|
||||
|
||||
def main(res_json: list, DASHSCOPE_API_KEY: str) -> dict:
    """Turn a Firecrawl result into embedded text chunks serialized as JSON.

    Only the first element of the result list is processed. Its markdown is
    cleaned, split into overlapping chunks, embedded via
    ``embedding_alibaba`` and packed into row dicts.

    Args:
        res_json: Firecrawl output. Accepted forms: a list of result dicts,
            a dict holding that list under ``'res_json'``, a single result
            dict, or a JSON string encoding any of these.
        DASHSCOPE_API_KEY: API key forwarded to ``embedding_alibaba``.

    Returns:
        ``{"sql_values": <JSON string>}`` where the JSON is a list of
        ``{"url", "title", "content", "chunk_index", "embedding"}`` objects
        (``embedding`` is ``null`` when vectorization failed). The list is
        ``"[]"`` when there is nothing to process; on an unexpected parsing
        failure an additional ``"error"`` key carries the message.
    """
    # --- 1. Parse the Firecrawl JSON (tolerant of several wrappers) ---
    try:
        raw_data = res_json
        if isinstance(raw_data, str):
            try:
                raw_data = json.loads(raw_data)
            except ValueError:
                # Not valid JSON: keep the raw string; it is rejected below
                # because it is not a dict.
                pass

        if isinstance(raw_data, dict) and 'res_json' in raw_data:
            data_list = raw_data['res_json']
        elif isinstance(raw_data, list):
            data_list = raw_data
        else:
            data_list = [raw_data]

        if not data_list or not isinstance(data_list, list):
            return {"sql_values": "[]"}

        first_result = data_list[0]
        if not isinstance(first_result, dict):
            return {"sql_values": "[]"}

        data_obj = first_result.get("data", {})
        # A null "metadata" should not discard an otherwise usable page.
        metadata = data_obj.get("metadata") or {}

        # Raw page content and its descriptive fields.
        text = data_obj.get("markdown", "")
        title = metadata.get("title", "No Title")
        url = metadata.get("sourceURL", metadata.get("url", ""))

        if not text or not isinstance(text, str):
            return {"sql_values": "[]"}

    except Exception as e:
        return {"sql_values": "[]", "error": f"Parse Error: {str(e)}"}

    # --- 2. Generic Markdown cleaning ---

    # 2.1 Drop Markdown images entirely.
    text = re.sub(r'!\[.*?\]\(.*?\)', '', text)

    # 2.2 Unwrap Markdown links, keeping only the link text.
    text = re.sub(r'\[(.*?)\]\(.*?\)', r'\1', text)

    # 2.3 Strip HTML tags (simple noise reduction).
    text = re.sub(r'<[^>]+>', '', text)

    # 2.4 Remove zero-width spaces.
    text = text.replace('\u200b', '')

    # 2.5 Collapse runs of 3+ newlines to 2: keeps paragraph breaks while
    #     removing large blank areas.
    text = re.sub(r'\n{3,}', '\n\n', text)

    # 2.6 Trim leading/trailing whitespace.
    text = text.strip()

    # Nothing survived cleaning (e.g. an image-only page): skip the API call.
    if not text:
        return {"sql_values": "[]"}

    # --- 3. Chunking: 800-character windows with 100 characters of overlap ---
    chunks = _split_text(text, chunk_size=800, overlap=100)

    # --- 4. Vectorization ---
    vectors = embedding_alibaba(chunks, DASHSCOPE_API_KEY)

    # Safety net: vectors are indexed by chunk position below, so the two
    # lists must have the same length.
    if len(vectors) != len(chunks):
        vectors = [None] * len(chunks)

    # --- 5. Build the output rows ---
    # Minimal SQL escaping (doubling single quotes), applied uniformly to
    # every text field. NOTE(review): this presumes the consumer interpolates
    # the values into SQL string literals; parameterized queries downstream
    # would be the safer design.
    safe_title = str(title).replace("'", "''")
    safe_url = url.replace("'", "''") if isinstance(url, str) else url

    result_list = []
    for idx, content in enumerate(chunks):
        clean_content = content.strip()
        if not clean_content:
            continue

        result_list.append({
            "url": safe_url,
            "title": safe_title,
            "content": clean_content.replace("'", "''"),
            "chunk_index": idx,
            "embedding": vectors[idx],
        })

    return {
        "sql_values": json.dumps(result_list)
    }


def _split_text(text: str, chunk_size: int, overlap: int) -> list[str]:
    """Split *text* into ``chunk_size``-character windows sharing ``overlap`` characters."""
    step = chunk_size - overlap
    text_len = len(text)
    chunks = []
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunks.append(text[start:end])
        # Stop once the window reaches the end of the text; advancing again
        # would only emit a tail already fully contained in this chunk.
        if end >= text_len:
            break
        start += step
    return chunks
|
||||
if __name__ == "__main__":
    # Manual smoke test: run the pipeline on a saved Firecrawl result and
    # print the decoded rows.
    import os
    from pathlib import Path

    # Read the API key from the environment; credentials must not live in
    # source control.
    # NOTE(review): a literal key was committed here earlier -- treat it as
    # leaked and rotate it.
    key = os.environ.get("DASHSCOPE_API_KEY", "")
    if not key:
        raise SystemExit(
            "Set the DASHSCOPE_API_KEY environment variable before running this script."
        )

    # Built with pathlib so the path works on every OS and carries no
    # backslash escape sequences.
    sample_path = Path("anyscript") / "wiki_crawler" / "chunk.json"
    with sample_path.open("r", encoding="utf-8") as f:
        data = json.load(f)

    res = main(data, key)

    result = json.loads(res["sql_values"])
    print(result)
|
||||
Reference in New Issue
Block a user