feat: 后台管理增加提供商动态管理功能

- 新增:添加新的大模型接口提供商
- 新增:编辑已有提供商的参数(API地址、Key、模型列表等)
- 新增:删除提供商
- 新增:拖拽排序调整auto模式的优先级顺序
- 新增:启用/禁用提供商开关
- 优化:主服务动态读取配置,后台修改实时生效
This commit is contained in:
2026-04-08 18:14:09 +08:00
parent 292ff7b03e
commit 82100cdf00
5 changed files with 819 additions and 124 deletions

95
app.py
View File

@@ -16,13 +16,39 @@ import sys
# 添加配置路径
sys.path.insert(0, str(Path(__file__).parent))
from config.settings import (
UPSTREAM_PROVIDERS, MODEL_ALIASES, SERVER_CONFIG,
get_providers, get_model_aliases, SERVER_CONFIG,
LOG_CONFIG, RETRY_CONFIG
)
app = Flask(__name__)
CORS(app)
# 配置缓存时间(秒)
CONFIG_CACHE_TTL = 5
_last_config_load = 0
_cached_providers = []
_cached_aliases = {}
def refresh_config():
"""动态刷新配置(支持后台管理修改)"""
global _last_config_load, _cached_providers, _cached_aliases, provider_status
current_time = time.time()
if current_time - _last_config_load > CONFIG_CACHE_TTL:
_cached_providers = get_providers()
_cached_aliases = get_model_aliases()
_last_config_load = current_time
# 更新提供商状态缓存(新增的提供商)
for provider in _cached_providers:
if provider['name'] not in provider_status:
provider_status[provider['name']] = {
'available': True,
'last_check': None,
'error_count': 0,
'last_error': None,
}
# 配置日志
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
log_dir.mkdir(exist_ok=True)
@@ -39,31 +65,29 @@ logger = logging.getLogger(__name__)
# 提供商状态缓存
provider_status = {}
for provider in UPSTREAM_PROVIDERS:
provider_status[provider['name']] = {
'available': True,
'last_check': None,
'error_count': 0,
'last_error': None,
}
# 初始化
refresh_config()
def get_provider_for_model(model_name):
"""根据模型名获取提供商"""
refresh_config()
# 解析别名
resolved_model = MODEL_ALIASES.get(model_name, model_name)
resolved_model = _cached_aliases.get(model_name, model_name)
# auto模式按优先级选择可用提供商
if resolved_model == 'auto':
return get_available_provider()
# 查找支持该模型的提供商
sorted_providers = sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority'])
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
for provider in sorted_providers:
if not provider['enabled']:
continue
if not provider_status[provider['name']]['available']:
if not provider_status.get(provider['name'], {}).get('available', True):
continue
if resolved_model in provider['models']:
return provider, resolved_model
@@ -72,7 +96,7 @@ def get_provider_for_model(model_name):
for provider in sorted_providers:
if not provider['enabled']:
continue
if not provider_status[provider['name']]['available']:
if not provider_status.get(provider['name'], {}).get('available', True):
continue
for m in provider['models']:
if resolved_model.lower() in m.lower() or m.lower() in resolved_model.lower():
@@ -83,10 +107,12 @@ def get_provider_for_model(model_name):
def get_available_provider():
"""获取可用的提供商(按优先级)"""
sorted_providers = sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority'])
refresh_config()
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
for provider in sorted_providers:
if provider['enabled'] and provider_status[provider['name']]['available']:
if provider['enabled'] and provider_status.get(provider['name'], {}).get('available', True):
return provider, provider['default_model']
# 如果都不可用,返回第一个尝试(让错误信息传递)
@@ -194,10 +220,12 @@ def index():
@app.route('/v1/models', methods=['GET'])
def list_models():
"""列出可用模型"""
refresh_config()
models_list = []
added_models = set()
for provider in UPSTREAM_PROVIDERS:
for provider in _cached_providers:
if not provider['enabled']:
continue
for model in provider['models']:
@@ -328,21 +356,26 @@ def chat_completions():
@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
"""嵌入API简单转发"""
refresh_config()
try:
data = request.get_json()
model = data.get('model', 'text-embedding-ada-002')
# 使用第一个可用提供商
provider = UPSTREAM_PROVIDERS[0]
url = f"{provider['base_url'].rstrip('/')}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {provider['api_key']}"
}
response = requests.post(url, headers=headers, json=data, timeout=60)
return jsonify(response.json()), response.status_code
if _cached_providers:
provider = _cached_providers[0]
url = f"{provider['base_url'].rstrip('/')}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {provider['api_key']}"
}
response = requests.post(url, headers=headers, json=data, timeout=60)
return jsonify(response.json()), response.status_code
else:
return jsonify({"error": "No providers available"}), 503
except Exception as e:
return jsonify({"error": str(e)}), 500
@@ -367,9 +400,11 @@ def health():
@app.route('/status', methods=['GET'])
def status():
"""详细状态"""
refresh_config()
providers_detail = []
for provider in UPSTREAM_PROVIDERS:
for provider in _cached_providers:
status_info = provider_status.get(provider['name'], {})
providers_detail.append({
"name": provider['name'],
@@ -385,7 +420,7 @@ def status():
"version": "1.0.0",
"uptime": time.time(),
"providers": providers_detail,
"model_aliases": MODEL_ALIASES,
"model_aliases": _cached_aliases,
})
@@ -405,6 +440,8 @@ def engine_completions(model):
if __name__ == '__main__':
refresh_config()
print("=" * 60)
print("大模型API中转系统")
print("=" * 60)
@@ -412,12 +449,12 @@ if __name__ == '__main__':
print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
print("=" * 60)
print("上游提供商:")
for p in sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority']):
for p in sorted(_cached_providers, key=lambda x: x['priority']):
print(f" [{p['priority']}] {p['name']}: {p['base_url']}")
print(f" 模型: {', '.join(p['models'])}")
print("=" * 60)
print("支持的模型别名:")
for alias, target in MODEL_ALIASES.items():
for alias, target in _cached_aliases.items():
print(f" {alias} -> {target}")
print("=" * 60)