254 lines
7.6 KiB
Python
254 lines
7.6 KiB
Python
"""
|
||
TTS 语音合成模块
|
||
支持多种 TTS 方案
|
||
"""
|
||
|
||
import os
|
||
import uuid
|
||
import logging
|
||
import asyncio
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional, Tuple
|
||
from datetime import datetime
|
||
|
||
# Configuration: directory where synthesized audio files are cached.
# Overridable through the AUDIO_DIR environment variable; the directory is
# created at import time if it does not exist yet.
AUDIO_DIR = os.environ.get("AUDIO_DIR", "audio_cache")
os.makedirs(AUDIO_DIR, exist_ok=True)

logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TTSProvider(ABC):
    """Abstract base class for TTS (text-to-speech) providers.

    Concrete providers implement :meth:`synthesize`, :meth:`get_name` and
    :meth:`is_available`.
    """

    @abstractmethod
    async def synthesize(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Synthesize speech for ``text``.

        Returns:
            A ``(file_path, audio_url)`` tuple. Either element may be ``None``:
            a provider that does not write a local file returns ``None`` as
            the path, and a provider that produces no audio returns
            ``(None, None)``.
        """

    @abstractmethod
    def get_name(self) -> str:
        """Return the provider's human-readable display name."""

    @abstractmethod
    def is_available(self) -> bool:
        """Return ``True`` if this provider can currently be used."""
|
||
|
||
|
||
class EdgeTTSProvider(TTSProvider):
    """Edge TTS provider (Microsoft's free TTS service, via the ``edge_tts`` package)."""

    # Selectable voices: Edge voice id -> display name (Chinese name, gender).
    VOICES = {
        "zh-CN-XiaoxiaoNeural": "晓晓(女)",
        "zh-CN-YunxiNeural": "云希(男)",
        "zh-CN-YunyangNeural": "云扬(男)",
        "zh-CN-XiaochenNeural": "晓晨(女)",
        "zh-CN-XiaohanNeural": "晓涵(女)",
        "zh-CN-XiaomengNeural": "晓梦(女)",
        "zh-CN-XiaomoNeural": "晓墨(女)",
        "zh-CN-XiaoruiNeural": "晓睿(女)",
        "zh-CN-XiaoshuangNeural": "晓双(女)",
        "zh-CN-XiaoxuanNeural": "晓萱(女)",
        "zh-CN-XiaoyanNeural": "晓颜(女)",
        "zh-CN-XiaoyouNeural": "晓悠(女)",
    }

    DEFAULT_VOICE = "zh-CN-XiaoxiaoNeural"

    def __init__(self, voice: Optional[str] = None):
        """Use ``voice``, or :attr:`DEFAULT_VOICE` when ``voice`` is falsy.

        The voice id is not checked against :attr:`VOICES` here.
        """
        self.voice = voice or self.DEFAULT_VOICE
        # Cached result of is_available(); None means "not checked yet".
        self._available = None

    async def synthesize(self, text: str) -> Tuple[str, str]:
        """Synthesize ``text`` with Edge TTS into an MP3 file under ``AUDIO_DIR``.

        Returns:
            ``(file_path, audio_url)`` where ``audio_url`` is ``/audio/<filename>``.

        Raises:
            ImportError: if ``edge_tts`` is not installed.
            Exception: whatever ``edge_tts`` raises while saving; any partially
                written file is removed before the error propagates.
        """
        import edge_tts

        # Random unique file name so concurrent requests never collide.
        filename = f"{uuid.uuid4().hex}.mp3"
        filepath = os.path.join(AUDIO_DIR, filename)

        communicate = edge_tts.Communicate(text, self.voice)
        try:
            await communicate.save(filepath)
        except BaseException:
            # On failure or task cancellation, delete the truncated file so
            # broken audio does not accumulate in AUDIO_DIR, then re-raise.
            try:
                os.remove(filepath)
            except OSError:
                pass  # best effort: the file may never have been created
            raise

        audio_url = f"/audio/{filename}"
        return filepath, audio_url

    def get_name(self) -> str:
        return "Edge TTS"

    def get_voice_name(self) -> str:
        """Return the display name of the current voice (the raw id if unknown)."""
        return self.VOICES.get(self.voice, self.voice)

    def is_available(self) -> bool:
        """Return whether the ``edge_tts`` package can be imported (cached after first call)."""
        if self._available is None:
            try:
                import edge_tts  # noqa: F401 - imported only to probe availability
                self._available = True
            except ImportError:
                logger.warning("edge-tts not installed")
                self._available = False
        return self._available

    def set_voice(self, voice: str):
        """Switch to ``voice`` if it is one of :attr:`VOICES`.

        An unknown voice id is rejected with a warning and the current voice
        is kept unchanged.
        """
        if voice in self.VOICES:
            self.voice = voice
        else:
            logger.warning(f"Unknown voice: {voice}, keeping current voice {self.voice}")
|
||
|
||
|
||
class ChatTTSProvider(TTSProvider):
    """ChatTTS provider backed by a self-hosted ChatTTS HTTP service."""

    # Base URL of the ChatTTS service, overridable via the CHATTTS_URL env var.
    # A trailing "/" is stripped so "{url}/synthesize" never contains "//".
    CHATTTS_URL = os.getenv("CHATTTS_URL", "http://192.168.2.5:12002").rstrip("/")

    def __init__(self):
        # Cached result of is_available(); None means "not checked yet".
        self._available = None

    async def synthesize(self, text: str) -> Tuple[Optional[str], str]:
        """Ask the ChatTTS service to synthesize ``text``.

        The audio file stays on the ChatTTS server, so no local path is returned.

        Returns:
            ``(None, audio_url)`` where ``audio_url`` is
            ``/chattts/audio/<filename>``, a path meant to be served by a local
            proxy in front of the ChatTTS service.

        Raises:
            RuntimeError: if the service answers with a non-200 status, or
                returns an ``audio_url`` with no file name.
            KeyError: if the JSON response has no ``audio_url`` field.
            Network errors raised by ``aiohttp`` (e.g. when the 60 s total
            timeout expires) propagate unchanged.
        """
        import aiohttp

        form = aiohttp.FormData()
        form.add_field('text', text)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self.CHATTTS_URL}/synthesize",
                data=form,
                timeout=aiohttp.ClientTimeout(total=60),
            ) as resp:
                if resp.status != 200:
                    error = await resp.text()
                    raise RuntimeError(f"ChatTTS error (HTTP {resp.status}): {error}")
                data = await resp.json()

        # ChatTTS returns a URL like /audio/xxx.wav on its own (plain HTTP) host.
        # Rewrite it to the local proxy path /chattts/audio/xxx.wav so pages
        # served over HTTPS can load it without requesting an HTTP resource.
        original_url = data['audio_url']
        filename = original_url.rsplit('/', 1)[-1]
        if not filename:
            raise RuntimeError(f"ChatTTS returned an invalid audio_url: {original_url!r}")
        return None, f"/chattts/audio/{filename}"

    def get_name(self) -> str:
        return "ChatTTS"

    def is_available(self) -> bool:
        """Probe ``{CHATTTS_URL}/health`` and report whether it returns ``{"status": "ok"}``.

        The result, positive or negative, is cached until :meth:`set_url` is
        called. Any error during the probe is logged and counts as unavailable.

        NOTE: this uses the blocking ``requests`` library (timeout=5); calling it
        directly from a coroutine stalls the event loop while it waits.
        """
        if self._available is None:
            try:
                import requests
                resp = requests.get(f"{self.CHATTTS_URL}/health", timeout=5)
                if resp.status_code == 200:
                    data = resp.json()
                    self._available = data.get("status") == "ok"
                else:
                    self._available = False
            except Exception as e:
                logger.warning(f"ChatTTS check failed: {e}")
                self._available = False
        return self._available

    def set_url(self, url: str):
        """Point this provider at a different ChatTTS service and force a re-check."""
        # Instance attribute shadows the class-level default for this provider only.
        self.CHATTTS_URL = url.rstrip("/")
        self._available = None  # re-probe /health on the next is_available()
