feat: 添加搜索工具功能(Tavily Search)
- 新增 SearchToolConfig 模型:支持搜索工具配置 - Agent 增加 tools 字段:可配置可用工具列表 - 后台管理增加搜索工具配置页面 - Agent 管理增加工具启用开关 - 网页端增加搜索工具禁用复选框 - WebSocket chat 处理增加搜索调用逻辑 - 默认配置 Tavily Search API
This commit is contained in:
178
main_v2.py
178
main_v2.py
@@ -18,11 +18,11 @@ import os
|
||||
from models_v2 import (
|
||||
init_db, get_db, SessionLocal,
|
||||
User, Conversation, Message, SystemConfig,
|
||||
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping,
|
||||
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, SearchToolConfig,
|
||||
init_default_data
|
||||
)
|
||||
from services.llm_service import llm_service
|
||||
from services.agent_service import AgentService, LLMProviderService, ChannelService
|
||||
from services.agent_service import AgentService, LLMProviderService, ChannelService, SearchToolService
|
||||
from services.conversation_service import ConversationService
|
||||
|
||||
# 配置日志
|
||||
@@ -204,6 +204,7 @@ async def get_agents(db: Session = Depends(get_db)):
|
||||
"thinking_prompt": a.thinking_prompt,
|
||||
"thinking_prefix": a.thinking_prefix,
|
||||
"thinking_suffix": a.thinking_suffix,
|
||||
"tools": a.tools or [], # 工具列表
|
||||
"max_history": a.max_history,
|
||||
"temperature_override": a.temperature_override,
|
||||
"is_default": a.is_default,
|
||||
@@ -418,6 +419,128 @@ async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
|
||||
return {"success": success}
|
||||
|
||||
|
||||
# ==================== 搜索工具 API ====================
|
||||
|
||||
@app.get("/api/v2/search-tools")
|
||||
async def get_search_tools(db: Session = Depends(get_db)):
|
||||
"""获取所有搜索工具配置"""
|
||||
service = SearchToolService(db)
|
||||
configs = service.get_all_configs()
|
||||
|
||||
return {
|
||||
"configs": [
|
||||
{
|
||||
"id": c.id,
|
||||
"name": c.name,
|
||||
"provider": c.provider,
|
||||
"api_key": c.api_key,
|
||||
"api_base": c.api_base,
|
||||
"max_results": c.max_results,
|
||||
"include_raw_content": c.include_raw_content,
|
||||
"search_depth": c.search_depth,
|
||||
"is_active": c.is_active,
|
||||
"is_default": c.is_default
|
||||
}
|
||||
for c in configs
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/v2/search-tools")
|
||||
async def create_search_tool(data: dict, db: Session = Depends(get_db)):
|
||||
"""创建搜索工具配置"""
|
||||
service = SearchToolService(db)
|
||||
|
||||
# 如果设为默认,更新其他
|
||||
if data.get('is_default'):
|
||||
service.set_default_config(0) # 先清除所有默认
|
||||
|
||||
config = service.create_config(data)
|
||||
return {"success": True, "config": {"id": config.id, "name": config.name}}
|
||||
|
||||
|
||||
@app.put("/api/v2/search-tools/{config_id}")
|
||||
async def update_search_tool(config_id: int, data: dict, db: Session = Depends(get_db)):
|
||||
"""更新搜索工具配置"""
|
||||
service = SearchToolService(db)
|
||||
|
||||
if data.get('is_default'):
|
||||
service.set_default_config(config_id)
|
||||
|
||||
config = service.update_config(config_id, data)
|
||||
|
||||
if not config:
|
||||
return {"success": False, "message": "配置不存在"}
|
||||
|
||||
return {"success": True, "config": {"id": config.id, "name": config.name}}
|
||||
|
||||
|
||||
@app.delete("/api/v2/search-tools/{config_id}")
|
||||
async def delete_search_tool(config_id: int, db: Session = Depends(get_db)):
|
||||
"""删除搜索工具配置"""
|
||||
service = SearchToolService(db)
|
||||
success = service.delete_config(config_id)
|
||||
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@app.post("/api/v2/search-tools/{config_id}/default")
|
||||
async def set_search_tool_default(config_id: int, db: Session = Depends(get_db)):
|
||||
"""设置默认搜索工具"""
|
||||
service = SearchToolService(db)
|
||||
success = service.set_default_config(config_id)
|
||||
|
||||
return {"success": success}
|
||||
|
||||
|
||||
@app.post("/api/v2/search")
|
||||
async def perform_search(data: dict, db: Session = Depends(get_db)):
|
||||
"""执行搜索(供前端或Agent调用)"""
|
||||
import httpx
|
||||
|
||||
query = data.get('query')
|
||||
if not query:
|
||||
return {"success": False, "message": "缺少搜索关键词"}
|
||||
|
||||
# 获取搜索工具配置
|
||||
service = SearchToolService(db)
|
||||
config_id = data.get('config_id')
|
||||
|
||||
if config_id:
|
||||
config = service.get_config(config_id)
|
||||
else:
|
||||
config = service.get_default_config()
|
||||
|
||||
if not config or not config.api_key:
|
||||
return {"success": False, "message": "未配置搜索工具"}
|
||||
|
||||
# Tavily Search API
|
||||
if config.provider == 'tavily':
|
||||
try:
|
||||
tavily_url = "https://api.tavily.com/search"
|
||||
payload = {
|
||||
"api_key": config.api_key,
|
||||
"query": query,
|
||||
"max_results": config.max_results,
|
||||
"include_raw_content": config.include_raw_content,
|
||||
"search_depth": config.search_depth
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
response = await client.post(tavily_url, json=payload)
|
||||
result = response.json()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"results": result.get("results", []),
|
||||
"query": query
|
||||
}
|
||||
except Exception as e:
|
||||
return {"success": False, "message": str(e)}
|
||||
|
||||
return {"success": False, "message": "不支持的搜索提供商"}
|
||||
|
||||
|
||||
# ==================== 对话 API(保留原有) ====================
|
||||
|
||||
@app.get("/api/conversations")
|
||||
@@ -610,6 +733,7 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||||
conversation_id = data.get("conversation_id")
|
||||
enable_thinking = data.get("enable_thinking", True)
|
||||
agent_id_override = data.get("agent_id")
|
||||
disabled_tools = data.get("disabled_tools", []) # 禁用的工具列表
|
||||
|
||||
if agent_id_override:
|
||||
agent = agent_service.get_agent(agent_id_override)
|
||||
@@ -619,6 +743,48 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||||
if not message.strip():
|
||||
continue
|
||||
|
||||
# 获取Agent配置
|
||||
agent_config = agent_service.get_agent_config(current_agent_id)
|
||||
agent_tools = agent_config.get('agent', {}).get('tools', [])
|
||||
|
||||
# 检查是否需要执行搜索
|
||||
search_context = None
|
||||
if 'search' in agent_tools and 'search' not in disabled_tools:
|
||||
# 使用关键词检测:如果消息包含搜索相关关键词,执行搜索
|
||||
search_keywords = ['搜索', '查找', '查询', '最新', '新闻', '新闻', 'weather', '天气', '股价', '股票', '汇率', 'what is', 'what are', 'find', 'search', 'look up']
|
||||
should_search = any(kw in message.lower() for kw in search_keywords)
|
||||
|
||||
if should_search:
|
||||
# 执行搜索
|
||||
search_service = SearchToolService(db)
|
||||
search_config = search_service.get_default_config()
|
||||
|
||||
if search_config and search_config.api_key:
|
||||
import httpx
|
||||
try:
|
||||
logger.info(f"执行搜索: query={message}")
|
||||
tavily_url = "https://api.tavily.com/search"
|
||||
payload = {
|
||||
"api_key": search_config.api_key,
|
||||
"query": message,
|
||||
"max_results": search_config.max_results,
|
||||
"search_depth": search_config.search_depth
|
||||
}
|
||||
|
||||
# 同步调用(简化处理)
|
||||
with httpx.Client(timeout=30) as client:
|
||||
resp = client.post(tavily_url, json=payload)
|
||||
search_result = resp.json()
|
||||
|
||||
if search_result.get("results"):
|
||||
# 构建搜索上下文
|
||||
search_context = "\n\n【搜索结果】\n"
|
||||
for i, r in enumerate(search_result["results"][:5], 1):
|
||||
search_context += f"{i}. {r.get('title', 'N/A')}\n {r.get('content', r.get('snippet', 'N/A'))[:200]}\n 来源: {r.get('url', 'N/A')}\n"
|
||||
logger.info(f"搜索完成: {len(search_result['results'])} 条结果")
|
||||
except Exception as e:
|
||||
logger.error(f"搜索失败: {e}")
|
||||
|
||||
# 获取或创建会话
|
||||
if conversation_id:
|
||||
conversation = conv_service.get_conversation(conversation_id)
|
||||
@@ -664,6 +830,14 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||||
# 获取对话历史
|
||||
history = conv_service.get_conversation_history(conversation_id, limit=agent_config['agent'].get('max_history', 20))
|
||||
|
||||
# 如果有搜索结果,添加到消息中
|
||||
if search_context:
|
||||
# 在系统提示中添加搜索结果说明
|
||||
modified_system_prompt = agent_config['agent'].get('system_prompt', '') + "\n\n如果提供了搜索结果,请基于搜索结果回答用户问题,并注明信息来源。"
|
||||
agent_config['agent']['system_prompt'] = modified_system_prompt
|
||||
# 将搜索结果作为系统消息添加到历史
|
||||
history.append({"role": "system", "content": f"以下是搜索到的相关信息,请参考这些内容回答用户问题:{search_context}"})
|
||||
|
||||
# 使用非流式调用LLM(简化版本,确保稳定)
|
||||
try:
|
||||
# 调用LLM(非流式)
|
||||
|
||||
Reference in New Issue
Block a user