- 新增 SearchToolConfig 模型:支持搜索工具配置
- Agent 增加 tools 字段:可配置可用工具列表
- 后台管理增加搜索工具配置页面
- Agent 管理增加工具启用开关
- 网页端增加搜索工具禁用复选框
- WebSocket chat 处理增加搜索调用逻辑
- 默认配置 Tavily Search API
492 lines
18 KiB
Python
492 lines
18 KiB
Python
"""
|
||
Agent管理服务
|
||
"""
|
||
from sqlalchemy.orm import Session
|
||
from typing import List, Optional, Dict
|
||
import logging
|
||
|
||
from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, SearchToolConfig, init_default_data
|
||
|
||
# Module-level logger; not referenced by any of the service classes in this file as shown.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentService:
    """Agent management service (Agent管理服务).

    CRUD and lookup operations for the ``Agent`` model, plus resolution of
    which agent serves a given channel. Every mutating method commits the
    session passed to the constructor.
    """

    def __init__(self, db: Session):
        """Bind the service to ``db``, the session used for all queries and commits."""
        self.db = db

    def get_all_agents(self) -> List[Agent]:
        """Return all agents, default-flagged agents first, then ordered by name."""
        return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all()

    def get_active_agents(self) -> List[Agent]:
        """Return all agents whose ``is_active`` flag is set."""
        return self.db.query(Agent).filter(Agent.is_active == True).all()

    def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the agent with primary key ``agent_id``, or ``None`` if absent."""
        return self.db.query(Agent).filter(Agent.id == agent_id).first()

    def get_agent_by_name(self, name: str) -> Optional[Agent]:
        """Return the first agent whose ``name`` equals ``name``, or ``None``."""
        return self.db.query(Agent).filter(Agent.name == name).first()

    def get_default_agent(self) -> Optional[Agent]:
        """Return the active default agent.

        If no active agent is flagged as default, fall back to the first active
        agent. Return ``None`` when there are no active agents at all.
        """
        agent = self.db.query(Agent).filter(Agent.is_default == True, Agent.is_active == True).first()
        if not agent:
            # No active default agent: fall back to any active agent.
            agent = self.db.query(Agent).filter(Agent.is_active == True).first()
        return agent

    def create_agent(self, data: Dict) -> Agent:
        """Create, commit and return a new agent built from ``data``.

        Missing keys fall back to the defaults below; ``display_name`` defaults
        to ``name``. Passing ``is_default=True`` does not clear the flag on
        other agents; use :meth:`set_default_agent` for that.
        """
        agent = Agent(
            name=data.get('name'),
            display_name=data.get('display_name', data.get('name')),
            llm_provider_id=data.get('llm_provider_id'),
            model_override=data.get('model_override'),
            system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'),
            enable_thinking=data.get('enable_thinking', True),
            thinking_prompt=data.get('thinking_prompt'),
            thinking_prefix=data.get('thinking_prefix', ''),
            thinking_suffix=data.get('thinking_suffix', ''),
            tools=data.get('tools', []),  # tool list
            max_history=data.get('max_history', 20),
            temperature_override=data.get('temperature_override'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
            description=data.get('description'),
        )
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]:
        """Apply ``data`` to an existing agent and commit.

        Only keys that name an existing attribute are applied. ``None`` values
        are skipped, so this method cannot reset a column to NULL.

        Returns:
            The refreshed agent, or ``None`` if ``agent_id`` does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return None

        # NOTE(review): hasattr() also accepts keys such as 'id'. If ``data``
        # comes straight from a client request, consider a field whitelist.
        for key, value in data.items():
            if hasattr(agent, key) and value is not None:
                setattr(agent, key, value)

        self.db.commit()
        self.db.refresh(agent)
        return agent

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an agent and commit.

        Returns:
            False if the agent does not exist or is the default agent (the
            default agent may not be deleted); True on success.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False

        # The default agent may not be deleted.
        if agent.is_default:
            return False

        # NOTE(review): ChannelAgentMapping rows referencing this agent are not
        # removed here (unless the model cascades, which is not visible here).
        self.db.delete(agent)
        self.db.commit()
        return True

    def set_default_agent(self, agent_id: int) -> bool:
        """Make ``agent_id`` the only agent flagged as default, and commit.

        Returns:
            False (leaving every flag untouched) if the agent does not exist;
            True on success.
        """
        # Validate first. Clearing the flags before this check used to leave
        # an uncommitted "no default" UPDATE in the session on failure.
        agent = self.get_agent(agent_id)
        if not agent:
            return False

        # Clear the flag on every other agent, then set it on the target.
        self.db.query(Agent).filter(Agent.id != agent_id).update({Agent.is_default: False})
        agent.is_default = True
        self.db.commit()
        return True

    def get_agent_config(self, agent_id: int) -> Dict:
        """Return the full agent config, including its LLM provider settings.

        Returns:
            ``{}`` if the agent does not exist. Otherwise a dict with an
            ``'agent'`` section and a ``'provider'`` section. ``'provider'``
            is ``{}`` when the agent's ``llm_provider_id`` matches no
            provider. The provider section includes the raw ``api_key``.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return {}

        provider = self.db.query(LLMProvider).filter(LLMProvider.id == agent.llm_provider_id).first()

        provider_info: Dict = {}
        if provider:
            provider_info = {
                'id': provider.id,
                'name': provider.name,
                'api_base': provider.api_base,
                'api_key': provider.api_key,
                'supports_thinking': provider.supports_thinking,
                'thinking_model': provider.thinking_model,
                'default_model': provider.default_model,
                'max_tokens': provider.max_tokens,
                'temperature': provider.temperature,
                'models': provider.models,
            }

        return {
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'system_prompt': agent.system_prompt,
                'enable_thinking': agent.enable_thinking,
                'thinking_prompt': agent.thinking_prompt,
                'thinking_prefix': agent.thinking_prefix,
                'thinking_suffix': agent.thinking_suffix,
                # Tools enabled for this agent (set by create_agent); a NULL
                # column is normalized to an empty list.
                'tools': agent.tools or [],
                'model_override': agent.model_override,
                'max_history': agent.max_history,
                'temperature_override': agent.temperature_override,
                'is_default': agent.is_default,
                'is_active': agent.is_active,
            },
            'provider': provider_info,
        }

    def get_agent_for_channel(self, channel_id: int, conditions: Dict = None) -> Optional[Agent]:
        """Resolve the agent that should serve ``channel_id``.

        Active mappings for the channel are examined in ascending ``priority``
        order:

        * If ``conditions`` is given and the mapping has its own
          ``conditions``, the mapping matches only when every key/value pair
          in ``conditions`` equals the mapping's value for that key.
        * Otherwise (no caller conditions, or a mapping without conditions),
          the mapping matches unconditionally.

        Returns:
            The agent of the first matching mapping. Mappings whose agent no
            longer exists are skipped. If nothing matches, or the channel has
            no active mappings, the agent from :meth:`get_default_agent`.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id,
            ChannelAgentMapping.is_active == True
        ).order_by(ChannelAgentMapping.priority).all()

        for mapping in mappings:
            if conditions and mapping.conditions:
                matched = all(
                    mapping.conditions.get(key) == value
                    for key, value in conditions.items()
                )
            else:
                # No conditions to check: the highest-priority mapping wins.
                matched = True

            if not matched:
                continue

            # Use this mapping's agent. The original code read ``agent_id``
            # from the whole list here, which raised AttributeError.
            agent = self.get_agent(mapping.agent_id)
            if agent is not None:
                return agent

        # Nothing matched: return the default agent.
        return self.get_default_agent()
