diff --git a/requirements.txt b/requirements.txt index eaecb69..a645629 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,5 @@ python-multipart==0.0.9 torch==2.2.0 torchaudio==2.2.0 transformers==4.38.0 -ChatTTS==0.1.1 \ No newline at end of file +ChatTTS +soundfile==0.12.1 \ No newline at end of file diff --git a/server.py b/server.py index 4188910..a6a6925 100644 --- a/server.py +++ b/server.py @@ -124,11 +124,15 @@ def save_audio(audio_tensor: torch.Tensor, filename: str) -> str: filepath = os.path.join(AUDIO_DIR, filename) # 确保 tensor 正确形状 - if audio_tensor.dim() == 1: - audio_tensor = audio_tensor.unsqueeze(0) + if audio_tensor.dim() == 2: + audio_tensor = audio_tensor.squeeze(0) - # 保存为 WAV - torchaudio.save(filepath, audio_tensor, SAMPLE_RATE, format="wav") + # 转换为 numpy + audio_np = audio_tensor.cpu().numpy() if audio_tensor.is_cuda else audio_tensor.numpy() + + # 使用 soundfile 保存 + import soundfile as sf + sf.write(filepath, audio_np, SAMPLE_RATE) return filepath