# chattts-server/server.py
"""
ChatTTS 语音合成服务
专为对话场景设计的 TTS 模型
"""
import os
import io
import uuid
import logging
import tempfile
import wave
import numpy as np
from typing import Optional
from datetime import datetime
import torch
import torchaudio
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
# Configuration
# Port the service listens on; overridable through the PORT environment variable.
PORT = int(os.getenv("PORT", "12002"))
SAMPLE_RATE = 24000 # ChatTTS default sample rate
# Directory that generated audio files are written to and served from.
AUDIO_DIR = os.getenv("AUDIO_DIR", "audio_output")
# ChatTTS model path configuration
# If the model has already been downloaded, set the MODEL_PATH environment variable
# Example: MODEL_PATH=/path/to/ChatTTS
MODEL_PATH = os.getenv("MODEL_PATH", None)
MODEL_SOURCE = os.getenv("MODEL_SOURCE", "auto") # auto / local / download
# Create the output directory at import time so it exists before any request arrives.
os.makedirs(AUDIO_DIR, exist_ok=True)
# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="ChatTTS Service",
    description="对话式语音合成服务",
    version="1.0.0"
)
# CORS
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ChatTTS model (lazily loaded); stays None until load_model() populates it.
chat_model = None
class SynthesizeRequest(BaseModel):
    """Synthesis request: the text to speak plus sampling parameters."""

    # Text to synthesize (required).
    text: str
    # Sampling temperature.
    temperature: float = 0.3
    # Top-P (nucleus) sampling threshold.
    top_p: float = 0.7
    # Top-K sampling cutoff.
    top_k: int = 20
class SynthesizeResponse(BaseModel):
    """Synthesis response describing one generated audio file."""

    # Relative URL of the generated file (the handlers build it as "/audio/<filename>").
    audio_url: str
    # Length of the audio in seconds.
    duration: float
    # Text associated with the audio.
    text: str
    # Creation time as an ISO-8601 string.
    timestamp: str
def load_model():
    """Return the shared ChatTTS model, loading it on first use.

    The weights are taken from one of three places:

    * ``MODEL_PATH`` when it is set and exists on disk (custom directory),
    * the default local cache when ``MODEL_SOURCE == "local"``,
    * otherwise whatever ``Chat.load`` does by default.

    The module-level ``chat_model`` is only assigned after loading has
    succeeded, so a failed attempt leaves it as ``None`` and the next call
    retries instead of handing out a half-initialised model.

    Returns:
        The loaded ``ChatTTS.Chat`` instance.

    Raises:
        RuntimeError: If ``Chat.load`` reports failure.
        Exception: Any error raised while importing or loading ChatTTS is
            logged and re-raised unchanged.
    """
    global chat_model
    if chat_model is not None:
        return chat_model

    logger.info("Loading ChatTTS model...")
    try:
        import ChatTTS

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Build into a local so the global never points at an unloaded model.
        candidate = ChatTTS.Chat()

        if MODEL_PATH and os.path.exists(MODEL_PATH):
            # Load from an explicit directory. ``source`` selects the loading
            # mode; the directory itself is passed via ``custom_path``.
            logger.info(f"Loading model from: {MODEL_PATH}")
            loaded = candidate.load(
                source="custom",
                custom_path=MODEL_PATH,
                compile=True,
                device=device,
            )
        elif MODEL_SOURCE == "local":
            # Load from the default local cache.
            logger.info("Loading from local cache...")
            loaded = candidate.load(
                compile=True,
                device=device,
                source="local",
            )
        else:
            # Let ChatTTS fetch the model automatically.
            logger.info("Auto downloading model...")
            loaded = candidate.load(
                compile=True,
                device=device,
            )

        # ``load`` signals failure by returning False rather than raising.
        if loaded is False:
            raise RuntimeError("ChatTTS.Chat.load() reported failure")

        chat_model = candidate
        logger.info("ChatTTS model loaded successfully")
    except Exception as e:
        logger.error(f"Failed to load ChatTTS: {e}")
        raise
    return chat_model
def save_audio(audio_tensor: torch.Tensor, filename: str) -> str:
    """Write a waveform to ``AUDIO_DIR/filename`` and return the path.

    Args:
        audio_tensor: Waveform of shape ``[1, samples]`` or ``[samples]``.
            A numpy array is accepted as well, since ``model.infer()`` output
            may be passed in without prior conversion.
        filename: Name of the file to create inside ``AUDIO_DIR``.

    Returns:
        The path of the written file.
    """
    import soundfile as sf

    # as_tensor accepts both tensors and numpy arrays; detach()/cpu() make the
    # later .numpy() call safe for CUDA tensors and tensors that require grad.
    waveform = torch.as_tensor(audio_tensor).detach().cpu()

    # Drop a leading singleton dimension so the data is written as 1-D samples.
    if waveform.dim() == 2:
        waveform = waveform.squeeze(0)

    filepath = os.path.join(AUDIO_DIR, filename)
    sf.write(filepath, waveform.numpy(), SAMPLE_RATE)
    return filepath
