- 任何非200响应(400、429、500等)都会触发切换 - 增加 exclude_providers 参数排除已尝试的提供商 - 避免重复尝试失败的提供商 - 添加切换日志便于调试
625 lines
22 KiB
Python
625 lines
22 KiB
Python
"""
|
||
大模型API中转系统
|
||
兼容OpenAI API格式,支持多上游提供商优先级调度
|
||
"""
|
||
|
||
from flask import Flask, request, jsonify, Response, stream_with_context
|
||
from flask_cors import CORS
|
||
import requests
|
||
import json
|
||
import time
|
||
import logging
|
||
from datetime import datetime, date
|
||
from pathlib import Path
|
||
import sys
|
||
import threading
|
||
|
||
# 添加配置路径
|
||
sys.path.insert(0, str(Path(__file__).parent))
|
||
from config.settings import (
|
||
get_providers, get_model_aliases, get_auto_profiles, SERVER_CONFIG,
|
||
LOG_CONFIG, RETRY_CONFIG
|
||
)
|
||
|
||
# Data directory and persisted statistics file
DATA_DIR = Path(__file__).parent / 'data'
DATA_DIR.mkdir(exist_ok=True)
STATS_FILE = DATA_DIR / 'stats.json'

# Lock serializing the stats read-modify-write cycle (avoids concurrent write corruption)
stats_lock = threading.Lock()
|
||
|
||
def load_stats():
    """Load persisted statistics from STATS_FILE.

    Returns:
        The parsed stats dict when the file exists and is valid JSON;
        otherwise a fresh, zeroed stats structure.
    """
    if STATS_FILE.exists():
        try:
            return json.loads(STATS_FILE.read_text(encoding='utf-8'))
        except (OSError, json.JSONDecodeError):
            # Corrupt or unreadable stats file: start from a clean slate
            # instead of crashing. (Was a bare `except:`, which also hid
            # KeyboardInterrupt/SystemExit.)
            pass
    return {
        'total_requests': 0,
        'total_success': 0,
        'total_errors': 0,
        'total_tokens': 0,
        'requests_today': 0,
        'requests_by_model': {},
        'providers': {},
        'last_updated': None,
        'date': None  # used to detect day rollover for the daily counter
    }
|
||
|
||
def save_stats(stats):
    """Persist *stats* to STATS_FILE, stamping 'last_updated' first."""
    stats['last_updated'] = datetime.now().isoformat()
    payload = json.dumps(stats, ensure_ascii=False, indent=2)
    STATS_FILE.write_text(payload, encoding='utf-8')
|
||
|
||
def increment_stats(model, provider_name, success=False, tokens=0, error=None):
    """Update persisted request counters for a single request.

    Args:
        model: model name as requested by the client.
        provider_name: upstream provider that served (or failed) the request.
        success: whether the upstream call succeeded.
        tokens: total tokens reported by the upstream (0 for streaming).
        error: error description; NOTE(review): accepted but never stored —
            confirm whether it should be persisted or dropped from the API.
    """
    # The whole load -> mutate -> save cycle runs under the lock so that
    # concurrent requests cannot lose each other's updates.
    with stats_lock:
        stats = load_stats()
        today = date.today().isoformat()

        # New day: reset the daily counter.
        if stats.get('date') != today:
            stats['date'] = today
            stats['requests_today'] = 0

        stats['total_requests'] += 1
        stats['requests_today'] += 1

        # Per-model counters (created lazily on first sight of the model).
        if model not in stats['requests_by_model']:
            stats['requests_by_model'][model] = {'count': 0, 'success': 0, 'tokens': 0}
        stats['requests_by_model'][model]['count'] += 1
        if success:
            stats['requests_by_model'][model]['success'] += 1
            stats['requests_by_model'][model]['tokens'] += tokens

        # Per-provider counters; global success/error totals are rolled up here too.
        if provider_name not in stats['providers']:
            stats['providers'][provider_name] = {'requests': 0, 'success': 0, 'errors': 0, 'tokens': 0}
        stats['providers'][provider_name]['requests'] += 1
        if success:
            stats['providers'][provider_name]['success'] += 1
            stats['providers'][provider_name]['tokens'] += tokens
            stats['total_success'] += 1
            stats['total_tokens'] += tokens
        else:
            stats['providers'][provider_name]['errors'] += 1
            stats['total_errors'] += 1

        save_stats(stats)
