变更项目架构,提高扩展性
This commit is contained in:
@@ -1,201 +0,0 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from firecrawl import FirecrawlApp
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from ..config import settings
|
||||
from .crawler_sql_service import crawler_sql_service
|
||||
|
||||
# 初始化配置
|
||||
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
class AutomatedCrawler:
|
||||
def __init__(self):
|
||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||
self.splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=500,
|
||||
chunk_overlap=100,
|
||||
separators=["\n\n", "\n", "。", "!", "?", " ", ""]
|
||||
)
|
||||
|
||||
def _get_embedding(self, text: str):
|
||||
"""内部方法:调用 Dashscope 生成向量"""
|
||||
# 注意:此方法是内部辅助,出错返回 None,由调用方处理状态
|
||||
embedding = None
|
||||
try:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=dashscope.TextEmbedding.Models.text_embedding_v3, # 确认你的模型版本
|
||||
input=text,
|
||||
dimension=1536
|
||||
)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
embedding = resp.output['embeddings'][0]['embedding']
|
||||
else:
|
||||
print(f"Embedding API Error: {resp}")
|
||||
except Exception as e:
|
||||
print(f"Embedding Exception: {e}")
|
||||
|
||||
return embedding
|
||||
|
||||
def map_and_ingest(self, start_url: str):
|
||||
"""
|
||||
V2 步骤1: 地图式扫描并入库
|
||||
"""
|
||||
print(f"[WorkFlow] Start mapping: {start_url}")
|
||||
result = {}
|
||||
|
||||
try:
|
||||
# 1. 在数据库注册任务
|
||||
task_info = crawler_sql_service.register_task(start_url)
|
||||
task_id = task_info['task_id']
|
||||
is_new_task = task_info['is_new_task']
|
||||
|
||||
# 2. 调用 Firecrawl Map
|
||||
if is_new_task:
|
||||
map_result = self.firecrawl.map(start_url)
|
||||
|
||||
urls = []
|
||||
# 兼容 firecrawl sdk 不同版本的返回结构
|
||||
# 如果 map_result 是对象且有 links 属性
|
||||
if hasattr(map_result, 'links'):
|
||||
for link in map_result.links:
|
||||
# 假设 link 是对象或字典,视具体 SDK 版本而定
|
||||
# 如果 link 是字符串直接 append
|
||||
if isinstance(link, str):
|
||||
urls.append(link)
|
||||
else:
|
||||
urls.append(getattr(link, 'url', str(link)))
|
||||
# 如果是字典
|
||||
elif isinstance(map_result, dict):
|
||||
urls = map_result.get('links', [])
|
||||
|
||||
print(f"[WorkFlow] Found {len(urls)} links")
|
||||
|
||||
# 3. 批量入库
|
||||
res = {"msg": "No urls found to add"}
|
||||
if urls:
|
||||
res = crawler_sql_service.add_urls(task_id, urls)
|
||||
|
||||
result = {
|
||||
"msg": "Task successfully mapped and URLs added",
|
||||
"task_id": task_id,
|
||||
"is_new_task": is_new_task,
|
||||
"url_count": len(urls),
|
||||
"map_detail": res
|
||||
}
|
||||
else:
|
||||
result = {
|
||||
"msg": "Task already exists, skipped mapping",
|
||||
"task_id": task_id,
|
||||
"is_new_task": False,
|
||||
"url_count": 0,
|
||||
"map_detail": {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WorkFlow] Map Error: {e}")
|
||||
# 向上抛出异常,由 main.py 捕获并返回错误 Response
|
||||
raise e
|
||||
|
||||
return result
|
||||
|
||||
def process_task_queue(self, task_id: int, limit: int = 10):
|
||||
"""
|
||||
V2 步骤2: 消费队列 -> 抓取 -> 切片 -> 向量化 -> 存储
|
||||
"""
|
||||
processed_count = 0
|
||||
total_chunks_saved = 0
|
||||
result = {}
|
||||
|
||||
# 1. 获取待处理 URL
|
||||
pending = crawler_sql_service.get_pending_urls(task_id, limit)
|
||||
urls = pending['urls']
|
||||
|
||||
if not urls:
|
||||
result = {"msg": "Queue is empty, no processing needed", "processed_count": 0}
|
||||
else:
|
||||
for url in urls:
|
||||
try:
|
||||
print(f"[WorkFlow] Processing: {url}")
|
||||
# 2. 单页抓取
|
||||
scrape_res = self.firecrawl.scrape(
|
||||
url,
|
||||
params={'formats': ['markdown'], 'onlyMainContent': True}
|
||||
)
|
||||
|
||||
# 兼容 SDK 返回类型 (对象或字典)
|
||||
content = ""
|
||||
metadata = {}
|
||||
|
||||
if isinstance(scrape_res, dict):
|
||||
content = scrape_res.get('markdown', '')
|
||||
metadata = scrape_res.get('metadata', {})
|
||||
else:
|
||||
content = getattr(scrape_res, 'markdown', '')
|
||||
metadata = getattr(scrape_res, 'metadata', {})
|
||||
if not metadata and hasattr(scrape_res, 'metadata_dict'):
|
||||
metadata = scrape_res.metadata_dict
|
||||
|
||||
title = metadata.get('title', url)
|
||||
|
||||
if not content:
|
||||
print(f"[WorkFlow] Skip empty content: {url}")
|
||||
continue
|
||||
|
||||
# 3. 切片
|
||||
chunks = self.splitter.split_text(content)
|
||||
results_to_save = []
|
||||
|
||||
# 4. 向量化
|
||||
for idx, chunk_text in enumerate(chunks):
|
||||
vector = self._get_embedding(chunk_text)
|
||||
if vector:
|
||||
results_to_save.append({
|
||||
"source_url": url,
|
||||
"chunk_index": idx,
|
||||
"title": title,
|
||||
"content": chunk_text,
|
||||
"embedding": vector
|
||||
})
|
||||
|
||||
# 5. 保存
|
||||
if results_to_save:
|
||||
save_res = crawler_sql_service.save_results(task_id, results_to_save)
|
||||
processed_count += 1
|
||||
total_chunks_saved += save_res['counts']['inserted'] + save_res['counts']['updated']
|
||||
|
||||
except Exception as e:
|
||||
print(f"[WorkFlow] Error processing {url}: {e}")
|
||||
# 此处不抛出异常,以免打断整个批次的循环
|
||||
# 实际生产建议在这里调用 service 将 url 标记为 failed
|
||||
|
||||
result = {
|
||||
"msg": f"Batch processing complete. URLs processed: {processed_count}",
|
||||
"processed_urls": processed_count,
|
||||
"total_chunks_saved": total_chunks_saved
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def search_with_embedding(self, query_text: str, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
V2 搜索: 输入文本 -> 自动转向量 -> 搜索数据库
|
||||
"""
|
||||
result = {}
|
||||
|
||||
# 1. 获取向量
|
||||
vector = self._get_embedding(query_text)
|
||||
|
||||
if not vector:
|
||||
result = {
|
||||
"msg": "Failed to generate embedding for query",
|
||||
"results": []
|
||||
}
|
||||
else:
|
||||
# 2. 执行搜索
|
||||
# search_knowledge 现在已经返回带 msg 的字典了
|
||||
result = crawler_sql_service.search_knowledge(vector, task_id, limit)
|
||||
|
||||
return result
|
||||
|
||||
# 单例模式
|
||||
workflow = AutomatedCrawler()
|
||||
150
backend/services/crawler_service.py
Normal file
150
backend/services/crawler_service.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# backend/services/crawler_service.py
|
||||
from firecrawl import FirecrawlApp
|
||||
from backend.core.config import settings
|
||||
from backend.services.data_service import data_service
|
||||
from backend.services.llm_service import llm_service
|
||||
from backend.utils.text_process import text_processor
|
||||
|
||||
class CrawlerService:
|
||||
"""
|
||||
爬虫编排服务
|
||||
协调 Firecrawl, LLM, 和 DataService
|
||||
"""
|
||||
def __init__(self):
|
||||
self.firecrawl = FirecrawlApp(api_key=settings.FIRECRAWL_API_KEY)
|
||||
|
||||
def map_site(self, start_url: str):
|
||||
print(f"[INFO] Mapping: {start_url}")
|
||||
try:
|
||||
# 1. 注册任务
|
||||
task_res = data_service.register_task(start_url)
|
||||
|
||||
# 2. 无论是否新任务,都尝试把 start_url 加入队列
|
||||
urls_to_add = [start_url]
|
||||
|
||||
# 3. 调用 Firecrawl Map
|
||||
try:
|
||||
map_res = self.firecrawl.map(start_url)
|
||||
|
||||
# 兼容不同 SDK 版本的返回 (对象 或 字典)
|
||||
found_links = []
|
||||
if isinstance(map_res, dict):
|
||||
found_links = map_res.get('links', [])
|
||||
else:
|
||||
found_links = getattr(map_res, 'links', [])
|
||||
|
||||
for link in found_links:
|
||||
# 链接可能是字符串,也可能是对象
|
||||
u = link if isinstance(link, str) else getattr(link, 'url', str(link))
|
||||
urls_to_add.append(u)
|
||||
|
||||
print(f"[INFO] Firecrawl Map found {len(found_links)} sub-links.")
|
||||
except Exception as e:
|
||||
print(f"[WARN] Firecrawl Map warning (proceeding with seed only): {e}")
|
||||
|
||||
# 4. 批量入库
|
||||
if urls_to_add:
|
||||
data_service.add_urls(task_res['task_id'], urls_to_add)
|
||||
|
||||
return {
|
||||
"msg": "Mapped successfully",
|
||||
"count": len(urls_to_add),
|
||||
"task_id": task_res['task_id']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Map failed: {e}")
|
||||
raise e
|
||||
|
||||
def process_queue(self, task_id: int, batch_size: int = 5):
|
||||
"""
|
||||
处理队列:爬取 -> 清洗 -> 切分 -> 向量化 -> 存储
|
||||
"""
|
||||
pending = data_service.get_pending_urls(task_id, batch_size)
|
||||
urls = pending['urls']
|
||||
if not urls: return {"msg": "Queue empty"}
|
||||
|
||||
processed = 0
|
||||
for url in urls:
|
||||
try:
|
||||
print(f"[INFO] Scraping: {url}")
|
||||
|
||||
# 调用 Firecrawl
|
||||
scrape_res = self.firecrawl.scrape(
|
||||
url,
|
||||
formats=['markdown'],
|
||||
only_main_content=True
|
||||
)
|
||||
|
||||
# =====================================================
|
||||
# 【核心修复】兼容字典和对象两种返回格式
|
||||
# =====================================================
|
||||
raw_md = ""
|
||||
metadata = None
|
||||
|
||||
# 1. 提取 markdown 和 metadata 本体
|
||||
if isinstance(scrape_res, dict):
|
||||
raw_md = scrape_res.get('markdown', '')
|
||||
metadata = scrape_res.get('metadata', {})
|
||||
else:
|
||||
# 如果是对象 (ScrapeResponse)
|
||||
raw_md = getattr(scrape_res, 'markdown', '')
|
||||
metadata = getattr(scrape_res, 'metadata', {})
|
||||
|
||||
# 2. 从 metadata 中安全提取 title
|
||||
title = url # 默认用 url 当标题
|
||||
if metadata:
|
||||
if isinstance(metadata, dict):
|
||||
# 如果 metadata 是字典
|
||||
title = metadata.get('title', url)
|
||||
else:
|
||||
# 如果 metadata 是对象 (DocumentMetadata) -> 这里就是你报错的地方
|
||||
title = getattr(metadata, 'title', url)
|
||||
|
||||
# =====================================================
|
||||
|
||||
if not raw_md:
|
||||
print(f"[WARN] Empty content for {url}")
|
||||
continue
|
||||
|
||||
# 1. 清洗
|
||||
clean_md = text_processor.clean_markdown(raw_md)
|
||||
if not clean_md: continue
|
||||
|
||||
# 2. 切分
|
||||
chunks = text_processor.split_markdown(clean_md)
|
||||
|
||||
chunks_data = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
content = chunk['content']
|
||||
headers = chunk['metadata']
|
||||
header_path = " > ".join(headers.values())
|
||||
|
||||
# 3. 向量化 (Title + Header + Content)
|
||||
emb_input = f"{title}\n{header_path}\n{content}"
|
||||
vector = llm_service.get_embedding(emb_input)
|
||||
|
||||
if vector:
|
||||
chunks_data.append({
|
||||
"index": i,
|
||||
"content": content,
|
||||
"embedding": vector,
|
||||
"meta_info": {"header_path": header_path, "headers": headers}
|
||||
})
|
||||
|
||||
# 4. 存储
|
||||
if chunks_data:
|
||||
data_service.save_chunks(task_id, url, title, chunks_data)
|
||||
processed += 1
|
||||
except Exception as e:
|
||||
# 打印详细错误方便调试
|
||||
print(f"[ERROR] Process URL {url} failed: {e}")
|
||||
|
||||
return {"msg": "Batch processed", "count": processed}
|
||||
|
||||
def search(self, query: str, task_id: int, limit: int):
|
||||
vector = llm_service.get_embedding(query)
|
||||
if not vector: return {"msg": "Embedding failed", "results": []}
|
||||
return data_service.search(vector, task_id, limit)
|
||||
|
||||
crawler_service = CrawlerService()
|
||||
@@ -1,230 +0,0 @@
|
||||
from sqlalchemy import select, insert, update, and_
|
||||
from ..database import db_instance
|
||||
from ..utils import normalize_url
|
||||
|
||||
class CrawlerSqlService:
|
||||
def __init__(self):
|
||||
self.db = db_instance
|
||||
|
||||
def register_task(self, url: str):
|
||||
"""完全使用库 API 实现的注册"""
|
||||
clean_url = normalize_url(url)
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 使用 select() API
|
||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
result = {
|
||||
"task_id": existing[0],
|
||||
"is_new_task": False,
|
||||
"msg": "Task already exists"
|
||||
}
|
||||
else:
|
||||
# 使用 insert() API
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
result = {
|
||||
"task_id": new_task[0],
|
||||
"is_new_task": True,
|
||||
"msg": "New task created successfully"
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
"""通用 API 实现的批量添加(含详细返回)"""
|
||||
success_urls, skipped_urls, failed_urls = [], [], []
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
# 检查队列中是否已存在该 URL
|
||||
check_q = select(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
if conn.execute(check_q).fetchone():
|
||||
skipped_urls.append(clean_url)
|
||||
continue
|
||||
|
||||
# 插入新 URL
|
||||
conn.execute(insert(self.db.queue).values(
|
||||
task_id=task_id, url=clean_url, status='pending'
|
||||
))
|
||||
success_urls.append(clean_url)
|
||||
except Exception:
|
||||
failed_urls.append(clean_url)
|
||||
|
||||
# 构造返回消息
|
||||
msg = f"Added {len(success_urls)} urls, skipped {len(skipped_urls)}, failed {len(failed_urls)}"
|
||||
|
||||
return {
|
||||
"success_urls": success_urls,
|
||||
"skipped_urls": skipped_urls,
|
||||
"failed_urls": failed_urls,
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
"""原子锁定 API 实现"""
|
||||
result = {}
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
).limit(limit)
|
||||
|
||||
urls = [r[0] for r in conn.execute(query).fetchall()]
|
||||
|
||||
if urls:
|
||||
upd = update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing')
|
||||
conn.execute(upd)
|
||||
result = {"urls": urls, "msg": f"Fetched {len(urls)} pending urls"}
|
||||
else:
|
||||
result = {"urls": [], "msg": "Queue is empty"}
|
||||
|
||||
return result
|
||||
|
||||
def save_results(self, task_id: int, results: list):
|
||||
"""
|
||||
保存同一 URL 的多个切片。
|
||||
"""
|
||||
if not results:
|
||||
return {"msg": "No data provided to save", "counts": {"inserted": 0, "updated": 0, "failed": 0}}
|
||||
|
||||
# 1. 基础信息提取
|
||||
first_item = results[0] if isinstance(results[0], dict) else results[0].__dict__
|
||||
target_url = normalize_url(first_item.get('source_url'))
|
||||
|
||||
# 结果统计容器
|
||||
inserted_chunks = []
|
||||
updated_chunks = []
|
||||
failed_chunks = []
|
||||
is_page_update = False
|
||||
|
||||
with self.db.engine.begin() as conn:
|
||||
# 2. 判断该 URL 是否已经有切片存在
|
||||
check_page_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id, self.db.chunks.c.source_url == target_url)
|
||||
).limit(1)
|
||||
if conn.execute(check_page_stmt).fetchone():
|
||||
is_page_update = True
|
||||
|
||||
# 3. 逐个处理切片
|
||||
for res in results:
|
||||
data = res if isinstance(res, dict) else res.__dict__
|
||||
c_idx = data.get('chunk_index')
|
||||
|
||||
try:
|
||||
# 检查切片是否存在
|
||||
find_chunk_stmt = select(self.db.chunks.c.id).where(
|
||||
and_(
|
||||
self.db.chunks.c.task_id == task_id,
|
||||
self.db.chunks.c.source_url == target_url,
|
||||
self.db.chunks.c.chunk_index == c_idx
|
||||
)
|
||||
)
|
||||
existing_chunk = conn.execute(find_chunk_stmt).fetchone()
|
||||
|
||||
if existing_chunk:
|
||||
# 覆盖更新
|
||||
upd_stmt = update(self.db.chunks).where(
|
||||
self.db.chunks.c.id == existing_chunk[0]
|
||||
).values(
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(upd_stmt)
|
||||
updated_chunks.append(c_idx)
|
||||
else:
|
||||
# 插入新切片
|
||||
ins_stmt = insert(self.db.chunks).values(
|
||||
task_id=task_id,
|
||||
source_url=target_url,
|
||||
chunk_index=c_idx,
|
||||
title=data.get('title'),
|
||||
content=data.get('content'),
|
||||
embedding=data.get('embedding')
|
||||
)
|
||||
conn.execute(ins_stmt)
|
||||
inserted_chunks.append(c_idx)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Chunk {c_idx} failed: {e}")
|
||||
failed_chunks.append(c_idx)
|
||||
|
||||
# 4. 更新队列状态
|
||||
conn.execute(
|
||||
update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == target_url)
|
||||
).values(status='completed')
|
||||
)
|
||||
|
||||
# 构造返回
|
||||
msg = f"Saved results for {target_url}. Inserted: {len(inserted_chunks)}, Updated: {len(updated_chunks)}"
|
||||
|
||||
return {
|
||||
"source_url": target_url,
|
||||
"is_page_update": is_page_update,
|
||||
"detail": {
|
||||
"inserted_chunk_indexes": inserted_chunks,
|
||||
"updated_chunk_indexes": updated_chunks,
|
||||
"failed_chunk_indexes": failed_chunks
|
||||
},
|
||||
"counts": {
|
||||
"inserted": len(inserted_chunks),
|
||||
"updated": len(updated_chunks),
|
||||
"failed": len(failed_chunks)
|
||||
},
|
||||
"msg": msg
|
||||
}
|
||||
|
||||
def search_knowledge(self, query_embedding: list, task_id: int = None, limit: int = 5):
|
||||
"""
|
||||
高性能向量搜索方法
|
||||
"""
|
||||
results = []
|
||||
msg = ""
|
||||
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id,
|
||||
self.db.chunks.c.source_url,
|
||||
self.db.chunks.c.title,
|
||||
self.db.chunks.c.content,
|
||||
self.db.chunks.c.chunk_index
|
||||
)
|
||||
|
||||
if task_id is not None:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
stmt = stmt.order_by(
|
||||
self.db.chunks.c.embedding.cosine_distance(query_embedding)
|
||||
).limit(limit)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
|
||||
for r in rows:
|
||||
results.append({
|
||||
"task_id": r[0],
|
||||
"source_url": r[1],
|
||||
"title": r[2],
|
||||
"content": r[3],
|
||||
"chunk_index": r[4]
|
||||
})
|
||||
|
||||
if results:
|
||||
msg = f"Found {len(results)} matches"
|
||||
else:
|
||||
msg = "No matching content found"
|
||||
|
||||
return {"results": results, "msg": msg}
|
||||
|
||||
|
||||
crawler_sql_service = CrawlerSqlService()
|
||||
111
backend/services/data_service.py
Normal file
111
backend/services/data_service.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from sqlalchemy import select, insert, update, and_
|
||||
from backend.core.database import db
|
||||
from backend.utils.common import normalize_url
|
||||
|
||||
class DataService:
|
||||
"""
|
||||
数据持久化服务层
|
||||
只负责数据库 CRUD 操作,不包含外部 API 调用逻辑
|
||||
"""
|
||||
def __init__(self):
|
||||
self.db = db
|
||||
|
||||
def register_task(self, url: str):
|
||||
clean_url = normalize_url(url)
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.tasks.c.id).where(self.db.tasks.c.root_url == clean_url)
|
||||
existing = conn.execute(query).fetchone()
|
||||
|
||||
if existing:
|
||||
return {"task_id": existing[0], "is_new_task": False, "msg": "Task already exists"}
|
||||
else:
|
||||
stmt = insert(self.db.tasks).values(root_url=clean_url).returning(self.db.tasks.c.id)
|
||||
new_task = conn.execute(stmt).fetchone()
|
||||
return {"task_id": new_task[0], "is_new_task": True, "msg": "New task created"}
|
||||
|
||||
def add_urls(self, task_id: int, urls: list[str]):
|
||||
success_urls = []
|
||||
with self.db.engine.begin() as conn:
|
||||
for url in urls:
|
||||
clean_url = normalize_url(url)
|
||||
try:
|
||||
check_q = select(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
)
|
||||
if not conn.execute(check_q).fetchone():
|
||||
conn.execute(insert(self.db.queue).values(task_id=task_id, url=clean_url, status='pending'))
|
||||
success_urls.append(clean_url)
|
||||
except Exception:
|
||||
pass
|
||||
return {"msg": f"Added {len(success_urls)} new urls"}
|
||||
|
||||
def get_pending_urls(self, task_id: int, limit: int):
|
||||
with self.db.engine.begin() as conn:
|
||||
query = select(self.db.queue.c.url).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.status == 'pending')
|
||||
).limit(limit)
|
||||
urls = [r[0] for r in conn.execute(query).fetchall()]
|
||||
|
||||
if urls:
|
||||
conn.execute(update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url.in_(urls))
|
||||
).values(status='processing'))
|
||||
|
||||
return {"urls": urls, "msg": "Fetched pending urls"}
|
||||
|
||||
def save_chunks(self, task_id: int, source_url: str, title: str, chunks_data: list):
|
||||
"""
|
||||
保存切片数据 (Phase 1.5: 支持 meta_info)
|
||||
"""
|
||||
clean_url = normalize_url(source_url)
|
||||
count = 0
|
||||
with self.db.engine.begin() as conn:
|
||||
for item in chunks_data:
|
||||
# item 结构: {'index': int, 'content': str, 'embedding': list, 'meta_info': dict}
|
||||
idx = item['index']
|
||||
meta = item.get('meta_info', {})
|
||||
|
||||
# 检查是否存在
|
||||
existing = conn.execute(select(self.db.chunks.c.id).where(
|
||||
and_(self.db.chunks.c.task_id == task_id,
|
||||
self.db.chunks.c.source_url == clean_url,
|
||||
self.db.chunks.c.chunk_index == idx)
|
||||
)).fetchone()
|
||||
|
||||
values = {
|
||||
"task_id": task_id, "source_url": clean_url, "chunk_index": idx,
|
||||
"title": title, "content": item['content'], "embedding": item['embedding'],
|
||||
"meta_info": meta
|
||||
}
|
||||
|
||||
if existing:
|
||||
conn.execute(update(self.db.chunks).where(self.db.chunks.c.id == existing[0]).values(**values))
|
||||
else:
|
||||
conn.execute(insert(self.db.chunks).values(**values))
|
||||
count += 1
|
||||
|
||||
# 标记队列完成
|
||||
conn.execute(update(self.db.queue).where(
|
||||
and_(self.db.queue.c.task_id == task_id, self.db.queue.c.url == clean_url)
|
||||
).values(status='completed'))
|
||||
|
||||
return {"msg": f"Saved {count} chunks", "count": count}
|
||||
|
||||
def search(self, vector: list, task_id: int = None, limit: int = 5):
|
||||
with self.db.engine.connect() as conn:
|
||||
stmt = select(
|
||||
self.db.chunks.c.task_id, self.db.chunks.c.source_url, self.db.chunks.c.title,
|
||||
self.db.chunks.c.content, self.db.chunks.c.meta_info
|
||||
).order_by(self.db.chunks.c.embedding.cosine_distance(vector)).limit(limit)
|
||||
|
||||
if task_id:
|
||||
stmt = stmt.where(self.db.chunks.c.task_id == task_id)
|
||||
|
||||
rows = conn.execute(stmt).fetchall()
|
||||
results = [
|
||||
{"task_id": r[0], "source_url": r[1], "title": r[2], "content": r[3], "meta_info": r[4]}
|
||||
for r in rows
|
||||
]
|
||||
return {"results": results, "msg": f"Found {len(results)}"}
|
||||
|
||||
data_service = DataService()
|
||||
30
backend/services/llm_service.py
Normal file
30
backend/services/llm_service.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import dashscope
|
||||
from http import HTTPStatus
|
||||
from backend.core.config import settings
|
||||
|
||||
class LLMService:
|
||||
"""
|
||||
LLM 服务封装层
|
||||
负责与 DashScope 或其他模型供应商交互
|
||||
"""
|
||||
def __init__(self):
|
||||
dashscope.api_key = settings.DASHSCOPE_API_KEY
|
||||
|
||||
def get_embedding(self, text: str, dimension: int = 1536):
|
||||
"""生成文本向量"""
|
||||
try:
|
||||
resp = dashscope.TextEmbedding.call(
|
||||
model=dashscope.TextEmbedding.Models.text_embedding_v4,
|
||||
input=text,
|
||||
dimension=dimension
|
||||
)
|
||||
if resp.status_code == HTTPStatus.OK:
|
||||
return resp.output['embeddings'][0]['embedding']
|
||||
else:
|
||||
print(f"[ERROR] Embedding API Error: {resp}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Embedding Exception: {e}")
|
||||
return None
|
||||
|
||||
llm_service = LLMService()
|
||||
Reference in New Issue
Block a user