Files
wiki_crawler/nodes/chunk_and_embedding.py
2025-12-23 00:36:49 +08:00

139 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import re
import requests
def text_cleaning(text: str) -> str:
    """Normalize the whitespace in ``text``.

    Every run of whitespace characters (spaces, tabs, newlines, ...) is
    collapsed into a single space, then leading and trailing whitespace is
    removed from the result.
    """
    collapsed = re.sub(r'\s+', ' ', text)
    return collapsed.strip()
def text_to_chunks(text: str, chunk_size: int = 800, overlap: int = 100) -> list[str]:
    """Split ``text`` into fixed-size, overlapping character windows.

    Consecutive chunks share ``overlap`` characters, i.e. each new chunk starts
    ``chunk_size - overlap`` characters after the previous one.

    Args:
        text: The text to split.
        chunk_size: Maximum number of characters per chunk.
        overlap: Number of characters shared by two adjacent chunks.

    Returns:
        The chunks in document order. Empty input yields an empty list; text
        no longer than ``chunk_size`` yields a single chunk.

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not in
            the range ``[0, chunk_size)``.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    # Nothing to embed: do not emit an empty-string chunk.
    if not text:
        return []

    step = chunk_size - overlap
    text_len = len(text)
    chunks = []
    start = 0
    while start < text_len:
        end = min(start + chunk_size, text_len)
        chunks.append(text[start:end])
        # Stop as soon as a chunk reaches the end of the text. Advancing once
        # more would only emit a tail fragment that is already fully contained
        # in the chunk just appended. This also guarantees the final chunk
        # always carries at least one character beyond the overlap.
        if end == text_len:
            break
        start += step
    return chunks
def _embed_batch(texts: list[str], api_key: str) -> "list[list[float] | None]":
    """Embed one request-sized batch; return one entry per text, ``None`` on failure."""
    model_name = "text-embedding-v4"
    url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": model_name,
        "input": {"texts": texts},
        "parameters": {"text_type": "document", "dimension": 1536}
    }
    vectors = [None] * len(texts)
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        for item in result["output"]["embeddings"]:
            index = item["text_index"]
            # Place each vector at the position the service reports for it, so
            # a missing or out-of-order item can never shift the alignment.
            if 0 <= index < len(texts):
                # Round every component to 8 decimal places; the original code
                # notes this avoids an "excessive precision" error downstream.
                vectors[index] = [round(float(val), 8) for val in item["embedding"]]
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        # Best-effort by design: a failed or malformed response yields None
        # placeholders for this batch instead of aborting the caller.
        print(f"Alibaba Embedding Error: {e}")
        return [None] * len(texts)
    return vectors


def chunks_embedding(texts: list[str], api_key: str, batch_size: int = 10) -> "list[list[float] | None]":
    """Embed ``texts`` with the DashScope text-embedding HTTP endpoint.

    The texts are sent in batches of ``batch_size`` because the service limits
    how many texts one request may carry.
    NOTE(review): 10 is believed to be the documented row limit for
    text-embedding-v4 - confirm against the current DashScope documentation.

    Args:
        texts: The texts to embed.
        api_key: DashScope API key, sent as a Bearer token.
        batch_size: Maximum number of texts per HTTP request.

    Returns:
        A list with exactly one entry per input text, in input order. Each
        entry is the embedding vector (components rounded to 8 decimal places),
        or ``None`` if the request covering that text failed or the response
        did not include it. Empty input yields an empty list.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if not texts:
        return []

    vectors = []
    for offset in range(0, len(texts), batch_size):
        vectors.extend(_embed_batch(texts[offset:offset + batch_size], api_key))
    return vectors
def main(scrape_json: list, DASHSCOPE_API_KEY: str) -> dict:
    """Convert a Firecrawl scrape result into cleaned, chunked, embedded rows.

    Args:
        scrape_json: List whose first element is the scrape result object.
            This function reads its "success" flag and, under "data", the
            "markdown" text and the "metadata" mapping ("sourceURL", "title").
        DASHSCOPE_API_KEY: Alibaba DashScope API key used for the embeddings.

    Returns:
        A dict of the form::

            {
                "results": [
                    {
                        "source_url": "https://example.com",
                        "title": "Example Title",
                        "content": "Example chunk content",
                        "chunk_index": 0,
                        "embedding": [0.123, 0.456, ...]
                    },
                    ...
                ]
            }

        "embedding" is ``None`` for a chunk whose vector could not be
        obtained. "results" is empty when the input is empty, the scrape was
        not successful, or the page has no non-whitespace text.
    """
    # --- 1. Parse the Firecrawl JSON defensively ---
    if not scrape_json:
        return {"results": []}
    scrape_obj = scrape_json[0]
    if not isinstance(scrape_obj, dict) or not scrape_obj.get("success"):
        return {"results": []}
    data = scrape_obj.get("data")
    if not isinstance(data, dict):
        return {"results": []}

    # "markdown" / "metadata" may be absent or null.
    text = data.get("markdown") or ""
    metadata = data.get("metadata") or {}

    # No usable text: skip cleaning, chunking and the embedding request.
    if not text.strip():
        return {"results": []}

    # --- 2. Generic whitespace cleaning ---
    text = text_cleaning(text)

    # --- 3. Chunking (800-character windows, 100-character overlap) ---
    chunks = text_to_chunks(text)

    # --- 4. Embedding via Alibaba DashScope ---
    vectors = chunks_embedding(chunks, DASHSCOPE_API_KEY) if chunks else []
    # Safety net: never index into a vector list of the wrong length.
    if len(vectors) != len(chunks):
        vectors = [None] * len(chunks)

    # --- 5. Build the rows destined for SQL ---
    source_url = metadata.get("sourceURL", "")
    title = metadata.get("title", "")
    result_list = []
    for idx, (content, vector) in enumerate(zip(chunks, vectors)):
        clean_content = content.strip()
        if not clean_content:
            continue
        result_list.append({
            "source_url": source_url,
            "title": title,
            "content": clean_content,
            "chunk_index": idx,
            "embedding": vector
        })
    return {
        "results": result_list,
    }