581 lines
21 KiB
Python
581 lines
21 KiB
Python
"""
|
||
Agent管理服务
|
||
"""
|
||
from sqlalchemy.orm import Session
|
||
from typing import List, Optional, Dict
|
||
import logging
|
||
|
||
from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, ToolConfig, ToolUsageLog, init_default_data
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentService:
    """CRUD and lookup operations for ``Agent`` rows.

    Every method runs against the SQLAlchemy session handed to the
    constructor; methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used by all queries."""
        self.db = db

    def get_all_agents(self) -> List[Agent]:
        """Return all agents, default ones first, then ordered by name."""
        return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all()

    def get_active_agents(self) -> List[Agent]:
        """Return the agents whose ``is_active`` flag is set."""
        # `== True` builds a SQL expression here; it is not a Python truth test.
        return self.db.query(Agent).filter(Agent.is_active == True).all()  # noqa: E712

    def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the agent with the given id, or ``None`` if there is none."""
        return self.db.query(Agent).filter(Agent.id == agent_id).first()

    def get_agent_by_name(self, name: str) -> Optional[Agent]:
        """Return the first agent with the given name, or ``None``."""
        return self.db.query(Agent).filter(Agent.name == name).first()

    def get_default_agent(self) -> Optional[Agent]:
        """Return the active default agent.

        Falls back to the first active agent when no active agent is flagged
        as default; returns ``None`` when there is no active agent at all.
        """
        agent = self.db.query(Agent).filter(
            Agent.is_default == True,  # noqa: E712
            Agent.is_active == True,  # noqa: E712
        ).first()
        if not agent:
            # No default configured: use the first active agent instead.
            agent = self.db.query(Agent).filter(Agent.is_active == True).first()  # noqa: E712
        return agent

    def create_agent(self, data: Dict) -> Agent:
        """Create, commit and return a new agent built from ``data``.

        Missing keys fall back to the defaults spelled out below;
        ``display_name`` defaults to ``name``.
        """
        agent = Agent(
            name=data.get('name'),
            display_name=data.get('display_name', data.get('name')),
            llm_provider_id=data.get('llm_provider_id'),
            model_override=data.get('model_override'),
            system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'),
            enable_thinking=data.get('enable_thinking', True),
            thinking_prompt=data.get('thinking_prompt'),
            thinking_prefix=data.get('thinking_prefix', ''),
            thinking_suffix=data.get('thinking_suffix', ''),
            tools=data.get('tools', []),  # list of tools
            max_history=data.get('max_history', 20),
            temperature_override=data.get('temperature_override'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
            description=data.get('description'),
        )
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]:
        """Apply the entries of ``data`` to an agent and commit.

        Keys that are not attributes of the agent, and entries whose value is
        ``None``, are ignored. Returns the refreshed agent, or ``None`` when
        the id does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return None

        for key, value in data.items():
            if hasattr(agent, key) and value is not None:
                setattr(agent, key, value)

        self.db.commit()
        self.db.refresh(agent)
        return agent

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an agent; return ``True`` on success.

        Returns ``False`` when the agent does not exist or is the default
        agent (the default agent may not be deleted).
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False

        # The default agent must not be deleted.
        if agent.is_default:
            return False

        self.db.delete(agent)
        self.db.commit()
        return True

    def set_default_agent(self, agent_id: int) -> bool:
        """Make the given agent the only default agent.

        Returns ``False`` (and changes nothing) when the agent does not exist.
        """
        # Look the agent up first so that an unknown id leaves no pending
        # "clear all defaults" update behind in the session.
        agent = self.get_agent(agent_id)
        if not agent:
            return False

        # Clear the flag on every other agent; the target row is excluded so
        # its in-session state is not touched by the bulk update.
        self.db.query(Agent).filter(Agent.id != agent_id).update({Agent.is_default: False})

        agent.is_default = True
        self.db.commit()
        return True

    def get_agent_config(self, agent_id: int) -> Dict:
        """Return the agent's settings together with its LLM provider settings.

        The result has two keys, ``'agent'`` and ``'provider'``; ``'provider'``
        is an empty dict when the agent's provider row cannot be found.
        Returns an empty dict when the agent itself does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return {}

        provider = self.db.query(LLMProvider).filter(LLMProvider.id == agent.llm_provider_id).first()

        provider_config: Dict = {}
        if provider:
            provider_config = {
                'id': provider.id,
                'name': provider.name,
                'api_base': provider.api_base,
                'api_key': provider.api_key,
                'supports_thinking': provider.supports_thinking,
                'thinking_model': provider.thinking_model,
                'supports_vision': provider.supports_vision,
                'vision_model': provider.vision_model,
                'supports_function_calling': provider.supports_function_calling,
                'default_model': provider.default_model,
                'max_tokens': provider.max_tokens,
                'temperature': provider.temperature,
                'models': provider.models,
            }

        return {
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'system_prompt': agent.system_prompt,
                'enable_thinking': agent.enable_thinking,
                'thinking_prompt': agent.thinking_prompt,
                'thinking_prefix': agent.thinking_prefix,
                'thinking_suffix': agent.thinking_suffix,
                'model_override': agent.model_override,
                'max_history': agent.max_history,
                'temperature_override': agent.temperature_override,
                'tools': agent.tools or [],  # list of tools
                'is_default': agent.is_default,
                'is_active': agent.is_active,
            },
            'provider': provider_config,
        }

    def get_agent_for_channel(self, channel_id: int, conditions: Optional[Dict] = None) -> Optional[Agent]:
        """Pick the agent bound to a channel, honouring priority and conditions.

        Active mappings of the channel are visited in ascending ``priority``
        order:

        * when both ``conditions`` and the mapping's own conditions are
          non-empty, the mapping is chosen only if every entry of
          ``conditions`` equals the mapping's value for that key;
        * otherwise (no ``conditions`` given, or the mapping is
          unconditional) the mapping is chosen immediately.

        When no mapping is chosen the default agent is returned.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id,
            ChannelAgentMapping.is_active == True,  # noqa: E712
        ).order_by(ChannelAgentMapping.priority).all()

        for mapping in mappings:
            if conditions and mapping.conditions:
                if all(mapping.conditions.get(key) == value for key, value in conditions.items()):
                    return self.get_agent(mapping.agent_id)
            else:
                # Nothing to compare: the first such mapping wins.
                return self.get_agent(mapping.agent_id)

        # No mappings, or none matched: fall back to the default agent.
        return self.get_default_agent()