@app.on_event("startup")
async def startup():
    """Preload the model when the application starts.

    A load failure is only logged as a warning so the service still comes up.
    """
    try:
        load_model()
    except Exception as exc:
        logger.warning(f"Model load delayed: {exc}")
        return
    logger.info(f"ChatTTS service ready on port {PORT}")
@app.get("/")
async def root():
    """Health check: report service name, model state and compute device."""
    is_loaded = chat_model is not None
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    return dict(
        status="ok",
        service="ChatTTS",
        model_loaded=is_loaded,
        device=device_name,
    )
@app.get("/health")
async def health():
    """Health check: "ok" once the model is set, "loading" before that."""
    if chat_model is None:
        status = "loading"
    else:
        status = "ok"
    gpu_available = torch.cuda.is_available()
    return {"status": status, "gpu": gpu_available}
@app.post("/synthesize", response_model=SynthesizeResponse)
async def synthesize(
    text: str = Form(..., description="要合成的文本"),
    temperature: float = Form(0.3, description="温度参数"),
    top_p: float = Form(0.7, description="Top-P参数"),
    top_k: int = Form(20, description="Top-K参数"),
):
    """Synthesize speech for a single piece of text.

    ChatTTS characteristics:
    - Designed for conversational scenarios, with tonal variation
    - Supports laughter, sighs and similar expressive sounds
    - Works well for Chinese

    Args:
        text: Text to synthesize; must not be blank.
        temperature: Sampling temperature forwarded to the model.
        top_p: Top-P value forwarded to the model.
        top_k: Top-K value forwarded to the model.

    Returns:
        A ``SynthesizeResponse`` pointing at the generated WAV file.

    Raises:
        HTTPException: 400 for blank text, 500 if synthesis fails.
    """
    # Validate outside the try block so the 400 is not rewritten into a 500.
    if not text.strip():
        raise HTTPException(status_code=400, detail="text must not be empty")

    try:
        model = load_model()

        # Unique file name so concurrent requests never collide.
        filename = f"{uuid.uuid4().hex}.wav"

        logger.info(f"Synthesizing: {text[:50]}...")

        # Forward the sampling parameters; they were previously accepted by
        # the endpoint but never reached the model.
        params_infer_code = model.InferCodeParams(
            temperature=temperature,
            top_P=top_p,
            top_K=top_k,
        )
        # NOTE(review): inference runs synchronously inside this coroutine —
        # confirm that blocking the event loop is acceptable here.
        result = model.infer(text, params_infer_code=params_infer_code)

        # infer() may hand back a list/tuple of clips or a single clip.
        audio = result[0] if isinstance(result, (list, tuple)) else result

        # Normalise numpy output (or an existing tensor) to a float tensor.
        audio_tensor = torch.as_tensor(audio).float()
        if audio_tensor.dim() == 1:
            audio_tensor = audio_tensor.unsqueeze(0)

        save_audio(audio_tensor, filename)

        duration = audio_tensor.shape[-1] / SAMPLE_RATE
        audio_url = f"/audio/{filename}"
        logger.info(f"Generated audio: {duration:.2f}s")

        return SynthesizeResponse(
            audio_url=audio_url,
            duration=round(duration, 2),
            text=text,
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        logger.error(f"Synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/synthesize/batch")
async def synthesize_batch(requests: list[SynthesizeRequest]):
    """Synthesize speech for several texts in one model call.

    The sampling parameters of the FIRST request are applied to the whole
    batch; per-item parameters on later requests are ignored.

    Args:
        requests: Items to synthesize. An empty list yields an empty result.

    Returns:
        ``{"results": [...]}`` with one entry (``audio_url``, ``duration``,
        ``text``) per generated clip.

    Raises:
        HTTPException: 500 if synthesis fails.
    """
    # Nothing to do; also avoids indexing requests[0] on an empty list.
    if not requests:
        return {"results": []}

    try:
        model = load_model()
        texts = [r.text for r in requests]

        # Shared parameters for the whole batch, taken from the first item.
        shared = requests[0]
        params_infer_code = model.InferCodeParams(
            temperature=shared.temperature,
            top_P=shared.top_p,
            top_K=shared.top_k,
        )

        # NOTE(review): verify that the installed ChatTTS version returns one
        # clip per input text for list input.
        audio_items = model.infer(texts, params_infer_code=params_infer_code)

        results = []
        for source_text, audio in zip(texts, audio_items):
            # infer() output may be numpy; normalise before saving/measuring.
            audio_tensor = torch.as_tensor(audio).float()
            filename = f"{uuid.uuid4().hex}.wav"
            save_audio(audio_tensor, filename)
            duration = audio_tensor.shape[-1] / SAMPLE_RATE
            results.append({
                "audio_url": f"/audio/{filename}",
                "duration": round(duration, 2),
                "text": source_text,
            })
        return {"results": results}
    except Exception as e:
        logger.error(f"Batch synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a previously generated audio file.

    Args:
        filename: Bare file name inside ``AUDIO_DIR``.

    Returns:
        A ``FileResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: 404 if the name is not a plain file name or no such
            file exists.
    """
    # The name comes from the URL: accept only a bare file name so it can
    # never resolve outside AUDIO_DIR (separators, drive prefixes, "..").
    is_plain_name = (
        filename == os.path.basename(filename)
        and filename not in ("", ".", "..")
    )
    if is_plain_name:
        filepath = os.path.join(AUDIO_DIR, filename)
        # isfile() rather than exists(): a directory is not servable audio.
        if os.path.isfile(filepath):
            return FileResponse(filepath, media_type="audio/wav")
    raise HTTPException(status_code=404, detail="Audio not found")
@app.delete("/audio/{filename}")
async def delete_audio(filename: str):
    """Delete a previously generated audio file.

    Args:
        filename: Bare file name inside ``AUDIO_DIR``.

    Returns:
        ``{"status": "deleted"}`` when the file was removed, otherwise
        ``{"status": "not_found"}`` (also used for names that are not a
        plain file name).
    """
    # The name comes from the URL: accept only a bare file name so nothing
    # outside AUDIO_DIR can ever be removed.
    is_plain_name = (
        filename == os.path.basename(filename)
        and filename not in ("", ".", "..")
    )
    if not is_plain_name:
        return {"status": "not_found"}

    filepath = os.path.join(AUDIO_DIR, filename)
    if not os.path.isfile(filepath):
        return {"status": "not_found"}
    try:
        os.remove(filepath)
    except FileNotFoundError:
        # Removed by someone else between the check and the delete.
        return {"status": "not_found"}
    return {"status": "deleted"}
# Emotion-control endpoint (a ChatTTS speciality)
@app.post("/synthesize/emotion")
async def synthesize_with_emotion(
    text: str = Form(...),
    emotion: str = Form("neutral", description="情感neutral/happy/sad/laugh"),
):
    """Synthesize speech with an emotion preset.

    Supported presets:
    - neutral: calm
    - happy: cheerful
    - sad: sorrowful
    - laugh: laughter

    Each preset selects sampling parameters and, except for ``neutral``,
    prefixes the text with a matching tag. An unknown ``emotion`` falls back
    to the ``neutral`` parameters and leaves the text untouched.

    Args:
        text: Text to synthesize.
        emotion: Preset name.

    Returns:
        A ``SynthesizeResponse``; its ``text`` field carries the text as sent
        to the model, including any emotion tag.

    Raises:
        HTTPException: 500 if synthesis fails.
    """
    try:
        model = load_model()

        # Sampling parameters per emotion.
        emotion_params = {
            'neutral': {'temperature': 0.3, 'top_p': 0.7},
            'happy': {'temperature': 0.5, 'top_p': 0.8},
            'sad': {'temperature': 0.2, 'top_p': 0.6},
            'laugh': {'temperature': 0.6, 'top_p': 0.9},
        }
        params = emotion_params.get(emotion, emotion_params['neutral'])

        # Tag prepended to the text for each non-neutral emotion.
        # NOTE(review): confirm "[happy]" and "[sad]" are control tokens the
        # installed ChatTTS version actually recognises.
        emotion_tags = {
            'laugh': "[laugh]",
            'happy': "[happy]",
            'sad': "[sad]",
        }
        text = f"{emotion_tags.get(emotion, '')}{text}"

        filename = f"{uuid.uuid4().hex}.wav"

        # infer() takes sampling options as an InferCodeParams object; the
        # previous ``params=`` keyword is not part of its signature.
        params_infer_code = model.InferCodeParams(
            temperature=params['temperature'],
            top_P=params['top_p'],
        )
        audio = model.infer([text], params_infer_code=params_infer_code)[0]

        # infer() output may be numpy; normalise before saving/measuring.
        audio_tensor = torch.as_tensor(audio).float()
        save_audio(audio_tensor, filename)
        duration = audio_tensor.shape[-1] / SAMPLE_RATE

        return SynthesizeResponse(
            audio_url=f"/audio/{filename}",
            duration=round(duration, 2),
            text=text,
            timestamp=datetime.now().isoformat(),
        )
    except Exception as e:
        logger.error(f"Emotion synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# Script entry point: serve the app on all interfaces at the configured port.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=PORT)