""" 大模型API中转系统 兼容OpenAI API格式,支持多上游提供商优先级调度 """ from flask import Flask, request, jsonify, Response, stream_with_context from flask_cors import CORS import requests import json import time import logging from datetime import datetime, date from pathlib import Path import sys import threading # 添加配置路径 sys.path.insert(0, str(Path(__file__).parent)) from config.settings import ( get_providers, get_model_aliases, get_auto_profiles, SERVER_CONFIG, LOG_CONFIG, RETRY_CONFIG ) # 数据目录和统计文件 DATA_DIR = Path(__file__).parent / 'data' DATA_DIR.mkdir(exist_ok=True) STATS_FILE = DATA_DIR / 'stats.json' # 统计锁(避免并发写入冲突) stats_lock = threading.Lock() def load_stats(): """加载统计数据""" if STATS_FILE.exists(): try: return json.loads(STATS_FILE.read_text(encoding='utf-8')) except: pass return { 'total_requests': 0, 'total_success': 0, 'total_errors': 0, 'total_tokens': 0, 'requests_today': 0, 'requests_by_model': {}, 'providers': {}, 'last_updated': None, 'date': None # 用于判断是否需要重置每日计数 } def save_stats(stats): """保存统计数据""" stats['last_updated'] = datetime.now().isoformat() STATS_FILE.write_text(json.dumps(stats, ensure_ascii=False, indent=2), encoding='utf-8') def increment_stats(model, provider_name, success=False, tokens=0, error=None): """增加统计计数""" with stats_lock: stats = load_stats() today = date.today().isoformat() # 如果是新的一天,重置每日计数 if stats.get('date') != today: stats['date'] = today stats['requests_today'] = 0 stats['total_requests'] += 1 stats['requests_today'] += 1 # 模型统计 if model not in stats['requests_by_model']: stats['requests_by_model'][model] = {'count': 0, 'success': 0, 'tokens': 0} stats['requests_by_model'][model]['count'] += 1 if success: stats['requests_by_model'][model]['success'] += 1 stats['requests_by_model'][model]['tokens'] += tokens # 提供商统计 if provider_name not in stats['providers']: stats['providers'][provider_name] = {'requests': 0, 'success': 0, 'errors': 0, 'tokens': 0} stats['providers'][provider_name]['requests'] += 1 if success: stats['providers'][provider_name]['success'] += 1 stats['providers'][provider_name]['tokens'] += tokens stats['total_success'] += 1 stats['total_tokens'] += tokens else: stats['providers'][provider_name]['errors'] += 1 stats['total_errors'] += 1 save_stats(stats) app = Flask(__name__) CORS(app) # 配置缓存时间(秒) CONFIG_CACHE_TTL = 5 _last_config_load = 0 _cached_providers = [] _cached_aliases = {} _cached_auto_profiles = {} def refresh_config(): """动态刷新配置(支持后台管理修改)""" global _last_config_load, _cached_providers, _cached_aliases, _cached_auto_profiles, provider_status current_time = time.time() if current_time - _last_config_load > CONFIG_CACHE_TTL: _cached_providers = get_providers() _cached_aliases = get_model_aliases() _cached_auto_profiles = get_auto_profiles() _last_config_load = current_time # 更新提供商状态缓存(新增的提供商) for provider in _cached_providers: if provider['name'] not in provider_status: provider_status[provider['name']] = { 'available': True, 'last_check': None, 'error_count': 0, 'last_error': None, } # 配置日志 log_dir = Path(__file__).parent / LOG_CONFIG['log_dir'] log_dir.mkdir(exist_ok=True) logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler(log_dir / 'proxy.log', encoding='utf-8'), logging.StreamHandler() ] ) logger = logging.getLogger(__name__) # 提供商状态缓存 provider_status = {} # 初始化 refresh_config() def get_provider_for_model(model_name): """根据模型名获取提供商""" refresh_config() # 解析别名 resolved_model = _cached_aliases.get(model_name, model_name) # auto模式:按优先级选择可用提供商 if resolved_model == 'auto' or resolved_model.startswith('auto-'): return get_available_provider_for_auto(resolved_model) # 查找支持该模型的提供商 sorted_providers = sorted(_cached_providers, key=lambda x: x['priority']) for provider in sorted_providers: if not provider['enabled']: continue if not provider_status.get(provider['name'], {}).get('available', True): continue if resolved_model in provider['models']: return provider, resolved_model # 如果没找到精确匹配,尝试模糊匹配 for provider in sorted_providers: if not provider['enabled']: continue if not provider_status.get(provider['name'], {}).get('available', True): continue for m in provider['models']: if resolved_model.lower() in m.lower() or m.lower() in resolved_model.lower(): return provider, m return None, None def get_available_provider_for_auto(auto_name='auto'): """获取auto模式下的可用提供商(支持自定义auto配置)""" refresh_config() # 获取auto配置 profile = _cached_auto_profiles.get(auto_name, _cached_auto_profiles.get('auto', {})) # 获取允许的提供商 allowed_providers = profile.get('providers', ['*']) # 筛选提供商 sorted_providers = sorted(_cached_providers, key=lambda x: x['priority']) candidates = [] for provider in sorted_providers: if not provider['enabled']: continue if not provider_status.get(provider['name'], {}).get('available', True): continue # 检查是否在允许列表中 if '*' in allowed_providers: candidates.append(provider) elif provider.get('id') in allowed_providers or provider['name'] in allowed_providers: candidates.append(provider) # 根据策略选择 strategy = profile.get('strategy', 'priority') if not candidates: # 如果没有候选,返回第一个尝试(让错误信息传递) if sorted_providers: return sorted_providers[0], sorted_providers[0]['default_model'] return None, None if strategy == 'priority': # 按优先级选择第一个 return candidates[0], candidates[0]['default_model'] elif strategy == 'random': # 随机选择 import random selected = random.choice(candidates) return selected, selected['default_model'] else: # 默认按优先级 return candidates[0], candidates[0]['default_model'] def mark_provider_error(provider_name, error): """标记提供商错误""" if provider_name in provider_status: provider_status[provider_name]['error_count'] += 1 provider_status[provider_name]['last_error'] = str(error) provider_status[provider_name]['last_check'] = datetime.now() # 连续错误超过阈值,暂时标记不可用 if provider_status[provider_name]['error_count'] >= 3: provider_status[provider_name]['available'] = False logger.warning(f"Provider {provider_name} marked as unavailable due to errors") def mark_provider_success(provider_name): """标记提供商成功""" if provider_name in provider_status: provider_status[provider_name]['error_count'] = 0 provider_status[provider_name]['available'] = True provider_status[provider_name]['last_check'] = datetime.now() def proxy_request(provider, model, request_data, stream=False): """转发请求到上游提供商""" url = f"{provider['base_url'].rstrip('/')}/chat/completions" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {provider['api_key']}" } # 构建请求数据 data = request_data.copy() data['model'] = model try: if stream: # 流式请求 response = requests.post( url, headers=headers, json=data, stream=True, timeout=provider.get('timeout', 120) ) return response else: # 非流式请求 response = requests.post( url, headers=headers, json=data, timeout=provider.get('timeout', 120) ) return response except requests.exceptions.Timeout: mark_provider_error(provider['name'], "Timeout") raise Exception(f"Provider {provider['name']} timeout") except requests.exceptions.ConnectionError: mark_provider_error(provider['name'], "Connection error") raise Exception(f"Provider {provider['name']} connection error") except Exception as e: mark_provider_error(provider['name'], str(e)) raise def stream_response(response): """流式响应生成器""" try: for line in response.iter_lines(): if line: yield line + b'\n' except Exception as e: logger.error(f"Stream error: {e}") yield b'data: {"error": "' + str(e).encode() + b'"}\n\n' # ============ API 路由 ============ @app.route('/') def index(): """首页""" return jsonify({ "name": "LLM Proxy", "version": "1.0.0", "description": "OpenAI-compatible LLM API Proxy", "endpoints": { "chat": "/v1/chat/completions", "models": "/v1/models", "embeddings": "/v1/embeddings", "health": "/health", "status": "/status" }, "examples": { "curl": { "description": "curl 命令行", "code": "curl -X POST http://localhost:19007/v1/chat/completions -H 'Content-Type: application/json' -d '{\"model\": \"auto\", \"messages\": [{\"role\": \"user\", \"content\": \"你好\"}]}'" }, "python_openai": { "description": "Python OpenAI SDK", "code": "from openai import OpenAI\nclient = OpenAI(base_url='http://localhost:19007/v1', api_key='any')\nresponse = client.chat.completions.create(model='auto', messages=[{'role': 'user', 'content': '你好'}])\nprint(response.choices[0].message.content)" }, "python_requests": { "description": "Python requests", "code": "import requests\nresponse = requests.post('http://localhost:19007/v1/chat/completions', json={'model': 'auto', 'messages': [{'role': 'user', 'content': '你好'}]})\nprint(response.json()['choices'][0]['message']['content'])" }, "javascript": { "description": "JavaScript fetch", "code": "fetch('http://localhost:19007/v1/chat/completions', {method: 'POST', headers: {'Content-Type': 'application/json'}, body: JSON.stringify({model: 'auto', messages: [{role: 'user', content: '你好'}]})}).then(r => r.json()).then(d => console.log(d.choices[0].message.content))" } } }) @app.route('/v1/models', methods=['GET']) def list_models(): """列出可用模型""" refresh_config() models_list = [] added_models = set() # 添加所有auto配置 for profile_name, profile in _cached_auto_profiles.items(): if profile_name not in added_models: models_list.append({ "id": profile_name, "object": "model", "created": int(time.time()), "owned_by": "proxy", "description": profile.get('description', 'Auto-select available model') }) added_models.add(profile_name) for provider in _cached_providers: if not provider['enabled']: continue for model in provider['models']: if model not in added_models: models_list.append({ "id": model, "object": "model", "created": int(time.time()), "owned_by": provider['name'], }) added_models.add(model) return jsonify({ "object": "list", "data": models_list }) @app.route('/v1/chat/completions', methods=['POST']) def chat_completions(): """聊天完成API""" # 用于统计的变量 request_model = None request_provider = None request_success = False request_tokens = 0 try: data = request.get_json() if not data: increment_stats('unknown', 'unknown', success=False, error='Invalid request body') return jsonify({"error": "Invalid request body"}), 400 model = data.get('model', 'auto') stream = data.get('stream', False) request_model = model # 获取提供商 provider, resolved_model = get_provider_for_model(model) if not provider: increment_stats(model, 'unknown', success=False, error=f'No provider for model: {model}') return jsonify({ "error": { "message": f"No available provider for model: {model}", "type": "invalid_request_error" } }), 400 request_provider = provider['name'] logger.info(f"Request: model={model} -> provider={provider['name']}, resolved_model={resolved_model}, stream={stream}") # 重试逻辑 last_error = None tried_providers = [] for attempt in range(RETRY_CONFIG['max_retries']): try: response = proxy_request(provider, resolved_model, data, stream) if response.status_code == 200: mark_provider_success(provider['name']) request_success = True if stream: # 流式响应 - 统计流式请求 increment_stats(model, provider['name'], success=True, tokens=0) return Response( stream_with_context(stream_response(response)), content_type='text/event-stream', headers={ 'Cache-Control': 'no-cache', 'Connection': 'keep-alive', } ) else: # 非流式响应 - 提取token统计 result = response.json() # 尝试提取usage信息 usage = result.get('usage', {}) request_tokens = usage.get('total_tokens', 0) increment_stats(model, provider['name'], success=True, tokens=request_tokens) return jsonify(result) elif response.status_code == 429: # 速率限制,尝试下一个提供商 mark_provider_error(provider['name'], "Rate limit") tried_providers.append(provider['name']) # 尝试下一个提供商 next_provider, next_model = get_available_provider_for_auto('auto') if next_provider and next_provider['name'] not in tried_providers: provider = next_provider resolved_model = next_model request_provider = provider['name'] continue increment_stats(model, provider['name'], success=False, error='Rate limit') return jsonify(response.json()), response.status_code else: last_error = response.json() if response.headers.get('content-type', '').startswith('application/json') else {"error": response.text} increment_stats(model, provider['name'], success=False, error=str(last_error)) return jsonify(last_error), response.status_code except Exception as e: last_error = str(e) logger.error(f"Attempt {attempt + 1} failed: {e}") tried_providers.append(provider['name']) # 尝试下一个提供商 next_provider, next_model = get_available_provider_for_auto('auto') if next_provider and next_provider['name'] not in tried_providers: provider = next_provider resolved_model = next_model request_provider = provider['name'] time.sleep(RETRY_CONFIG['retry_delay']) continue # 所有重试都失败 increment_stats(model, request_provider or 'unknown', success=False, error=str(last_error)) return jsonify({ "error": { "message": f"All providers failed. Last error: {last_error}", "type": "api_error" } }), 503 except Exception as e: logger.error(f"Unexpected error: {e}") increment_stats(request_model or 'unknown', request_provider or 'unknown', success=False, error=str(e)) return jsonify({ "error": { "message": str(e), "type": "internal_error" } }), 500 @app.route('/v1/embeddings', methods=['POST']) def embeddings(): """嵌入API(简单转发)""" refresh_config() try: data = request.get_json() model = data.get('model', 'text-embedding-ada-002') # 使用第一个可用提供商 if _cached_providers: provider = _cached_providers[0] url = f"{provider['base_url'].rstrip('/')}/embeddings" headers = { "Content-Type": "application/json", "Authorization": f"Bearer {provider['api_key']}" } response = requests.post(url, headers=headers, json=data, timeout=60) return jsonify(response.json()), response.status_code else: return jsonify({"error": "No providers available"}), 503 except Exception as e: return jsonify({"error": str(e)}), 500 @app.route('/health', methods=['GET']) def health(): """健康检查""" available_count = sum(1 for s in provider_status.values() if s['available']) total_count = len(provider_status) return jsonify({ "status": "healthy" if available_count > 0 else "degraded", "providers": { "available": available_count, "total": total_count, }, "timestamp": datetime.now().isoformat() }) @app.route('/status', methods=['GET']) def status(): """详细状态""" refresh_config() providers_detail = [] for provider in _cached_providers: status_info = provider_status.get(provider['name'], {}) providers_detail.append({ "name": provider['name'], "priority": provider['priority'], "enabled": provider['enabled'], "available": status_info.get('available', True), "error_count": status_info.get('error_count', 0), "last_error": status_info.get('last_error'), "models": provider['models'], }) return jsonify({ "version": "1.0.0", "uptime": time.time(), "providers": providers_detail, "model_aliases": _cached_aliases, }) # 兼容 OpenAI SDK 的其他端点 @app.route('/v1/engines', methods=['GET']) def list_engines(): """兼容旧版 engines 端点""" return list_models() @app.route('/v1/engines//completions', methods=['POST']) def engine_completions(model): """兼容旧版 completions 端点""" data = request.get_json() data['model'] = model return chat_completions() if __name__ == '__main__': refresh_config() print("=" * 60) print("大模型API中转系统") print("=" * 60) print(f"访问地址: http://localhost:{SERVER_CONFIG['port']}") print(f"API端点: http://localhost:{SERVER_CONFIG['port']}/v1/chat/completions") print("=" * 60) print("上游提供商:") for p in sorted(_cached_providers, key=lambda x: x['priority']): print(f" [{p['priority']}] {p['name']}: {p['base_url']}") print(f" 模型: {', '.join(p['models'])}") print("=" * 60) print("支持的模型别名:") for alias, target in _cached_aliases.items(): print(f" {alias} -> {target}") print("=" * 60) app.run( host=SERVER_CONFIG['host'], port=SERVER_CONFIG['port'], debug=SERVER_CONFIG['debug'] )