添加search方法
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from .service import crawler_service
|
from .service import crawler_service
|
||||||
from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest
|
from .schemas import RegisterRequest, PendingRequest, SaveResultsRequest, AddUrlsRequest, SearchRequest
|
||||||
from .utils import make_response
|
from .utils import make_response
|
||||||
|
|
||||||
app = FastAPI(title="Wiki Crawler API")
|
app = FastAPI(title="Wiki Crawler API")
|
||||||
@@ -37,7 +37,41 @@ async def save_results(req: SaveResultsRequest):
|
|||||||
return make_response(1, "Success", data)
|
return make_response(1, "Success", data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return make_response(0, str(e))
|
return make_response(0, str(e))
|
||||||
|
|
||||||
|
@app.post("/search")
|
||||||
|
async def search(req: SearchRequest):
|
||||||
|
"""
|
||||||
|
通用搜索接口:
|
||||||
|
支持基于 task_id 的局部搜索,也支持不传 task_id 的全库搜索。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536)
|
||||||
|
if not req.query_embedding or len(req.query_embedding) != 1536:
|
||||||
|
return make_response(
|
||||||
|
code=2,
|
||||||
|
msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
|
||||||
|
data=None
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 调用业务类执行搜索
|
||||||
|
data = crawler_service.search_knowledge(
|
||||||
|
query_embedding=req.query_embedding,
|
||||||
|
task_id=req.task_id,
|
||||||
|
limit=req.limit
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 统一返回
|
||||||
|
return make_response(
|
||||||
|
code=1,
|
||||||
|
msg="搜索完成",
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# 记录日志并返回失败信息
|
||||||
|
print(f"搜索接口异常: {str(e)}")
|
||||||
|
return make_response(code=0, msg=f"搜索失败: {str(e)}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
|
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True)
|
||||||
@@ -22,4 +22,10 @@ class CrawlResult(BaseModel):
|
|||||||
|
|
||||||
class SaveResultsRequest(BaseModel):
|
class SaveResultsRequest(BaseModel):
|
||||||
task_id: int
|
task_id: int
|
||||||
results: List[CrawlResult]
|
results: List[CrawlResult]
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
|
||||||
|
# 如果不传 task_id,则进行全库搜索
|
||||||
|
task_id: Optional[int] = None
|
||||||
|
query_embedding: List[float]
|
||||||
|
limit: Optional[int] = 5
|
||||||
@@ -156,6 +156,50 @@ class CrawlerService:
|
|||||||
"failed": len(failed_chunks)
|
"failed": len(failed_chunks)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||||
|
"""
|
||||||
|
高性能向量搜索方法
|
||||||
|
:param query_embedding: 问题的向量
|
||||||
|
:param task_id: 可选的任务ID,不传则搜全表
|
||||||
|
:param limit: 返回结果数量
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
with self.db.engine.connect() as conn:
|
||||||
|
# 1. 选择需要的字段
|
||||||
|
# 我们同时返回 task_id,方便在全库搜索时知道来源哪个任务
|
||||||
|
stmt = select(
|
||||||
|
self.db.chunks.c.task_id,
|
||||||
|
self.db.chunks.c.source_url,
|
||||||
|
self.db.chunks.c.title,
|
||||||
|
self.db.chunks.c.content,
|
||||||
|
self.db.chunks.c.chunk_index
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 动态添加过滤条件
|
||||||
|
if task_id is not None:
|
||||||
|
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||||
|
|
||||||
|
# 3. 按余弦距离排序(1 - 余弦相似度)
|
||||||
|
# 距离越小,相似度越高
|
||||||
|
stmt = stmt.order_by(
|
||||||
|
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||||
|
).limit(limit)
|
||||||
|
|
||||||
|
# 4. 执行并解析结果
|
||||||
|
rows = conn.execute(stmt).fetchall()
|
||||||
|
|
||||||
|
results = []
|
||||||
|
for r in rows:
|
||||||
|
results.append({
|
||||||
|
"task_id": r[0],
|
||||||
|
"source_url": r[1],
|
||||||
|
"title": r[2],
|
||||||
|
"content": r[3],
|
||||||
|
"chunk_index": r[4]
|
||||||
|
})
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
crawler_service = CrawlerService()
|
crawler_service = CrawlerService()
|
||||||
@@ -57,7 +57,7 @@ def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
|
|||||||
|
|
||||||
if "output" in result and "embeddings" in result["output"]:
|
if "output" in result and "embeddings" in result["output"]:
|
||||||
embeddings_list = result["output"]["embeddings"]
|
embeddings_list = result["output"]["embeddings"]
|
||||||
embeddings_list.sort(key=lambda x: x["text_index"])
|
embeddings_list.sort(key=lambda x: x["text_index"]) # 按文本索引排序,确保顺序一致
|
||||||
|
|
||||||
# --- 核心修复:对每个浮点数保留 8 位小数,解决精度过高报错 ---
|
# --- 核心修复:对每个浮点数保留 8 位小数,解决精度过高报错 ---
|
||||||
final_vectors = []
|
final_vectors = []
|
||||||
|
|||||||
46
nodes/embedding.py
Normal file
46
nodes/embedding.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import requests
|
||||||
|
|
||||||
|
def chunks_embedding(texts: list[str], api_key: str) -> list[list[float]]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
MODEL_NAME = "text-embedding-v4"
|
||||||
|
url = "https://dashscope.aliyuncs.com/api/v1/services/embeddings/text-embedding/text-embedding"
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
payload = {
|
||||||
|
"model": MODEL_NAME,
|
||||||
|
"input": {"texts": texts},
|
||||||
|
"parameters": {"text_type": "document", "dimension": 1536}
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=payload, timeout=60)
|
||||||
|
response.raise_for_status()
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
if "output" in result and "embeddings" in result["output"]:
|
||||||
|
embeddings_list = result["output"]["embeddings"]
|
||||||
|
embeddings_list.sort(key=lambda x: x["text_index"])
|
||||||
|
|
||||||
|
# --- 核心修复:对每个浮点数保留 8 位小数,解决精度过高报错 ---
|
||||||
|
final_vectors = []
|
||||||
|
for item in embeddings_list:
|
||||||
|
# 将每个 float 限制在 8 位精度以内
|
||||||
|
rounded_vector = [round(float(val), 8) for val in item["embedding"]]
|
||||||
|
final_vectors.append(rounded_vector)
|
||||||
|
return final_vectors
|
||||||
|
else:
|
||||||
|
return [None] * len(texts)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Alibaba Embedding Error: {e}")
|
||||||
|
return [None] * len(texts)
|
||||||
|
|
||||||
|
def main(text: str, api_key: str):
|
||||||
|
|
||||||
|
vector = chunks_embedding([text], api_key)[0]
|
||||||
|
return {
|
||||||
|
'vector': vector
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user