Files
wiki_crawler/docs/t.md
2025-12-23 00:09:32 +08:00

150 lines
5.1 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
针对你希望实现“通用、支持全库或特定任务搜索”的需求,我重新设计并实现了搜索 API。
这一版代码继续采用 **SQLAlchemy Core API**实现了逻辑上的“存取分离”Dify 只管发向量,后端决定怎么搜。
### 1. 修改 `schemas.py`
增加搜索请求模型,将 `task_id` 设为可选(`Optional`),从而支持全局搜索。
```python
from pydantic import BaseModel
from typing import List, Optional
class SearchRequest(BaseModel):
# 如果不传 task_id则进行全库搜索
task_id: Optional[int] = None
query_embedding: List[float]
limit: Optional[int] = 5
```
### 2. 重新实现 `service.py` 中的搜索方法
利用 `pgvector``cosine_distance` 算子。注意这里使用了动态构建查询的技巧。
```python
from sqlalchemy import select, and_
from .database import db_instance
class CrawlerService:
def __init__(self):
self.db = db_instance
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
"""
高性能向量搜索方法
:param query_embedding: 问题的向量
:param task_id: 可选的任务ID不传则搜全表
:param limit: 返回结果数量
"""
with self.db.engine.connect() as conn:
# 1. 选择需要的字段
# 我们同时返回 task_id方便在全库搜索时知道来源哪个任务
stmt = select(
self.db.chunks.c.task_id,
self.db.chunks.c.source_url,
self.db.chunks.c.title,
self.db.chunks.c.content,
self.db.chunks.c.chunk_index
)
# 2. 动态添加过滤条件
if task_id is not None:
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
# 3. 按余弦距离排序1 - 余弦相似度)
# 距离越小,相似度越高
stmt = stmt.order_by(
self.db.chunks.c.embedding.cosine_distance(query_embedding)
).limit(limit)
# 4. 执行并解析结果
rows = conn.execute(stmt).fetchall()
results = []
for r in rows:
results.append({
"task_id": r[0],
"source_url": r[1],
"title": r[2],
"content": r[3],
"chunk_index": r[4]
})
return results
```
### 3. 重新实现 `main.py` 中的接口
增加了对向量维度的基础校验,确保系统健壮性。
```python
@app.post("/search")
async def search(req: SearchRequest):
"""
通用搜索接口:
支持基于 task_id 的局部搜索,也支持不传 task_id 的全库搜索。
"""
try:
# 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536
if not req.query_embedding or len(req.query_embedding) != 1536:
return make_response(
code=2,
msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
data=None
)
# 2. 调用业务类执行搜索
data = crawler_service.search_knowledge(
query_embedding=req.query_embedding,
task_id=req.task_id,
limit=req.limit
)
# 3. 统一返回
return make_response(
code=1,
msg="搜索完成",
data=data
)
except Exception as e:
# 记录日志并返回失败信息
print(f"搜索接口异常: {str(e)}")
return make_response(code=0, msg=f"搜索失败: {str(e)}")
```
### 4. 配合性能优化的 SQL 建议
如果你在 Dify 中运行一段时间后发现搜索变慢(数据达到万级或十万级),请在数据库执行以下语句开启 **HNSW 索引**。这是目前向量搜索性能最高、内存占用最平衡的方案。
```sql
-- 针对余弦距离优化搜索速度
-- m: 列表最大连接数ef_construction: 动态候选列表大小
-- 这两个参数决定了索引的精度和构建速度
CREATE INDEX idx_chunks_embedding_hnsw
ON knowledge_chunks
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);
```
### 5. 在 Dify 中这个节点的输出怎么用?
由于 `/search` 返回的是一个列表,在 Dify 中接一个 **“代码节点”** 处理成文本最方便:
```python
def main(search_data: list) -> dict:
# search_data 是从 HTTP 节点获得的 data 数组
context_list = []
for i, item in enumerate(search_data):
# 格式化每条资料,包含来源和内容
block = f"【资料{i+1}】来源: {item['source_url']}\n内容: {item['content']}"
context_list.append(block)
# 用换行符连接所有资料
return {
"final_context": "\n\n".join(context_list)
}
```
最后把这个 `final_context` 塞进 LLM 节点的 Prompt 即可。这样的设计确保了你的 Dify 流程非常干净:**输入 -> 转向量 -> 搜后端 -> 出答案**。