Files
ai-chat-system/services/agent_service.py
hubian 0c9bfca346 feat: 添加搜索工具功能(Tavily Search)
- 新增 SearchToolConfig 模型:支持搜索工具配置
- Agent 增加 tools 字段:可配置可用工具列表
- 后台管理增加搜索工具配置页面
- Agent 管理增加工具启用开关
- 网页端增加搜索工具禁用复选框
- WebSocket chat 处理增加搜索调用逻辑
- 默认配置 Tavily Search API
2026-04-13 13:26:43 +08:00

492 lines
18 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Service layer for managing agents, LLM providers, channels, and search tool configurations.
"""
from sqlalchemy.orm import Session
from typing import List, Optional, Dict
import logging
from models_v2 import Agent, LLMProvider, ChannelAgentMapping, Channel, SearchToolConfig, init_default_data
logger = logging.getLogger(__name__)
class AgentService:
    """Manage ``Agent`` records: CRUD, default selection, and channel routing.

    Every method operates on the SQLAlchemy session passed to the constructor.
    Methods that modify data commit on that session before returning.

    Note: comparisons such as ``Agent.is_active == True`` are intentional;
    they build SQL expressions rather than Python booleans.
    """

    def __init__(self, db: Session):
        """Keep a reference to the session used for all queries."""
        self.db = db

    def get_all_agents(self) -> List[Agent]:
        """Return all agents, default ones first, then ordered by name."""
        return self.db.query(Agent).order_by(Agent.is_default.desc(), Agent.name).all()

    def get_active_agents(self) -> List[Agent]:
        """Return the agents whose ``is_active`` flag is set."""
        return self.db.query(Agent).filter(Agent.is_active == True).all()  # noqa: E712

    def get_agent(self, agent_id: int) -> Optional[Agent]:
        """Return the agent with the given id, or ``None`` if there is none."""
        return self.db.query(Agent).filter(Agent.id == agent_id).first()

    def get_agent_by_name(self, name: str) -> Optional[Agent]:
        """Return the first agent with the given name, or ``None``."""
        return self.db.query(Agent).filter(Agent.name == name).first()

    def get_default_agent(self) -> Optional[Agent]:
        """Return the active default agent, falling back to any active agent.

        Returns ``None`` only when there is no active agent at all.
        """
        agent = self.db.query(Agent).filter(
            Agent.is_default == True,  # noqa: E712
            Agent.is_active == True,  # noqa: E712
        ).first()
        if not agent:
            # No active default: fall back to the first active agent.
            agent = self.db.query(Agent).filter(Agent.is_active == True).first()  # noqa: E712
        return agent

    def create_agent(self, data: Dict) -> Agent:
        """Create, persist and return a new agent built from ``data``.

        Keys missing from ``data`` fall back to the defaults written below;
        ``display_name`` falls back to ``name``.
        """
        agent = Agent(
            name=data.get('name'),
            display_name=data.get('display_name', data.get('name')),
            llm_provider_id=data.get('llm_provider_id'),
            model_override=data.get('model_override'),
            system_prompt=data.get('system_prompt', '你是一个有用的AI助手。'),
            enable_thinking=data.get('enable_thinking', True),
            thinking_prompt=data.get('thinking_prompt'),
            thinking_prefix=data.get('thinking_prefix', ''),
            thinking_suffix=data.get('thinking_suffix', ''),
            tools=data.get('tools', []),  # list of tools available to the agent
            max_history=data.get('max_history', 20),
            temperature_override=data.get('temperature_override'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
            description=data.get('description'),
        )
        self.db.add(agent)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def update_agent(self, agent_id: int, data: Dict) -> Optional[Agent]:
        """Apply ``data`` to an existing agent and return it.

        Only keys that name an existing attribute are applied, and ``None``
        values are skipped, so a field cannot be reset to ``None`` through
        this method. Returns ``None`` when the agent does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return None
        for key, value in data.items():
            if hasattr(agent, key) and value is not None:
                setattr(agent, key, value)
        self.db.commit()
        self.db.refresh(agent)
        return agent

    def delete_agent(self, agent_id: int) -> bool:
        """Delete an agent; return ``True`` on success.

        Returns ``False`` when the agent does not exist or is the default
        agent (the default agent may not be deleted).
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        # The default agent must not be deleted.
        if agent.is_default:
            return False
        self.db.delete(agent)
        self.db.commit()
        return True

    def set_default_agent(self, agent_id: int) -> bool:
        """Make the given agent the only default agent.

        Returns ``False`` and changes nothing when the agent does not exist.
        """
        # Resolve the target first so that an unknown id never leaves a
        # pending "clear every default" UPDATE behind in the session.
        agent = self.get_agent(agent_id)
        if not agent:
            return False
        # Clear the flag on every other agent, then set it on the target.
        self.db.query(Agent).filter(Agent.id != agent_id).update({Agent.is_default: False})
        agent.is_default = True
        self.db.commit()
        return True

    def get_agent_config(self, agent_id: int) -> Dict:
        """Return the agent's settings together with its LLM provider's settings.

        The result has two keys, ``'agent'`` and ``'provider'``. ``'provider'``
        is an empty dict when the agent's provider cannot be found. An empty
        dict is returned when the agent itself does not exist.
        """
        agent = self.get_agent(agent_id)
        if not agent:
            return {}
        provider = self.db.query(LLMProvider).filter(LLMProvider.id == agent.llm_provider_id).first()
        provider_config: Dict = {}
        if provider:
            provider_config = {
                'id': provider.id,
                'name': provider.name,
                'api_base': provider.api_base,
                'api_key': provider.api_key,
                'supports_thinking': provider.supports_thinking,
                'thinking_model': provider.thinking_model,
                'default_model': provider.default_model,
                'max_tokens': provider.max_tokens,
                'temperature': provider.temperature,
                'models': provider.models,
            }
        return {
            'agent': {
                'id': agent.id,
                'name': agent.name,
                'display_name': agent.display_name,
                'system_prompt': agent.system_prompt,
                'enable_thinking': agent.enable_thinking,
                'thinking_prompt': agent.thinking_prompt,
                'thinking_prefix': agent.thinking_prefix,
                'thinking_suffix': agent.thinking_suffix,
                'model_override': agent.model_override,
                'max_history': agent.max_history,
                'temperature_override': agent.temperature_override,
                'is_default': agent.is_default,
                'is_active': agent.is_active,
            },
            'provider': provider_config,
        }

    def get_agent_for_channel(self, channel_id: int, conditions: Optional[Dict] = None) -> Optional[Agent]:
        """Pick the agent that should serve a channel.

        Active mappings of the channel are examined in ascending ``priority``
        order:

        * When both ``conditions`` and the mapping's own conditions are
          non-empty, the mapping is chosen only if every key/value pair in
          ``conditions`` equals the mapping's value for that key.
        * Otherwise (no ``conditions`` passed, or the mapping has none) the
          mapping is chosen immediately.

        Falls back to the default agent when the channel has no active
        mapping or no mapping matches. The result may be ``None`` if the
        chosen mapping points at an agent that no longer exists.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id,
            ChannelAgentMapping.is_active == True,  # noqa: E712
        ).order_by(ChannelAgentMapping.priority).all()
        if not mappings:
            return self.get_default_agent()
        for mapping in mappings:
            if conditions and mapping.conditions:
                matches = all(
                    mapping.conditions.get(key) == value
                    for key, value in conditions.items()
                )
                if matches:
                    return self.get_agent(mapping.agent_id)
            else:
                # Nothing to compare: use this mapping. The agent id is read
                # from the current mapping, not from the list of mappings.
                return self.get_agent(mapping.agent_id)
        # No mapping matched: use the default agent.
        return self.get_default_agent()