|
||
|
||
|
||
class NoTTSProvider(TTSProvider):
    """Null provider used when speech synthesis is disabled; produces no audio."""

    async def synthesize(self, text: str) -> Tuple[Optional[str], Optional[str]]:
        """Ignore ``text`` and return ``(None, None)``: no file and no URL."""
        return None, None

    def get_name(self) -> str:
        return "无 TTS"  # display name: "No TTS"

    def is_available(self) -> bool:
        """Always ``True``: doing nothing needs no backend."""
        return True
|
||
|
||
|
||
# TTS manager
class TTSManager:
    """Registry of TTS providers plus the currently selected provider name.

    Edge TTS is registered only if ``edge_tts`` can be imported; ChatTTS and
    the no-op provider are always registered.
    """

    # Provider name -> provider class. Used to tell a known provider that was
    # not registered (e.g. edge-tts not installed) apart from an unknown name.
    PROVIDERS = {
        "edge": EdgeTTSProvider,
        "chattts": ChatTTSProvider,
        "none": NoTTSProvider,
    }

    def __init__(self, default_provider: str = "none"):
        """Register the providers and select ``default_provider``.

        ``default_provider`` is stored as given; if it is not registered,
        :meth:`get_provider` falls back to the no-op provider.
        """
        self.current_provider = default_provider
        self._providers = {}

        # Edge TTS: register only when the edge_tts package is importable.
        edge_provider = EdgeTTSProvider()
        if edge_provider.is_available():
            self._providers["edge"] = edge_provider

        # ChatTTS: always registered; its availability is probed on demand.
        self._providers["chattts"] = ChatTTSProvider()

        # No-op provider: always registered, used as the fallback.
        self._providers["none"] = NoTTSProvider()

    def get_provider(self, provider_name: Optional[str] = None) -> TTSProvider:
        """Return the named provider, or the current one when no name is given.

        Names that are not registered fall back to the no-op provider.
        """
        name = provider_name or self.current_provider
        return self._providers.get(name, self._providers["none"])

    def set_provider(self, provider_name: str):
        """Select ``provider_name`` as the current provider if it is registered.

        Otherwise the selection is left unchanged and a warning is logged.
        """
        if provider_name in self._providers:
            self.current_provider = provider_name
        elif provider_name in self.PROVIDERS:
            # Known provider that was not registered at startup.
            logger.warning(f"Provider not available: {provider_name}")
        else:
            logger.warning(f"Unknown provider: {provider_name}")

    def list_providers(self) -> list:
        """Describe every registered provider.

        Returns:
            A list of dicts with keys ``name``, ``display_name`` and ``available``.

        NOTE: calls each provider's ``is_available()``; for ChatTTS this may run
        a blocking HTTP health check.
        """
        return [
            {
                "name": name,
                "display_name": provider.get_name(),
                "available": provider.is_available(),
            }
            for name, provider in self._providers.items()
        ]

    def get_edge_voices(self) -> dict:
        """Return a copy of the Edge TTS voice table (voice id -> display name).

        A copy is returned so callers cannot mutate the class-level table.
        """
        return dict(EdgeTTSProvider.VOICES)

    async def synthesize(self, text: str, provider_name: Optional[str] = None) -> Optional[str]:
        """Synthesize ``text`` with the named (or current) provider.

        Returns:
            The audio URL, or ``None`` if the provider is unavailable, produced
            no audio, or raised during synthesis (logged with traceback).
        """
        provider = self.get_provider(provider_name)

        # is_available() may do blocking I/O (the ChatTTS health check uses
        # requests), so run it in the default executor, off the event loop.
        loop = asyncio.get_running_loop()
        available = await loop.run_in_executor(None, provider.is_available)
        if not available:
            logger.warning(f"Provider {provider.get_name()} not available")
            return None

        try:
            _, audio_url = await provider.synthesize(text)
        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"TTS synthesis failed: {e}")
            return None
        return audio_url
|
||
|
||
|
||
# Module-wide TTS manager shared by importers (default provider: "none").
tts_manager: TTSManager = TTSManager()