""" Agent管理服务 """ from sqlalchemy.orm import Session from typing import List, Optional, Dict import logging from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, ToolConfig, ToolUsageLog, init_default_data logger = logging.getLogger(__name__) class AgentService: """Agent管理服务""" def __init__(self, db: Session): self.db = db def get_all_agents(self) -> List[Agent]: """获取所有Agent""" return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all() def get_active_agents(self) -> List[Agent]: """获取活跃的Agent""" return self.db.query(Agent).filter(Agent.is_active == True).all() def get_agent(self, agent_id: int) -> Optional[Agent]: """获取单个Agent""" return self.db.query(Agent).filter(Agent.id == agent_id).first() def get_agent_by_name(self, name: str) -> Optional[Agent]: """通过名称获取Agent""" return self.db.query(Agent).filter(Agent.name == name).first() def get_default_agent(self) -> Optional[Agent]: """获取默认Agent""" agent = self.db.query(Agent).filter(Agent.is_default == True, Agent.is_active == True).first() if not agent: # 如果没有默认,获取第一个活跃的 agent = self.db.query(Agent).filter(Agent.is_active == True).first() return agent def create_agent(self, data: Dict) -> Agent: """创建Agent""" agent = Agent( name=data.get('name'), display_name=data.get('display_name', data.get('name')), llm_provider_id=data.get('llm_provider_id'), model_override=data.get('model_override'), system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'), enable_thinking=data.get('enable_thinking', True), thinking_prompt=data.get('thinking_prompt'), thinking_prefix=data.get('thinking_prefix', ''), thinking_suffix=data.get('thinking_suffix', ''), tools=data.get('tools', []), # 工具列表 max_history=data.get('max_history', 20), temperature_override=data.get('temperature_override'), is_active=data.get('is_active', True), is_default=data.get('is_default', False), description=data.get('description') ) self.db.add(agent) self.db.commit() self.db.refresh(agent) return agent def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]: """更新Agent""" agent = self.get_agent(agent_id) if not agent: return None for key, value in data.items(): if hasattr(agent, key) and value is not None: setattr(agent, key, value) self.db.commit() self.db.refresh(agent) return agent def delete_agent(self, agent_id: int) -> bool: """删除Agent""" agent = self.get_agent(agent_id) if not agent: return False # 如果是默认Agent,不允许删除 if agent.is_default: return False self.db.delete(agent) self.db.commit() return True def set_default_agent(self, agent_id: int) -> bool: """设置默认Agent""" # 先清除所有默认 self.db.query(Agent).update({Agent.is_default: False}) agent = self.get_agent(agent_id) if not agent: return False agent.is_default = True self.db.commit() return True def get_agent_config(self, agent_id: int) -> Dict: """获取Agent完整配置(包含LLM Provider信息)""" agent = self.get_agent(agent_id) if not agent: return {} provider = self.db.query(LLMProvider).filter(LLMProvider.id == agent.llm_provider_id).first() return { 'agent': { 'id': agent.id, 'name': agent.name, 'display_name': agent.display_name, 'system_prompt': agent.system_prompt, 'enable_thinking': agent.enable_thinking, 'thinking_prompt': agent.thinking_prompt, 'thinking_prefix': agent.thinking_prefix, 'thinking_suffix': agent.thinking_suffix, 'model_override': agent.model_override, 'max_history': agent.max_history, 'temperature_override': agent.temperature_override, 'tools': agent.tools or [], # 工具列表 'is_default': agent.is_default, 'is_active': agent.is_active }, 'provider': { 'id': provider.id if provider else None, 'name': provider.name if provider else None, 'api_base': provider.api_base if provider else None, 'api_key': provider.api_key if provider else None, 'supports_thinking': provider.supports_thinking if provider else False, 'thinking_model': provider.thinking_model if provider else None, 'default_model': provider.default_model if provider else 'auto', 'max_tokens': provider.max_tokens if provider else 4096, 'temperature': provider.temperature if provider else 0.7, 'models': provider.models if provider else [] } if provider else {} } def get_agent_for_channel(self, channel_id: int, conditions: Dict = None) -> Optional[Agent]: """根据渠道获取Agent(支持优先级和条件)""" mappings = self.db.query(ChannelAgentMapping).filter( ChannelAgentMapping.channel_id == channel_id, ChannelAgentMapping.is_active == True ).order_by(ChannelAgentMapping.priority).all() if not mappings: return self.get_default_agent() # 检查条件匹配 for mapping in mappings: if conditions and mapping.conditions: # 检查条件是否匹配 match = True for key, value in conditions.items(): if mapping.conditions.get(key) != value: match = False break if match: return self.get_agent(mapping.agent_id) else: # 无条件或无条件映射,返回第一个 return self.get_agent(mappings.agent_id) # 没有匹配的,返回默认 return self.get_default_agent() class LLMProviderService: """大模型池管理服务""" def __init__(self, db: Session): self.db = db def get_all_providers(self) -> List[LLMProvider]: """获取所有Provider""" return self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name).all() def get_active_providers(self) -> List[LLMProvider]: """获取活跃的Provider""" return self.db.query(LLMProvider).filter(LLMProvider.is_active == True).all() def get_provider(self, provider_id: int) -> Optional[LLMProvider]: """获取单个Provider""" return self.db.query(LLMProvider).filter(LLMProvider.id == provider_id).first() def get_provider_by_name(self, name: str) -> Optional[LLMProvider]: """通过名称获取Provider""" return self.db.query(LLMProvider).filter(LLMProvider.name == name).first() def create_provider(self, data: Dict) -> LLMProvider: """创建Provider""" provider = LLMProvider( name=data.get('name'), api_base=data.get('api_base'), api_key=data.get('api_key', ''), models=data.get('models', []), default_model=data.get('default_model', 'auto'), supports_thinking=data.get('supports_thinking', False), thinking_model=data.get('thinking_model'), max_tokens=data.get('max_tokens', 4096), temperature=data.get('temperature', 0.7), is_active=data.get('is_active', True), priority=data.get('priority', 0), description=data.get('description') ) self.db.add(provider) self.db.commit() self.db.refresh(provider) return provider def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]: """更新Provider""" provider = self.get_provider(provider_id) if not provider: return None for key, value in data.items(): if hasattr(provider, key) and value is not None: setattr(provider, key, value) self.db.commit() self.db.refresh(provider) return provider def delete_provider(self, provider_id: int) -> bool: """删除Provider""" provider = self.get_provider(provider_id) if not provider: return False # 检查是否有Agent在使用 from models_v2 import Agent agents_count = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id).count() if agents_count > 0: return False self.db.delete(provider) self.db.commit() return True def get_provider_config(self, provider_id: int) -> Dict: """获取Provider配置""" provider = self.get_provider(provider_id) if not provider: return {} return { 'id': provider.id, 'name': provider.name, 'api_base': provider.api_base, 'api_key': provider.api_key, 'models': provider.models, 'default_model': provider.default_model, 'supports_thinking': provider.supports_thinking, 'thinking_model': provider.thinking_model, 'max_tokens': provider.max_tokens, 'temperature': provider.temperature, 'is_active': provider.is_active, 'priority': provider.priority, 'description': provider.description } class ChannelService: """渠道管理服务""" def __init__(self, db: Session): self.db = db def get_all_channels(self) -> List[Channel]: """获取所有渠道""" return self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type).all() def get_active_channels(self) -> List[Channel]: """获取活跃渠道""" return self.db.query(Channel).filter(Channel.is_active == True).all() def get_channel(self, channel_id: int) -> Optional[Channel]: """获取单个渠道""" return self.db.query(Channel).filter(Channel.id == channel_id).first() def get_channel_by_type(self, channel_type: str) -> Optional[Channel]: """获取指定类型的渠道""" return self.db.query(Channel).filter( Channel.channel_type == channel_type, Channel.is_active == True ).first() def get_web_channel(self) -> Optional[Channel]: """获取网页渠道""" return self.get_channel_by_type('web') def get_matrix_channel(self) -> Optional[Channel]: """获取Matrix渠道""" return self.get_channel_by_type('matrix') def create_channel(self, data: Dict) -> Channel: """创建渠道""" channel = Channel( channel_type=data.get('channel_type'), name=data.get('name'), config=data.get('config', {}), is_active=data.get('is_active', True), is_primary=data.get('is_primary', False) ) self.db.add(channel) self.db.commit() self.db.refresh(channel) return channel def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]: """更新渠道""" channel = self.get_channel(channel_id) if not channel: return None for key, value in data.items(): if hasattr(channel, key) and value is not None: setattr(channel, key, value) self.db.commit() self.db.refresh(channel) return channel def delete_channel(self, channel_id: int) -> bool: """删除渠道""" channel = self.get_channel(channel_id) if not channel: return False # 检查是否是主渠道 if channel.is_primary: return False self.db.delete(channel) self.db.commit() return True def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single', conditions: Dict = None) -> ChannelAgentMapping: """绑定Agent到渠道""" mapping = ChannelAgentMapping( channel_id=channel_id, agent_id=agent_id, priority=priority, mode=mode, conditions=conditions, is_active=True ) self.db.add(mapping) self.db.commit() self.db.refresh(mapping) return mapping def unbind_agent(self, mapping_id: int) -> bool: """解绑Agent""" mapping = self.db.query(ChannelAgentMapping).filter(ChannelAgentMapping.id == mapping_id).first() if not mapping: return False self.db.delete(mapping) self.db.commit() return True def get_channel_agents(self, channel_id: int) -> List[Dict]: """获取渠道绑定的所有Agent""" mappings = self.db.query(ChannelAgentMapping).filter( ChannelAgentMapping.channel_id == channel_id ).order_by(ChannelAgentMapping.priority).all() result = [] for mapping in mappings: agent = self.db.query(Agent).filter(Agent.id == mapping.agent_id).first() if agent: result.append({ 'mapping_id': mapping.id, 'priority': mapping.priority, 'mode': mapping.mode, 'conditions': mapping.conditions, 'agent': { 'id': agent.id, 'name': agent.name, 'display_name': agent.display_name, 'is_active': agent.is_active } }) return result def get_channel_config(self, channel_id: int) -> Dict: """获取渠道完整配置""" channel = self.get_channel(channel_id) if not channel: return {} return { 'id': channel.id, 'channel_type': channel.channel_type, 'name': channel.name, 'config': channel.config, 'is_active': channel.is_active, 'is_primary': channel.is_primary, 'agent_mappings': self.get_channel_agents(channel_id) } class ToolService: """工具管理服务""" def __init__(self, db: Session): self.db = db def get_all_tools(self) -> List[ToolConfig]: """获取所有工具配置""" return self.db.query(ToolConfig).order_by(ToolConfig.tool_type, ToolConfig.is_default.desc()).all() def get_tools_by_type(self, tool_type: str) -> List[ToolConfig]: """获取指定类型的工具""" return self.db.query(ToolConfig).filter(ToolConfig.tool_type == tool_type).all() def get_active_tools(self) -> List[ToolConfig]: """获取活跃的工具""" return self.db.query(ToolConfig).filter(ToolConfig.is_active == True).all() def get_tool(self, tool_id: int) -> Optional[ToolConfig]: """获取单个工具""" return self.db.query(ToolConfig).filter(ToolConfig.id == tool_id).first() def get_default_tool(self, tool_type: str) -> Optional[ToolConfig]: """获取指定类型的默认工具""" tool = self.db.query(ToolConfig).filter( ToolConfig.tool_type == tool_type, ToolConfig.is_default == True, ToolConfig.is_active == True ).first() if not tool: tool = self.db.query(ToolConfig).filter( ToolConfig.tool_type == tool_type, ToolConfig.is_active == True ).first() return tool def create_tool(self, data: Dict) -> ToolConfig: """创建工具配置""" tool = ToolConfig( name=data.get('name'), tool_type=data.get('tool_type', 'search'), provider=data.get('provider'), config=data.get('config', {}), is_active=data.get('is_active', True), is_default=data.get('is_default', False) ) self.db.add(tool) self.db.commit() self.db.refresh(tool) return tool def update_tool(self, tool_id: int, data: Dict) -> Optional[ToolConfig]: """更新工具配置""" tool = self.get_tool(tool_id) if not tool: return None for key, value in data.items(): if hasattr(tool, key) and value is not None: setattr(tool, key, value) self.db.commit() self.db.refresh(tool) return tool def delete_tool(self, tool_id: int) -> bool: """删除工具配置""" tool = self.get_tool(tool_id) if not tool: return False self.db.delete(tool) self.db.commit() return True def set_default_tool(self, tool_id: int) -> bool: """设置默认工具""" tool = self.get_tool(tool_id) if not tool: return False # 清除同类型的其他默认 self.db.query(ToolConfig).filter( ToolConfig.tool_type == tool.tool_type ).update({ToolConfig.is_default: False}) tool.is_default = True self.db.commit() return True def increment_stats(self, tool_id: int, success: bool): """更新工具调用统计""" tool = self.get_tool(tool_id) if tool: tool.total_calls += 1 if success: tool.success_calls += 1 else: tool.failed_calls += 1 self.db.commit() def log_usage(self, data: Dict) -> ToolUsageLog: """记录工具使用日志""" log = ToolUsageLog( tool_id=data.get('tool_id'), tool_type=data.get('tool_type'), query=data.get('query'), success=data.get('success', True), error_message=data.get('error_message'), result_summary=data.get('result_summary'), conversation_id=data.get('conversation_id'), agent_id=data.get('agent_id'), user_id=data.get('user_id'), duration_ms=data.get('duration_ms') ) self.db.add(log) self.db.commit() self.db.refresh(log) return log def get_usage_stats(self, days: int = 7) -> Dict: """获取工具使用统计""" from datetime import datetime, timedelta start_date = datetime.utcnow() - timedelta(days=days) # 按工具类型统计 logs = self.db.query(ToolUsageLog).filter( ToolUsageLog.called_at >= start_date ).all() stats = { 'total_calls': len(logs), 'success_rate': sum(1 for l in logs if l.success) / len(logs) * 100 if logs else 0, 'by_type': {}, 'by_tool': {}, 'recent_errors': [] } for log in logs: # 按类型 if log.tool_type not in stats['by_type']: stats['by_type'][log.tool_type] = {'total': 0, 'success': 0, 'failed': 0} stats['by_type'][log.tool_type]['total'] += 1 if log.success: stats['by_type'][log.tool_type]['success'] += 1 else: stats['by_type'][log.tool_type]['failed'] += 1 # 按工具 tool = self.get_tool(log.tool_id) if log.tool_id else None tool_name = tool.name if tool else f'Tool#{log.tool_id}' if tool_name not in stats['by_tool']: stats['by_tool'][tool_name] = {'total': 0, 'success': 0} stats['by_tool'][tool_name]['total'] += 1 if log.success: stats['by_tool'][tool_name]['success'] += 1 # 最近错误 if not log.success and log.error_message: stats['recent_errors'].append({ 'tool': tool_name, 'error': log.error_message[:100], 'time': log.called_at.isoformat() }) return stats