# Source: ai-chat-system/services/agent_service.py (581 lines, 21 KiB, Python)
# NOTE: the repository viewer reported that this file contains ambiguous
# Unicode characters (characters that might be confused with other
# characters). If that is not intentional, review and normalise them.
"""
Agent管理服务
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict
import logging
from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, ToolConfig, ToolUsageLog, init_default_data
# Module-level logger, named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)
class AgentService:
    """Agent management service.

    CRUD and lookup helpers for ``Agent`` rows, plus resolution of which agent
    should serve a given channel. Every write commits immediately on the
    session passed to the constructor.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_all_agents(self) -> List[Agent]:
        """Return every agent, default agents first, then sorted by name."""
        return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all()

    def get_active_agents(self) -> List[Agent]:
        """Return every agent whose ``is_active`` flag is set."""
        # ``== True`` builds a SQL expression here; ``is True`` would not.
        return self.db.query(Agent).filter(Agent.is_active == True).all()

    def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the agent with primary key ``agent_id``, or ``None``."""
        return self.db.query(Agent).filter(Agent.id == agent_id).first()

    def get_agent_by_name(self, name: str) -> Optional[Agent]:
        """Return the first agent whose ``name`` equals ``name``, or ``None``."""
        return self.db.query(Agent).filter(Agent.name == name).first()

    def get_default_agent(self) -> Optional[Agent]:
        """Return the active default agent.

        If no active agent is flagged as default, fall back to the first
        active agent. Return ``None`` when no agent is active at all.
        """
        agent = self.db.query(Agent).filter(Agent.is_default == True, Agent.is_active == True).first()
        if not agent:
            # No active default configured: use the first active agent instead.
            agent = self.db.query(Agent).filter(Agent.is_active == True).first()
        return agent

    def create_agent(self, data: Dict) -> Agent:
        """Create, commit and return a new agent built from ``data``.

        Missing keys fall back to the defaults below (for example
        ``display_name`` defaults to ``name``).

        NOTE(review): passing ``is_default=True`` does not clear the flag on
        other agents; call ``set_default_agent`` to keep a single default.
        """
        agent = Agent(
            name=data.get('name'),
            display_name=data.get('display_name', data.get('name')),
            llm_provider_id=data.get('llm_provider_id'),
            model_override=data.get('model_override'),
            system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'),
            enable_thinking=data.get('enable_thinking', True),
            thinking_prompt=data.get('thinking_prompt'),
            thinking_prefix=data.get('thinking_prefix', ''),
            thinking_suffix=data.get('thinking_suffix', ''),
            tools=data.get('tools', []),  # tool list
            max_history=data.get('max_history', 20),
            temperature_override=data.get('temperature_override'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
            description=data.get('description')
        )
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]:
        """Update an agent from ``data`` and return it, or ``None`` if missing.

        Every key that names an existing attribute is applied; unknown keys are
        ignored. ``None`` values are skipped, so this cannot clear a field.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return None
        for key, value in data.items():
            if hasattr(agent, key) and value is not None:
                setattr(agent, key, value)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an agent.

        Returns ``False`` when the agent does not exist or is the default agent
        (the default agent may not be deleted), ``True`` otherwise.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        if agent.is_default:
            return False
        self.db.delete(agent)
        self.db.commit()
        return True

    def set_default_agent(self, agent_id: int) -> bool:
        """Make ``agent_id`` the only default agent.

        Returns ``False`` (and changes nothing) when the agent does not exist.
        """
        # Look the agent up first so an unknown id never leaves the other
        # agents' cleared flags pending in the session.
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        # Clear the flag on every other agent; excluding the target keeps its
        # in-memory state consistent regardless of the session sync mode.
        self.db.query(Agent).filter(Agent.id != agent_id).update({Agent.is_default: False})
        agent.is_default = True
        self.db.commit()
        return True

    def get_agent_config(self, agent_id: int) -> Dict:
        """Return the agent's full config together with its LLM provider's.

        Result shape: ``{'agent': {...}, 'provider': {...}}``. Returns ``{}``
        when the agent does not exist; ``'provider'`` is ``{}`` when no
        provider row matches ``agent.llm_provider_id``. The provider section
        includes the provider's ``api_key``.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return {}
        provider = self.db.query(LLMProvider).filter(LLMProvider.id == agent.llm_provider_id).first()

        provider_config = {}
        if provider:
            provider_config = {
                'id': provider.id,
                'name': provider.name,
                'api_base': provider.api_base,
                'api_key': provider.api_key,
                'supports_thinking': provider.supports_thinking,
                'thinking_model': provider.thinking_model,
                'supports_vision': provider.supports_vision,
                'vision_model': provider.vision_model,
                'supports_function_calling': provider.supports_function_calling,
                'default_model': provider.default_model,
                'max_tokens': provider.max_tokens,
                'temperature': provider.temperature,
                'models': provider.models
            }

        return {
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'system_prompt': agent.system_prompt,
                'enable_thinking': agent.enable_thinking,
                'thinking_prompt': agent.thinking_prompt,
                'thinking_prefix': agent.thinking_prefix,
                'thinking_suffix': agent.thinking_suffix,
                'model_override': agent.model_override,
                'max_history': agent.max_history,
                'temperature_override': agent.temperature_override,
                'tools': agent.tools or [],  # tool list; never None for callers
                'is_default': agent.is_default,
                'is_active': agent.is_active
            },
            'provider': provider_config
        }

    def get_agent_for_channel(self, channel_id: int, conditions: Optional[Dict] = None) -> Optional[Agent]:
        """Resolve the agent for a channel using mapping priority and conditions.

        Active mappings for ``channel_id`` are tried in ascending ``priority``.

        - When both ``conditions`` and the mapping's conditions are non-empty,
          the mapping is used only if they match.
        - Otherwise the mapping is used unconditionally.
        - A mapping whose agent no longer exists is skipped.
        - If nothing is selected, the default agent is returned.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id,
            ChannelAgentMapping.is_active == True
        ).order_by(ChannelAgentMapping.priority).all()

        for mapping in mappings:
            if conditions and mapping.conditions:
                # NOTE(review): every key of the *request* conditions must be
                # present with an equal value in the mapping's conditions, so a
                # mapping with fewer keys than the request never matches.
                # Confirm this direction is intended.
                matched = all(mapping.conditions.get(key) == value for key, value in conditions.items())
                if not matched:
                    continue
            # Matched, or one side has no conditions: use this mapping's agent.
            agent = self.get_agent(mapping.agent_id)
            if agent:
                return agent

        # No mapping (or none matched / all dangling): fall back to the default.
        return self.get_default_agent()