|
||
|
||
|
||
class LLMProviderService:
    """Management service for the pool of LLM providers."""

    def __init__(self, db: Session):
        """Remember the session every query of this service runs on."""
        self.db = db

    def get_all_providers(self) -> List[LLMProvider]:
        """List every provider, sorted by priority and then by name."""
        ordered = self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name)
        return ordered.all()

    def get_active_providers(self) -> List[LLMProvider]:
        """List the providers whose ``is_active`` flag is set."""
        active = self.db.query(LLMProvider).filter(LLMProvider.is_active == True)  # noqa: E712
        return active.all()

    def get_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Fetch one provider by id; ``None`` when it does not exist."""
        by_id = self.db.query(LLMProvider).filter(LLMProvider.id == provider_id)
        return by_id.first()

    def get_provider_by_name(self, name: str) -> Optional[LLMProvider]:
        """Fetch the first provider carrying the given name, or ``None``."""
        by_name = self.db.query(LLMProvider).filter(LLMProvider.name == name)
        return by_name.first()

    def create_provider(self, data: Dict) -> LLMProvider:
        """Insert a provider built from ``data`` and return the stored row.

        Keys absent from ``data`` take the defaults listed in the field map.
        """
        fields = {
            'name': data.get('name'),
            'api_base': data.get('api_base'),
            'api_key': data.get('api_key', ''),
            'models': data.get('models', []),
            'default_model': data.get('default_model', 'auto'),
            'supports_thinking': data.get('supports_thinking', False),
            'thinking_model': data.get('thinking_model'),
            'max_tokens': data.get('max_tokens', 4096),
            'temperature': data.get('temperature', 0.7),
            'is_active': data.get('is_active', True),
            'priority': data.get('priority', 0),
            'description': data.get('description'),
        }
        provider = LLMProvider(**fields)
        self.db.add(provider)
        self.db.commit()
        self.db.refresh(provider)
        return provider

    def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]:
        """Copy the usable entries of ``data`` onto a provider and commit.

        Entries are skipped when the provider has no such attribute or when
        the value is ``None``. Returns ``None`` for an unknown id.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return None

        for field, new_value in data.items():
            if not hasattr(provider, field) or new_value is None:
                continue
            setattr(provider, field, new_value)

        self.db.commit()
        self.db.refresh(provider)
        return provider

    def delete_provider(self, provider_id: int) -> bool:
        """Remove a provider unless it is missing or still used by an agent."""
        provider = self.get_provider(provider_id)
        if not provider:
            return False

        # Refuse to delete while any agent still references this provider.
        from models_v2 import Agent
        in_use = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id).count()
        if in_use > 0:
            return False

        self.db.delete(provider)
        self.db.commit()
        return True

    def get_provider_config(self, provider_id: int) -> Dict:
        """Return the provider's settings as a plain dict ({} if not found)."""
        provider = self.get_provider(provider_id)
        if not provider:
            return {}

        exported = (
            'id',
            'name',
            'api_base',
            'api_key',
            'models',
            'default_model',
            'supports_thinking',
            'thinking_model',
            'max_tokens',
            'temperature',
            'is_active',
            'priority',
            'description',
        )
        return {field: getattr(provider, field) for field in exported}
