llm-proxy/app.py
hubian 82100cdf00 feat: add dynamic provider management to the admin panel
- New: add new LLM API providers
- New: edit an existing provider's settings (API base URL, key, model list, etc.)
- New: delete providers
- New: drag-and-drop ordering to adjust the priority used by "auto" mode
- New: per-provider enable/disable toggle
- Improved: the main service reads configuration dynamically, so admin changes take effect immediately
2026-04-08 18:14:09 +08:00

"""
LLM API proxy/relay service.
OpenAI-compatible API with support for multiple upstream providers and priority-based scheduling.
"""
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import requests
import json
import time
import logging
from datetime import datetime
from pathlib import Path
import sys
# Make the local config package importable
sys.path.insert(0, str(Path(__file__).parent))
from config.settings import (
get_providers, get_model_aliases, SERVER_CONFIG,
LOG_CONFIG, RETRY_CONFIG
)
app = Flask(__name__)
CORS(app)
# Config cache TTL in seconds: admin-panel changes are picked up within this window
CONFIG_CACHE_TTL = 5
_last_config_load = 0
_cached_providers = []
_cached_aliases = {}
_start_time = time.time()  # process start time, reported as uptime by /status
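# Providers and model aliases are re-read from config.settings at most once every
# CONFIG_CACHE_TTL seconds, so edits made in the admin panel take effect on the next
# request without restarting the proxy.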
def refresh_config():
"""动态刷新配置(支持后台管理修改)"""
global _last_config_load, _cached_providers, _cached_aliases, provider_status
current_time = time.time()
if current_time - _last_config_load > CONFIG_CACHE_TTL:
_cached_providers = get_providers()
_cached_aliases = get_model_aliases()
_last_config_load = current_time
        # Create status entries for any providers added since the last refresh
for provider in _cached_providers:
if provider['name'] not in provider_status:
provider_status[provider['name']] = {
'available': True,
'last_check': None,
'error_count': 0,
'last_error': None,
}
# Logging setup
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(log_dir / 'proxy.log', encoding='utf-8'),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
# Per-provider health/status cache
provider_status = {}
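# Maps provider name -> {'available', 'last_check', 'error_count', 'last_error'};
# entries are created in refresh_config() and updated by mark_provider_error() /
# mark_provider_success().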
# Initial configuration load
refresh_config()
def get_provider_for_model(model_name):
"""根据模型名获取提供商"""
refresh_config()
    # Resolve model aliases first
resolved_model = _cached_aliases.get(model_name, model_name)
    # "auto" mode: pick the first available provider by priority
if resolved_model == 'auto':
return get_available_provider()
    # Exact match: find an enabled, available provider that lists this model
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
for provider in sorted_providers:
if not provider['enabled']:
continue
if not provider_status.get(provider['name'], {}).get('available', True):
continue
if resolved_model in provider['models']:
return provider, resolved_model
    # No exact match found; fall back to a fuzzy substring match
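    # Illustrative example (model names are hypothetical): a request for "gpt-4" could
    # fall through to a provider that only lists "gpt-4-turbo", since the substring
    # test below matches in either direction.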
for provider in sorted_providers:
if not provider['enabled']:
continue
if not provider_status.get(provider['name'], {}).get('available', True):
continue
for m in provider['models']:
if resolved_model.lower() in m.lower() or m.lower() in resolved_model.lower():
return provider, m
return None, None
def get_available_provider():
"""获取可用的提供商(按优先级)"""
refresh_config()
sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
for provider in sorted_providers:
if provider['enabled'] and provider_status.get(provider['name'], {}).get('available', True):
return provider, provider['default_model']
    # If none is marked available, fall back to the first provider so the upstream error is surfaced
if sorted_providers:
return sorted_providers[0], sorted_providers[0]['default_model']
return None, None
def mark_provider_error(provider_name, error):
"""标记提供商错误"""
if provider_name in provider_status:
provider_status[provider_name]['error_count'] += 1
provider_status[provider_name]['last_error'] = str(error)
provider_status[provider_name]['last_check'] = datetime.now()
        # Too many consecutive errors: temporarily mark the provider as unavailable
if provider_status[provider_name]['error_count'] >= 3:
provider_status[provider_name]['available'] = False
logger.warning(f"Provider {provider_name} marked as unavailable due to errors")
def mark_provider_success(provider_name):
"""标记提供商成功"""
if provider_name in provider_status:
provider_status[provider_name]['error_count'] = 0
provider_status[provider_name]['available'] = True
provider_status[provider_name]['last_check'] = datetime.now()
def proxy_request(provider, model, request_data, stream=False):
"""转发请求到上游提供商"""
url = f"{provider['base_url'].rstrip('/')}/chat/completions"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {provider['api_key']}"
}
    # Build the upstream request payload, overriding the model name
data = request_data.copy()
data['model'] = model
try:
        # The stream flag is passed straight through to requests; streaming and
        # non-streaming requests are otherwise identical.
        response = requests.post(
            url,
            headers=headers,
            json=data,
            stream=stream,
            timeout=provider.get('timeout', 120)
        )
        return response
except requests.exceptions.Timeout:
mark_provider_error(provider['name'], "Timeout")
raise Exception(f"Provider {provider['name']} timeout")
except requests.exceptions.ConnectionError:
mark_provider_error(provider['name'], "Connection error")
raise Exception(f"Provider {provider['name']} connection error")
except Exception as e:
mark_provider_error(provider['name'], str(e))
raise
def stream_response(response):
"""流式响应生成器"""
try:
for line in response.iter_lines():
if line:
yield line + b'\n'
except Exception as e:
logger.error(f"Stream error: {e}")
yield b'data: {"error": "' + str(e).encode() + b'"}\n\n'
# ============ API routes ============
@app.route('/')
def index():
"""首页"""
return jsonify({
"name": "LLM Proxy",
"version": "1.0.0",
"description": "OpenAI-compatible LLM API Proxy",
"endpoints": {
"chat": "/v1/chat/completions",
"models": "/v1/models",
"health": "/health",
"status": "/status"
}
})
@app.route('/v1/models', methods=['GET'])
def list_models():
"""列出可用模型"""
refresh_config()
models_list = []
added_models = set()
for provider in _cached_providers:
if not provider['enabled']:
continue
for model in provider['models']:
if model not in added_models:
models_list.append({
"id": model,
"object": "model",
"created": int(time.time()),
"owned_by": provider['name'],
})
added_models.add(model)
    # Expose the virtual "auto" model
if "auto" not in added_models:
models_list.insert(0, {
"id": "auto",
"object": "model",
"created": int(time.time()),
"owned_by": "proxy",
"description": "Auto-select available model by priority"
})
return jsonify({
"object": "list",
"data": models_list
})
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
"""聊天完成API"""
try:
data = request.get_json()
if not data:
return jsonify({"error": "Invalid request body"}), 400
model = data.get('model', 'auto')
stream = data.get('stream', False)
        # Pick a provider for the requested model
provider, resolved_model = get_provider_for_model(model)
if not provider:
return jsonify({
"error": {
"message": f"No available provider for model: {model}",
"type": "invalid_request_error"
}
}), 400
logger.info(f"Request: model={model} -> provider={provider['name']}, resolved_model={resolved_model}, stream={stream}")
        # Retry / failover loop
last_error = None
tried_providers = []
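        # Failover strategy: on a network error or HTTP 429 the current provider is marked
        # and the loop retries with the next highest-priority provider that has not been
        # tried yet, up to RETRY_CONFIG['max_retries'] attempts in total.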
for attempt in range(RETRY_CONFIG['max_retries']):
try:
response = proxy_request(provider, resolved_model, data, stream)
if response.status_code == 200:
mark_provider_success(provider['name'])
if stream:
                        # Streaming response
return Response(
stream_with_context(stream_response(response)),
content_type='text/event-stream',
headers={
'Cache-Control': 'no-cache',
'Connection': 'keep-alive',
}
)
else:
                        # Non-streaming response
return jsonify(response.json())
elif response.status_code == 429:
                    # Rate limited; mark the provider and try the next one
mark_provider_error(provider['name'], "Rate limit")
tried_providers.append(provider['name'])
                    # Fail over to the next available provider
next_provider, next_model = get_available_provider()
if next_provider and next_provider['name'] not in tried_providers:
provider = next_provider
resolved_model = next_model
continue
return jsonify(response.json()), response.status_code
else:
last_error = response.json() if response.headers.get('content-type', '').startswith('application/json') else {"error": response.text}
return jsonify(last_error), response.status_code
except Exception as e:
last_error = str(e)
logger.error(f"Attempt {attempt + 1} failed: {e}")
tried_providers.append(provider['name'])
                # Fail over to the next available provider
next_provider, next_model = get_available_provider()
if next_provider and next_provider['name'] not in tried_providers:
provider = next_provider
resolved_model = next_model
time.sleep(RETRY_CONFIG['retry_delay'])
continue
        # All retries exhausted
return jsonify({
"error": {
"message": f"All providers failed. Last error: {last_error}",
"type": "api_error"
}
}), 503
except Exception as e:
logger.error(f"Unexpected error: {e}")
return jsonify({
"error": {
"message": str(e),
"type": "internal_error"
}
}), 500
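# Unlike chat completions, the embeddings route below is a thin pass-through with no
# retry or failover logic; a single upstream failure is returned to the client as-is.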
@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
"""嵌入API简单转发"""
refresh_config()
try:
data = request.get_json()
model = data.get('model', 'text-embedding-ada-002')
        # Use the highest-priority enabled provider
        candidates = [p for p in sorted(_cached_providers, key=lambda x: x['priority']) if p['enabled']]
        if candidates:
            provider = candidates[0]
url = f"{provider['base_url'].rstrip('/')}/embeddings"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {provider['api_key']}"
}
response = requests.post(url, headers=headers, json=data, timeout=60)
return jsonify(response.json()), response.status_code
else:
return jsonify({"error": "No providers available"}), 503
except Exception as e:
return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
"""健康检查"""
available_count = sum(1 for s in provider_status.values() if s['available'])
total_count = len(provider_status)
return jsonify({
"status": "healthy" if available_count > 0 else "degraded",
"providers": {
"available": available_count,
"total": total_count,
},
"timestamp": datetime.now().isoformat()
})
@app.route('/status', methods=['GET'])
def status():
"""详细状态"""
refresh_config()
providers_detail = []
for provider in _cached_providers:
status_info = provider_status.get(provider['name'], {})
providers_detail.append({
"name": provider['name'],
"priority": provider['priority'],
"enabled": provider['enabled'],
"available": status_info.get('available', True),
"error_count": status_info.get('error_count', 0),
"last_error": status_info.get('last_error'),
"models": provider['models'],
})
return jsonify({
"version": "1.0.0",
"uptime": time.time(),
"providers": providers_detail,
"model_aliases": _cached_aliases,
})
# Additional endpoints for OpenAI SDK compatibility
@app.route('/v1/engines', methods=['GET'])
def list_engines():
"""兼容旧版 engines 端点"""
return list_models()
@app.route('/v1/engines/<model>/completions', methods=['POST'])
def engine_completions(model):
"""兼容旧版 completions 端点"""
data = request.get_json()
data['model'] = model
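    # Mutating the dict returned by request.get_json() works because Flask caches the
    # parsed JSON on the request object, so chat_completions() sees the updated "model"
    # when it calls request.get_json() again (relies on Flask's default JSON caching).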
return chat_completions()
if __name__ == '__main__':
refresh_config()
print("=" * 60)
print("大模型API中转系统")
print("=" * 60)
print(f"访问地址: http://localhost:{SERVER_CONFIG['port']}")
print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
print("=" * 60)
print("上游提供商:")
for p in sorted(_cached_providers, key=lambda x: x['priority']):
print(f" [{p['priority']}] {p['name']}: {p['base_url']}")
print(f" 模型: {', '.join(p['models'])}")
print("=" * 60)
print("支持的模型别名:")
for alias, target in _cached_aliases.items():
print(f" {alias} -> {target}")
print("=" * 60)
app.run(
host=SERVER_CONFIG['host'],
port=SERVER_CONFIG['port'],
debug=SERVER_CONFIG['debug']
)