class LLMProviderService:
    """LLM provider pool management service.

    CRUD helpers for ``LLMProvider`` rows. Every write commits immediately on
    the session passed to the constructor.
    """

    # Capability fields read by ``AgentService.get_agent_config``. They are
    # copied from input only when supplied, so model-level defaults still
    # apply otherwise.
    _OPTIONAL_CAPABILITY_FIELDS = ('supports_vision', 'vision_model', 'supports_function_calling')

    def __init__(self, db: Session):
        self.db = db

    def get_all_providers(self) -> List[LLMProvider]:
        """Return every provider ordered by ``priority`` then ``name``."""
        return self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name).all()

    def get_active_providers(self) -> List[LLMProvider]:
        """Return every provider whose ``is_active`` flag is set."""
        return self.db.query(LLMProvider).filter(LLMProvider.is_active == True).all()

    def get_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Return the provider with primary key ``provider_id``, or ``None``."""
        return self.db.query(LLMProvider).filter(LLMProvider.id == provider_id).first()

    def get_provider_by_name(self, name: str) -> Optional[LLMProvider]:
        """Return the first provider whose ``name`` equals ``name``, or ``None``."""
        return self.db.query(LLMProvider).filter(LLMProvider.name == name).first()

    def create_provider(self, data: Dict) -> LLMProvider:
        """Create, commit and return a new provider built from ``data``.

        Missing keys fall back to the defaults below. The optional capability
        fields (vision / function calling) are set only when present in
        ``data``; previously they were silently dropped.
        """
        provider = LLMProvider(
            name=data.get('name'),
            api_base=data.get('api_base'),
            api_key=data.get('api_key', ''),
            models=data.get('models', []),
            default_model=data.get('default_model', 'auto'),
            supports_thinking=data.get('supports_thinking', False),
            thinking_model=data.get('thinking_model'),
            max_tokens=data.get('max_tokens', 4096),
            temperature=data.get('temperature', 0.7),
            is_active=data.get('is_active', True),
            priority=data.get('priority', 0),
            description=data.get('description')
        )
        for key in self._OPTIONAL_CAPABILITY_FIELDS:
            if key in data:
                setattr(provider, key, data[key])
        self.db.add(provider)
        self.db.commit()
        self.db.refresh(provider)
        return provider

    def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]:
        """Update a provider from ``data`` and return it, or ``None`` if missing.

        Every key that names an existing attribute is applied; unknown keys are
        ignored. ``None`` values are skipped, so this cannot clear a field.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return None
        for key, value in data.items():
            if hasattr(provider, key) and value is not None:
                setattr(provider, key, value)
        self.db.commit()
        self.db.refresh(provider)
        return provider

    def delete_provider(self, provider_id: int) -> bool:
        """Delete a provider.

        Returns ``False`` when the provider does not exist or any agent still
        references it via ``llm_provider_id``, ``True`` otherwise.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return False
        # Refuse to delete a provider that agents still depend on.
        agents_count = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id).count()
        if agents_count > 0:
            return False
        self.db.delete(provider)
        self.db.commit()
        return True

    def get_provider_config(self, provider_id: int) -> Dict:
        """Return the provider's configuration as a plain dict.

        Returns ``{}`` when the provider does not exist. The result includes
        the provider's ``api_key``.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return {}
        return {
            'id': provider.id,
            'name': provider.name,
            'api_base': provider.api_base,
            'api_key': provider.api_key,
            'models': provider.models,
            'default_model': provider.default_model,
            'supports_thinking': provider.supports_thinking,
            'thinking_model': provider.thinking_model,
            'supports_vision': provider.supports_vision,
            'vision_model': provider.vision_model,
            'supports_function_calling': provider.supports_function_calling,
            'max_tokens': provider.max_tokens,
            'temperature': provider.temperature,
            'is_active': provider.is_active,
            'priority': provider.priority,
            'description': provider.description
        }