|
||
|
||
|
||
class LLMProviderService:
    """LLM provider pool management service (大模型池管理服务).

    CRUD and lookup operations for the ``LLMProvider`` model. Every mutating
    method commits the session passed to the constructor.
    """

    def __init__(self, db: Session):
        """Bind the service to ``db``, the session used for all queries and commits."""
        self.db = db

    def get_all_providers(self) -> List[LLMProvider]:
        """Return all providers ordered by ``priority``, then ``name``."""
        return self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name).all()

    def get_active_providers(self) -> List[LLMProvider]:
        """Return all providers whose ``is_active`` flag is set (no explicit ordering)."""
        return self.db.query(LLMProvider).filter(LLMProvider.is_active == True).all()

    def get_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Return the provider with primary key ``provider_id``, or ``None``."""
        return self.db.query(LLMProvider).filter(LLMProvider.id == provider_id).first()

    def get_provider_by_name(self, name: str) -> Optional[LLMProvider]:
        """Return the first provider whose ``name`` equals ``name``, or ``None``."""
        return self.db.query(LLMProvider).filter(LLMProvider.name == name).first()

    def create_provider(self, data: Dict) -> LLMProvider:
        """Create, commit and return a new provider built from ``data``.

        Missing keys fall back to the defaults below.
        """
        provider = LLMProvider(
            name=data.get('name'),
            api_base=data.get('api_base'),
            api_key=data.get('api_key', ''),
            models=data.get('models', []),
            default_model=data.get('default_model', 'auto'),
            supports_thinking=data.get('supports_thinking', False),
            thinking_model=data.get('thinking_model'),
            max_tokens=data.get('max_tokens', 4096),
            temperature=data.get('temperature', 0.7),
            is_active=data.get('is_active', True),
            priority=data.get('priority', 0),
            description=data.get('description'),
        )
        self.db.add(provider)
        self.db.commit()
        self.db.refresh(provider)
        return provider

    def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]:
        """Apply ``data`` to an existing provider and commit.

        Only keys that name an existing attribute are applied. ``None`` values
        are skipped, so this method cannot reset a column to NULL.

        Returns:
            The refreshed provider, or ``None`` if ``provider_id`` does not exist.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return None

        for key, value in data.items():
            if hasattr(provider, key) and value is not None:
                setattr(provider, key, value)

        self.db.commit()
        self.db.refresh(provider)
        return provider

    def delete_provider(self, provider_id: int) -> bool:
        """Delete a provider and commit.

        Returns:
            False if the provider does not exist or any agent still references
            it through ``llm_provider_id``; True on success.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return False

        # Refuse to delete a provider that agents still depend on.
        agents_count = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id).count()
        if agents_count > 0:
            return False

        self.db.delete(provider)
        self.db.commit()
        return True

    def get_provider_config(self, provider_id: int) -> Dict:
        """Return the provider's settings as a plain dict, or ``{}`` if it does not exist.

        The result includes the raw ``api_key``.
        """
        provider = self.get_provider(provider_id)
        if not provider:
            return {}

        return {
            'id': provider.id,
            'name': provider.name,
            'api_base': provider.api_base,
            'api_key': provider.api_key,
            'models': provider.models,
            'default_model': provider.default_model,
            'supports_thinking': provider.supports_thinking,
            'thinking_model': provider.thinking_model,
            'max_tokens': provider.max_tokens,
            'temperature': provider.temperature,
            'is_active': provider.is_active,
            'priority': provider.priority,
            'description': provider.description,
        }
|
||
|
||
|
||
class ChannelService:
    """Channel management service (渠道管理服务).

    CRUD for the ``Channel`` model and management of channel-to-agent
    bindings (``ChannelAgentMapping``). Every mutating method commits the
    session passed to the constructor.
    """

    def __init__(self, db: Session):
        """Bind the service to ``db``, the session used for all queries and commits."""
        self.db = db

    def get_all_channels(self) -> List[Channel]:
        """Return all channels, primary channels first, then by ``channel_type``."""
        return self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type).all()

    def get_active_channels(self) -> List[Channel]:
        """Return all channels whose ``is_active`` flag is set."""
        return self.db.query(Channel).filter(Channel.is_active == True).all()

    def get_channel(self, channel_id: int) -> Optional[Channel]:
        """Return the channel with primary key ``channel_id``, or ``None``."""
        return self.db.query(Channel).filter(Channel.id == channel_id).first()

    def get_channel_by_type(self, channel_type: str) -> Optional[Channel]:
        """Return the first *active* channel of the given ``channel_type``, or ``None``."""
        return self.db.query(Channel).filter(
            Channel.channel_type == channel_type,
            Channel.is_active == True
        ).first()

    def get_web_channel(self) -> Optional[Channel]:
        """Return the active channel of type ``'web'``, or ``None``."""
        return self.get_channel_by_type('web')

    def get_matrix_channel(self) -> Optional[Channel]:
        """Return the active channel of type ``'matrix'``, or ``None``."""
        return self.get_channel_by_type('matrix')

    def create_channel(self, data: Dict) -> Channel:
        """Create, commit and return a new channel built from ``data``."""
        channel = Channel(
            channel_type=data.get('channel_type'),
            name=data.get('name'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_primary=data.get('is_primary', False),
        )
        self.db.add(channel)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]:
        """Apply ``data`` to an existing channel and commit.

        Only keys that name an existing attribute are applied. ``None`` values
        are skipped, so this method cannot reset a column to NULL.

        Returns:
            The refreshed channel, or ``None`` if ``channel_id`` does not exist.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return None

        for key, value in data.items():
            if hasattr(channel, key) and value is not None:
                setattr(channel, key, value)

        self.db.commit()
        self.db.refresh(channel)
        return channel

    def delete_channel(self, channel_id: int) -> bool:
        """Delete a channel and commit.

        Returns:
            False if the channel does not exist or is a primary channel
            (primary channels may not be deleted); True on success.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return False

        # Primary channels may not be deleted.
        if channel.is_primary:
            return False

        self.db.delete(channel)
        self.db.commit()
        return True

    def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single', conditions: Dict = None) -> ChannelAgentMapping:
        """Bind an agent to a channel by creating an active mapping, and commit.

        Args:
            channel_id: Channel to bind to.
            agent_id: Agent to bind.
            priority: Lower values are tried first by agent resolution.
            mode: Stored as-is on the mapping.
            conditions: Optional key/value conditions stored on the mapping.

        Returns:
            The refreshed, newly created mapping.
        """
        mapping = ChannelAgentMapping(
            channel_id=channel_id,
            agent_id=agent_id,
            priority=priority,
            mode=mode,
            conditions=conditions,
            is_active=True,
        )
        self.db.add(mapping)
        self.db.commit()
        self.db.refresh(mapping)
        return mapping

    def unbind_agent(self, mapping_id: int) -> bool:
        """Delete a channel/agent mapping and commit. Returns False if it does not exist."""
        mapping = self.db.query(ChannelAgentMapping).filter(ChannelAgentMapping.id == mapping_id).first()
        if not mapping:
            return False

        self.db.delete(mapping)
        self.db.commit()
        return True

    def get_channel_agents(self, channel_id: int) -> List[Dict]:
        """List every agent bound to ``channel_id``, ordered by mapping priority.

        Both active and inactive mappings are included. Mappings whose agent
        no longer exists are omitted.
        """
        # One inner join instead of one Agent query per mapping (N+1). The
        # inner join drops dangling mappings, like the old ``if agent`` check.
        rows = (
            self.db.query(ChannelAgentMapping, Agent)
            .join(Agent, Agent.id == ChannelAgentMapping.agent_id)
            .filter(ChannelAgentMapping.channel_id == channel_id)
            .order_by(ChannelAgentMapping.priority)
            .all()
        )

        return [
            {
                'mapping_id': mapping.id,
                'priority': mapping.priority,
                'mode': mapping.mode,
                'conditions': mapping.conditions,
                'agent': {
                    'id': agent.id,
                    'name': agent.name,
                    'display_name': agent.display_name,
                    'is_active': agent.is_active,
                },
            }
            for mapping, agent in rows
        ]

    def get_channel_config(self, channel_id: int) -> Dict:
        """Return the channel's settings plus its agent bindings, or ``{}`` if it does not exist."""
        channel = self.get_channel(channel_id)
        if not channel:
            return {}

        return {
            'id': channel.id,
            'channel_type': channel.channel_type,
            'name': channel.name,
            'config': channel.config,
            'is_active': channel.is_active,
            'is_primary': channel.is_primary,
            'agent_mappings': self.get_channel_agents(channel_id),
        }
