Files
wiki_crawler/docs/t.md
2025-12-23 00:09:32 +08:00

5.1 KiB
Raw Blame History

针对你希望实现“通用、支持全库或特定任务搜索”的需求,我重新设计并实现了搜索 API。

这一版代码继续采用 SQLAlchemy Core API实现了逻辑上的“存取分离”Dify 只管发向量,后端决定怎么搜。

1. 修改 schemas.py

增加搜索请求模型,将 task_id 设为可选(Optional),从而支持全局搜索。

from pydantic import BaseModel
from typing import List, Optional

class SearchRequest(BaseModel):
    # 如果不传 task_id则进行全库搜索
    task_id: Optional[int] = None
    query_embedding: List[float]
    limit: Optional[int] = 5

2. 重新实现 service.py 中的搜索方法

利用 pgvectorcosine_distance 算子。注意这里使用了动态构建查询的技巧。

from sqlalchemy import select, and_
from .database import db_instance

class CrawlerService:
    def __init__(self):
        self.db = db_instance

    def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
        """
        高性能向量搜索方法
        :param query_embedding: 问题的向量
        :param task_id: 可选的任务ID不传则搜全表
        :param limit: 返回结果数量
        """
        with self.db.engine.connect() as conn:
            # 1. 选择需要的字段
            # 我们同时返回 task_id方便在全库搜索时知道来源哪个任务
            stmt = select(
                self.db.chunks.c.task_id,
                self.db.chunks.c.source_url,
                self.db.chunks.c.title,
                self.db.chunks.c.content,
                self.db.chunks.c.chunk_index
            )

            # 2. 动态添加过滤条件
            if task_id is not None:
                stmt = stmt.where(self.db.chunks.c.task_id == task_id)

            # 3. 按余弦距离排序1 - 余弦相似度)
            # 距离越小,相似度越高
            stmt = stmt.order_by(
                self.db.chunks.c.embedding.cosine_distance(query_embedding)
            ).limit(limit)

            # 4. 执行并解析结果
            rows = conn.execute(stmt).fetchall()
          
            results = []
            for r in rows:
                results.append({
                    "task_id": r[0],
                    "source_url": r[1],
                    "title": r[2],
                    "content": r[3],
                    "chunk_index": r[4]
                })
          
            return results

3. 重新实现 main.py 中的接口

增加了对向量维度的基础校验,确保系统健壮性。

@app.post("/search")
async def search(req: SearchRequest):
    """
    通用搜索接口:
    支持基于 task_id 的局部搜索,也支持不传 task_id 的全库搜索。
    """
    try:
        # 1. 基础校验:确保向量不为空且维度正确(阿里 v4 模型通常为 1536
        if not req.query_embedding or len(req.query_embedding) != 1536:
            return make_response(
                code=2, 
                msg=f"向量维度错误。期望 1536, 实际收到 {len(req.query_embedding) if req.query_embedding else 0}",
                data=None
            )

        # 2. 调用业务类执行搜索
        data = crawler_service.search_knowledge(
            query_embedding=req.query_embedding,
            task_id=req.task_id,
            limit=req.limit
        )

        # 3. 统一返回
        return make_response(
            code=1, 
            msg="搜索完成", 
            data=data
        )

    except Exception as e:
        # 记录日志并返回失败信息
        print(f"搜索接口异常: {str(e)}")
        return make_response(code=0, msg=f"搜索失败: {str(e)}")

4. 配合性能优化的 SQL 建议

如果你在 Dify 中运行一段时间后发现搜索变慢(数据达到万级或十万级),请在数据库执行以下语句开启 HNSW 索引。这是目前向量搜索性能最高、内存占用最平衡的方案。

-- 针对余弦距离优化搜索速度
-- m: 列表最大连接数ef_construction: 动态候选列表大小
-- 这两个参数决定了索引的精度和构建速度
CREATE INDEX idx_chunks_embedding_hnsw 
ON knowledge_chunks 
USING hnsw (embedding vector_cosine_ops) 
WITH (m = 16, ef_construction = 64);

5. 在 Dify 中这个节点的输出怎么用?

由于 /search 返回的是一个列表,在 Dify 中接一个 “代码节点” 处理成文本最方便:

def main(search_data: list) -> dict:
    # search_data 是从 HTTP 节点获得的 data 数组
    context_list = []
    for i, item in enumerate(search_data):
        # 格式化每条资料,包含来源和内容
        block = f"【资料{i+1}】来源: {item['source_url']}\n内容: {item['content']}"
        context_list.append(block)
  
    # 用换行符连接所有资料
    return {
        "final_context": "\n\n".join(context_list)
    }

最后把这个 final_context 塞进 LLM 节点的 Prompt 即可。这样的设计确保了你的 Dify 流程非常干净:输入 -> 转向量 -> 搜后端 -> 出答案