""" ChatTTS 语音合成服务 专为对话场景设计的 TTS 模型 """ import os import io import uuid import logging import tempfile import wave import numpy as np from typing import Optional from datetime import datetime import torch import torchaudio from fastapi import FastAPI, UploadFile, File, HTTPException, Form from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse from pydantic import BaseModel # 配置 PORT = int(os.getenv("PORT", "12002")) SAMPLE_RATE = 24000 # ChatTTS 默认采样率 AUDIO_DIR = os.getenv("AUDIO_DIR", "audio_output") # ChatTTS 模型路径配置 # 如果已下载模型,设置 MODEL_PATH 环境变量 # 例如: MODEL_PATH=/path/to/ChatTTS MODEL_PATH = os.getenv("MODEL_PATH", None) MODEL_SOURCE = os.getenv("MODEL_SOURCE", "auto") # auto / local / download os.makedirs(AUDIO_DIR, exist_ok=True) # 日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI( title="ChatTTS Service", description="对话式语音合成服务", version="1.0.0" ) # CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ChatTTS 模型(延迟加载) chat_model = None class SynthesizeRequest(BaseModel): """合成请求""" text: str temperature: float = 0.3 top_p: float = 0.7 top_k: int = 20 class SynthesizeResponse(BaseModel): """合成响应""" audio_url: str duration: float text: str timestamp: str def load_model(): """加载 ChatTTS 模型""" global chat_model if chat_model is None: logger.info("Loading ChatTTS model...") try: import ChatTTS chat_model = ChatTTS.Chat() device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {device}") # 加载模型方式 if MODEL_PATH and os.path.exists(MODEL_PATH): # 从本地路径加载 logger.info(f"Loading model from: {MODEL_PATH}") chat_model.load( source=MODEL_PATH, compile=True, device=device ) elif MODEL_SOURCE == "local": # 从默认本地缓存加载 logger.info("Loading from local cache...") chat_model.load( compile=True, device=device, source="local" ) else: # 自动下载 logger.info("Auto downloading model...") chat_model.load( compile=True, device=device ) logger.info("ChatTTS model loaded successfully") except Exception as e: logger.error(f"Failed to load ChatTTS: {e}") raise return chat_model def save_audio(audio_tensor: torch.Tensor, filename: str) -> str: """ 保存音频文件 audio_tensor: shape [1, samples] 或 [samples] """ filepath = os.path.join(AUDIO_DIR, filename) # 确保 tensor 正确形状 if audio_tensor.dim() == 2: audio_tensor = audio_tensor.squeeze(0) # 转换为 numpy audio_np = audio_tensor.cpu().numpy() if audio_tensor.is_cuda else audio_tensor.numpy() # 使用 soundfile 保存 import soundfile as sf sf.write(filepath, audio_np, SAMPLE_RATE) return filepath @app.on_event("startup") async def startup(): """启动时预加载模型""" try: load_model() logger.info(f"ChatTTS service ready on port {PORT}") except Exception as e: logger.warning(f"Model load delayed: {e}") @app.get("/") async def root(): """健康检查""" return { "status": "ok", "service": "ChatTTS", "model_loaded": chat_model is not None, "device": "cuda" if torch.cuda.is_available() else "cpu" } @app.get("/health") async def health(): """健康检查""" model_loaded = chat_model is not None return { "status": "ok" if model_loaded else "loading", "gpu": torch.cuda.is_available() } @app.post("/synthesize", response_model=SynthesizeResponse) async def synthesize( text: str = Form(..., description="要合成的文本"), temperature: float = Form(0.3, description="温度参数"), top_p: float = Form(0.7, description="Top-P参数"), top_k: int = Form(20, description="Top-K参数") ): """ 合成语音 ChatTTS 特点: - 支持对话场景,有语气变化 - 支持笑声、叹气等情感 - 中文效果好 """ try: model = load_model() # 生成唯一文件名 filename = f"{uuid.uuid4().hex}.wav" # 合成语音 logger.info(f"Synthesizing: {text[:50]}...") # ChatTTS 基本调用(简化版) # 返回: list of audio tensors result = model.infer(text) # 处理返回结果 if isinstance(result, list): audio_tensor = result[0] elif isinstance(result, tuple): audio_tensor = result[0] else: audio_tensor = result # 转换为 torch tensor(如果是 numpy) import numpy as np if isinstance(audio_tensor, np.ndarray): audio_tensor = torch.from_numpy(audio_tensor).float() # 确保 tensor 正确形状 if audio_tensor.dim() == 1: audio_tensor = audio_tensor.unsqueeze(0) # 保存音频 filepath = save_audio(audio_tensor, filename) # 计算时长 duration = audio_tensor.shape[-1] / SAMPLE_RATE audio_url = f"/audio/{filename}" logger.info(f"Generated audio: {duration:.2f}s") return SynthesizeResponse( audio_url=audio_url, duration=round(duration, 2), text=text, timestamp=datetime.now().isoformat() ) except Exception as e: logger.error(f"Synthesis error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.post("/synthesize/batch") async def synthesize_batch(requests: list[SynthesizeRequest]): """ 批量合成语音 """ try: model = load_model() texts = [r.text for r in requests] # 统一参数 infer_params = {} # 批量生成 audio_tensors = model.infer( texts, temperature=requests[0].temperature, top_P=requests[0].top_p, top_K=requests[0].top_k, ) results = [] for i, audio_tensor in enumerate(audio_tensors): filename = f"{uuid.uuid4().hex}.wav" filepath = save_audio(audio_tensor, filename) duration = audio_tensor.shape[-1] / SAMPLE_RATE results.append({ "audio_url": f"/audio/{filename}", "duration": round(duration, 2), "text": texts[i] }) return {"results": results} except Exception as e: logger.error(f"Batch synthesis error: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.get("/audio/{filename}") async def get_audio(filename: str): """获取音频文件""" filepath = os.path.join(AUDIO_DIR, filename) if os.path.exists(filepath): return FileResponse(filepath, media_type="audio/wav") raise HTTPException(status_code=404, detail="Audio not found") @app.delete("/audio/{filename}") async def delete_audio(filename: str): """删除音频文件""" filepath = os.path.join(AUDIO_DIR, filename) if os.path.exists(filepath): os.remove(filepath) return {"status": "deleted"} return {"status": "not_found"} # 情感控制接口(ChatTTS 特色) @app.post("/synthesize/emotion") async def synthesize_with_emotion( text: str = Form(...), emotion: str = Form("neutral", description="情感:neutral/happy/sad/laugh") ): """ 带情感的语音合成 ChatTTS 支持情感控制: - neutral: 平静 - happy: 开心 - sad: 悲伤 - laugh: 笑声 """ try: model = load_model() # 根据情感调整参数 emotion_params = { 'neutral': {'temperature': 0.3, 'top_p': 0.7}, 'happy': {'temperature': 0.5, 'top_p': 0.8}, 'sad': {'temperature': 0.2, 'top_p': 0.6}, 'laugh': {'temperature': 0.6, 'top_p': 0.9}, } params = emotion_params.get(emotion, emotion_params['neutral']) # 添加情感标记(ChatTTS 特有) if emotion == 'laugh': # 添加笑声标记 text = f"[laugh]{text}" elif emotion == 'happy': text = f"[happy]{text}" elif emotion == 'sad': text = f"[sad]{text}" filename = f"{uuid.uuid4().hex}.wav" audio_tensor = model.infer([text], params=params)[0] filepath = save_audio(audio_tensor, filename) duration = audio_tensor.shape[-1] / SAMPLE_RATE return SynthesizeResponse( audio_url=f"/audio/{filename}", duration=round(duration, 2), text=text, timestamp=datetime.now().isoformat() ) except Exception as e: logger.error(f"Emotion synthesis error: {e}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=PORT) audio_url=f"/audio/{filename}", duration=round(duration, 2), text=text, timestamp=datetime.now().isoformat() ) except Exception as e: logger.error(f"Emotion synthesis error: {e}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=PORT)