Compare commits
5 Commits

| SHA1 |
|---|
| 6b73e2287c |
| 3ca2921820 |
| 24b590d07f |
| 264c0413c6 |
| e6ae11b7c4 |
README.md
@@ -30,15 +30,29 @@ pip install -r requirements.txt

### 2. Start the service

```bash
# Default port: 12002
# Option 1: auto-download the model
python3 server.py

# Or use the script
PORT=12002 ./start.sh

# Option 2: use an already-downloaded model (specify its path)
MODEL_PATH=/path/to/ChatTTS python3 server.py

# Option 3: use the locally cached model
MODEL_SOURCE=local python3 server.py
```

On first start, the ChatTTS model (about 2 GB) is downloaded automatically.

**Model path configuration**:

| Environment variable | Description | Example |
|---------|------|------|
| `MODEL_PATH` | Path to the model directory | `/data/models/ChatTTS` |
| `MODEL_SOURCE` | Load mode | `auto` / `local` / `download` |

Default ChatTTS model download location:
- Linux: `~/.cache/modelscope/hub/ChatTTS/`
- Or download manually and extract to any directory

### 3. Verify the service

```bash
```
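The verification commands themselves were truncated in this hunk. As a stand-in, here is a minimal Python smoke test; the `/synthesize` route, the JSON field names, and the response shape are assumptions read off the `synthesize` handler and the response fields visible in the server.py diff below.

```python
# Hypothetical smoke test -- the /synthesize route and field names are
# assumptions based on the handler shown in the server.py diff below.
import requests

BASE = "http://localhost:12002"

resp = requests.post(
    f"{BASE}/synthesize",
    json={"text": "你好,ChatTTS"},  # SynthesizeRequest carries at least `text`
    timeout=300,  # the first request may wait on the ~2 GB model download
)
resp.raise_for_status()
data = resp.json()
print(data)  # expected fields include audio_url, duration, text, timestamp

# Fetch the generated WAV served by the same app
audio = requests.get(f"{BASE}{data['audio_url']}", timeout=30)
with open("check.wav", "wb") as f:
    f.write(audio.content)
```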
requirements.txt

@@ -4,4 +4,5 @@ python-multipart==0.0.9
 torch==2.2.0
 torchaudio==2.2.0
 transformers==4.38.0
-ChatTTS==0.1.1
+ChatTTS
+soundfile==0.12.1
server.py
@@ -24,6 +24,13 @@ from pydantic import BaseModel
 PORT = int(os.getenv("PORT", "12002"))
 SAMPLE_RATE = 24000  # ChatTTS default sample rate
 AUDIO_DIR = os.getenv("AUDIO_DIR", "audio_output")
+
+# ChatTTS model path configuration
+# If the model is already downloaded, set the MODEL_PATH environment variable,
+# e.g. MODEL_PATH=/path/to/ChatTTS
+MODEL_PATH = os.getenv("MODEL_PATH", None)
+MODEL_SOURCE = os.getenv("MODEL_SOURCE", "auto")  # auto / local / download
+
 os.makedirs(AUDIO_DIR, exist_ok=True)
 
 # Logging
@@ -74,13 +81,33 @@ def load_model():
         import ChatTTS
         chat_model = ChatTTS.Chat()
 
-        # Load the model
-        # Option 1: load from local files (if present)
-        # Option 2: auto-download
-        chat_model.load(
-            compile=True,  # compile optimization
-            device="cuda" if torch.cuda.is_available() else "cpu"
-        )
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        logger.info(f"Using device: {device}")
+
+        # Model loading strategy
+        if MODEL_PATH and os.path.exists(MODEL_PATH):
+            # Load from an explicit local path
+            logger.info(f"Loading model from: {MODEL_PATH}")
+            chat_model.load(
+                source=MODEL_PATH,
+                compile=True,
+                device=device
+            )
+        elif MODEL_SOURCE == "local":
+            # Load from the default local cache
+            logger.info("Loading from local cache...")
+            chat_model.load(
+                compile=True,
+                device=device,
+                source="local"
+            )
+        else:
+            # Auto-download
+            logger.info("Auto downloading model...")
+            chat_model.load(
+                compile=True,
+                device=device
+            )
 
         logger.info("ChatTTS model loaded successfully")
     except Exception as e:
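The three branches differ only in the kwargs handed to `chat_model.load()`. Below is a sketch of the same precedence factored into a pure function so it can be tested without instantiating ChatTTS; the helper name is illustrative and not part of this diff.

```python
# Illustrative helper (not in the diff): mirrors load_model()'s precedence --
# explicit MODEL_PATH first, then MODEL_SOURCE=local, else auto-download.
import os

def resolve_load_kwargs(model_path, model_source):
    """Extra kwargs for chat_model.load() beyond compile/device."""
    if model_path and os.path.exists(model_path):
        return {"source": model_path}   # explicit path wins
    if model_source == "local":
        return {"source": "local"}      # default local cache
    return {}                           # let ChatTTS download the model

assert resolve_load_kwargs("/no/such/dir", "local") == {"source": "local"}
assert resolve_load_kwargs(None, "auto") == {}
```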
@@ -97,11 +124,15 @@ def save_audio(audio_tensor: torch.Tensor, filename: str) -> str:
     filepath = os.path.join(AUDIO_DIR, filename)
 
     # Ensure the tensor has the correct shape
     if audio_tensor.dim() == 1:
         audio_tensor = audio_tensor.unsqueeze(0)
     if audio_tensor.dim() == 2:
         audio_tensor = audio_tensor.squeeze(0)
 
-    # Save as WAV
-    torchaudio.save(filepath, audio_tensor, SAMPLE_RATE, format="wav")
+    # Convert to numpy
+    audio_np = audio_tensor.cpu().numpy() if audio_tensor.is_cuda else audio_tensor.numpy()
+
+    # Save with soundfile
+    import soundfile as sf
+    sf.write(filepath, audio_np, SAMPLE_RATE)
 
     return filepath
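The switch from `torchaudio.save` to `soundfile` changes the expected array layout: `torchaudio.save` wants a 2-D `(channels, frames)` tensor, while `sf.write` wants `(frames,)` for mono or `(frames, channels)` for multichannel, which is why the tensor is squeezed back to 1-D before writing. A small sketch of that contract:

```python
# Sketch of the layout soundfile expects (cf. the squeeze above):
# sf.write takes (frames,) for mono or (frames, channels) for multichannel.
import numpy as np
import soundfile as sf

mono = np.zeros(24000, dtype=np.float32)   # 1 s of silence at 24 kHz
sf.write("mono.wav", mono, 24000)

stereo = np.stack([mono, mono], axis=1)    # shape (24000, 2)
sf.write("stereo.wav", stereo, 24000)
```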
@@ -158,22 +189,29 @@ async def synthesize(
         # Generate a unique filename
         filename = f"{uuid.uuid4().hex}.wav"
 
-        # Synthesis parameters
-        params = {
-            'temperature': temperature,
-            'top_p': top_p,
-            'top_k': top_k,
-            'spk_emb': None,  # optional: speaker embedding
-        }
-
         # Synthesize speech
         logger.info(f"Synthesizing: {text[:50]}...")
 
-        # ChatTTS generation
-        audio_tensor = model.infer(
-            [text],
-            params=params
-        )[0]  # returns a list; take the first element
+        # Basic ChatTTS call (simplified)
+        # Returns: a list of audio tensors
+        result = model.infer(text)
+
+        # Handle the returned result
+        if isinstance(result, list):
+            audio_tensor = result[0]
+        elif isinstance(result, tuple):
+            audio_tensor = result[0]
+        else:
+            audio_tensor = result
+
+        # Convert to a torch tensor (if it is a numpy array)
+        import numpy as np
+        if isinstance(audio_tensor, np.ndarray):
+            audio_tensor = torch.from_numpy(audio_tensor).float()
+
+        # Ensure the tensor has the correct shape
+        if audio_tensor.dim() == 1:
+            audio_tensor = audio_tensor.unsqueeze(0)
 
         # Save the audio
         filepath = save_audio(audio_tensor, filename)
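The list/tuple/array handling above guards against return-type differences across ChatTTS versions. The same logic as a reusable helper (illustrative, not part of this diff):

```python
# Illustrative helper (not in the diff): normalizes whatever model.infer()
# returns -- list, tuple, numpy array, or tensor -- into a (1, samples) tensor.
import numpy as np
import torch

def to_audio_tensor(result):
    if isinstance(result, (list, tuple)):
        result = result[0]                         # first utterance only
    if isinstance(result, np.ndarray):
        result = torch.from_numpy(result).float()
    if result.dim() == 1:
        result = result.unsqueeze(0)               # shape (1, samples)
    return result
```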
@@ -208,14 +246,15 @@ async def synthesize_batch(requests: list[SynthesizeRequest]):
         texts = [r.text for r in requests]
 
-        # Shared parameters
-        params = {
-            'temperature': requests[0].temperature,
-            'top_p': requests[0].top_p,
-            'top_k': requests[0].top_k,
-        }
+        infer_params = {}
 
         # Batch generation
-        audio_tensors = model.infer(texts, params=params)
+        audio_tensors = model.infer(
+            texts,
+            temperature=requests[0].temperature,
+            top_P=requests[0].top_p,
+            top_K=requests[0].top_k,
+        )
 
         results = []
         for i, audio_tensor in enumerate(audio_tensors):
@@ -310,6 +349,19 @@ async def synthesize_with_emotion(
         raise HTTPException(status_code=500, detail=str(e))
 
+            audio_url=f"/audio/{filename}",
+            duration=round(duration, 2),
+            text=text,
+            timestamp=datetime.now().isoformat()
+        )
+
+    except Exception as e:
+        logger.error(f"Emotion synthesis error: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
 
 
 if __name__ == "__main__":
     import uvicorn
     uvicorn.run(app, host="0.0.0.0", port=PORT)
start.sh
@@ -4,8 +4,17 @@
 cd "$(dirname "$0")"
 
 PORT=${PORT:-12002}
+MODEL_PATH=${MODEL_PATH:-""}
+MODEL_SOURCE=${MODEL_SOURCE:-"auto"}
 
 echo "Starting ChatTTS service on port $PORT..."
 echo "Device: $(python3 -c 'import torch; print("cuda" if torch.cuda.is_available() else "cpu")')"
 
+if [ -n "$MODEL_PATH" ]; then
+    echo "Model path: $MODEL_PATH"
+fi
+
+export MODEL_PATH
+export MODEL_SOURCE
+
 python3 server.py