"""
ChatTTS speech synthesis service.

A TTS model designed for conversational scenarios.
"""
|
||
|
||
import os
|
||
import io
|
||
import uuid
|
||
import logging
|
||
import tempfile
|
||
import wave
|
||
import numpy as np
|
||
from typing import Optional
|
||
from datetime import datetime
|
||
|
||
import torch
|
||
import torchaudio
|
||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.responses import FileResponse
|
||
from pydantic import BaseModel
|
||
|
||
# Service configuration; each environment-backed value falls back to a default.
PORT = int(os.environ.get("PORT", "12002"))
SAMPLE_RATE = 24000  # ChatTTS default sample rate
AUDIO_DIR = os.environ.get("AUDIO_DIR", "audio_output")

# ChatTTS model location.
# If the model has already been downloaded, set the MODEL_PATH environment
# variable, for example: MODEL_PATH=/path/to/ChatTTS
MODEL_PATH = os.environ.get("MODEL_PATH")
MODEL_SOURCE = os.environ.get("MODEL_SOURCE", "auto")  # auto / local / download

# Create the output directory at import time so later writes can assume it exists.
os.makedirs(AUDIO_DIR, exist_ok=True)
|
||
|
||
# Logging: root logger at INFO, plus a module-level logger for this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object that every route below is registered on.
app = FastAPI(title="ChatTTS Service", description="对话式语音合成服务", version="1.0.0")

# CORS: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is
# discouraged by the FastAPI CORS documentation; confirm whether credentialed
# cross-origin requests are really required before tightening this.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

# ChatTTS model instance; stays None until load_model() populates it.
chat_model = None
|
||
|
||
|
||
class SynthesizeRequest(BaseModel):
    """Request body describing one text to synthesize and its sampling options."""

    # Text to convert to speech (required).
    text: str

    # Sampling options; the defaults match the form defaults of /synthesize.
    temperature: float = 0.3
    top_p: float = 0.7
    top_k: int = 20
|
||
|
||
|
||
class SynthesizeResponse(BaseModel):
    """Response body returned after a successful synthesis."""

    # URL path under which the generated WAV file can be fetched.
    audio_url: str

    # Length of the generated audio in seconds.
    duration: float

    # The text that was synthesized.
    text: str

    # ISO-8601 timestamp string recorded when the response was built.
    timestamp: str
|
||
|
||
|
||
def load_model():
    """Load the ChatTTS model on first use and return the shared instance.

    The model is cached in the module-level ``chat_model``. The global is only
    assigned after loading has succeeded, so a failed attempt leaves it as
    ``None`` and the next call retries instead of handing out a
    half-initialised model.

    Loading strategy:
        * ``MODEL_PATH`` set and present on disk: load from that directory.
        * ``MODEL_SOURCE == "local"``: load from the default local location.
        * otherwise: call ``load()`` with its default source.

    Returns:
        The loaded ``ChatTTS.Chat`` instance.

    Raises:
        Exception: whatever the import or ``load()`` raised (logged first), or
            ``RuntimeError`` when ``load()`` explicitly reports failure.
    """
    global chat_model
    if chat_model is not None:
        return chat_model

    logger.info("Loading ChatTTS model...")
    try:
        import ChatTTS

        # Build into a local first; publish to the global only on success.
        candidate = ChatTTS.Chat()

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        if MODEL_PATH and os.path.exists(MODEL_PATH):
            # Load from an explicit directory. ChatTTS selects this mode with
            # source="custom" and receives the directory through custom_path.
            logger.info(f"Loading model from: {MODEL_PATH}")
            loaded = candidate.load(
                source="custom",
                custom_path=MODEL_PATH,
                compile=True,
                device=device,
            )
        elif MODEL_SOURCE == "local":
            # Load from the default local cache.
            logger.info("Loading from local cache...")
            loaded = candidate.load(
                compile=True,
                device=device,
                source="local",
            )
        else:
            # Let ChatTTS fetch the assets with its default source.
            logger.info("Auto downloading model...")
            loaded = candidate.load(
                compile=True,
                device=device,
            )

        # Only an explicit False is treated as failure, so a version of
        # load() that returns None on success keeps working.
        if loaded is False:
            raise RuntimeError("ChatTTS model assets could not be loaded")

        chat_model = candidate
        logger.info("ChatTTS model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load ChatTTS: {e}")
        raise
    return chat_model
|
||
|
||
|
||
def save_audio(audio_tensor: torch.Tensor, filename: str) -> str:
    """Write audio samples to a WAV file inside ``AUDIO_DIR``.

    Args:
        audio_tensor: Samples shaped ``[1, samples]`` or ``[samples]``. A
            ``torch.Tensor`` is expected; numpy arrays and other array-likes
            are converted first.
        filename: File name to create inside ``AUDIO_DIR``.

    Returns:
        The path of the written file.
    """
    filepath = os.path.join(AUDIO_DIR, filename)

    # Accept numpy arrays as well as tensors; only tensors have .dim().
    if not isinstance(audio_tensor, torch.Tensor):
        audio_tensor = torch.as_tensor(np.asarray(audio_tensor))

    # Hand torchaudio a plain CPU tensor detached from any autograd graph.
    audio_tensor = audio_tensor.detach().cpu()

    # Promote mono [samples] to [1, samples].
    if audio_tensor.dim() == 1:
        audio_tensor = audio_tensor.unsqueeze(0)

    # Save as WAV at the service sample rate.
    torchaudio.save(filepath, audio_tensor, SAMPLE_RATE, format="wav")

    return filepath
