Files
ai-chat-system/services/agent_service.py

408 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Agent management services: CRUD for Agents, LLM providers and channels, plus channel-to-Agent routing.
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict
import logging
from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, init_default_data
logger = logging.getLogger(__name__)
class AgentService:
    """Manage Agent records and pick the Agent that serves a channel.

    Every method runs against the SQLAlchemy session passed to the
    constructor. Methods that successfully create, modify or delete a row
    commit the session before returning.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used for all queries."""
        self.db = db

    def get_all_agents(self) -> List[Agent]:
        """Return all Agents ordered by ``is_default`` descending, then by name."""
        return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all()

    def get_active_agents(self) -> List[Agent]:
        """Return the Agents whose ``is_active`` flag is set."""
        # ``== True`` builds a SQL expression here; it is not a Python comparison.
        return self.db.query(Agent).filter(Agent.is_active == True).all()  # noqa: E712

    def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the Agent with the given primary key, or ``None``."""
        return self.db.query(Agent).filter(Agent.id == agent_id).first()

    def get_agent_by_name(self, name: str) -> Optional[Agent]:
        """Return the first Agent whose ``name`` equals *name*, or ``None``."""
        return self.db.query(Agent).filter(Agent.name == name).first()

    def get_default_agent(self) -> Optional[Agent]:
        """Return the active default Agent.

        When no active Agent is flagged as default, the first active Agent
        returned by the database is used instead. Returns ``None`` when there
        is no active Agent at all.
        """
        agent = (
            self.db.query(Agent)
            .filter(Agent.is_default == True, Agent.is_active == True)  # noqa: E712
            .first()
        )
        if not agent:
            # No default configured: fall back to any active Agent.
            agent = self.db.query(Agent).filter(Agent.is_active == True).first()  # noqa: E712
        return agent

    def create_agent(self, data: Dict) -> Agent:
        """Insert a new Agent built from *data* and return the refreshed row.

        Missing keys fall back to the defaults visible below; in particular
        ``display_name`` defaults to ``name``.
        """
        agent = Agent(
            name=data.get('name'),
            display_name=data.get('display_name', data.get('name')),
            llm_provider_id=data.get('llm_provider_id'),
            model_override=data.get('model_override'),
            system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'),
            enable_thinking=data.get('enable_thinking', True),
            thinking_prompt=data.get('thinking_prompt'),
            thinking_prefix=data.get('thinking_prefix', ''),
            thinking_suffix=data.get('thinking_suffix', ''),
            max_history=data.get('max_history', 20),
            temperature_override=data.get('temperature_override'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
            description=data.get('description'),
        )
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]:
        """Apply the entries of *data* to an existing Agent.

        Keys that are not attributes of the Agent are ignored. Entries whose
        value is ``None`` are skipped, so this method cannot reset a field to
        ``None``.

        Returns the refreshed Agent, or ``None`` if *agent_id* does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return None
        for key, value in data.items():
            if hasattr(agent, key) and value is not None:
                setattr(agent, key, value)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an Agent.

        Returns ``False`` without deleting when the Agent does not exist or
        is flagged as the default Agent; ``True`` after a committed delete.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        # The default Agent must not be deleted.
        if agent.is_default:
            return False
        self.db.delete(agent)
        self.db.commit()
        return True

    def set_default_agent(self, agent_id: int) -> bool:
        """Make the given Agent the only one flagged as default.

        The target is looked up *before* any row is modified, so an unknown
        *agent_id* returns ``False`` and leaves no pending change in the
        session.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        # Clear the flag on every other Agent; the target row is excluded so
        # its new value is written through the ORM attribute below.
        self.db.query(Agent).filter(Agent.id != agent_id).update({Agent.is_default: False})
        agent.is_default = True
        self.db.commit()
        return True

    def get_agent_config(self, agent_id: int) -> Dict:
        """Return an Agent's settings together with its LLM provider settings.

        Returns ``{}`` when the Agent does not exist. Otherwise returns a dict
        with an ``'agent'`` entry and a ``'provider'`` entry; ``'provider'``
        is ``{}`` when no provider row matches ``agent.llm_provider_id``.

        NOTE(review): the provider entry includes ``api_key`` verbatim.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return {}
        provider = (
            self.db.query(LLMProvider)
            .filter(LLMProvider.id == agent.llm_provider_id)
            .first()
        )
        provider_config: Dict = {}
        if provider:
            provider_config = {
                'id': provider.id,
                'name': provider.name,
                'api_base': provider.api_base,
                'api_key': provider.api_key,
                'supports_thinking': provider.supports_thinking,
                'thinking_model': provider.thinking_model,
                'default_model': provider.default_model,
                'max_tokens': provider.max_tokens,
                'temperature': provider.temperature,
                'models': provider.models,
            }
        return {
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'system_prompt': agent.system_prompt,
                'enable_thinking': agent.enable_thinking,
                'thinking_prompt': agent.thinking_prompt,
                'thinking_prefix': agent.thinking_prefix,
                'thinking_suffix': agent.thinking_suffix,
                'model_override': agent.model_override,
                'max_history': agent.max_history,
                'temperature_override': agent.temperature_override,
                'is_default': agent.is_default,
                'is_active': agent.is_active,
            },
            'provider': provider_config,
        }

    def get_agent_for_channel(self, channel_id: int, conditions: Optional[Dict] = None) -> Optional[Agent]:
        """Resolve the Agent for a channel from its active mappings.

        Active mappings of the channel are visited in ascending ``priority``
        order:

        * If both the caller and the mapping carry conditions, the mapping is
          selected only when every key/value pair in *conditions* equals the
          value stored under the same key in the mapping's conditions.
        * Otherwise (the caller passed no conditions, or the mapping has
          none) the mapping is selected immediately.

        When the channel has no active mapping, or none is selected, the
        result of ``get_default_agent()`` is returned. The return value may
        be ``None`` if the selected mapping points at an Agent id that no
        longer exists.
        """
        mappings = (
            self.db.query(ChannelAgentMapping)
            .filter(
                ChannelAgentMapping.channel_id == channel_id,
                ChannelAgentMapping.is_active == True,  # noqa: E712
            )
            .order_by(ChannelAgentMapping.priority)
            .all()
        )
        for mapping in mappings:
            if conditions and mapping.conditions:
                if all(mapping.conditions.get(key) == value for key, value in conditions.items()):
                    return self.get_agent(mapping.agent_id)
            else:
                # Unconditional selection: use the mapping currently being
                # visited (the list itself has no ``agent_id`` attribute).
                return self.get_agent(mapping.agent_id)
        # No mapping at all, or none matched: fall back to the default Agent.
        return self.get_default_agent()
