llm-proxy/app.py
hubian 247f9e2165 fix: automatically switch to a fallback provider when a higher-priority provider fails
- Any non-200 response (400, 429, 500, etc.) triggers a switch
- Added an exclude_providers parameter so providers that were already tried are skipped
- Avoids retrying providers that have already failed (a client-side sketch follows below)
- Added switch-over logging to ease debugging
2026-04-10 01:51:36 +08:00
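For context, a minimal client call that exercises this failover path. This is only a sketch: it assumes the proxy is reachable on port 19007 (the port used in the examples served by the index route below); the real host and port come from SERVER_CONFIG.

import requests

# Ask the proxy to pick a provider itself ("auto"). If the first provider returns a
# non-200 response, the proxy retries the same request against the next provider in
# priority order before answering the client.
resp = requests.post(
    "http://localhost:19007/v1/chat/completions",
    json={"model": "auto", "messages": [{"role": "user", "content": "Hello"}]},
)
print(resp.status_code, resp.json())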

"""
大模型API中转系统
兼容OpenAI API格式支持多上游提供商优先级调度
"""
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS
import requests
import json
import time
import logging
from datetime import datetime, date
from pathlib import Path
import sys
import threading
# 添加配置路径
sys.path.insert(0, str(Path(__file__).parent))
from config.settings import (
get_providers, get_model_aliases, get_auto_profiles, SERVER_CONFIG,
LOG_CONFIG, RETRY_CONFIG
)
# 数据目录和统计文件
DATA_DIR = Path(__file__).parent / 'data'
DATA_DIR.mkdir(exist_ok=True)
STATS_FILE = DATA_DIR / 'stats.json'
# 统计锁(避免并发写入冲突)
stats_lock = threading.Lock()
def load_stats():
    """Load statistics from disk."""
    if STATS_FILE.exists():
        try:
            return json.loads(STATS_FILE.read_text(encoding='utf-8'))
        except (OSError, json.JSONDecodeError):
            pass
    return {
        'total_requests': 0,
        'total_success': 0,
        'total_errors': 0,
        'total_tokens': 0,
        'requests_today': 0,
        'requests_by_model': {},
        'providers': {},
        'last_updated': None,
        'date': None  # Used to decide whether the daily counter needs resetting
    }


def save_stats(stats):
    """Persist statistics to disk."""
    stats['last_updated'] = datetime.now().isoformat()
    STATS_FILE.write_text(json.dumps(stats, ensure_ascii=False, indent=2), encoding='utf-8')
def increment_stats(model, provider_name, success=False, tokens=0, error=None):
    """Update request counters."""
    with stats_lock:
        stats = load_stats()
        today = date.today().isoformat()
        # Reset the daily counter when a new day starts
        if stats.get('date') != today:
            stats['date'] = today
            stats['requests_today'] = 0
        stats['total_requests'] += 1
        stats['requests_today'] += 1
        # Per-model statistics
        if model not in stats['requests_by_model']:
            stats['requests_by_model'][model] = {'count': 0, 'success': 0, 'tokens': 0}
        stats['requests_by_model'][model]['count'] += 1
        if success:
            stats['requests_by_model'][model]['success'] += 1
            stats['requests_by_model'][model]['tokens'] += tokens
        # Per-provider statistics
        if provider_name not in stats['providers']:
            stats['providers'][provider_name] = {'requests': 0, 'success': 0, 'errors': 0, 'tokens': 0}
        stats['providers'][provider_name]['requests'] += 1
        if success:
            stats['providers'][provider_name]['success'] += 1
            stats['providers'][provider_name]['tokens'] += tokens
            stats['total_success'] += 1
            stats['total_tokens'] += tokens
        else:
            stats['providers'][provider_name]['errors'] += 1
            stats['total_errors'] += 1
        save_stats(stats)
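# Note: increment_stats() re-reads stats.json, mutates it, and writes it back while
# holding stats_lock, so counters stay consistent across request threads at the cost
# of serializing every update through file I/O.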
app = Flask(__name__)
CORS(app)

# Configuration cache TTL (seconds)
CONFIG_CACHE_TTL = 5
_last_config_load = 0
_cached_providers = []
_cached_aliases = {}
_cached_auto_profiles = {}


def refresh_config():
    """Reload configuration dynamically (supports changes made via the admin backend)."""
    global _last_config_load, _cached_providers, _cached_aliases, _cached_auto_profiles, provider_status
    current_time = time.time()
    if current_time - _last_config_load > CONFIG_CACHE_TTL:
        _cached_providers = get_providers()
        _cached_aliases = get_model_aliases()
        _cached_auto_profiles = get_auto_profiles()
        _last_config_load = current_time
        # Seed the status cache for newly added providers
        for provider in _cached_providers:
            if provider['name'] not in provider_status:
                provider_status[provider['name']] = {
                    'available': True,
                    'last_check': None,
                    'error_count': 0,
                    'last_error': None,
                }
# Logging configuration
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
log_dir.mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'proxy.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Provider status cache
provider_status = {}

# Initialization
refresh_config()
def get_provider_for_model(model_name):
    """Pick a provider for the given model name."""
    refresh_config()
    # Resolve aliases
    resolved_model = _cached_aliases.get(model_name, model_name)
    # In auto mode, choose an available provider by priority
    if resolved_model == 'auto' or resolved_model.startswith('auto-'):
        return get_available_provider_for_auto(resolved_model)
    # Look for a provider that supports this model
    sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
    for provider in sorted_providers:
        if not provider['enabled']:
            continue
        if not provider_status.get(provider['name'], {}).get('available', True):
            continue
        if resolved_model in provider['models']:
            return provider, resolved_model
    # No exact match found: try a fuzzy match
    for provider in sorted_providers:
        if not provider['enabled']:
            continue
        if not provider_status.get(provider['name'], {}).get('available', True):
            continue
        for m in provider['models']:
            if resolved_model.lower() in m.lower() or m.lower() in resolved_model.lower():
                return provider, m
    return None, None
def get_available_provider_for_auto(auto_name='auto', exclude_providers=None):
    """Get an available provider in auto mode (supports custom auto profiles).

    Args:
        auto_name: name of the auto profile
        exclude_providers: provider names to skip (already tried)
    """
    refresh_config()
    exclude_providers = exclude_providers or []
    # Fetch the auto profile
    profile = _cached_auto_profiles.get(auto_name, _cached_auto_profiles.get('auto', {}))
    # Providers allowed by the profile
    allowed_providers = profile.get('providers', ['*'])
    # Filter candidate providers
    sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
    candidates = []
    for provider in sorted_providers:
        if not provider['enabled']:
            continue
        if provider['name'] in exclude_providers:
            continue
        if not provider_status.get(provider['name'], {}).get('available', True):
            continue
        # Check against the allow list
        if '*' in allowed_providers:
            candidates.append(provider)
        elif provider.get('id') in allowed_providers or provider['name'] in allowed_providers:
            candidates.append(provider)
    # Select according to the configured strategy
    strategy = profile.get('strategy', 'priority')
    if not candidates:
        # No candidates: return the first provider anyway so its error surfaces to the client
        if sorted_providers:
            return sorted_providers[0], sorted_providers[0]['default_model']
        return None, None
    if strategy == 'priority':
        # Highest priority first
        return candidates[0], candidates[0]['default_model']
    elif strategy == 'random':
        # Random choice
        import random
        selected = random.choice(candidates)
        return selected, selected['default_model']
    else:
        # Default to priority order
        return candidates[0], candidates[0]['default_model']
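# Example (hypothetical provider names): after "provider-a" fails, the retry loop in
# chat_completions() calls
#     get_available_provider_for_auto('auto', exclude_providers=['provider-a'])
# so the next enabled provider in priority order is chosen instead of retrying the
# one that just failed.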
def mark_provider_error(provider_name, error):
    """Record an error for a provider."""
    if provider_name in provider_status:
        provider_status[provider_name]['error_count'] += 1
        provider_status[provider_name]['last_error'] = str(error)
        provider_status[provider_name]['last_check'] = datetime.now()
        # After enough consecutive errors, temporarily mark the provider as unavailable
        if provider_status[provider_name]['error_count'] >= 3:
            provider_status[provider_name]['available'] = False
            logger.warning(f"Provider {provider_name} marked as unavailable due to errors")


def mark_provider_success(provider_name):
    """Record a successful call for a provider."""
    if provider_name in provider_status:
        provider_status[provider_name]['error_count'] = 0
        provider_status[provider_name]['available'] = True
        provider_status[provider_name]['last_check'] = datetime.now()
def proxy_request(provider, model, request_data, stream=False):
    """Forward a request to an upstream provider."""
    url = f"{provider['base_url'].rstrip('/')}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {provider['api_key']}"
    }
    # Build the upstream payload
    data = request_data.copy()
    data['model'] = model
    try:
        if stream:
            # Streaming request
            response = requests.post(
                url,
                headers=headers,
                json=data,
                stream=True,
                timeout=provider.get('timeout', 120)
            )
            return response
        else:
            # Non-streaming request
            response = requests.post(
                url,
                headers=headers,
                json=data,
                timeout=provider.get('timeout', 120)
            )
            return response
    except requests.exceptions.Timeout:
        mark_provider_error(provider['name'], "Timeout")
        raise Exception(f"Provider {provider['name']} timeout")
    except requests.exceptions.ConnectionError:
        mark_provider_error(provider['name'], "Connection error")
        raise Exception(f"Provider {provider['name']} connection error")
    except Exception as e:
        mark_provider_error(provider['name'], str(e))
        raise
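# Timeouts and connection failures above are recorded via mark_provider_error() and then
# re-raised, so the retry loop in chat_completions() can catch them and switch providers.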
def stream_response(response):
    """Generator that relays a streaming (SSE) response."""
    try:
        for line in response.iter_lines():
            if line:
                yield line + b'\n'
    except Exception as e:
        logger.error(f"Stream error: {e}")
        # Emit the error as a properly escaped SSE data frame
        yield b'data: ' + json.dumps({"error": str(e)}).encode('utf-8') + b'\n\n'
# ============ API routes ============

@app.route('/')
def index():
    """Landing page."""
    return jsonify({
        "name": "LLM Proxy",
        "version": "1.0.0",
        "description": "OpenAI-compatible LLM API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "models": "/v1/models",
            "embeddings": "/v1/embeddings",
            "health": "/health",
            "status": "/status"
        },
        "examples": {
            "curl": {
                "description": "curl command line",
                "code": "curl -X POST http://localhost:19007/v1/chat/completions -H 'Content-Type: application/json' -d '{\"model\": \"auto\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello\"}]}'"
            },
            "python_openai": {
                "description": "Python OpenAI SDK",
                "code": "from openai import OpenAI\nclient = OpenAI(base_url='http://localhost:19007/v1', api_key='any')\nresponse = client.chat.completions.create(model='auto', messages=[{'role': 'user', 'content': 'Hello'}])\nprint(response.choices[0].message.content)"
            },
            "python_requests": {
                "description": "Python requests",
                "code": "import requests\nresponse = requests.post('http://localhost:19007/v1/chat/completions', json={'model': 'auto', 'messages': [{'role': 'user', 'content': 'Hello'}]})\nprint(response.json()['choices'][0]['message']['content'])"
            },
            "javascript": {
                "description": "JavaScript fetch",
                "code": "fetch('http://localhost:19007/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({model: 'auto', messages: [{role: 'user', content: 'Hello'}]})}).then(r => r.json()).then(d => console.log(d.choices[0].message.content))"
            }
        }
    })
@app.route('/v1/models', methods=['GET'])
def list_models():
    """List available models."""
    refresh_config()
    models_list = []
    added_models = set()
    # Add every auto profile
    for profile_name, profile in _cached_auto_profiles.items():
        if profile_name not in added_models:
            models_list.append({
                "id": profile_name,
                "object": "model",
                "created": int(time.time()),
                "owned_by": "proxy",
                "description": profile.get('description', 'Auto-select available model')
            })
            added_models.add(profile_name)
    for provider in _cached_providers:
        if not provider['enabled']:
            continue
        for model in provider['models']:
            if model not in added_models:
                models_list.append({
                    "id": model,
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": provider['name'],
                })
                added_models.add(model)
    return jsonify({
        "object": "list",
        "data": models_list
    })
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Chat completions API."""
    # Variables used for statistics
    request_model = None
    request_provider = None
    request_success = False
    request_tokens = 0
    try:
        data = request.get_json()
        if not data:
            increment_stats('unknown', 'unknown', success=False, error='Invalid request body')
            return jsonify({"error": "Invalid request body"}), 400
        model = data.get('model', 'auto')
        stream = data.get('stream', False)
        request_model = model
        # Pick a provider
        provider, resolved_model = get_provider_for_model(model)
        if not provider:
            increment_stats(model, 'unknown', success=False, error=f'No provider for model: {model}')
            return jsonify({
                "error": {
                    "message": f"No available provider for model: {model}",
                    "type": "invalid_request_error"
                }
            }), 400
        request_provider = provider['name']
        logger.info(f"Request: model={model} -> provider={provider['name']}, resolved_model={resolved_model}, stream={stream}")
        # Retry loop
        last_error = None
        tried_providers = []
        for attempt in range(RETRY_CONFIG['max_retries']):
            try:
                response = proxy_request(provider, resolved_model, data, stream)
                if response.status_code == 200:
                    mark_provider_success(provider['name'])
                    request_success = True
                    if stream:
                        # Streaming response: count the request (token usage unknown)
                        increment_stats(model, provider['name'], success=True, tokens=0)
                        return Response(
                            stream_with_context(stream_response(response)),
                            content_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                            }
                        )
                    else:
                        # Non-streaming response: extract token usage if present
                        result = response.json()
                        usage = result.get('usage', {})
                        request_tokens = usage.get('total_tokens', 0)
                        increment_stats(model, provider['name'], success=True, tokens=request_tokens)
                        return jsonify(result)
                else:
                    # Any non-200 response: try the next provider
                    error_info = response.json() if response.headers.get('content-type', '').startswith('application/json') else {"error": response.text}
                    last_error = error_info
                    logger.warning(f"Provider {provider['name']} returned {response.status_code}: {error_info}")
                    mark_provider_error(provider['name'], f"HTTP {response.status_code}")
                    tried_providers.append(provider['name'])
                    # Switch to the next provider
                    next_provider, next_model = get_available_provider_for_auto('auto', exclude_providers=tried_providers)
                    if next_provider and next_provider['name'] not in tried_providers:
                        logger.info(f"Switching to next provider: {next_provider['name']}")
                        provider = next_provider
                        resolved_model = next_model
                        request_provider = provider['name']
                        time.sleep(RETRY_CONFIG['retry_delay'])
                        continue
                    # Every provider has been tried: return the last error
                    increment_stats(model, provider['name'], success=False, error=str(last_error))
                    return jsonify(error_info), response.status_code
            except Exception as e:
                last_error = str(e)
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                tried_providers.append(provider['name'])
                # Switch to the next provider
                next_provider, next_model = get_available_provider_for_auto('auto', exclude_providers=tried_providers)
                if next_provider and next_provider['name'] not in tried_providers:
                    provider = next_provider
                    resolved_model = next_model
                    request_provider = provider['name']
                    time.sleep(RETRY_CONFIG['retry_delay'])
                    continue
                # All retries failed
                increment_stats(model, request_provider or 'unknown', success=False, error=str(last_error))
                return jsonify({
                    "error": {
                        "message": f"All providers failed. Last error: {last_error}",
                        "type": "api_error"
                    }
                }), 503
        # Fallback: the retry budget was exhausted by provider switches without a success
        increment_stats(model, request_provider or 'unknown', success=False, error=str(last_error))
        return jsonify({
            "error": {
                "message": f"All retries exhausted. Last error: {last_error}",
                "type": "api_error"
            }
        }), 503
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        increment_stats(request_model or 'unknown', request_provider or 'unknown', success=False, error=str(e))
        return jsonify({
            "error": {
                "message": str(e),
                "type": "internal_error"
            }
        }), 500
@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
    """Embeddings API (simple pass-through)."""
    refresh_config()
    try:
        data = request.get_json()
        model = data.get('model', 'text-embedding-ada-002')
        # Use the first available provider
        if _cached_providers:
            provider = _cached_providers[0]
            url = f"{provider['base_url'].rstrip('/')}/embeddings"
            headers = {
                "Content-Type": "application/json",
                "Authorization": f"Bearer {provider['api_key']}"
            }
            response = requests.post(url, headers=headers, json=data, timeout=60)
            return jsonify(response.json()), response.status_code
        else:
            return jsonify({"error": "No providers available"}), 503
    except Exception as e:
        return jsonify({"error": str(e)}), 500
@app.route('/health', methods=['GET'])
def health():
    """Health check."""
    available_count = sum(1 for s in provider_status.values() if s['available'])
    total_count = len(provider_status)
    return jsonify({
        "status": "healthy" if available_count > 0 else "degraded",
        "providers": {
            "available": available_count,
            "total": total_count,
        },
        "timestamp": datetime.now().isoformat()
    })
@app.route('/status', methods=['GET'])
def status():
    """Detailed status report."""
    refresh_config()
    providers_detail = []
    for provider in _cached_providers:
        status_info = provider_status.get(provider['name'], {})
        providers_detail.append({
            "name": provider['name'],
            "priority": provider['priority'],
            "enabled": provider['enabled'],
            "available": status_info.get('available', True),
            "error_count": status_info.get('error_count', 0),
            "last_error": status_info.get('last_error'),
            "models": provider['models'],
        })
    return jsonify({
        "version": "1.0.0",
        "uptime": time.time(),
        "providers": providers_detail,
        "model_aliases": _cached_aliases,
    })
# Additional endpoints for OpenAI SDK compatibility

@app.route('/v1/engines', methods=['GET'])
def list_engines():
    """Legacy engines endpoint."""
    return list_models()


@app.route('/v1/engines/<model>/completions', methods=['POST'])
def engine_completions(model):
    """Legacy completions endpoint."""
    # Flask caches the parsed JSON body, so setting 'model' on the dict returned here
    # is visible when chat_completions() calls request.get_json() again.
    data = request.get_json()
    data['model'] = model
    return chat_completions()
if __name__ == '__main__':
    refresh_config()
    print("=" * 60)
    print("LLM API proxy")
    print("=" * 60)
    print(f"Base URL: http://localhost:{SERVER_CONFIG['port']}")
    print(f"API endpoint: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
    print("=" * 60)
    print("Upstream providers:")
    for p in sorted(_cached_providers, key=lambda x: x['priority']):
        print(f" [{p['priority']}] {p['name']}: {p['base_url']}")
        print(f"   models: {', '.join(p['models'])}")
    print("=" * 60)
    print("Supported model aliases:")
    for alias, target in _cached_aliases.items():
        print(f" {alias} -> {target}")
    print("=" * 60)
    app.run(
        host=SERVER_CONFIG['host'],
        port=SERVER_CONFIG['port'],
        debug=SERVER_CONFIG['debug']
    )