|
||
|
||
app = Flask(__name__)
CORS(app)

# Config cache TTL in seconds: refresh_config() reloads provider/alias/profile
# settings at most this often, so admin-side edits are picked up within ~5s.
CONFIG_CACHE_TTL = 5
_last_config_load = 0
_cached_providers = []
_cached_aliases = {}
_cached_auto_profiles = {}
|
||
|
||
def refresh_config():
    """Reload provider/alias/auto-profile config once the cache TTL expires.

    Also seeds a status entry for any provider that appeared since the last
    refresh, so admin-side config changes take effect without a restart.
    """
    global _last_config_load, _cached_providers, _cached_aliases, _cached_auto_profiles, provider_status

    now = time.time()
    if now - _last_config_load > CONFIG_CACHE_TTL:
        _cached_providers = get_providers()
        _cached_aliases = get_model_aliases()
        _cached_auto_profiles = get_auto_profiles()
        _last_config_load = now

    # Runs on every call (not just on reload): newly configured providers
    # start out available with a clean error record.
    for entry in _cached_providers:
        provider_status.setdefault(entry['name'], {
            'available': True,
            'last_check': None,
            'error_count': 0,
            'last_error': None,
        })
|
||
|
||
# Logging: file handler under LOG_CONFIG['log_dir'] plus stderr.
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
log_dir.mkdir(exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'proxy.log', encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Per-provider health cache: name -> {available, last_check, error_count, last_error}.
# Populated by refresh_config(); mutated by mark_provider_error/success.
provider_status = {}

# Initial configuration load at import time.
refresh_config()
|
||
|
||
|
||
def get_provider_for_model(model_name):
    """Resolve a requested model name to a (provider, upstream_model) pair.

    Resolution order: alias expansion, then 'auto' profile dispatch, then an
    exact model match across enabled/available providers by priority, then a
    case-insensitive substring match. Returns (None, None) when nothing fits.
    """
    refresh_config()

    # Expand aliases; unknown names pass through unchanged.
    target = _cached_aliases.get(model_name, model_name)

    # 'auto' (or any 'auto-*' profile) delegates to priority-based selection.
    if target == 'auto' or target.startswith('auto-'):
        return get_available_provider_for_auto(target)

    by_priority = sorted(_cached_providers, key=lambda p: p['priority'])
    usable = [
        p for p in by_priority
        if p['enabled'] and provider_status.get(p['name'], {}).get('available', True)
    ]

    # Pass 1: exact model name match.
    for candidate in usable:
        if target in candidate['models']:
            return candidate, target

    # Pass 2: fuzzy match — substring in either direction, case-insensitive.
    lowered = target.lower()
    for candidate in usable:
        for upstream_model in candidate['models']:
            if lowered in upstream_model.lower() or upstream_model.lower() in lowered:
                return candidate, upstream_model

    return None, None
|
||
|
||
|
||
def get_available_provider_for_auto(auto_name='auto', exclude_providers=None):
    """Pick a provider for an 'auto' pseudo-model.

    Args:
        auto_name: auto-profile name to use (falls back to the 'auto' profile).
        exclude_providers: provider names to skip (already tried / failed).

    Returns:
        (provider_dict, default_model_name), or (None, None) when no provider
        is configured at all.
    """
    refresh_config()

    exclude_providers = exclude_providers or []

    # Profile lookup, defaulting to the base 'auto' profile.
    profile = _cached_auto_profiles.get(auto_name, _cached_auto_profiles.get('auto', {}))

    # Providers this profile may use ('*' means all).
    allowed_providers = profile.get('providers', ['*'])

    # Collect enabled, healthy, non-excluded providers in priority order.
    sorted_providers = sorted(_cached_providers, key=lambda x: x['priority'])
    candidates = []

    for provider in sorted_providers:
        if not provider['enabled']:
            continue
        if provider['name'] in exclude_providers:
            continue
        if not provider_status.get(provider['name'], {}).get('available', True):
            continue

        if '*' in allowed_providers:
            candidates.append(provider)
        elif provider.get('id') in allowed_providers or provider['name'] in allowed_providers:
            candidates.append(provider)

    strategy = profile.get('strategy', 'priority')

    if not candidates:
        # No healthy candidate: fall back to some provider so the upstream
        # error can propagate to the client.
        # Fix: prefer an enabled provider that has not been excluded, instead
        # of blindly returning the first provider (which could be disabled or
        # one the caller already tried).
        for provider in sorted_providers:
            if provider['enabled'] and provider['name'] not in exclude_providers:
                return provider, provider['default_model']
        if sorted_providers:
            return sorted_providers[0], sorted_providers[0]['default_model']
        return None, None

    if strategy == 'random':
        import random  # local import: only needed for this strategy
        selected = random.choice(candidates)
        return selected, selected['default_model']

    # 'priority' and any unknown strategy: highest-priority candidate wins.
    return candidates[0], candidates[0]['default_model']
|
||
|
||
|
||
def mark_provider_error(provider_name, error):
    """Record a failure for *provider_name*; disable it after 3 consecutive errors."""
    status = provider_status.get(provider_name)
    if status is None:
        return

    status['error_count'] += 1
    status['last_error'] = str(error)
    status['last_check'] = datetime.now()

    # Three consecutive failures take the provider out of rotation until
    # mark_provider_success() clears the counter.
    if status['error_count'] >= 3:
        status['available'] = False
        logger.warning(f"Provider {provider_name} marked as unavailable due to errors")
|
||
|
||
|
||
def mark_provider_success(provider_name):
    """Reset error tracking for *provider_name* and put it back in rotation."""
    status = provider_status.get(provider_name)
    if status is None:
        return

    status['error_count'] = 0
    status['available'] = True
    status['last_check'] = datetime.now()
|
||
|
||
|
||
def proxy_request(provider, model, request_data, stream=False):
    """Forward a chat-completions request to an upstream provider.

    Args:
        provider: provider config dict (base_url, api_key, optional timeout).
        model: upstream model name to place in the payload.
        request_data: original request body; copied, never mutated.
        stream: when True, request a streaming response.

    Returns:
        The `requests.Response` object (streaming when stream=True).

    Raises:
        Exception: on timeout or connection failure, after recording the
        error against the provider via mark_provider_error().
    """
    url = f"{provider['base_url'].rstrip('/')}/chat/completions"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {provider['api_key']}"
    }

    # Work on a copy so the caller's payload keeps its original model name.
    data = request_data.copy()
    data['model'] = model

    try:
        # The stream and non-stream branches were duplicated except for the
        # `stream` flag; a single call with `stream=stream` behaves identically.
        return requests.post(
            url,
            headers=headers,
            json=data,
            stream=stream,
            timeout=provider.get('timeout', 120),
        )
    except requests.exceptions.Timeout:
        mark_provider_error(provider['name'], "Timeout")
        raise Exception(f"Provider {provider['name']} timeout")
    except requests.exceptions.ConnectionError:
        mark_provider_error(provider['name'], "Connection error")
        raise Exception(f"Provider {provider['name']} connection error")
    except Exception as e:
        mark_provider_error(provider['name'], str(e))
        raise
