feat: 后台管理增加提供商动态管理功能
- 新增:添加新的大模型接口提供商 - 新增:编辑已有提供商的参数(API地址、Key、模型列表等) - 新增:删除提供商 - 新增:拖拽排序调整auto模式的优先级顺序 - 新增:启用/禁用提供商开关 - 优化:主服务动态读取配置,后台修改实时生效
This commit is contained in:
95
app.py
95
app.py
@@ -16,13 +16,39 @@ import sys
|
||||
# 添加配置路径
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
from config.settings import (
|
||||
UPSTREAM_PROVIDERS, MODEL_ALIASES, SERVER_CONFIG,
|
||||
get_providers, get_model_aliases, SERVER_CONFIG,
|
||||
LOG_CONFIG, RETRY_CONFIG
|
||||
)
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# 配置缓存时间(秒)
|
||||
CONFIG_CACHE_TTL = 5
|
||||
_last_config_load = 0
|
||||
_cached_providers = []
|
||||
_cached_aliases = {}
|
||||
|
||||
def refresh_config():
|
||||
"""动态刷新配置(支持后台管理修改)"""
|
||||
global _last_config_load, _cached_providers, _cached_aliases, provider_status
|
||||
|
||||
current_time = time.time()
|
||||
if current_time - _last_config_load > CONFIG_CACHE_TTL:
|
||||
_cached_providers = get_providers()
|
||||
_cached_aliases = get_model_aliases()
|
||||
_last_config_load = current_time
|
||||
|
||||
# 更新提供商状态缓存(新增的提供商)
|
||||
for provider in _cached_providers:
|
||||
if provider['name'] not in provider_status:
|
||||
provider_status[provider['name']] = {
|
||||
'available': True,
|
||||
'last_check': None,
|
||||
'error_count': 0,
|
||||
'last_error': None,
|
||||
}
|
||||
|
||||
# 配置日志
|
||||
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
@@ -39,31 +65,29 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# 提供商状态缓存
|
||||
provider_status = {}
|
||||
for provider in UPSTREAM_PROVIDERS:
|
||||
provider_status[provider['name']] = {
|
||||
'available': True,
|
||||
'last_check': None,
|
||||
'error_count': 0,
|
||||
'last_error': None,
|
||||
}
|
||||
|
||||
# 初始化
|
||||
refresh_config()
|
||||
|
||||
|
||||
def get_provider_for_model(model_name):
|
||||
"""根据模型名获取提供商"""
|
||||
refresh_config()
|
||||
|
||||
# 解析别名
|
||||
resolved_model = MODEL_ALIASES.get(model_name, model_name)
|
||||
resolved_model = _cached_aliases.get(model_name, model_name)
|
||||
|
||||
# auto模式:按优先级选择可用提供商
|
||||
if resolved_model == 'auto':
|
||||
return get_available_provider()
|
||||
|
||||
# 查找支持该模型的提供商
|
||||
sorted_providers = sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority'])
|
||||
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
|
||||
|
||||
for provider in sorted_providers:
|
||||
if not provider['enabled']:
|
||||
continue
|
||||
if not provider_status[provider['name']]['available']:
|
||||
if not provider_status.get(provider['name'], {}).get('available', True):
|
||||
continue
|
||||
if resolved_model in provider['models']:
|
||||
return provider, resolved_model
|
||||
@@ -72,7 +96,7 @@ def get_provider_for_model(model_name):
|
||||
for provider in sorted_providers:
|
||||
if not provider['enabled']:
|
||||
continue
|
||||
if not provider_status[provider['name']]['available']:
|
||||
if not provider_status.get(provider['name'], {}).get('available', True):
|
||||
continue
|
||||
for m in provider['models']:
|
||||
if resolved_model.lower() in m.lower() or m.lower() in resolved_model.lower():
|
||||
@@ -83,10 +107,12 @@ def get_provider_for_model(model_name):
|
||||
|
||||
def get_available_provider():
|
||||
"""获取可用的提供商(按优先级)"""
|
||||
sorted_providers = sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority'])
|
||||
refresh_config()
|
||||
|
||||
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
|
||||
|
||||
for provider in sorted_providers:
|
||||
if provider['enabled'] and provider_status[provider['name']]['available']:
|
||||
if provider['enabled'] and provider_status.get(provider['name'], {}).get('available', True):
|
||||
return provider, provider['default_model']
|
||||
|
||||
# 如果都不可用,返回第一个尝试(让错误信息传递)
|
||||
@@ -194,10 +220,12 @@ def index():
|
||||
@app.route('/v1/models', methods=['GET'])
|
||||
def list_models():
|
||||
"""列出可用模型"""
|
||||
refresh_config()
|
||||
|
||||
models_list = []
|
||||
added_models = set()
|
||||
|
||||
for provider in UPSTREAM_PROVIDERS:
|
||||
for provider in _cached_providers:
|
||||
if not provider['enabled']:
|
||||
continue
|
||||
for model in provider['models']:
|
||||
@@ -328,21 +356,26 @@ def chat_completions():
|
||||
@app.route('/v1/embeddings', methods=['POST'])
|
||||
def embeddings():
|
||||
"""嵌入API(简单转发)"""
|
||||
refresh_config()
|
||||
|
||||
try:
|
||||
data = request.get_json()
|
||||
model = data.get('model', 'text-embedding-ada-002')
|
||||
|
||||
# 使用第一个可用提供商
|
||||
provider = UPSTREAM_PROVIDERS[0]
|
||||
|
||||
url = f"{provider['base_url'].rstrip('/')}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {provider['api_key']}"
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=60)
|
||||
return jsonify(response.json()), response.status_code
|
||||
if _cached_providers:
|
||||
provider = _cached_providers[0]
|
||||
|
||||
url = f"{provider['base_url'].rstrip('/')}/embeddings"
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {provider['api_key']}"
|
||||
}
|
||||
|
||||
response = requests.post(url, headers=headers, json=data, timeout=60)
|
||||
return jsonify(response.json()), response.status_code
|
||||
else:
|
||||
return jsonify({"error": "No providers available"}), 503
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"error": str(e)}), 500
|
||||
@@ -367,9 +400,11 @@ def health():
|
||||
@app.route('/status', methods=['GET'])
|
||||
def status():
|
||||
"""详细状态"""
|
||||
refresh_config()
|
||||
|
||||
providers_detail = []
|
||||
|
||||
for provider in UPSTREAM_PROVIDERS:
|
||||
for provider in _cached_providers:
|
||||
status_info = provider_status.get(provider['name'], {})
|
||||
providers_detail.append({
|
||||
"name": provider['name'],
|
||||
@@ -385,7 +420,7 @@ def status():
|
||||
"version": "1.0.0",
|
||||
"uptime": time.time(),
|
||||
"providers": providers_detail,
|
||||
"model_aliases": MODEL_ALIASES,
|
||||
"model_aliases": _cached_aliases,
|
||||
})
|
||||
|
||||
|
||||
@@ -405,6 +440,8 @@ def engine_completions(model):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
refresh_config()
|
||||
|
||||
print("=" * 60)
|
||||
print("大模型API中转系统")
|
||||
print("=" * 60)
|
||||
@@ -412,12 +449,12 @@ if __name__ == '__main__':
|
||||
print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
|
||||
print("=" * 60)
|
||||
print("上游提供商:")
|
||||
for p in sorted(UPSTREAM_PROVIDERS, key=lambda x: x['priority']):
|
||||
for p in sorted(_cached_providers, key=lambda x: x['priority']):
|
||||
print(f" [{p['priority']}] {p['name']}: {p['base_url']}")
|
||||
print(f" 模型: {', '.join(p['models'])}")
|
||||
print("=" * 60)
|
||||
print("支持的模型别名:")
|
||||
for alias, target in MODEL_ALIASES.items():
|
||||
for alias, target in _cached_aliases.items():
|
||||
print(f" {alias} -> {target}")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user