class LLMProviderService:
    """Manage the pool of ``LLMProvider`` records.

    All queries run on the SQLAlchemy session given to the constructor;
    methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Remember the session used by every method."""
        self.db = db

    def get_all_providers(self) -> List[LLMProvider]:
        """Return every provider ordered by priority, then by name."""
        ordered = self.db.query(LLMProvider).order_by(LLMProvider.priority, LLMProvider.name)
        return ordered.all()

    def get_active_providers(self) -> List[LLMProvider]:
        """Return the providers whose ``is_active`` flag is set."""
        # ``== True`` builds a SQL expression; it is not a Python comparison.
        active = self.db.query(LLMProvider).filter(LLMProvider.is_active == True)  # noqa: E712
        return active.all()

    def get_provider(self, provider_id: int) -> Optional[LLMProvider]:
        """Return the provider with the given id, or ``None``."""
        by_id = self.db.query(LLMProvider).filter(LLMProvider.id == provider_id)
        return by_id.first()

    def get_provider_by_name(self, name: str) -> Optional[LLMProvider]:
        """Return the first provider with the given name, or ``None``."""
        by_name = self.db.query(LLMProvider).filter(LLMProvider.name == name)
        return by_name.first()

    def create_provider(self, data: Dict) -> LLMProvider:
        """Create, persist and return a provider built from ``data``.

        Keys absent from ``data`` take the defaults listed below.
        """
        field_values = {
            'name': data.get('name'),
            'api_base': data.get('api_base'),
            'api_key': data.get('api_key', ''),
            'models': data.get('models', []),
            'default_model': data.get('default_model', 'auto'),
            'supports_thinking': data.get('supports_thinking', False),
            'thinking_model': data.get('thinking_model'),
            'max_tokens': data.get('max_tokens', 4096),
            'temperature': data.get('temperature', 0.7),
            'is_active': data.get('is_active', True),
            'priority': data.get('priority', 0),
            'description': data.get('description'),
        }
        created = LLMProvider(**field_values)
        self.db.add(created)
        self.db.commit()
        self.db.refresh(created)
        return created

    def update_provider(self, provider_id: int, data: Dict) -> Optional[LLMProvider]:
        """Apply ``data`` to an existing provider and return it.

        Keys that are not attributes of the provider and values that are
        ``None`` are ignored. Returns ``None`` when the provider is missing.
        """
        target = self.get_provider(provider_id)
        if not target:
            return None
        for field, new_value in data.items():
            if not hasattr(target, field):
                continue
            if new_value is None:
                continue
            setattr(target, field, new_value)
        self.db.commit()
        self.db.refresh(target)
        return target

    def delete_provider(self, provider_id: int) -> bool:
        """Delete a provider; return ``True`` on success.

        Returns ``False`` when the provider does not exist or when at least
        one agent still references it.
        """
        target = self.get_provider(provider_id)
        if not target:
            return False
        # Refuse to delete a provider that agents still depend on.
        from models_v2 import Agent
        referencing = self.db.query(Agent).filter(Agent.llm_provider_id == provider_id)
        in_use = referencing.count() > 0
        if in_use:
            return False
        self.db.delete(target)
        self.db.commit()
        return True

    def get_provider_config(self, provider_id: int) -> Dict:
        """Return the provider's fields as a plain dict (empty if not found)."""
        target = self.get_provider(provider_id)
        if not target:
            return {}
        exported_fields = (
            'id',
            'name',
            'api_base',
            'api_key',
            'models',
            'default_model',
            'supports_thinking',
            'thinking_model',
            'max_tokens',
            'temperature',
            'is_active',
            'priority',
            'description',
        )
        return {field: getattr(target, field) for field in exported_fields}