class ChannelService:
    """Channel management service.

    CRUD helpers for ``Channel`` rows and for the ``ChannelAgentMapping``
    rows that bind agents to channels. Every write commits immediately on the
    session passed to the constructor.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_all_channels(self) -> List[Channel]:
        """Return every channel, primary channels first, then by type."""
        return self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type).all()

    def get_active_channels(self) -> List[Channel]:
        """Return every channel whose ``is_active`` flag is set."""
        return self.db.query(Channel).filter(Channel.is_active == True).all()

    def get_channel(self, channel_id: int) -> Optional[Channel]:
        """Return the channel with primary key ``channel_id``, or ``None``."""
        return self.db.query(Channel).filter(Channel.id == channel_id).first()

    def get_channel_by_type(self, channel_type: str) -> Optional[Channel]:
        """Return the first *active* channel of ``channel_type``, or ``None``."""
        return self.db.query(Channel).filter(
            Channel.channel_type == channel_type,
            Channel.is_active == True
        ).first()

    def get_web_channel(self) -> Optional[Channel]:
        """Return the active ``'web'`` channel, or ``None``."""
        return self.get_channel_by_type('web')

    def get_matrix_channel(self) -> Optional[Channel]:
        """Return the active ``'matrix'`` channel, or ``None``."""
        return self.get_channel_by_type('matrix')

    def create_channel(self, data: Dict) -> Channel:
        """Create, commit and return a new channel built from ``data``."""
        channel = Channel(
            channel_type=data.get('channel_type'),
            name=data.get('name'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_primary=data.get('is_primary', False)
        )
        self.db.add(channel)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]:
        """Update a channel from ``data`` and return it, or ``None`` if missing.

        Every key that names an existing attribute is applied; unknown keys are
        ignored. ``None`` values are skipped, so this cannot clear a field.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return None
        for key, value in data.items():
            if hasattr(channel, key) and value is not None:
                setattr(channel, key, value)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def delete_channel(self, channel_id: int) -> bool:
        """Delete a channel.

        Returns ``False`` when the channel does not exist or is a primary
        channel (primary channels may not be deleted), ``True`` otherwise.
        NOTE(review): agent mappings for the channel are not removed here;
        confirm the model cascades the delete.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return False
        if channel.is_primary:
            return False
        self.db.delete(channel)
        self.db.commit()
        return True

    def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single', conditions: Optional[Dict] = None) -> ChannelAgentMapping:
        """Bind an agent to a channel and return the new active mapping.

        ``priority`` orders mappings (lower first when resolving). ``conditions``
        is stored as given and may be ``None``.
        """
        mapping = ChannelAgentMapping(
            channel_id=channel_id,
            agent_id=agent_id,
            priority=priority,
            mode=mode,
            conditions=conditions,
            is_active=True
        )
        self.db.add(mapping)
        self.db.commit()
        self.db.refresh(mapping)
        return mapping

    def unbind_agent(self, mapping_id: int) -> bool:
        """Delete a channel/agent mapping; return ``False`` if it does not exist."""
        mapping = self.db.query(ChannelAgentMapping).filter(ChannelAgentMapping.id == mapping_id).first()
        if not mapping:
            return False
        self.db.delete(mapping)
        self.db.commit()
        return True

    def get_channel_agents(self, channel_id: int) -> List[Dict]:
        """Return every mapping of the channel with a summary of its agent.

        Mappings are ordered by ``priority``. Both active and inactive mappings
        are included; mappings whose agent no longer exists are omitted.
        """
        # One joined query instead of one Agent query per mapping; the inner
        # join drops mappings with a missing agent, as the old loop did.
        rows = (
            self.db.query(ChannelAgentMapping, Agent)
            .join(Agent, Agent.id == ChannelAgentMapping.agent_id)
            .filter(ChannelAgentMapping.channel_id == channel_id)
            .order_by(ChannelAgentMapping.priority)
            .all()
        )
        return [
            {
                'mapping_id': mapping.id,
                'priority': mapping.priority,
                'mode': mapping.mode,
                'conditions': mapping.conditions,
                'agent': {
                    'id': agent.id,
                    'name': agent.name,
                    'display_name': agent.display_name,
                    'is_active': agent.is_active
                }
            }
            for mapping, agent in rows
        ]

    def get_channel_config(self, channel_id: int) -> Dict:
        """Return the channel's full config, including its agent mappings.

        Returns ``{}`` when the channel does not exist.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return {}
        return {
            'id': channel.id,
            'channel_type': channel.channel_type,
            'name': channel.name,
            'config': channel.config,
            'is_active': channel.is_active,
            'is_primary': channel.is_primary,
            'agent_mappings': self.get_channel_agents(channel_id)
        }
