From 0c9bfca346873652e1eb996ec57b0ca1ba248893 Mon Sep 17 00:00:00 2001 From: hubian <908234780@qq.com> Date: Mon, 13 Apr 2026 13:26:43 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E6=90=9C=E7=B4=A2?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E5=8A=9F=E8=83=BD=EF=BC=88Tavily=20Search?= =?UTF-8?q?=EF=BC=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 SearchToolConfig 模型:支持搜索工具配置 - Agent 增加 tools 字段:可配置可用工具列表 - 后台管理增加搜索工具配置页面 - Agent 管理增加工具启用开关 - 网页端增加搜索工具禁用复选框 - WebSocket chat 处理增加搜索调用逻辑 - 默认配置 Tavily Search API --- main_v2.py | 178 +++++++++++++++++++++++++++++++++- models_v2.py | 42 ++++++++ services/agent_service.py | 88 ++++++++++++++++- templates/admin_v2/index.html | 143 ++++++++++++++++++++++++++- templates/index.html | 21 +++- 5 files changed, 466 insertions(+), 6 deletions(-) diff --git a/main_v2.py b/main_v2.py index 4fc996f..b10607a 100644 --- a/main_v2.py +++ b/main_v2.py @@ -18,11 +18,11 @@ import os from models_v2 import ( init_db, get_db, SessionLocal, User, Conversation, Message, SystemConfig, - LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, + LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, SearchToolConfig, init_default_data ) from services.llm_service import llm_service -from services.agent_service import AgentService, LLMProviderService, ChannelService +from services.agent_service import AgentService, LLMProviderService, ChannelService, SearchToolService from services.conversation_service import ConversationService # 配置日志 @@ -204,6 +204,7 @@ async def get_agents(db: Session = Depends(get_db)): "thinking_prompt": a.thinking_prompt, "thinking_prefix": a.thinking_prefix, "thinking_suffix": a.thinking_suffix, + "tools": a.tools or [], # 工具列表 "max_history": a.max_history, "temperature_override": a.temperature_override, "is_default": a.is_default, @@ -418,6 +419,128 @@ async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)): return {"success": success} +# ==================== 搜索工具 API ==================== + +@app.get("/api/v2/search-tools") +async def get_search_tools(db: Session = Depends(get_db)): + """获取所有搜索工具配置""" + service = SearchToolService(db) + configs = service.get_all_configs() + + return { + "configs": [ + { + "id": c.id, + "name": c.name, + "provider": c.provider, + "api_key": c.api_key, + "api_base": c.api_base, + "max_results": c.max_results, + "include_raw_content": c.include_raw_content, + "search_depth": c.search_depth, + "is_active": c.is_active, + "is_default": c.is_default + } + for c in configs + ] + } + + +@app.post("/api/v2/search-tools") +async def create_search_tool(data: dict, db: Session = Depends(get_db)): + """创建搜索工具配置""" + service = SearchToolService(db) + + # 如果设为默认,更新其他 + if data.get('is_default'): + service.set_default_config(0) # 先清除所有默认 + + config = service.create_config(data) + return {"success": True, "config": {"id": config.id, "name": config.name}} + + +@app.put("/api/v2/search-tools/{config_id}") +async def update_search_tool(config_id: int, data: dict, db: Session = Depends(get_db)): + """更新搜索工具配置""" + service = SearchToolService(db) + + if data.get('is_default'): + service.set_default_config(config_id) + + config = service.update_config(config_id, data) + + if not config: + return {"success": False, "message": "配置不存在"} + + return {"success": True, "config": {"id": config.id, "name": config.name}} + + +@app.delete("/api/v2/search-tools/{config_id}") +async def delete_search_tool(config_id: int, db: Session = Depends(get_db)): + """删除搜索工具配置""" + service = SearchToolService(db) + success = service.delete_config(config_id) + + return {"success": success} + + +@app.post("/api/v2/search-tools/{config_id}/default") +async def set_search_tool_default(config_id: int, db: Session = Depends(get_db)): + """设置默认搜索工具""" + service = SearchToolService(db) + success = service.set_default_config(config_id) + + return {"success": success} + + +@app.post("/api/v2/search") +async def perform_search(data: dict, db: Session = Depends(get_db)): + """执行搜索(供前端或Agent调用)""" + import httpx + + query = data.get('query') + if not query: + return {"success": False, "message": "缺少搜索关键词"} + + # 获取搜索工具配置 + service = SearchToolService(db) + config_id = data.get('config_id') + + if config_id: + config = service.get_config(config_id) + else: + config = service.get_default_config() + + if not config or not config.api_key: + return {"success": False, "message": "未配置搜索工具"} + + # Tavily Search API + if config.provider == 'tavily': + try: + tavily_url = "https://api.tavily.com/search" + payload = { + "api_key": config.api_key, + "query": query, + "max_results": config.max_results, + "include_raw_content": config.include_raw_content, + "search_depth": config.search_depth + } + + async with httpx.AsyncClient(timeout=30) as client: + response = await client.post(tavily_url, json=payload) + result = response.json() + + return { + "success": True, + "results": result.get("results", []), + "query": query + } + except Exception as e: + return {"success": False, "message": str(e)} + + return {"success": False, "message": "不支持的搜索提供商"} + + # ==================== 对话 API(保留原有) ==================== @app.get("/api/conversations") @@ -610,6 +733,7 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): conversation_id = data.get("conversation_id") enable_thinking = data.get("enable_thinking", True) agent_id_override = data.get("agent_id") + disabled_tools = data.get("disabled_tools", []) # 禁用的工具列表 if agent_id_override: agent = agent_service.get_agent(agent_id_override) @@ -619,6 +743,48 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): if not message.strip(): continue + # 获取Agent配置 + agent_config = agent_service.get_agent_config(current_agent_id) + agent_tools = agent_config.get('agent', {}).get('tools', []) + + # 检查是否需要执行搜索 + search_context = None + if 'search' in agent_tools and 'search' not in disabled_tools: + # 使用关键词检测:如果消息包含搜索相关关键词,执行搜索 + search_keywords = ['搜索', '查找', '查询', '最新', '新闻', '新闻', 'weather', '天气', '股价', '股票', '汇率', 'what is', 'what are', 'find', 'search', 'look up'] + should_search = any(kw in message.lower() for kw in search_keywords) + + if should_search: + # 执行搜索 + search_service = SearchToolService(db) + search_config = search_service.get_default_config() + + if search_config and search_config.api_key: + import httpx + try: + logger.info(f"执行搜索: query={message}") + tavily_url = "https://api.tavily.com/search" + payload = { + "api_key": search_config.api_key, + "query": message, + "max_results": search_config.max_results, + "search_depth": search_config.search_depth + } + + # 同步调用(简化处理) + with httpx.Client(timeout=30) as client: + resp = client.post(tavily_url, json=payload) + search_result = resp.json() + + if search_result.get("results"): + # 构建搜索上下文 + search_context = "\n\n【搜索结果】\n" + for i, r in enumerate(search_result["results"][:5], 1): + search_context += f"{i}. {r.get('title', 'N/A')}\n {r.get('content', r.get('snippet', 'N/A'))[:200]}\n 来源: {r.get('url', 'N/A')}\n" + logger.info(f"搜索完成: {len(search_result['results'])} 条结果") + except Exception as e: + logger.error(f"搜索失败: {e}") + # 获取或创建会话 if conversation_id: conversation = conv_service.get_conversation(conversation_id) @@ -664,6 +830,14 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): # 获取对话历史 history = conv_service.get_conversation_history(conversation_id, limit=agent_config['agent'].get('max_history', 20)) + # 如果有搜索结果,添加到消息中 + if search_context: + # 在系统提示中添加搜索结果说明 + modified_system_prompt = agent_config['agent'].get('system_prompt', '') + "\n\n如果提供了搜索结果,请基于搜索结果回答用户问题,并注明信息来源。" + agent_config['agent']['system_prompt'] = modified_system_prompt + # 将搜索结果作为系统消息添加到历史 + history.append({"role": "system", "content": f"以下是搜索到的相关信息,请参考这些内容回答用户问题:{search_context}"}) + # 使用非流式调用LLM(简化版本,确保稳定) try: # 调用LLM(非流式) diff --git a/models_v2.py b/models_v2.py index 3d29dfd..1809653 100644 --- a/models_v2.py +++ b/models_v2.py @@ -58,6 +58,9 @@ class Agent(Base): name = Column(String(100), unique=True, index=True) # Agent名称 display_name = Column(String(100)) # 显示名称 + # 工具配置 + tools = Column(JSON, default=list) # 可用工具列表 ["search", "calculator", ...] + # 大模型配置 llm_provider_id = Column(Integer, ForeignKey('llm_providers.id')) model_override = Column(String(100), nullable=True) # 覆盖Provider默认模型 @@ -224,6 +227,33 @@ class MatrixRoomMapping(Base): created_at = Column(DateTime, default=datetime.utcnow) +# ==================== 搜索工具配置 ==================== + +class SearchToolConfig(Base): + """搜索工具配置(Tavily等)""" + __tablename__ = 'search_tool_config' + + id = Column(Integer, primary_key=True, index=True) + name = Column(String(100)) # 工具名称,如 "Tavily Search" + provider = Column(String(50)) # 提供商:tavily, google, bing + + # API配置 + api_key = Column(String(200)) # API密钥 + api_base = Column(String(200), nullable=True) # API地址(可选) + + # 搜索参数 + max_results = Column(Integer, default=5) # 最大返回结果数 + include_raw_content = Column(Boolean, default=False) # 是否包含原始内容 + search_depth = Column(String(20), default='basic') # basic 或 advanced + + # 状态 + is_active = Column(Boolean, default=True) + is_default = Column(Boolean, default=False) # 是否为默认搜索工具 + + created_at = Column(DateTime, default=datetime.utcnow) + updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow) + + # ==================== 系统配置(保留) ==================== class SystemConfig(Base): @@ -335,6 +365,18 @@ def init_default_data(): ) db.add(matrix_mapping) + # 5. 创建默认搜索工具配置 + search_config = SearchToolConfig( + name="Tavily Search", + provider="tavily", + api_key="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", + max_results=5, + search_depth="basic", + is_active=True, + is_default=True + ) + db.add(search_config) + db.commit() print("默认数据初始化完成") diff --git a/services/agent_service.py b/services/agent_service.py index 9dfc476..e07b346 100644 --- a/services/agent_service.py +++ b/services/agent_service.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from typing import List, Optional, Dict import logging -from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, init_default_data +from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, SearchToolConfig, init_default_data logger = logging.getLogger(__name__) @@ -52,6 +52,7 @@ class AgentService: thinking_prompt=data.get('thinking_prompt'), thinking_prefix=data.get('thinking_prefix', ''), thinking_suffix=data.get('thinking_suffix', ''), + tools=data.get('tools', []), # 工具列表 max_history=data.get('max_history', 20), temperature_override=data.get('temperature_override'), is_active=data.get('is_active', True), @@ -405,4 +406,87 @@ class ChannelService: 'is_active': channel.is_active, 'is_primary': channel.is_primary, 'agent_mappings': self.get_channel_agents(channel_id) - } \ No newline at end of file + } + + +class SearchToolService: + """搜索工具管理服务""" + + def __init__(self, db: Session): + self.db = db + + def get_all_configs(self) -> List[SearchToolConfig]: + """获取所有搜索工具配置""" + return self.db.query(SearchToolConfig).order_by(SearchToolConfig.is_default.desc(), SearchToolConfig.name).all() + + def get_active_configs(self) -> List[SearchToolConfig]: + """获取活跃的搜索工具配置""" + return self.db.query(SearchToolConfig).filter(SearchToolConfig.is_active == True).all() + + def get_config(self, config_id: int) -> Optional[SearchToolConfig]: + """获取单个配置""" + return self.db.query(SearchToolConfig).filter(SearchToolConfig.id == config_id).first() + + def get_default_config(self) -> Optional[SearchToolConfig]: + """获取默认配置""" + config = self.db.query(SearchToolConfig).filter( + SearchToolConfig.is_default == True, + SearchToolConfig.is_active == True + ).first() + if not config: + config = self.db.query(SearchToolConfig).filter(SearchToolConfig.is_active == True).first() + return config + + def create_config(self, data: Dict) -> SearchToolConfig: + """创建搜索工具配置""" + config = SearchToolConfig( + name=data.get('name'), + provider=data.get('provider', 'tavily'), + api_key=data.get('api_key'), + api_base=data.get('api_base'), + max_results=data.get('max_results', 5), + include_raw_content=data.get('include_raw_content', False), + search_depth=data.get('search_depth', 'basic'), + is_active=data.get('is_active', True), + is_default=data.get('is_default', False) + ) + self.db.add(config) + self.db.commit() + self.db.refresh(config) + return config + + def update_config(self, config_id: int, data: Dict) -> Optional[SearchToolConfig]: + """更新配置""" + config = self.get_config(config_id) + if not config: + return None + + for key, value in data.items(): + if hasattr(config, key) and value is not None: + setattr(config, key, value) + + self.db.commit() + self.db.refresh(config) + return config + + def delete_config(self, config_id: int) -> bool: + """删除配置""" + config = self.get_config(config_id) + if not config: + return False + + self.db.delete(config) + self.db.commit() + return True + + def set_default_config(self, config_id: int) -> bool: + """设置默认配置""" + self.db.query(SearchToolConfig).update({SearchToolConfig.is_default: False}) + + config = self.get_config(config_id) + if not config: + return False + + config.is_default = True + self.db.commit() + return True \ No newline at end of file diff --git a/templates/admin_v2/index.html b/templates/admin_v2/index.html index 04bc535..e92b362 100644 --- a/templates/admin_v2/index.html +++ b/templates/admin_v2/index.html @@ -42,6 +42,7 @@ + @@ -109,6 +110,23 @@ + +
+ +
+
+ 搜索工具列表(Tavily、Google等) + +
+
+ + + +
名称提供商API Key最大结果默认状态操作
加载中...
+
+
+
+
@@ -152,6 +170,8 @@

思考功能
+
工具配置
+
启用后 Agent 可以使用搜索功能获取实时信息
@@ -180,6 +200,22 @@ + + +