class ChannelService:
    """Manage ``Channel`` records and their bindings to agents.

    All queries run on the SQLAlchemy session given to the constructor;
    methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Remember the session used by every method."""
        self.db = db

    def get_all_channels(self) -> List[Channel]:
        """Return every channel, primary ones first, then by channel type."""
        return self.db.query(Channel).order_by(Channel.is_primary.desc(), Channel.channel_type).all()

    def get_active_channels(self) -> List[Channel]:
        """Return the channels whose ``is_active`` flag is set."""
        # ``== True`` builds a SQL expression; it is not a Python comparison.
        return self.db.query(Channel).filter(Channel.is_active == True).all()  # noqa: E712

    def get_channel(self, channel_id: int) -> Optional[Channel]:
        """Return the channel with the given id, or ``None``."""
        return self.db.query(Channel).filter(Channel.id == channel_id).first()

    def get_channel_by_type(self, channel_type: str) -> Optional[Channel]:
        """Return the first active channel of the given type, or ``None``."""
        return self.db.query(Channel).filter(
            Channel.channel_type == channel_type,
            Channel.is_active == True,  # noqa: E712
        ).first()

    def get_web_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'web'``."""
        return self.get_channel_by_type('web')

    def get_matrix_channel(self) -> Optional[Channel]:
        """Return the first active channel of type ``'matrix'``."""
        return self.get_channel_by_type('matrix')

    def create_channel(self, data: Dict) -> Channel:
        """Create, persist and return a channel built from ``data``."""
        channel = Channel(
            channel_type=data.get('channel_type'),
            name=data.get('name'),
            config=data.get('config', {}),
            is_active=data.get('is_active', True),
            is_primary=data.get('is_primary', False),
        )
        self.db.add(channel)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def update_channel(self, channel_id: int, data: Dict) -> Optional[Channel]:
        """Apply ``data`` to an existing channel and return it.

        Keys that are not attributes of the channel and values that are
        ``None`` are ignored. Returns ``None`` when the channel is missing.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return None
        for key, value in data.items():
            if hasattr(channel, key) and value is not None:
                setattr(channel, key, value)
        self.db.commit()
        self.db.refresh(channel)
        return channel

    def delete_channel(self, channel_id: int) -> bool:
        """Delete a channel; return ``True`` on success.

        Returns ``False`` when the channel does not exist or is a primary
        channel (primary channels may not be deleted).
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return False
        # A primary channel must not be deleted.
        if channel.is_primary:
            return False
        self.db.delete(channel)
        self.db.commit()
        return True

    def bind_agent(self, channel_id: int, agent_id: int, priority: int = 0, mode: str = 'single',
                   conditions: Optional[Dict] = None) -> ChannelAgentMapping:
        """Create, persist and return an active channel-to-agent mapping.

        No check is made here that the channel or the agent exists.
        """
        mapping = ChannelAgentMapping(
            channel_id=channel_id,
            agent_id=agent_id,
            priority=priority,
            mode=mode,
            conditions=conditions,
            is_active=True,
        )
        self.db.add(mapping)
        self.db.commit()
        self.db.refresh(mapping)
        return mapping

    def unbind_agent(self, mapping_id: int) -> bool:
        """Delete a mapping by id; return ``False`` when it does not exist."""
        mapping = self.db.query(ChannelAgentMapping).filter(ChannelAgentMapping.id == mapping_id).first()
        if not mapping:
            return False
        self.db.delete(mapping)
        self.db.commit()
        return True

    def get_channel_agents(self, channel_id: int) -> List[Dict]:
        """Return the channel's mappings with a summary of each bound agent.

        Entries are ordered by mapping ``priority``. Mappings whose agent no
        longer exists are left out of the result.
        """
        mappings = self.db.query(ChannelAgentMapping).filter(
            ChannelAgentMapping.channel_id == channel_id
        ).order_by(ChannelAgentMapping.priority).all()

        # Load all referenced agents with one query instead of one per mapping.
        agent_ids = list({m.agent_id for m in mappings if m.agent_id is not None})
        agents_by_id = {}
        if agent_ids:
            agents_by_id = {
                agent.id: agent
                for agent in self.db.query(Agent).filter(Agent.id.in_(agent_ids)).all()
            }

        result = []
        for mapping in mappings:
            agent = agents_by_id.get(mapping.agent_id)
            if not agent:
                continue
            result.append({
                'mapping_id': mapping.id,
                'priority': mapping.priority,
                'mode': mapping.mode,
                'conditions': mapping.conditions,
                'agent': {
                    'id': agent.id,
                    'name': agent.name,
                    'display_name': agent.display_name,
                    'is_active': agent.is_active,
                },
            })
        return result

    def get_channel_config(self, channel_id: int) -> Dict:
        """Return the channel's fields plus its agent mappings as a dict.

        Returns an empty dict when the channel does not exist.
        """
        channel = self.get_channel(channel_id)
        if not channel:
            return {}
        return {
            'id': channel.id,
            'channel_type': channel.channel_type,
            'name': channel.name,
            'config': channel.config,
            'is_active': channel.is_active,
            'is_primary': channel.is_primary,
            'agent_mappings': self.get_channel_agents(channel_id),
        }
class SearchToolService:
    """Manage ``SearchToolConfig`` records (search tool settings).

    All queries run on the SQLAlchemy session given to the constructor;
    methods that modify data commit on that session.
    """

    def __init__(self, db: Session):
        """Remember the session used by every method."""
        self.db = db

    def get_all_configs(self) -> List[SearchToolConfig]:
        """Return every configuration, default ones first, then by name."""
        return self.db.query(SearchToolConfig).order_by(
            SearchToolConfig.is_default.desc(), SearchToolConfig.name
        ).all()

    def get_active_configs(self) -> List[SearchToolConfig]:
        """Return the configurations whose ``is_active`` flag is set."""
        # ``== True`` builds a SQL expression; it is not a Python comparison.
        return self.db.query(SearchToolConfig).filter(SearchToolConfig.is_active == True).all()  # noqa: E712

    def get_config(self, config_id: int) -> Optional[SearchToolConfig]:
        """Return the configuration with the given id, or ``None``."""
        return self.db.query(SearchToolConfig).filter(SearchToolConfig.id == config_id).first()

    def get_default_config(self) -> Optional[SearchToolConfig]:
        """Return the active default configuration, else any active one.

        Returns ``None`` only when no configuration is active.
        """
        config = self.db.query(SearchToolConfig).filter(
            SearchToolConfig.is_default == True,  # noqa: E712
            SearchToolConfig.is_active == True,  # noqa: E712
        ).first()
        if not config:
            # No active default: fall back to the first active configuration.
            config = self.db.query(SearchToolConfig).filter(
                SearchToolConfig.is_active == True  # noqa: E712
            ).first()
        return config

    def create_config(self, data: Dict) -> SearchToolConfig:
        """Create, persist and return a configuration built from ``data``.

        Keys absent from ``data`` take the defaults listed below.
        """
        config = SearchToolConfig(
            name=data.get('name'),
            provider=data.get('provider', 'tavily'),
            api_key=data.get('api_key'),
            api_base=data.get('api_base'),
            max_results=data.get('max_results', 5),
            include_raw_content=data.get('include_raw_content', False),
            search_depth=data.get('search_depth', 'basic'),
            is_active=data.get('is_active', True),
            is_default=data.get('is_default', False),
        )
        self.db.add(config)
        self.db.commit()
        self.db.refresh(config)
        return config

    def update_config(self, config_id: int, data: Dict) -> Optional[SearchToolConfig]:
        """Apply ``data`` to an existing configuration and return it.

        Keys that are not attributes of the configuration and values that
        are ``None`` are ignored. Returns ``None`` when it does not exist.
        """
        config = self.get_config(config_id)
        if not config:
            return None
        for key, value in data.items():
            if hasattr(config, key) and value is not None:
                setattr(config, key, value)
        self.db.commit()
        self.db.refresh(config)
        return config

    def delete_config(self, config_id: int) -> bool:
        """Delete a configuration; return ``False`` when it does not exist."""
        config = self.get_config(config_id)
        if not config:
            return False
        self.db.delete(config)
        self.db.commit()
        return True

    def set_default_config(self, config_id: int) -> bool:
        """Make the given configuration the only default one.

        Returns ``False`` and changes nothing when it does not exist.
        """
        # Resolve the target first so that an unknown id never leaves a
        # pending "clear every default" UPDATE behind in the session.
        config = self.get_config(config_id)
        if not config:
            return False
        # Clear the flag on every other row, then set it on the target.
        self.db.query(SearchToolConfig).filter(SearchToolConfig.id != config_id).update(
            {SearchToolConfig.is_default: False}
        )
        config.is_default = True
        self.db.commit()
        return True