class ToolService:
    """Tool management service.

    CRUD helpers for ``ToolConfig`` rows, per-tool call counters, and usage
    logging / statistics via ``ToolUsageLog``. Every write commits immediately
    on the session passed to the constructor.
    """

    def __init__(self, db: Session):
        self.db = db

    def get_all_tools(self) -> List[ToolConfig]:
        """Return every tool config, grouped by type with defaults first."""
        return self.db.query(ToolConfig).order_by(ToolConfig.tool_type, ToolConfig.is_default.desc()).all()

    def get_tools_by_type(self, tool_type: str) -> List[ToolConfig]:
        """Return every tool config (active or not) of ``tool_type``."""
        return self.db.query(ToolConfig).filter(ToolConfig.tool_type == tool_type).all()

    def get_active_tools(self) -> List[ToolConfig]:
        """Return every tool config whose ``is_active`` flag is set."""
        return self.db.query(ToolConfig).filter(ToolConfig.is_active == True).all()

    def get_tool(self, tool_id: int) -> Optional[ToolConfig]:
        """Return the tool config with primary key ``tool_id``, or ``None``."""
        return self.db.query(ToolConfig).filter(ToolConfig.id == tool_id).first()

    def get_default_tool(self, tool_type: str) -> Optional[ToolConfig]:
        """Return the active default tool of ``tool_type``.

        Falls back to any active tool of that type; returns ``None`` when the
        type has no active tool.
        """
        tool = self.db.query(ToolConfig).filter(
            ToolConfig.tool_type == tool_type,
            ToolConfig.is_default == True,
            ToolConfig.is_active == True
        ).first()
        if not tool:
            tool = self.db.query(ToolConfig).filter(
                ToolConfig.tool_type == tool_type,
                ToolConfig.is_active == True
            ).first()
        return tool

    def create_tool(self, data: Dict) -> ToolConfig:
        """Create, commit and return a new tool config built from ``data``."""
        tool = ToolConfig(
            name=data.get('name'),
            tool_type=data.get('tool_type', 'search'),
            provider=data.get('provider'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False)
        )
        self.db.add(tool)
        self.db.commit()
        self.db.refresh(tool)
        return tool

    def update_tool(self, tool_id: int, data: Dict) -> Optional[ToolConfig]:
        """Update a tool config from ``data`` and return it, or ``None`` if missing.

        Every key that names an existing attribute is applied; unknown keys are
        ignored. ``None`` values are skipped, so this cannot clear a field.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return None
        for key, value in data.items():
            if hasattr(tool, key) and value is not None:
                setattr(tool, key, value)
        self.db.commit()
        self.db.refresh(tool)
        return tool

    def delete_tool(self, tool_id: int) -> bool:
        """Delete a tool config; return ``False`` if it does not exist."""
        tool = self.get_tool(tool_id)
        if not tool:
            return False
        self.db.delete(tool)
        self.db.commit()
        return True

    def set_default_tool(self, tool_id: int) -> bool:
        """Make ``tool_id`` the only default tool of its ``tool_type``.

        Returns ``False`` (and changes nothing) when the tool does not exist.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return False
        # Clear the flag on the other tools of the same type. Excluding the
        # target keeps its in-memory state consistent regardless of the
        # session sync mode.
        self.db.query(ToolConfig).filter(
            ToolConfig.tool_type == tool.tool_type,
            ToolConfig.id != tool.id
        ).update({ToolConfig.is_default: False})
        tool.is_default = True
        self.db.commit()
        return True

    def increment_stats(self, tool_id: int, success: bool):
        """Increment the tool's call counters; a no-op if the tool is missing.

        ``total_calls`` is always incremented. ``success_calls`` or
        ``failed_calls`` is incremented depending on ``success``.
        """
        tool = self.get_tool(tool_id)
        if not tool:
            return
        # Treat NULL counters as 0 so rows without initialised counters do not
        # raise TypeError.
        tool.total_calls = (tool.total_calls or 0) + 1
        if success:
            tool.success_calls = (tool.success_calls or 0) + 1
        else:
            tool.failed_calls = (tool.failed_calls or 0) + 1
        self.db.commit()

    def log_usage(self, data: Dict) -> ToolUsageLog:
        """Create, commit and return a ``ToolUsageLog`` row built from ``data``."""
        log = ToolUsageLog(
            tool_id=data.get('tool_id'),
            tool_type=data.get('tool_type'),
            query=data.get('query'),
            success=data.get('success', True),
            error_message=data.get('error_message'),
            result_summary=data.get('result_summary'),
            conversation_id=data.get('conversation_id'),
            agent_id=data.get('agent_id'),
            user_id=data.get('user_id'),
            duration_ms=data.get('duration_ms')
        )
        self.db.add(log)
        self.db.commit()
        self.db.refresh(log)
        return log

    def get_usage_stats(self, days: int = 7) -> Dict:
        """Aggregate tool-usage logs from the last ``days`` days.

        Returns a dict with:

        - ``total_calls``: number of logs in the window.
        - ``success_rate``: percentage of successful calls (``0`` if no logs).
        - ``by_type``: ``{tool_type: {'total', 'success', 'failed'}}``.
        - ``by_tool``: ``{tool_name: {'total', 'success'}}``; unknown tools
          are named ``'Tool#<id>'``.
        - ``recent_errors``: failed calls that have an error message, newest
          first, each ``{'tool', 'error' (first 100 chars), 'time' (ISO)}``.
        """
        from datetime import datetime, timedelta
        # Naive UTC cutoff; presumably matches how called_at is stored — confirm.
        start_date = datetime.utcnow() - timedelta(days=days)
        # Newest first so that ``recent_errors`` is actually ordered by recency.
        logs = (
            self.db.query(ToolUsageLog)
            .filter(ToolUsageLog.called_at >= start_date)
            .order_by(ToolUsageLog.called_at.desc())
            .all()
        )
        success_count = sum(1 for entry in logs if entry.success)
        stats = {
            'total_calls': len(logs),
            'success_rate': success_count / len(logs) * 100 if logs else 0,
            'by_type': {},
            'by_tool': {},
            'recent_errors': []
        }
        # tool_id -> display name, so each tool is looked up at most once
        # instead of once per log row.
        tool_names: Dict = {}
        for log in logs:
            # Per tool type.
            type_stats = stats['by_type'].setdefault(log.tool_type, {'total': 0, 'success': 0, 'failed': 0})
            type_stats['total'] += 1
            if log.success:
                type_stats['success'] += 1
            else:
                type_stats['failed'] += 1

            # Per tool.
            if log.tool_id not in tool_names:
                tool = self.get_tool(log.tool_id) if log.tool_id else None
                tool_names[log.tool_id] = tool.name if tool else f'Tool#{log.tool_id}'
            tool_name = tool_names[log.tool_id]
            tool_stats = stats['by_tool'].setdefault(tool_name, {'total': 0, 'success': 0})
            tool_stats['total'] += 1
            if log.success:
                tool_stats['success'] += 1

            # Recent errors.
            if not log.success and log.error_message:
                stats['recent_errors'].append({
                    'tool': tool_name,
                    'error': log.error_message[:100],
                    'time': log.called_at.isoformat()
                })
        return stats