|
||
|
||
|
||
class SearchToolService:
    """Search tool management service (搜索工具管理服务).

    CRUD and default selection for ``SearchToolConfig`` rows. Every mutating
    method commits the session passed to the constructor.
    """

    def __init__(self, db: Session):
        """Bind the service to ``db``, the session used for all queries and commits."""
        self.db = db

    def get_all_configs(self) -> List[SearchToolConfig]:
        """Return all search tool configs, default-flagged first, then by name."""
        return self.db.query(SearchToolConfig).order_by(SearchToolConfig.is_default.desc(), SearchToolConfig.name).all()

    def get_active_configs(self) -> List[SearchToolConfig]:
        """Return all configs whose ``is_active`` flag is set."""
        return self.db.query(SearchToolConfig).filter(SearchToolConfig.is_active == True).all()

    def get_config(self, config_id: int) -> Optional[SearchToolConfig]:
        """Return the config with primary key ``config_id``, or ``None``."""
        return self.db.query(SearchToolConfig).filter(SearchToolConfig.id == config_id).first()

    def get_default_config(self) -> Optional[SearchToolConfig]:
        """Return the active default config.

        If no active config is flagged as default, fall back to the first
        active config. Return ``None`` when there are no active configs.
        """
        config = self.db.query(SearchToolConfig).filter(
            SearchToolConfig.is_default == True,
            SearchToolConfig.is_active == True
        ).first()
        if not config:
            config = self.db.query(SearchToolConfig).filter(SearchToolConfig.is_active == True).first()
        return config

    def create_config(self, data: Dict) -> SearchToolConfig:
        """Create, commit and return a new search tool config built from ``data``.

        ``provider`` defaults to ``'tavily'`` and ``search_depth`` to
        ``'basic'``. Passing ``is_default=True`` does not clear the flag on
        other configs; use :meth:`set_default_config` for that.
        """
        config = SearchToolConfig(
            name=data.get('name'),
            provider=data.get('provider', 'tavily'),
            api_key=data.get('api_key'),
            api_base=data.get('api_base'),
            max_results=data.get('max_results', 5),
            include_raw_content=data.get('include_raw_content', False),
            search_depth=data.get('search_depth', 'basic'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
        )
        self.db.add(config)
        self.db.commit()
        self.db.refresh(config)
        return config

    def update_config(self, config_id: int, data: Dict) -> Optional[SearchToolConfig]:
        """Apply ``data`` to an existing config and commit.

        Only keys that name an existing attribute are applied. ``None`` values
        are skipped, so this method cannot reset a column to NULL.

        Returns:
            The refreshed config, or ``None`` if ``config_id`` does not exist.
        """
        config = self.get_config(config_id)
        if not config:
            return None

        for key, value in data.items():
            if hasattr(config, key) and value is not None:
                setattr(config, key, value)

        self.db.commit()
        self.db.refresh(config)
        return config

    def delete_config(self, config_id: int) -> bool:
        """Delete a config and commit. Returns False if it does not exist.

        Unlike agents, the default config may be deleted. Afterwards
        :meth:`get_default_config` falls back to the first active config.
        """
        config = self.get_config(config_id)
        if not config:
            return False

        self.db.delete(config)
        self.db.commit()
        return True

    def set_default_config(self, config_id: int) -> bool:
        """Make ``config_id`` the only config flagged as default, and commit.

        Returns:
            False (leaving every flag untouched) if the config does not exist;
            True on success.
        """
        # Validate first. Clearing the flags before this check used to leave
        # an uncommitted "no default" UPDATE in the session on failure.
        config = self.get_config(config_id)
        if not config:
            return False

        # Clear the flag on every other config, then set it on the target.
        self.db.query(SearchToolConfig).filter(
            SearchToolConfig.id != config_id
        ).update({SearchToolConfig.is_default: False})
        config.is_default = True
        self.db.commit()
        return True