From e6ae11b7c4a52418a147d3255561999f43f4295a Mon Sep 17 00:00:00 2001 From: hubian <908234780@qq.com> Date: Wed, 22 Apr 2026 16:34:03 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=9C=AC=E5=9C=B0=E6=A8=A1=E5=9E=8B=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 20 +++++++++++++++++--- server.py | 41 ++++++++++++++++++++++++++++++++++------- start.sh | 9 +++++++++ 3 files changed, 60 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index b878d46..c0ce1b4 100644 --- a/README.md +++ b/README.md @@ -30,15 +30,29 @@ pip install -r requirements.txt ### 2. 启动服务 ```bash -# 默认端口 12002 +# 方式一:自动下载模型 python3 server.py -# 或使用脚本 -PORT=12002 ./start.sh +# 方式二:使用已下载的模型(指定路径) +MODEL_PATH=/path/to/ChatTTS python3 server.py + +# 方式三:使用本地缓存模型 +MODEL_SOURCE=local python3 server.py ``` 首次启动会自动下载 ChatTTS 模型(约 2GB)。 +**模型路径配置**: + +| 环境变量 | 说明 | 示例 | +|---------|------|------| +| `MODEL_PATH` | 模型目录路径 | `/data/models/ChatTTS` | +| `MODEL_SOURCE` | 加载方式 | `auto` / `local` / `download` | + +ChatTTS 模型默认下载位置: +- Linux: `~/.cache/modelscope/hub/ChatTTS/` +- 或手动下载后解压到任意目录 + ### 3. 
验证服务 ```bash diff --git a/server.py b/server.py index 6d71e87..03337b6 100644 --- a/server.py +++ b/server.py @@ -24,6 +24,13 @@ from pydantic import BaseModel PORT = int(os.getenv("PORT", "12002")) SAMPLE_RATE = 24000 # ChatTTS 默认采样率 AUDIO_DIR = os.getenv("AUDIO_DIR", "audio_output") + +# ChatTTS 模型路径配置 +# 如果已下载模型,设置 MODEL_PATH 环境变量 +# 例如: MODEL_PATH=/path/to/ChatTTS +MODEL_PATH = os.getenv("MODEL_PATH", None) +MODEL_SOURCE = os.getenv("MODEL_SOURCE", "auto") # auto / local / download + os.makedirs(AUDIO_DIR, exist_ok=True) # 日志 @@ -74,13 +81,33 @@ def load_model(): import ChatTTS chat_model = ChatTTS.Chat() - # 加载模型 - # 方式一:从本地加载(如果有) - # 方式二:自动下载 - chat_model.load( - compile=True, # 编译优化 - device="cuda" if torch.cuda.is_available() else "cpu" - ) + device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Using device: {device}") + + # 加载模型方式 + if MODEL_PATH and os.path.exists(MODEL_PATH): + # 从本地路径加载 + logger.info(f"Loading model from: {MODEL_PATH}") + chat_model.load( + source="custom", custom_path=MODEL_PATH, + compile=True, + device=device + ) + elif MODEL_SOURCE == "local": + # 从默认本地缓存加载 + logger.info("Loading from local cache...") + chat_model.load( + compile=True, + device=device, + source="local" + ) + else: + # 自动下载 + logger.info("Auto downloading model...") + chat_model.load( + compile=True, + device=device + ) logger.info("ChatTTS model loaded successfully") except Exception as e: diff --git a/start.sh b/start.sh index 8d2850f..2840f9c 100755 --- a/start.sh +++ b/start.sh @@ -4,8 +4,17 @@ cd "$(dirname "$0")" PORT=${PORT:-12002} +MODEL_PATH=${MODEL_PATH:-""} +MODEL_SOURCE=${MODEL_SOURCE:-"auto"} echo "Starting ChatTTS service on port $PORT..." echo "Device: $(python3 -c 'import torch; print("cuda" if torch.cuda.is_available() else "cpu")')" +if [ -n "$MODEL_PATH" ]; then + echo "Model path: $MODEL_PATH" +fi + +export MODEL_PATH +export MODEL_SOURCE + python3 server.py \ No newline at end of file