|
||
|
||
|
||
def stream_response(response):
    """Generator relaying upstream SSE lines to the client.

    Skips keep-alive blank lines and re-appends a newline to each relayed
    line. On a relay failure, emits a final SSE error event.
    """
    try:
        for line in response.iter_lines():
            if line:
                yield line + b'\n'
    except Exception as e:
        logger.error(f"Stream error: {e}")
        # Fix: the error payload was assembled by raw byte concatenation,
        # which produced invalid JSON whenever the message contained quotes
        # or backslashes; json.dumps escapes it properly.
        yield b'data: ' + json.dumps({"error": str(e)}).encode('utf-8') + b'\n\n'
|
||
|
||
|
||
# ============ API 路由 ============
|
||
|
||
@app.route('/')
def index():
    """Landing page: service metadata, endpoint map, and client usage examples."""
    endpoint_map = {
        "chat": "/v1/chat/completions",
        "models": "/v1/models",
        "embeddings": "/v1/embeddings",
        "health": "/health",
        "status": "/status"
    }

    usage_examples = {
        "curl": {
            "description": "curl 命令行",
            "code": "curl -X POST http://localhost:19007/v1/chat/completions -H 'Content-Type: application/json' -d '{\"model\": \"auto\", \"messages\": [{\"role\": \"user\", \"content\": \"你好\"}]}'"
        },
        "python_openai": {
            "description": "Python OpenAI SDK",
            "code": "from openai import OpenAI\nclient = OpenAI(base_url='http://localhost:19007/v1', api_key='any')\nresponse = client.chat.completions.create(model='auto', messages=[{'role': 'user', 'content': '你好'}])\nprint(response.choices[0].message.content)"
        },
        "python_requests": {
            "description": "Python requests",
            "code": "import requests\nresponse = requests.post('http://localhost:19007/v1/chat/completions', json={'model': 'auto', 'messages': [{'role': 'user', 'content': '你好'}]})\nprint(response.json()['choices'][0]['message']['content'])"
        },
        "javascript": {
            "description": "JavaScript fetch",
            "code": "fetch('http://localhost:19007/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({model: 'auto', messages: [{role: 'user', content: '你好'}]})}).then(r => r.json()).then(d => console.log(d.choices[0].message.content))"
        }
    }

    return jsonify({
        "name": "LLM Proxy",
        "version": "1.0.0",
        "description": "OpenAI-compatible LLM API Proxy",
        "endpoints": endpoint_map,
        "examples": usage_examples
    })
