87 lines
2.8 KiB
Python
87 lines
2.8 KiB
Python
|
|
import requests
|
|||
|
|
import json
|
|||
|
|
import random
|
|||
|
|
|
|||
|
|
# Base URL of the backend service that every test request is sent to.
BASE_URL = "http://127.0.0.1:8000"
|
|||
|
|
|
|||
|
|
def log_res(name, response):
    """Print a readable report for one API call and return its JSON body.

    Args:
        name: Label of the endpoint under test, shown in the report header.
        response: Object exposing ``status_code``, ``text`` and ``json()``
            (a ``requests.Response`` in this script).

    Returns:
        The decoded JSON body when the status is HTTP 200 and the body is
        valid JSON; ``None`` in every other case.
    """
    print(f"\n=== 测试接口: {name} ===")

    # Guard clause: anything other than HTTP 200 is reported as a failure.
    if response.status_code != 200:
        print(f"状态: 失败 (HTTP {response.status_code})")
        print(f"错误信息: {response.text}")
        return None

    try:
        res_json = response.json()
    except ValueError:
        # A 200 reply may still carry a non-JSON body (HTML error page, empty
        # body). JSON decode errors are ValueError subclasses; report them as
        # a failed call instead of crashing the whole test run.
        print("状态: 失败 (HTTP 200, 响应不是合法 JSON)")
        print(f"错误信息: {response.text}")
        return None

    print("状态: 成功 (HTTP 200)")
    print(f"返回数据: {json.dumps(res_json, indent=2, ensure_ascii=False)}")
    return res_json
|
|||
|
|
|
|||
|
|
def run_tests():
    """Run the backend API smoke test: register, add URLs, fetch, save.

    The flow posts to ``/register``, ``/add_urls``, ``/pending_urls`` and
    ``/save_results`` on ``BASE_URL`` in that order, logging every reply via
    ``log_res``. It stops early when registration fails or when no pending
    URL is returned.

    Returns:
        None. Progress and results are printed to stdout.
    """
    # Without an explicit timeout requests.post can block indefinitely.
    request_timeout = 10  # seconds

    def call(name, path, body):
        """POST ``body`` as JSON to ``path``; return the logged JSON or None."""
        try:
            response = requests.post(
                f"{BASE_URL}{path}", json=body, timeout=request_timeout
            )
        except requests.RequestException as exc:
            # Connection refused, timeout, etc.: report it like a failed call
            # rather than ending the run with a traceback.
            print(f"\n=== 测试接口: {name} ===")
            print(f"状态: 失败 (请求异常: {exc})")
            return None
        return log_res(name, response)

    def payload_of(reply):
        """Return the ``data`` object of a reply, or {} if absent/malformed."""
        inner = reply.get("data") if isinstance(reply, dict) else None
        return inner if isinstance(inner, dict) else {}

    # Test data: a randomised root URL for each run.
    test_root_url = f"https://example.com/wiki_{random.randint(1000, 9999)}"

    # 1. /register
    print("步骤 1: 注册新任务...")
    data = call("注册任务", "/register", {"url": test_root_url})
    if not isinstance(data, dict) or data.get("code") != 1:
        return
    task_id = payload_of(data).get("task_id")
    if task_id is None:
        print("注册响应中缺少 task_id,停止后续测试")
        return

    # 2. /add_urls
    print("\n步骤 2: 模拟爬虫发现了新链接,存入队列...")
    sub_urls = [
        f"{test_root_url}/page1",
        f"{test_root_url}/page2",
        f"{test_root_url}/page1",  # deliberate duplicate to exercise backend de-duplication
    ]
    call("存入新链接", "/add_urls", {"task_id": task_id, "urls": sub_urls})

    # 3. /pending_urls
    print("\n步骤 3: 模拟爬虫节点获取待处理任务...")
    data = call("获取待处理URL", "/pending_urls", {"task_id": task_id, "limit": 2})
    urls = payload_of(data).get("urls")
    if not isinstance(urls, list) or not urls:
        print("没有获取到待处理URL,停止后续测试")
        return
    target_url = urls[0]

    # 4. /save_results
    print("\n步骤 4: 模拟爬虫抓取完成,存入知识片段和向量...")
    # Mock 1536-dimensional embedding, each value rounded to 8 decimals.
    mock_embedding = [round(random.uniform(-1, 1), 8) for _ in range(1536)]

    payload = {
        "task_id": task_id,
        "results": [
            {
                "source_url": target_url,
                "chunk_index": 0,
                "title": "测试页面标题 - 切片1",
                "content": "这是模拟抓取到的第一段网页内容...",
                "embedding": mock_embedding,
            },
            {
                "source_url": target_url,
                "chunk_index": 1,
                "title": "测试页面标题 - 切片2",
                "content": "这是模拟抓取到的第二段网页内容...",
                "embedding": mock_embedding,
            },
        ],
    }
    saved = call("保存结果", "/save_results", payload)

    # Only announce success when the final step actually succeeded.
    if saved is None:
        print("\n❌ 保存结果失败,API 流程测试未通过")
        return

    print("\n✅ 所有 API 流程测试完成!")
|
|||
|
|
|
|||
|
|
# Script entry point: run the API smoke test only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    run_tests()
|