|
||
|
||
|
||
class ChannelService:
    """CRUD operations for ``Channel`` rows and their agent bindings.

    Every method runs against the SQLAlchemy session handed to the
    constructor; methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used by all queries."""
        self.db = db

    def get_all_channels(self) -> List[Channel]:
        """Return all channels, primary ones first, then by channel type."""
        return self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type).all()

    def get_active_channels(self) -> List[Channel]:
        """Return the channels whose ``is_active`` flag is set."""
        # `== True` builds a SQL expression here; it is not a Python truth test.
        return self.db.query(Channel).filter(Channel.is_active == True).all()  # noqa: E712

    def get_channel(self, channel_id: int) -> Optional[Channel]:
        """Return the channel with the given id, or ``None``."""
        return self.db.query(Channel).filter(Channel.id == channel_id).first()

    def get_channel_by_type(self, channel_type: str) -> Optional[Channel]:
        """Return the first active channel of the given type, or ``None``."""
        return self.db.query(Channel).filter(
            Channel.channel_type == channel_type,
            Channel.is_active == True,  # noqa: E712
        ).first()

    def get_web_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'web'``."""
        return self.get_channel_by_type('web')

    def get_matrix_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'matrix'``."""
        return self.get_channel_by_type('matrix')

    def create_channel(self, data: Dict) -> Channel:
        """Create, commit and return a new channel built from ``data``."""
        channel = Channel(
            channel_type=data.get('channel_type'),
            name=data.get('name'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_primary=data.get('is_primary', False),
        )
        self.db.add(channel)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]:
        """Apply the entries of ``data`` to a channel and commit.

        Keys that are not attributes of the channel, and entries whose value
        is ``None``, are ignored. Returns ``None`` for an unknown id.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return None

        for key, value in data.items():
            if hasattr(channel, key) and value is not None:
                setattr(channel, key, value)

        self.db.commit()
        self.db.refresh(channel)
        return channel

    def delete_channel(self, channel_id: int) -> bool:
        """Delete a channel; return ``True`` on success.

        Returns ``False`` when the channel does not exist or is flagged as
        primary (a primary channel may not be deleted).
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return False

        # A primary channel must not be deleted.
        if channel.is_primary:
            return False

        self.db.delete(channel)
        self.db.commit()
        return True

    def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single', conditions: Optional[Dict] = None) -> ChannelAgentMapping:
        """Create an active channel-to-agent mapping and return the stored row."""
        mapping = ChannelAgentMapping(
            channel_id=channel_id,
            agent_id=agent_id,
            priority=priority,
            mode=mode,
            conditions=conditions,
            is_active=True,
        )
        self.db.add(mapping)
        self.db.commit()
        self.db.refresh(mapping)
        return mapping

    def unbind_agent(self, mapping_id: int) -> bool:
        """Delete a mapping by id; return ``False`` when it does not exist."""
        mapping = self.db.query(ChannelAgentMapping).filter(ChannelAgentMapping.id == mapping_id).first()
        if not mapping:
            return False

        self.db.delete(mapping)
        self.db.commit()
        return True

    def get_channel_agents(self, channel_id: int) -> List[Dict]:
        """Describe every agent bound to a channel, ordered by mapping priority.

        Each entry holds the mapping's id, priority, mode and conditions plus
        a short ``'agent'`` summary. Mappings whose agent row no longer exists
        are left out.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id
        ).order_by(ChannelAgentMapping.priority).all()
        if not mappings:
            return []

        # Load all referenced agents with one query instead of one per mapping.
        agent_ids = list({m.agent_id for m in mappings if m.agent_id is not None})
        agents_by_id = {}
        if agent_ids:
            agents = self.db.query(Agent).filter(Agent.id.in_(agent_ids)).all()
            agents_by_id = {agent.id: agent for agent in agents}

        result = []
        for mapping in mappings:
            agent = agents_by_id.get(mapping.agent_id)
            if not agent:
                continue
            result.append({
                'mapping_id': mapping.id,
                'priority': mapping.priority,
                'mode': mapping.mode,
                'conditions': mapping.conditions,
                'agent': {
                    'id': agent.id,
                    'name': agent.name,
                    'display_name': agent.display_name,
                    'is_active': agent.is_active,
                },
            })
        return result

    def get_channel_config(self, channel_id: int) -> Dict:
        """Return a channel's settings plus its agent mappings ({} if not found)."""
        channel = self.get_channel(channel_id)
        if not channel:
            return {}

        return {
            'id': channel.id,
            'channel_type': channel.channel_type,
            'name': channel.name,
            'config': channel.config,
            'is_active': channel.is_active,
            'is_primary': channel.is_primary,
            'agent_mappings': self.get_channel_agents(channel_id),
        }