class LLMProviderService:
    """CRUD helpers for the pool of LLM providers.

    All queries go through the SQLAlchemy session given to the constructor;
    successful writes are committed before the method returns.
    """

    def __init__(self, db: Session):
        """Remember the session this service operates on."""
        self.db = db

    def get_all_providers(self) -> List[LLMProvider]:
        """Return every provider ordered by priority, then by name."""
        ordered = self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name)
        return ordered.all()

    def get_active_providers(self) -> List[LLMProvider]:
        """Return the providers whose ``is_active`` flag is set."""
        only_active = self.db.query(LLMProvider).filter(LLMProvider.is_active == True)  # noqa: E712
        return only_active.all()

    def get_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Look a provider up by primary key; ``None`` if it is absent."""
        by_id = self.db.query(LLMProvider).filter(LLMProvider.id == provider_id)
        return by_id.first()

    def get_provider_by_name(self, name: str) -> Optional[LLMProvider]:
        """Look a provider up by its ``name``; ``None`` if it is absent."""
        by_name = self.db.query(LLMProvider).filter(LLMProvider.name == name)
        return by_name.first()

    def create_provider(self, data: Dict) -> LLMProvider:
        """Insert a provider built from *data* and return the refreshed row.

        Keys missing from *data* take the fallback values listed below.
        """
        fields = {
            'name': data.get('name'),
            'api_base': data.get('api_base'),
            'api_key': data.get('api_key', ''),
            'models': data.get('models', []),
            'default_model': data.get('default_model', 'auto'),
            'supports_thinking': data.get('supports_thinking', False),
            'thinking_model': data.get('thinking_model'),
            'max_tokens': data.get('max_tokens', 4096),
            'temperature': data.get('temperature', 0.7),
            'is_active': data.get('is_active', True),
            'priority': data.get('priority', 0),
            'description': data.get('description'),
        }
        record = LLMProvider(**fields)
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record

    def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]:
        """Copy the entries of *data* onto an existing provider.

        Keys the provider has no attribute for are ignored, and entries whose
        value is ``None`` are skipped. Returns the refreshed provider, or
        ``None`` when *provider_id* is unknown.
        """
        record = self.get_provider(provider_id)
        if not record:
            return None
        for field, new_value in data.items():
            if not hasattr(record, field):
                continue
            if new_value is None:
                continue
            setattr(record, field, new_value)
        self.db.commit()
        self.db.refresh(record)
        return record

    def delete_provider(self, provider_id: int) -> bool:
        """Delete a provider unless it is missing or still used by an Agent.

        Returns ``True`` after a committed delete, ``False`` otherwise.
        """
        record = self.get_provider(provider_id)
        if not record:
            return False
        # Refuse to delete a provider that at least one Agent still points at.
        referencing = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id).count()
        if referencing > 0:
            return False
        self.db.delete(record)
        self.db.commit()
        return True

    def get_provider_config(self, provider_id: int) -> Dict:
        """Return the provider's columns as a plain dict, or ``{}`` if absent."""
        record = self.get_provider(provider_id)
        if not record:
            return {}
        exported = (
            'id',
            'name',
            'api_base',
            'api_key',
            'models',
            'default_model',
            'supports_thinking',
            'thinking_model',
            'max_tokens',
            'temperature',
            'is_active',
            'priority',
            'description',
        )
        return {field: getattr(record, field) for field in exported}