|
||
|
||
|
||
@app.route('/v1/models', methods=['GET'])
def list_models():
    """OpenAI-compatible model listing: auto profiles first, then provider models."""
    refresh_config()

    catalog = []
    seen = set()

    # Auto profiles are exposed as pseudo-models owned by the proxy itself.
    for profile_name, profile in _cached_auto_profiles.items():
        if profile_name in seen:
            continue
        catalog.append({
            "id": profile_name,
            "object": "model",
            "created": int(time.time()),
            "owned_by": "proxy",
            "description": profile.get('description', 'Auto-select available model')
        })
        seen.add(profile_name)

    # Concrete upstream models, deduplicated across providers (first
    # provider listing a model is recorded as its owner).
    for prov in _cached_providers:
        if not prov['enabled']:
            continue
        for model_id in prov['models']:
            if model_id in seen:
                continue
            catalog.append({
                "id": model_id,
                "object": "model",
                "created": int(time.time()),
                "owned_by": prov['name'],
            })
            seen.add(model_id)

    return jsonify({
        "object": "list",
        "data": catalog
    })
|
||
|
||
|
||
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Chat-completions API (OpenAI-compatible) with provider failover.

    Any non-200 upstream response or transport error triggers a switch to the
    next untried provider (up to RETRY_CONFIG['max_retries'] attempts).
    Returns the upstream JSON on success, the last upstream error when all
    providers were tried, or a 503/500 error envelope.
    """
    # Stats bookkeeping for this request.
    request_model = None
    request_provider = None
    request_success = False
    request_tokens = 0

    try:
        data = request.get_json()

        if not data:
            increment_stats('unknown', 'unknown', success=False, error='Invalid request body')
            return jsonify({"error": "Invalid request body"}), 400

        model = data.get('model', 'auto')
        stream = data.get('stream', False)

        request_model = model

        # Resolve the requested model to (provider, upstream model name).
        provider, resolved_model = get_provider_for_model(model)

        if not provider:
            increment_stats(model, 'unknown', success=False, error=f'No provider for model: {model}')
            return jsonify({
                "error": {
                    "message": f"No available provider for model: {model}",
                    "type": "invalid_request_error"
                }
            }), 400

        request_provider = provider['name']

        logger.info(f"Request: model={model} -> provider={provider['name']}, resolved_model={resolved_model}, stream={stream}")

        # Retry/failover loop. NOTE(review): failover always consults the
        # 'auto' profile and uses the next provider's default_model, even when
        # the client asked for a specific model — confirm this is intended.
        last_error = None
        tried_providers = []

        for attempt in range(RETRY_CONFIG['max_retries']):
            try:
                response = proxy_request(provider, resolved_model, data, stream)

                if response.status_code == 200:
                    mark_provider_success(provider['name'])
                    request_success = True

                    if stream:
                        # Streaming: token usage is unknown here; count the request only.
                        increment_stats(model, provider['name'], success=True, tokens=0)
                        return Response(
                            stream_with_context(stream_response(response)),
                            content_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                            }
                        )
                    else:
                        # Non-streaming: pull token usage out of the upstream payload.
                        result = response.json()
                        usage = result.get('usage', {})
                        request_tokens = usage.get('total_tokens', 0)

                        increment_stats(model, provider['name'], success=True, tokens=request_tokens)
                        return jsonify(result)

                else:
                    # Any non-200 (400/429/500/...) triggers failover.
                    error_info = response.json() if response.headers.get('content-type', '').startswith('application/json') else {"error": response.text}
                    last_error = error_info
                    logger.warning(f"Provider {provider['name']} returned {response.status_code}: {error_info}")
                    mark_provider_error(provider['name'], f"HTTP {response.status_code}")
                    tried_providers.append(provider['name'])

                    # Switch to the next untried provider, if any.
                    next_provider, next_model = get_available_provider_for_auto('auto', exclude_providers=tried_providers)
                    if next_provider and next_provider['name'] not in tried_providers:
                        logger.info(f"Switching to next provider: {next_provider['name']}")
                        provider = next_provider
                        resolved_model = next_model
                        request_provider = provider['name']
                        time.sleep(RETRY_CONFIG['retry_delay'])
                        continue

                    # All providers tried: surface the last upstream error as-is.
                    increment_stats(model, provider['name'], success=False, error=str(last_error))
                    return jsonify(error_info), response.status_code

            except Exception as e:
                last_error = str(e)
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                tried_providers.append(provider['name'])

                # Switch to the next untried provider, if any.
                next_provider, next_model = get_available_provider_for_auto('auto', exclude_providers=tried_providers)
                if next_provider and next_provider['name'] not in tried_providers:
                    provider = next_provider
                    resolved_model = next_model
                    request_provider = provider['name']
                    time.sleep(RETRY_CONFIG['retry_delay'])
                    continue

                # All retries exhausted.
                increment_stats(model, request_provider or 'unknown', success=False, error=str(last_error))
                return jsonify({
                    "error": {
                        "message": f"All providers failed. Last error: {last_error}",
                        "type": "api_error"
                    }
                }), 503

        # Fix: previously the function could fall off the end of the retry
        # loop (when the final attempt ended in `continue`) and implicitly
        # return None, which makes Flask raise. Return an explicit 503.
        increment_stats(model, request_provider or 'unknown', success=False, error=str(last_error))
        return jsonify({
            "error": {
                "message": f"All providers failed. Last error: {last_error}",
                "type": "api_error"
            }
        }), 503

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        increment_stats(request_model or 'unknown', request_provider or 'unknown', success=False, error=str(e))
        return jsonify({
            "error": {
                "message": str(e),
                "type": "internal_error"
            }
        }), 500
|
||
|
||
|
||
@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
    """Embeddings API: forwards the request body to the first configured provider."""
    refresh_config()

    try:
        data = request.get_json()
        # NOTE(review): `model` is read but never used to pick a provider.
        model = data.get('model', 'text-embedding-ada-002')

        if not _cached_providers:
            return jsonify({"error": "No providers available"}), 503

        upstream = _cached_providers[0]

        url = f"{upstream['base_url'].rstrip('/')}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {upstream['api_key']}"
        }

        response = requests.post(url, headers=headers, json=data, timeout=60)
        return jsonify(response.json()), response.status_code

    except Exception as e:
        return jsonify({"error": str(e)}), 500
|
||
|
||
|
||
@app.route('/health', methods=['GET'])
def health():
    """Liveness endpoint: 'healthy' while at least one provider is available."""
    available = [s for s in provider_status.values() if s['available']]

    return jsonify({
        "status": "healthy" if available else "degraded",
        "providers": {
            "available": len(available),
            "total": len(provider_status),
        },
        "timestamp": datetime.now().isoformat()
    })
|
||
|
||
|
||
@app.route('/status', methods=['GET'])
def status():
    """Detailed status: per-provider health and the alias table."""
    refresh_config()

    details = []
    for prov in _cached_providers:
        info = provider_status.get(prov['name'], {})
        details.append({
            "name": prov['name'],
            "priority": prov['priority'],
            "enabled": prov['enabled'],
            "available": info.get('available', True),
            "error_count": info.get('error_count', 0),
            "last_error": info.get('last_error'),
            "models": prov['models'],
        })

    return jsonify({
        "version": "1.0.0",
        # NOTE(review): this is the current epoch timestamp, not a real uptime.
        "uptime": time.time(),
        "providers": details,
        "model_aliases": _cached_aliases,
    })
|
||
|
||
|
||
# Compatibility endpoints for older OpenAI SDKs
@app.route('/v1/engines', methods=['GET'])
def list_engines():
    """Legacy `engines` endpoint; returns the same payload as /v1/models."""
    return list_models()
|
||
|
||
|
||
@app.route('/v1/engines/<model>/completions', methods=['POST'])
def engine_completions(model):
    """Legacy per-engine completions endpoint; delegates to chat_completions.

    NOTE(review): the model override reaches the delegate only because Flask
    caches the parsed JSON body — chat_completions' own request.get_json()
    returns this same mutated dict. Confirm this assumption holds if the
    request parsing ever changes.
    """
    data = request.get_json()
    data['model'] = model
    return chat_completions()
|
||
|
||
|
||
if __name__ == '__main__':
    # Load providers/aliases before printing the startup banner.
    refresh_config()

    print("=" * 60)
    print("大模型API中转系统")
    print("=" * 60)
    print(f"访问地址: http://localhost:{SERVER_CONFIG['port']}")
    print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
    print("=" * 60)
    print("上游提供商:")
    # Providers listed in priority order, with their model lists.
    for p in sorted(_cached_providers, key=lambda x: x['priority']):
        print(f"  [{p['priority']}] {p['name']}: {p['base_url']}")
        print(f"      模型: {', '.join(p['models'])}")
    print("=" * 60)
    print("支持的模型别名:")
    for alias, target in _cached_aliases.items():
        print(f"  {alias} -> {target}")
    print("=" * 60)

    app.run(
        host=SERVER_CONFIG['host'],
        port=SERVER_CONFIG['port'],
        debug=SERVER_CONFIG['debug']
    )