|
||
|
||
|
||
@app.on_event("startup")
async def startup():
    """Try to preload the model when the application starts.

    A loading failure is logged as a warning and does not stop startup.
    """
    try:
        load_model()
    except Exception as e:
        logger.warning(f"Model load delayed: {e}")
    else:
        logger.info(f"ChatTTS service ready on port {PORT}")
|
||
|
||
|
||
@app.get("/")
async def root():
    """Health check: report service name, model state and compute device."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    is_loaded = chat_model is not None
    return {
        "status": "ok",
        "service": "ChatTTS",
        "model_loaded": is_loaded,
        "device": device,
    }
|
||
|
||
|
||
@app.get("/health")
async def health():
    """Health check: "ok" once the model is set, "loading" until then."""
    status = "loading" if chat_model is None else "ok"
    return {"status": status, "gpu": torch.cuda.is_available()}
|
||
|
||
|
||
@app.post("/synthesize", response_model=SynthesizeResponse)
async def synthesize(
    text: str = Form(..., description="要合成的文本"),
    temperature: float = Form(0.3, description="温度参数"),
    top_p: float = Form(0.7, description="Top-P参数"),
    top_k: int = Form(20, description="Top-K参数")
):
    """
    Synthesize speech for a single text.

    ChatTTS highlights:
    - designed for dialogue, with natural changes of tone
    - supports laughter, sighs and similar expressive sounds
    - works well for Chinese

    Args:
        text: Text to synthesize.
        temperature: Sampling temperature.
        top_p: Top-P sampling parameter.
        top_k: Top-K sampling parameter.

    Returns:
        SynthesizeResponse with the audio URL, duration in seconds, the input
        text and a timestamp.

    Raises:
        HTTPException: status 500 carrying the error message when model
            loading, inference or saving fails.
    """
    try:
        import ChatTTS

        model = load_model()

        # Unique output file name.
        filename = f"{uuid.uuid4().hex}.wav"

        logger.info(f"Synthesizing: {text[:50]}...")

        # Chat.infer takes sampling settings bundled in an InferCodeParams
        # object passed as params_infer_code, not as direct keyword arguments.
        audio_tensor = model.infer(
            text,
            params_infer_code=ChatTTS.Chat.InferCodeParams(
                temperature=temperature,
                top_P=top_p,
                top_K=top_k,
            ),
        )

        # The result may be wrapped in a list and/or a tuple.
        if isinstance(audio_tensor, list):
            audio_tensor = audio_tensor[0]
        if isinstance(audio_tensor, tuple):
            audio_tensor = audio_tensor[0]

        # Normalise numpy output to a tensor before saving.
        audio_tensor = torch.as_tensor(audio_tensor)

        save_audio(audio_tensor, filename)

        # Duration in seconds = samples on the last axis / sample rate.
        duration = audio_tensor.shape[-1] / SAMPLE_RATE

        audio_url = f"/audio/{filename}"

        logger.info(f"Generated audio: {duration:.2f}s")

        return SynthesizeResponse(
            audio_url=audio_url,
            duration=round(duration, 2),
            text=text,
            timestamp=datetime.now().isoformat()
        )

    except Exception as e:
        logger.error(f"Synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.post("/synthesize/batch")
async def synthesize_batch(requests: list[SynthesizeRequest]):
    """
    Synthesize speech for several texts in one inference call.

    The sampling parameters of the FIRST request are applied to the whole
    batch; the parameters of the remaining requests are ignored.

    Args:
        requests: Non-empty list of synthesis requests.

    Returns:
        ``{"results": [...]}`` with one entry per generated audio, each
        holding ``audio_url``, ``duration`` and ``text``.

    Raises:
        HTTPException: status 400 when the list is empty; status 500 carrying
            the error message when loading, inference or saving fails.
    """
    # Reject an empty batch up front; requests[0] below would raise IndexError.
    # This sits outside the try block so it is not rewritten into a 500.
    if not requests:
        raise HTTPException(status_code=400, detail="requests must not be empty")

    try:
        import ChatTTS

        model = load_model()

        texts = [r.text for r in requests]
        shared = requests[0]

        # Chat.infer takes sampling settings bundled in an InferCodeParams
        # object passed as params_infer_code, not as direct keyword arguments.
        audio_tensors = model.infer(
            texts,
            params_infer_code=ChatTTS.Chat.InferCodeParams(
                temperature=shared.temperature,
                top_P=shared.top_p,
                top_K=shared.top_k,
            ),
        )

        results = []
        for source_text, audio in zip(texts, audio_tensors):
            # Normalise numpy output to a tensor before saving.
            audio_tensor = torch.as_tensor(audio)
            filename = f"{uuid.uuid4().hex}.wav"
            save_audio(audio_tensor, filename)
            duration = audio_tensor.shape[-1] / SAMPLE_RATE

            results.append({
                "audio_url": f"/audio/{filename}",
                "duration": round(duration, 2),
                "text": source_text
            })

        return {"results": results}

    except Exception as e:
        logger.error(f"Batch synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Return a previously generated audio file.

    Args:
        filename: Bare file name inside ``AUDIO_DIR``. Names containing a path
            component, and the special names ``.`` and ``..``, are rejected.

    Returns:
        FileResponse streaming the file as ``audio/wav``.

    Raises:
        HTTPException: status 404 when the name is invalid or no regular file
            with that name exists.
    """
    # Untrusted input: only a plain file name may be joined onto AUDIO_DIR.
    is_plain_name = (
        filename == os.path.basename(filename)
        and filename not in ("", ".", "..")
    )
    if is_plain_name:
        filepath = os.path.join(AUDIO_DIR, filename)
        # isfile, not exists: a directory must not be passed to FileResponse.
        if os.path.isfile(filepath):
            return FileResponse(filepath, media_type="audio/wav")
    raise HTTPException(status_code=404, detail="Audio not found")