class ChannelService:
    """Manage channels and the Agents bound to them.

    All queries go through the SQLAlchemy session given to the constructor;
    successful writes are committed before the method returns.
    """

    def __init__(self, db: Session):
        """Remember the session this service operates on."""
        self.db = db

    def get_all_channels(self) -> List[Channel]:
        """Return every channel ordered by ``is_primary`` descending, then type."""
        ordered = self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type)
        return ordered.all()

    def get_active_channels(self) -> List[Channel]:
        """Return the channels whose ``is_active`` flag is set."""
        only_active = self.db.query(Channel).filter(Channel.is_active == True)  # noqa: E712
        return only_active.all()

    def get_channel(self, channel_id: int) -> Optional[Channel]:
        """Look a channel up by primary key; ``None`` if it is absent."""
        by_id = self.db.query(Channel).filter(Channel.id == channel_id)
        return by_id.first()

    def get_channel_by_type(self, channel_type: str) -> Optional[Channel]:
        """Return the first active channel of the given type, or ``None``."""
        criteria = (
            Channel.channel_type == channel_type,
            Channel.is_active == True,  # noqa: E712
        )
        return self.db.query(Channel).filter(*criteria).first()

    def get_web_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'web'``."""
        return self.get_channel_by_type('web')

    def get_matrix_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'matrix'``."""
        return self.get_channel_by_type('matrix')

    def create_channel(self, data: Dict) -> Channel:
        """Insert a channel built from *data* and return the refreshed row.

        Keys missing from *data* take the fallback values listed below.
        """
        fields = {
            'channel_type': data.get('channel_type'),
            'name': data.get('name'),
            'config': data.get('config', {}),
            'is_active': data.get('is_active', True),
            'is_primary': data.get('is_primary', False),
        }
        record = Channel(**fields)
        self.db.add(record)
        self.db.commit()
        self.db.refresh(record)
        return record

    def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]:
        """Copy the entries of *data* onto an existing channel.

        Keys the channel has no attribute for are ignored, and entries whose
        value is ``None`` are skipped. Returns the refreshed channel, or
        ``None`` when *channel_id* is unknown.
        """
        record = self.get_channel(channel_id)
        if not record:
            return None
        for field, new_value in data.items():
            if not hasattr(record, field):
                continue
            if new_value is None:
                continue
            setattr(record, field, new_value)
        self.db.commit()
        self.db.refresh(record)
        return record

    def delete_channel(self, channel_id: int) -> bool:
        """Delete a channel unless it is missing or flagged as primary.

        Returns ``True`` after a committed delete, ``False`` otherwise.
        """
        record = self.get_channel(channel_id)
        if not record:
            return False
        # The primary channel is protected from deletion.
        if record.is_primary:
            return False
        self.db.delete(record)
        self.db.commit()
        return True

    def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single', conditions: Dict = None) -> ChannelAgentMapping:
        """Create an active channel-to-Agent mapping and return the refreshed row."""
        binding = ChannelAgentMapping(
            channel_id=channel_id,
            agent_id=agent_id,
            priority=priority,
            mode=mode,
            conditions=conditions,
            is_active=True,
        )
        self.db.add(binding)
        self.db.commit()
        self.db.refresh(binding)
        return binding

    def unbind_agent(self, mapping_id: int) -> bool:
        """Delete a mapping by id; ``False`` when no such mapping exists."""
        binding = (
            self.db.query(ChannelAgentMapping)
            .filter(ChannelAgentMapping.id == mapping_id)
            .first()
        )
        if not binding:
            return False
        self.db.delete(binding)
        self.db.commit()
        return True

    def get_channel_agents(self, channel_id: int) -> List[Dict]:
        """Describe every Agent bound to a channel, in ascending priority order.

        Mappings whose Agent row cannot be found are left out of the result.
        """
        bindings = (
            self.db.query(ChannelAgentMapping)
            .filter(ChannelAgentMapping.channel_id == channel_id)
            .order_by(ChannelAgentMapping.priority)
            .all()
        )
        summaries = []
        for binding in bindings:
            bound = self.db.query(Agent).filter(Agent.id == binding.agent_id).first()
            if not bound:
                continue
            entry = {
                'mapping_id': binding.id,
                'priority': binding.priority,
                'mode': binding.mode,
                'conditions': binding.conditions,
            }
            entry['agent'] = {
                'id': bound.id,
                'name': bound.name,
                'display_name': bound.display_name,
                'is_active': bound.is_active,
            }
            summaries.append(entry)
        return summaries

    def get_channel_config(self, channel_id: int) -> Dict:
        """Return a channel's columns plus its Agent mappings, or ``{}`` if absent."""
        record = self.get_channel(channel_id)
        if not record:
            return {}
        config = {
            'id': record.id,
            'channel_type': record.channel_type,
            'name': record.name,
            'config': record.config,
            'is_active': record.is_active,
            'is_primary': record.is_primary,
        }
        config['agent_mappings'] = self.get_channel_agents(channel_id)
        return config