Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6b73e2287c | |||
| 3ca2921820 | |||
| 24b590d07f | |||
| 264c0413c6 |
@@ -4,4 +4,5 @@ python-multipart==0.0.9
|
|||||||
torch==2.2.0
|
torch==2.2.0
|
||||||
torchaudio==2.2.0
|
torchaudio==2.2.0
|
||||||
transformers==4.38.0
|
transformers==4.38.0
|
||||||
ChatTTS==0.1.1
|
ChatTTS
|
||||||
|
soundfile==0.12.1
|
||||||
71
server.py
71
server.py
@@ -124,11 +124,15 @@ def save_audio(audio_tensor: torch.Tensor, filename: str) -> str:
|
|||||||
filepath = os.path.join(AUDIO_DIR, filename)
|
filepath = os.path.join(AUDIO_DIR, filename)
|
||||||
|
|
||||||
# 确保 tensor 正确形状
|
# 确保 tensor 正确形状
|
||||||
if audio_tensor.dim() == 1:
|
if audio_tensor.dim() == 2:
|
||||||
audio_tensor = audio_tensor.unsqueeze(0)
|
audio_tensor = audio_tensor.squeeze(0)
|
||||||
|
|
||||||
# 保存为 WAV
|
# 转换为 numpy
|
||||||
torchaudio.save(filepath, audio_tensor, SAMPLE_RATE, format="wav")
|
audio_np = audio_tensor.cpu().numpy() if audio_tensor.is_cuda else audio_tensor.numpy()
|
||||||
|
|
||||||
|
# 使用 soundfile 保存
|
||||||
|
import soundfile as sf
|
||||||
|
sf.write(filepath, audio_np, SAMPLE_RATE)
|
||||||
|
|
||||||
return filepath
|
return filepath
|
||||||
|
|
||||||
@@ -185,22 +189,29 @@ async def synthesize(
|
|||||||
# 生成唯一文件名
|
# 生成唯一文件名
|
||||||
filename = f"{uuid.uuid4().hex}.wav"
|
filename = f"{uuid.uuid4().hex}.wav"
|
||||||
|
|
||||||
# 合成参数
|
|
||||||
params = {
|
|
||||||
'temperature': temperature,
|
|
||||||
'top_p': top_p,
|
|
||||||
'top_k': top_k,
|
|
||||||
'spk_emb': None, # 可选:说话人嵌入
|
|
||||||
}
|
|
||||||
|
|
||||||
# 合成语音
|
# 合成语音
|
||||||
logger.info(f"Synthesizing: {text[:50]}...")
|
logger.info(f"Synthesizing: {text[:50]}...")
|
||||||
|
|
||||||
# ChatTTS 生成
|
# ChatTTS 基本调用(简化版)
|
||||||
audio_tensor = model.infer(
|
# 返回: list of audio tensors
|
||||||
[text],
|
result = model.infer(text)
|
||||||
params=params
|
|
||||||
)[0] # 返回是列表,取第一个
|
# 处理返回结果
|
||||||
|
if isinstance(result, list):
|
||||||
|
audio_tensor = result[0]
|
||||||
|
elif isinstance(result, tuple):
|
||||||
|
audio_tensor = result[0]
|
||||||
|
else:
|
||||||
|
audio_tensor = result
|
||||||
|
|
||||||
|
# 转换为 torch tensor(如果是 numpy)
|
||||||
|
import numpy as np
|
||||||
|
if isinstance(audio_tensor, np.ndarray):
|
||||||
|
audio_tensor = torch.from_numpy(audio_tensor).float()
|
||||||
|
|
||||||
|
# 确保 tensor 正确形状
|
||||||
|
if audio_tensor.dim() == 1:
|
||||||
|
audio_tensor = audio_tensor.unsqueeze(0)
|
||||||
|
|
||||||
# 保存音频
|
# 保存音频
|
||||||
filepath = save_audio(audio_tensor, filename)
|
filepath = save_audio(audio_tensor, filename)
|
||||||
@@ -235,14 +246,15 @@ async def synthesize_batch(requests: list[SynthesizeRequest]):
|
|||||||
texts = [r.text for r in requests]
|
texts = [r.text for r in requests]
|
||||||
|
|
||||||
# 统一参数
|
# 统一参数
|
||||||
params = {
|
infer_params = {}
|
||||||
'temperature': requests[0].temperature,
|
|
||||||
'top_p': requests[0].top_p,
|
|
||||||
'top_k': requests[0].top_k,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 批量生成
|
# 批量生成
|
||||||
audio_tensors = model.infer(texts, params=params)
|
audio_tensors = model.infer(
|
||||||
|
texts,
|
||||||
|
temperature=requests[0].temperature,
|
||||||
|
top_P=requests[0].top_p,
|
||||||
|
top_K=requests[0].top_k,
|
||||||
|
)
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i, audio_tensor in enumerate(audio_tensors):
|
for i, audio_tensor in enumerate(audio_tensors):
|
||||||
@@ -337,6 +349,19 @@ async def synthesize_with_emotion(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=PORT) audio_url=f"/audio/{filename}",
|
||||||
|
duration=round(duration, 2),
|
||||||
|
text=text,
|
||||||
|
timestamp=datetime.now().isoformat()
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Emotion synthesis error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
uvicorn.run(app, host="0.0.0.0", port=PORT)
|
uvicorn.run(app, host="0.0.0.0", port=PORT)
|
||||||
Reference in New Issue
Block a user