|
||
|
||
|
||
class ToolService:
    """CRUD, default selection and usage statistics for tool configurations.

    Every method runs against the SQLAlchemy session handed to the
    constructor; methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used by all queries."""
        self.db = db

    def get_all_tools(self) -> List[ToolConfig]:
        """Return all tool configs ordered by type, defaults first within a type."""
        return self.db.query(ToolConfig).order_by(ToolConfig.tool_type, ToolConfig.is_default.desc()).all()

    def get_tools_by_type(self, tool_type: str) -> List[ToolConfig]:
        """Return every tool config of the given type."""
        return self.db.query(ToolConfig).filter(ToolConfig.tool_type == tool_type).all()

    def get_active_tools(self) -> List[ToolConfig]:
        """Return the tool configs whose ``is_active`` flag is set."""
        # `== True` builds a SQL expression here; it is not a Python truth test.
        return self.db.query(ToolConfig).filter(ToolConfig.is_active == True).all()  # noqa: E712

    def get_tool(self, tool_id: int) -> Optional[ToolConfig]:
        """Return the tool config with the given id, or ``None``."""
        return self.db.query(ToolConfig).filter(ToolConfig.id == tool_id).first()

    def get_default_tool(self, tool_type: str) -> Optional[ToolConfig]:
        """Return the active default tool of a type.

        Falls back to the first active tool of that type when none is flagged
        as default; returns ``None`` when the type has no active tool.
        """
        tool = self.db.query(ToolConfig).filter(
            ToolConfig.tool_type == tool_type,
            ToolConfig.is_default == True,  # noqa: E712
            ToolConfig.is_active == True,  # noqa: E712
        ).first()
        if not tool:
            tool = self.db.query(ToolConfig).filter(
                ToolConfig.tool_type == tool_type,
                ToolConfig.is_active == True,  # noqa: E712
            ).first()
        return tool

    def create_tool(self, data: Dict) -> ToolConfig:
        """Create, commit and return a new tool config built from ``data``."""
        tool = ToolConfig(
            name=data.get('name'),
            tool_type=data.get('tool_type', 'search'),
            provider=data.get('provider'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
        )
        self.db.add(tool)
        self.db.commit()
        self.db.refresh(tool)
        return tool

    def update_tool(self, tool_id: int, data: Dict) -> Optional[ToolConfig]:
        """Apply the entries of ``data`` to a tool config and commit.

        Keys that are not attributes of the tool, and entries whose value is
        ``None``, are ignored. Returns ``None`` for an unknown id.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return None

        for key, value in data.items():
            if hasattr(tool, key) and value is not None:
                setattr(tool, key, value)

        self.db.commit()
        self.db.refresh(tool)
        return tool

    def delete_tool(self, tool_id: int) -> bool:
        """Delete a tool config; return ``False`` when it does not exist."""
        tool = self.get_tool(tool_id)
        if not tool:
            return False

        self.db.delete(tool)
        self.db.commit()
        return True

    def set_default_tool(self, tool_id: int) -> bool:
        """Make a tool the only default within its own tool type.

        Returns ``False`` (and changes nothing) when the tool does not exist.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return False

        # Clear the default flag on every tool of the same type first.
        self.db.query(ToolConfig).filter(
            ToolConfig.tool_type == tool.tool_type
        ).update({ToolConfig.is_default: False})

        tool.is_default = True
        self.db.commit()
        return True

    def increment_stats(self, tool_id: int, success: bool):
        """Bump the call counters of a tool and commit.

        ``total_calls`` always increases; ``success_calls`` or ``failed_calls``
        increases depending on ``success``. Unknown ids are ignored.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return

        # Treat an unset (None) counter as 0 instead of raising TypeError.
        tool.total_calls = (tool.total_calls or 0) + 1
        if success:
            tool.success_calls = (tool.success_calls or 0) + 1
        else:
            tool.failed_calls = (tool.failed_calls or 0) + 1
        self.db.commit()

    def log_usage(self, data: Dict) -> ToolUsageLog:
        """Store one tool usage log entry built from ``data`` and return it."""
        log = ToolUsageLog(
            tool_id=data.get('tool_id'),
            tool_type=data.get('tool_type'),
            query=data.get('query'),
            success=data.get('success', True),
            error_message=data.get('error_message'),
            result_summary=data.get('result_summary'),
            conversation_id=data.get('conversation_id'),
            agent_id=data.get('agent_id'),
            user_id=data.get('user_id'),
            duration_ms=data.get('duration_ms'),
        )
        self.db.add(log)
        self.db.commit()
        self.db.refresh(log)
        return log

    def get_usage_stats(self, days: int = 7) -> Dict:
        """Aggregate the usage logs of the last ``days`` days.

        Returns a dict with:

        * ``total_calls`` - number of log rows in the window;
        * ``success_rate`` - percentage of successful rows (0 when empty);
        * ``by_type`` - per tool type ``total`` / ``success`` / ``failed``;
        * ``by_tool`` - per tool name ``total`` / ``success``; logs whose tool
          cannot be found are grouped under ``Tool#<tool_id>``;
        * ``recent_errors`` - one entry per failed row that has an error
          message (message truncated to 100 characters), in query order.
        """
        from datetime import datetime, timedelta, timezone

        # Naive UTC timestamp, same value as the deprecated datetime.utcnow().
        # NOTE(review): presumably `called_at` is stored as naive UTC - confirm.
        now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
        start_date = now_utc - timedelta(days=days)

        logs = self.db.query(ToolUsageLog).filter(
            ToolUsageLog.called_at >= start_date
        ).all()

        success_count = sum(1 for entry in logs if entry.success)
        stats = {
            'total_calls': len(logs),
            'success_rate': success_count / len(logs) * 100 if logs else 0,
            'by_type': {},
            'by_tool': {},
            'recent_errors': [],
        }

        # tool_id -> display name, so each tool is looked up at most once.
        tool_names: Dict = {}

        for log in logs:
            # Per tool type.
            type_stats = stats['by_type'].setdefault(
                log.tool_type, {'total': 0, 'success': 0, 'failed': 0}
            )
            type_stats['total'] += 1
            if log.success:
                type_stats['success'] += 1
            else:
                type_stats['failed'] += 1

            # Per tool.
            if log.tool_id not in tool_names:
                tool = self.get_tool(log.tool_id) if log.tool_id else None
                tool_names[log.tool_id] = tool.name if tool else f'Tool#{log.tool_id}'
            tool_name = tool_names[log.tool_id]

            tool_stats = stats['by_tool'].setdefault(tool_name, {'total': 0, 'success': 0})
            tool_stats['total'] += 1
            if log.success:
                tool_stats['success'] += 1

            # Failed calls that carry an error message.
            if not log.success and log.error_message:
                stats['recent_errors'].append({
                    'tool': tool_name,
                    'error': log.error_message[:100],
                    'time': log.called_at.isoformat(),
                })

        return stats