"""
LLM API relay service.

Exposes an OpenAI-compatible HTTP API and forwards requests to one of several
upstream providers, selected by priority.
"""

import json
import logging
import random
import sys
import time
from datetime import datetime
from pathlib import Path

import requests
from flask import Flask, request, jsonify, Response, stream_with_context
from flask_cors import CORS

# Put this file's directory on the import path so ``config.settings`` resolves.
sys.path.insert(0, str(Path(__file__).parent))

from config.settings import (
    get_providers,
    get_model_aliases,
    get_auto_profiles,
    SERVER_CONFIG,
    LOG_CONFIG,
    RETRY_CONFIG,
)

app = Flask(__name__)
CORS(app)

# How long (seconds) a loaded configuration snapshot is reused.
CONFIG_CACHE_TTL = 5
# Consecutive errors after which a provider is taken out of rotation.
PROVIDER_ERROR_THRESHOLD = 3
# Seconds after its last recorded error before an out-of-rotation provider
# is given another chance.
PROVIDER_COOLDOWN_SECONDS = 60
# Module load time, used to report uptime in /status.
_START_TIME = time.time()

_last_config_load = 0
_cached_providers = []
_cached_aliases = {}
_cached_auto_profiles = {}

# Per-provider health bookkeeping, keyed by provider name.
provider_status = {}