|
||
|
||
|
||
@app.delete("/audio/{filename}")
async def delete_audio(filename: str):
    """Delete a previously generated audio file.

    Args:
        filename: Bare file name inside ``AUDIO_DIR``. Names containing a path
            component, and the special names ``.`` and ``..``, are treated as
            not found.

    Returns:
        ``{"status": "deleted"}`` when the file was removed, otherwise
        ``{"status": "not_found"}``.
    """
    # Untrusted input: only a plain file name may be joined onto AUDIO_DIR.
    is_plain_name = (
        filename == os.path.basename(filename)
        and filename not in ("", ".", "..")
    )
    if not is_plain_name:
        return {"status": "not_found"}

    filepath = os.path.join(AUDIO_DIR, filename)
    # isfile, not exists: os.remove() raises on a directory.
    if not os.path.isfile(filepath):
        return {"status": "not_found"}

    try:
        os.remove(filepath)
    except FileNotFoundError:
        # The file vanished between the check and the removal.
        return {"status": "not_found"}
    return {"status": "deleted"}
|
||
|
||
|
||
# Emotion control endpoint (a ChatTTS speciality)
@app.post("/synthesize/emotion")
async def synthesize_with_emotion(
    text: str = Form(...),
    emotion: str = Form("neutral", description="情感:neutral/happy/sad/laugh")
):
    """
    Synthesize speech with an emotion preset.

    Supported presets:
    - neutral: calm
    - happy: cheerful
    - sad: sorrowful
    - laugh: laughter

    Each preset selects a temperature / top-p pair; every preset except
    ``neutral`` also prefixes the text with a ``[<emotion>]`` tag. An unknown
    emotion uses the ``neutral`` parameters and leaves the text untouched.

    Args:
        text: Text to synthesize.
        emotion: One of ``neutral``, ``happy``, ``sad``, ``laugh``.

    Returns:
        SynthesizeResponse; its ``text`` field contains the text as sent to
        the model, including any emotion tag.

    Raises:
        HTTPException: status 500 carrying the error message when loading,
            inference or saving fails.
    """
    try:
        import ChatTTS

        model = load_model()

        # Sampling parameters per emotion.
        emotion_params = {
            'neutral': {'temperature': 0.3, 'top_p': 0.7},
            'happy': {'temperature': 0.5, 'top_p': 0.8},
            'sad': {'temperature': 0.2, 'top_p': 0.6},
            'laugh': {'temperature': 0.6, 'top_p': 0.9},
        }

        params = emotion_params.get(emotion, emotion_params['neutral'])

        # Prefix the emotion tag.
        # NOTE(review): confirm that the installed ChatTTS recognises
        # "[happy]" and "[sad]" as control tokens.
        if emotion in ('laugh', 'happy', 'sad'):
            text = f"[{emotion}]{text}"

        filename = f"{uuid.uuid4().hex}.wav"

        # Chat.infer takes sampling settings bundled in an InferCodeParams
        # object (field name top_P), passed as params_infer_code.
        audio = model.infer(
            [text],
            params_infer_code=ChatTTS.Chat.InferCodeParams(
                temperature=params['temperature'],
                top_P=params['top_p'],
            ),
        )[0]

        # Normalise numpy output to a tensor before saving.
        audio_tensor = torch.as_tensor(audio)

        save_audio(audio_tensor, filename)
        duration = audio_tensor.shape[-1] / SAMPLE_RATE

        return SynthesizeResponse(
            audio_url=f"/audio/{filename}",
            duration=round(duration, 2),
            text=text,
            timestamp=datetime.now().isoformat()
        )

    except Exception as e:
        logger.error(f"Emotion synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces at PORT.
    # uvicorn is imported here so importing this module does not require it.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)