# Logging: write to <this dir>/<LOG_CONFIG['log_dir']>/proxy.log and to stderr.
log_dir = Path(__file__).parent / LOG_CONFIG['log_dir']
log_dir.mkdir(parents=True, exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(log_dir / 'proxy.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)


def refresh_config():
    """Reload providers, aliases and auto profiles when the cache has expired.

    The configuration getters are called at most once every
    ``CONFIG_CACHE_TTL`` seconds. Providers that have no entry in
    ``provider_status`` yet get a fresh "available" entry.
    """
    global _last_config_load, _cached_providers, _cached_aliases, _cached_auto_profiles

    now = time.time()
    if now - _last_config_load <= CONFIG_CACHE_TTL:
        return

    _cached_providers = get_providers()
    _cached_aliases = get_model_aliases()
    _cached_auto_profiles = get_auto_profiles()
    _last_config_load = now

    for provider in _cached_providers:
        provider_status.setdefault(provider['name'], {
            'available': True,
            'last_check': None,
            'error_count': 0,
            'last_error': None,
        })


# Initial load at import time.
refresh_config()


def _providers_by_priority():
    """Return the cached providers sorted by ascending ``priority`` value."""
    return sorted(_cached_providers, key=lambda p: p['priority'])


def _provider_is_available(provider_name):
    """Return True when the provider may be selected.

    A provider flagged unavailable is re-admitted once
    ``PROVIDER_COOLDOWN_SECONDS`` have passed since its last recorded check.
    Without this, a provider that hit the error threshold would be skipped
    by every selection path and could never record a success again.
    """
    state = provider_status.get(provider_name)
    if not state or state.get('available', True):
        return True

    last_check = state.get('last_check')
    if last_check is None:
        return False

    if (datetime.now() - last_check).total_seconds() < PROVIDER_COOLDOWN_SECONDS:
        return False

    # Re-admit on probation: one more failure reaches the threshold again.
    state['available'] = True
    state['error_count'] = PROVIDER_ERROR_THRESHOLD - 1
    logger.info(f"Provider {provider_name} re-enabled after cooldown")
    return True


def _provider_is_usable(provider):
    """Return True when the provider is enabled and currently available."""
    return bool(provider['enabled']) and _provider_is_available(provider['name'])


def _provider_is_allowed(provider, allowed_providers):
    """Return True when an auto profile's provider list admits ``provider``.

    The list may contain ``'*'`` (everything), provider ids, or provider names.
    """
    return (
        '*' in allowed_providers
        or provider.get('id') in allowed_providers
        or provider['name'] in allowed_providers
    )


def _upstream_error_payload(response):
    """Return a JSON-serialisable body describing a non-success upstream reply.

    The upstream JSON is passed through when the reply declares a JSON content
    type and actually parses; otherwise the raw text is wrapped as
    ``{"error": <text>}``.
    """
    if response.headers.get('content-type', '').startswith('application/json'):
        try:
            return response.json()
        except ValueError:
            # Declared JSON but unparsable: fall back to the raw text.
            pass
    return {"error": response.text}


def get_provider_for_model(model_name):
    """Resolve a requested model name to ``(provider, upstream_model)``.

    Aliases are resolved first. ``auto`` and ``auto-*`` names are delegated to
    :func:`get_available_provider_for_auto`. Otherwise enabled, available
    providers are scanned in priority order, first for an exact model match
    and then for a case-insensitive substring match in either direction.

    Returns:
        ``(provider_dict, model_name)``, or ``(None, None)`` when nothing
        matches.
    """
    refresh_config()

    resolved_model = _cached_aliases.get(model_name, model_name)

    if resolved_model == 'auto' or resolved_model.startswith('auto-'):
        return get_available_provider_for_auto(resolved_model)

    usable = [p for p in _providers_by_priority() if _provider_is_usable(p)]

    # Pass 1: exact match.
    for provider in usable:
        if resolved_model in provider['models']:
            return provider, resolved_model

    # Pass 2: fuzzy match (substring in either direction, case-insensitive).
    needle = resolved_model.lower()
    for provider in usable:
        for candidate in provider['models']:
            lowered = candidate.lower()
            if needle in lowered or lowered in needle:
                return provider, candidate

    return None, None


def get_available_provider_for_auto(auto_name='auto', exclude=None):
    """Pick a provider for an auto profile.

    Args:
        auto_name: Name of the auto profile; falls back to the ``'auto'``
            profile, then to an empty profile, when it is not configured.
        exclude: Optional iterable of provider names that must not be
            returned (used by the retry loop to skip providers that have
            already failed for the current request).

    Returns:
        ``(provider_dict, provider['default_model'])``, or ``(None, None)``
        when no enabled, allowed, non-excluded provider exists. When such
        providers exist but all are flagged unavailable, the highest-priority
        one is still returned so the caller can surface the upstream error.
    """
    refresh_config()

    excluded = set(exclude or ())
    profile = _cached_auto_profiles.get(auto_name, _cached_auto_profiles.get('auto', {}))
    allowed_providers = profile.get('providers', ['*'])

    eligible = [
        p for p in _providers_by_priority()
        if p['enabled']
        and p['name'] not in excluded
        and _provider_is_allowed(p, allowed_providers)
    ]
    candidates = [p for p in eligible if _provider_is_available(p['name'])]

    if not candidates:
        if eligible:
            # Last resort: try the best eligible provider anyway.
            return eligible[0], eligible[0]['default_model']
        return None, None

    if profile.get('strategy', 'priority') == 'random':
        chosen = random.choice(candidates)
    else:
        # 'priority' and any unknown strategy: first by priority.
        chosen = candidates[0]
    return chosen, chosen['default_model']


def mark_provider_error(provider_name, error):
    """Record a failure for a known provider.

    Increments the error counter, stores the error text and timestamp, and
    flags the provider unavailable once the counter reaches
    ``PROVIDER_ERROR_THRESHOLD``. Unknown provider names are ignored.
    """
    state = provider_status.get(provider_name)
    if state is None:
        return

    state['error_count'] += 1
    state['last_error'] = str(error)
    state['last_check'] = datetime.now()

    if state['error_count'] >= PROVIDER_ERROR_THRESHOLD:
        state['available'] = False
        logger.warning(f"Provider {provider_name} marked as unavailable due to errors")


def mark_provider_success(provider_name):
    """Record a success for a known provider, resetting its error state."""
    state = provider_status.get(provider_name)
    if state is None:
        return

    state['error_count'] = 0
    state['available'] = True
    state['last_check'] = datetime.now()


def proxy_request(provider, model, request_data, stream=False):
    """POST a chat-completions request to an upstream provider.

    Args:
        provider: Provider dict; ``base_url``, ``api_key`` and ``name`` are
            read, plus an optional ``timeout`` (default 120 seconds).
        model: Model name sent upstream; overrides ``request_data['model']``.
        request_data: Client request body. It is shallow-copied, so the
            caller's dict is not modified.
        stream: When truthy, the upstream body is not read eagerly.

    Returns:
        The ``requests.Response`` from the upstream, whatever its status.

    Raises:
        Exception: On timeout or connection failure (with a descriptive
            message), or whatever else ``requests.post`` raised. The failure
            is recorded via :func:`mark_provider_error` before raising.
    """
    url = f"{provider['base_url'].rstrip('/')}/chat/completions"
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {provider['api_key']}"
    }

    data = request_data.copy()
    data['model'] = model

    try:
        return requests.post(
            url,
            headers=headers,
            json=data,
            stream=bool(stream),
            timeout=provider.get('timeout', 120),
        )
    except requests.exceptions.Timeout as exc:
        mark_provider_error(provider['name'], "Timeout")
        raise Exception(f"Provider {provider['name']} timeout") from exc
    except requests.exceptions.ConnectionError as exc:
        mark_provider_error(provider['name'], "Connection error")
        raise Exception(f"Provider {provider['name']} connection error") from exc
    except Exception as e:
        mark_provider_error(provider['name'], str(e))
        raise


def stream_response(response):
    """Yield the upstream streaming body line by line.

    Blank lines are forwarded too: in a server-sent-event stream they mark
    the end of each event, so dropping them would leave clients unable to
    tell events apart. If iteration fails, one final ``data:`` event carrying
    the error text is emitted. The upstream response is always closed.
    """
    try:
        for line in response.iter_lines():
            yield line + b'\n'
    except Exception as e:
        logger.error(f"Stream error: {e}")
        # json.dumps keeps the payload valid even if the message has quotes.
        yield b'data: ' + json.dumps({"error": str(e)}).encode() + b'\n\n'
    finally:
        response.close()


# ============ API routes ============

@app.route('/')
def index():
    """Landing page: service name, version and endpoint map."""
    return jsonify({
        "name": "LLM Proxy",
        "version": "1.0.0",
        "description": "OpenAI-compatible LLM API Proxy",
        "endpoints": {
            "chat": "/v1/chat/completions",
            "models": "/v1/models",
            "health": "/health",
            "status": "/status"
        }
    })


@app.route('/v1/models', methods=['GET'])
def list_models():
    """List auto profiles and the models of all enabled providers.

    Each id appears once; auto profiles are listed first, and a provider
    model that shares an id with an earlier entry is skipped.
    """
    refresh_config()

    created = int(time.time())
    models_list = []
    seen = set()

    for profile_name, profile in _cached_auto_profiles.items():
        if profile_name in seen:
            continue
        seen.add(profile_name)
        models_list.append({
            "id": profile_name,
            "object": "model",
            "created": created,
            "owned_by": "proxy",
            "description": profile.get('description', 'Auto-select available model')
        })

    for provider in _cached_providers:
        if not provider['enabled']:
            continue
        for model in provider['models']:
            if model in seen:
                continue
            seen.add(model)
            models_list.append({
                "id": model,
                "object": "model",
                "created": created,
                "owned_by": provider['name'],
            })

    return jsonify({
        "object": "list",
        "data": models_list
    })


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Chat completions endpoint with provider failover.

    Resolves the requested model to a provider, forwards the request, and on
    HTTP 429 or a transport error moves on to the next provider that has not
    been tried for this request (using that provider's default model). At
    most ``RETRY_CONFIG['max_retries']`` attempts are made.

    Responses:
        200 with the upstream body (streamed when ``stream`` is truthy);
        400 for an invalid body or unknown model; the upstream status and
        body for other upstream errors; 503 when every attempt raised;
        500 for unexpected internal errors.
    """
    try:
        # silent=True: a malformed body yields None instead of raising, so it
        # is reported as 400 rather than being swallowed into a 500 below.
        data = request.get_json(silent=True)
        if not data or not isinstance(data, dict):
            return jsonify({"error": "Invalid request body"}), 400

        # Treat a missing, null or empty model as 'auto'.
        model = data.get('model') or 'auto'
        if not isinstance(model, str):
            return jsonify({"error": "Invalid request body"}), 400
        stream = data.get('stream', False)

        provider, resolved_model = get_provider_for_model(model)

        if not provider:
            return jsonify({
                "error": {
                    "message": f"No available provider for model: {model}",
                    "type": "invalid_request_error"
                }
            }), 400

        logger.info(f"Request: model={model} -> provider={provider['name']}, resolved_model={resolved_model}, stream={stream}")

        last_error = None
        tried_providers = []
        max_retries = RETRY_CONFIG['max_retries']

        for attempt in range(max_retries):
            is_last_attempt = attempt + 1 >= max_retries
            try:
                response = proxy_request(provider, resolved_model, data, stream)

                if response.status_code == 200:
                    mark_provider_success(provider['name'])

                    if stream:
                        return Response(
                            stream_with_context(stream_response(response)),
                            content_type='text/event-stream',
                            headers={
                                'Cache-Control': 'no-cache',
                                'Connection': 'keep-alive',
                            }
                        )
                    return jsonify(response.json())

                if response.status_code == 429:
                    # Rate limited: move to a provider not yet tried.
                    mark_provider_error(provider['name'], "Rate limit")
                    if provider['name'] not in tried_providers:
                        tried_providers.append(provider['name'])

                    next_provider, next_model = get_available_provider_for_auto(
                        'auto', exclude=tried_providers
                    )
                    if next_provider and not is_last_attempt:
                        response.close()
                        provider = next_provider
                        resolved_model = next_model
                        continue

                    # Nothing left to try: hand the upstream reply back.
                    return jsonify(_upstream_error_payload(response)), response.status_code

                # Any other upstream status is returned to the client as-is.
                last_error = _upstream_error_payload(response)
                return jsonify(last_error), response.status_code

            except Exception as e:
                last_error = str(e)
                logger.error(f"Attempt {attempt + 1} failed: {e}")
                if provider['name'] not in tried_providers:
                    tried_providers.append(provider['name'])

                # Prefer an untried provider; otherwise retry the same one.
                next_provider, next_model = get_available_provider_for_auto(
                    'auto', exclude=tried_providers
                )
                if next_provider:
                    provider = next_provider
                    resolved_model = next_model

                # No point in waiting when no further attempt follows.
                if not is_last_attempt:
                    time.sleep(RETRY_CONFIG['retry_delay'])

        return jsonify({
            "error": {
                "message": f"All providers failed. Last error: {last_error}",
                "type": "api_error"
            }
        }), 503

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        return jsonify({
            "error": {
                "message": str(e),
                "type": "internal_error"
            }
        }), 500


@app.route('/v1/embeddings', methods=['POST'])
def embeddings():
    """Embeddings endpoint (plain pass-through).

    Forwards the body unchanged to the highest-priority provider that is
    enabled and currently available, and returns the upstream JSON with the
    upstream status code.
    """
    refresh_config()

    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Invalid request body"}), 400

        provider = next(
            (p for p in _providers_by_priority() if _provider_is_usable(p)),
            None,
        )
        if provider is None:
            return jsonify({"error": "No providers available"}), 503

        url = f"{provider['base_url'].rstrip('/')}/embeddings"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {provider['api_key']}"
        }

        response = requests.post(url, headers=headers, json=data, timeout=60)
        return jsonify(response.json()), response.status_code

    except Exception as e:
        logger.error(f"Embeddings error: {e}")
        return jsonify({"error": str(e)}), 500


@app.route('/health', methods=['GET'])
def health():
    """Health check: counts of tracked providers and how many are available."""
    available_count = sum(1 for s in provider_status.values() if s['available'])
    total_count = len(provider_status)

    return jsonify({
        "status": "healthy" if available_count > 0 else "degraded",
        "providers": {
            "available": available_count,
            "total": total_count,
        },
        "timestamp": datetime.now().isoformat()
    })


@app.route('/status', methods=['GET'])
def status():
    """Detailed status: per-provider health, model aliases and uptime.

    ``uptime`` is the number of seconds since this module was loaded.
    """
    refresh_config()

    providers_detail = []
    for provider in _cached_providers:
        status_info = provider_status.get(provider['name'], {})
        providers_detail.append({
            "name": provider['name'],
            "priority": provider['priority'],
            "enabled": provider['enabled'],
            "available": status_info.get('available', True),
            "error_count": status_info.get('error_count', 0),
            "last_error": status_info.get('last_error'),
            "models": provider['models'],
        })

    return jsonify({
        "version": "1.0.0",
        "uptime": time.time() - _START_TIME,
        "providers": providers_detail,
        "model_aliases": _cached_aliases,
    })


# Additional endpoints for OpenAI SDK compatibility

@app.route('/v1/engines', methods=['GET'])
def list_engines():
    """Legacy ``engines`` listing; same payload as ``/v1/models``."""
    return list_models()


@app.route('/v1/engines/<model>/completions', methods=['POST'])
def engine_completions(model):
    """Legacy per-engine completions endpoint.

    The engine name from the URL is written into the request body as
    ``model`` and the request is then handled by :func:`chat_completions`.
    This relies on ``request.get_json`` caching the parsed body, so the
    mutation made here is visible to the delegate.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Invalid request body"}), 400
    data['model'] = model
    return chat_completions()


if __name__ == '__main__':
    refresh_config()

    print("=" * 60)
    print("大模型API中转系统")
    print("=" * 60)
    print(f"访问地址: http://localhost:{SERVER_CONFIG['port']}")
    print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions")
    print("=" * 60)
    print("上游提供商:")
    for p in sorted(_cached_providers, key=lambda x: x['priority']):
        print(f" [{p['priority']}] {p['name']}: {p['base_url']}")
        print(f" 模型: {', '.join(p['models'])}")
    print("=" * 60)
    print("支持的模型别名:")
    for alias, target in _cached_aliases.items():
        print(f" {alias} -> {target}")
    print("=" * 60)

    app.run(
        host=SERVER_CONFIG['host'],
        port=SERVER_CONFIG['port'],
        debug=SERVER_CONFIG['debug']
    )