"""
|
||
语音交互网页后端
|
||
代理转发到 Qwen2-Audio 模型服务
|
||
"""
|
||
|
||
import asyncio
import logging
import os
from datetime import datetime
from typing import Optional
from urllib.parse import quote

import aiohttp
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, Response
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# Import the TTS service (local module)
from tts_service import tts_manager, AUDIO_DIR

# Configuration -- every value can be overridden through the environment.
# Base URL of the upstream Qwen2-Audio model service.
MODEL_SERVICE_URL = os.environ.get("MODEL_SERVICE_URL", "http://localhost:19018")
# TCP port used when this module is executed directly.
PORT = int(os.environ.get("PORT", "19019"))

# Logging: INFO-level root configuration plus a module-level logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Voice Chat Web",
    description="语音交互网页后端",
    version="1.0.0",
)

# CORS: let the browser front-end call this API from any origin.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site issue credentialed requests -- consider restricting allow_origins.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)

class VoiceResponse(BaseModel):
    """Chat reply returned by the /voice/chat and /voice/text endpoints."""

    # Reply text taken from the model service's JSON response.
    reply: str
    # Conversation identifier as reported by the model service.
    conversation_id: str
    # Timestamp string; the handlers substitute the local ISO-8601 time when
    # the model service does not provide one.
    timestamp: str

class StatusResponse(BaseModel):
    """Health summary returned by GET /status."""

    # "ok" when the model service answered the probe, "partial" otherwise.
    status: str
    # Base URL of the model service that was probed.
    model_service: str
    # Whether the model service answered the probe successfully.
    model_online: bool

@app.get("/")
|
||
async def root():
|
||
"""根路径返回状态"""
|
||
return {"status": "ok", "service": "voice-chat-web"}
|
||
|
||
|
||
@app.get("/status", response_model=StatusResponse)
|
||
async def get_status():
|
||
"""检查服务状态"""
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.get(f"{MODEL_SERVICE_URL}/", timeout=aiohttp.ClientTimeout(total=5)) as resp:
|
||
if resp.status == 200:
|
||
data = await resp.json()
|
||
return StatusResponse(
|
||
status="ok",
|
||
model_service=MODEL_SERVICE_URL,
|
||
model_online=True
|
||
)
|
||
except Exception as e:
|
||
logger.warning(f"Model service check failed: {e}")
|
||
|
||
return StatusResponse(
|
||
status="partial",
|
||
model_service=MODEL_SERVICE_URL,
|
||
model_online=False
|
||
)
|
||
|
||
|
||
@app.post("/voice/chat", response_model=VoiceResponse)
|
||
async def voice_chat(
|
||
audio: UploadFile = File(..., description="音频文件"),
|
||
conversation_id: Optional[str] = Form(None, description="对话ID")
|
||
):
|
||
"""
|
||
语音聊天接口
|
||
转发到模型服务
|
||
"""
|
||
try:
|
||
# 读取音频数据
|
||
audio_bytes = await audio.read()
|
||
|
||
# 转发到模型服务
|
||
async with aiohttp.ClientSession() as session:
|
||
form = aiohttp.FormData()
|
||
form.add_field(
|
||
'audio',
|
||
audio_bytes,
|
||
filename=audio.filename or 'audio.wav',
|
||
content_type=audio.content_type or 'audio/wav'
|
||
)
|
||
if conversation_id:
|
||
form.add_field('conversation_id', conversation_id)
|
||
|
||
async with session.post(
|
||
f"{MODEL_SERVICE_URL}/api/voice/inference",
|
||
data=form,
|
||
timeout=aiohttp.ClientTimeout(total=120) # 模型推理可能较慢
|
||
) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
logger.error(f"Model service error: {error_text}")
|
||
raise HTTPException(status_code=resp.status, detail=error_text)
|
||
|
||
data = await resp.json()
|
||
return VoiceResponse(
|
||
reply=data["reply"],
|
||
conversation_id=data["conversation_id"],
|
||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||
)
|
||
|
||
except aiohttp.ClientError as e:
|
||
logger.error(f"Connection error: {e}")
|
||
raise HTTPException(status_code=503, detail="模型服务连接失败")
|
||
except Exception as e:
|
||
logger.error(f"Voice chat error: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.post("/voice/text", response_model=VoiceResponse)
|
||
async def text_chat(
|
||
text: str = Form(..., description="文本消息"),
|
||
conversation_id: Optional[str] = Form(None, description="对话ID")
|
||
):
|
||
"""
|
||
文字聊天接口
|
||
转发到模型服务
|
||
"""
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
form = aiohttp.FormData()
|
||
form.add_field('text', text)
|
||
if conversation_id:
|
||
form.add_field('conversation_id', conversation_id)
|
||
|
||
async with session.post(
|
||
f"{MODEL_SERVICE_URL}/api/voice/text",
|
||
data=form,
|
||
timeout=aiohttp.ClientTimeout(total=120)
|
||
) as resp:
|
||
if resp.status != 200:
|
||
error_text = await resp.text()
|
||
logger.error(f"Model service error: {error_text}")
|
||
raise HTTPException(status_code=resp.status, detail=error_text)
|
||
|
||
data = await resp.json()
|
||
return VoiceResponse(
|
||
reply=data["reply"],
|
||
conversation_id=data["conversation_id"],
|
||
timestamp=data.get("timestamp", datetime.now().isoformat())
|
||
)
|
||
|
||
except aiohttp.ClientError as e:
|
||
logger.error(f"Connection error: {e}")
|
||
raise HTTPException(status_code=503, detail="模型服务连接失败")
|
||
except Exception as e:
|
||
logger.error(f"Text chat error: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.delete("/conversation/{conversation_id}")
|
||
async def delete_conversation(conversation_id: str):
|
||
"""删除对话"""
|
||
try:
|
||
async with aiohttp.ClientSession() as session:
|
||
async with session.delete(
|
||
f"{MODEL_SERVICE_URL}/api/voice/conversation/{conversation_id}",
|
||
timeout=aiohttp.ClientTimeout(total=10)
|
||
) as resp:
|
||
return await resp.json()
|
||
except Exception as e:
|
||
logger.error(f"Delete conversation error: {e}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# ========== TTS endpoints ==========


class TTSSettings(BaseModel):
    """Request body for POST /tts/settings."""

    # Name of the TTS provider to activate; defaults to "none".
    provider: str = "none"
    # Optional voice name; the handler only applies it when provider == "edge".
    voice: Optional[str] = None

class TTSResponse(BaseModel):
    """Response body of POST /tts/synthesize."""

    # URL of the synthesized audio as returned by the TTS manager; may be
    # None (presumably when no audio was produced -- confirm in tts_service).
    audio_url: Optional[str]
    # Provider name reported for this synthesis.
    provider: str

@app.get("/tts/providers")
|
||
async def get_tts_providers():
|
||
"""获取可用的 TTS 方案列表"""
|
||
providers = tts_manager.list_providers()
|
||
voices = tts_manager.get_edge_voices()
|
||
return {
|
||
"providers": providers,
|
||
"voices": voices,
|
||
"current": tts_manager.current_provider
|
||
}
|
||
|
||
|
||
@app.post("/tts/settings")
|
||
async def set_tts_settings(settings: TTSSettings):
|
||
"""设置 TTS 方案"""
|
||
tts_manager.set_provider(settings.provider)
|
||
|
||
# 设置音色(仅 Edge TTS)
|
||
if settings.provider == "edge" and settings.voice:
|
||
provider = tts_manager.get_provider("edge")
|
||
if hasattr(provider, 'set_voice'):
|
||
provider.set_voice(settings.voice)
|
||
|
||
return {
|
||
"provider": settings.provider,
|
||
"voice": settings.voice
|
||
}
|
||
|
||
|
||
@app.post("/tts/synthesize")
|
||
async def synthesize_tts(text: str = Form(...), provider: Optional[str] = Form(None)):
|
||
"""
|
||
合成语音
|
||
返回音频文件 URL
|
||
"""
|
||
try:
|
||
audio_url = await tts_manager.synthesize(text, provider)
|
||
return TTSResponse(
|
||
audio_url=audio_url,
|
||
provider=provider or tts_manager.current_provider
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"TTS synthesis error: {e}")
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
# Serve the TTS service's audio directory (AUDIO_DIR) as static files under /audio.
_audio_files = StaticFiles(directory=AUDIO_DIR)
app.mount("/audio", _audio_files, name="audio")

# ChatTTS audio proxy: lets pages served over HTTPS load audio that ChatTTS
# serves over plain HTTP (avoids mixed-content blocking).
@app.get("/chattts/audio/{filename}")
async def proxy_chattts_audio(filename: str):
    """Proxy an audio file from the ChatTTS service.

    Fetches ``{CHATTTS_URL}/audio/{filename}`` with a 30 second timeout.
    ``CHATTTS_URL`` is read from the environment on every call. The bytes are
    returned as ``audio/wav`` with a one-hour public cache header.

    Args:
        filename: Name of the audio file on the ChatTTS server.

    Raises:
        HTTPException: 404 when the name is '.'/'..' or ChatTTS answers with
            any non-200 status; 500 for connection errors, timeouts or other
            failures.
    """
    # '.'/'..' would be normalised away in the upstream URL and fetch a
    # different resource, so reject them outright.
    if filename in (".", ".."):
        raise HTTPException(status_code=404, detail="Audio not found")

    chattts_url = os.getenv("CHATTTS_URL", "http://192.168.2.5:12002")
    # The path parameter arrives percent-decoded. Re-encode it so a '?' or '#'
    # cannot inject a query or fragment into the upstream URL.
    safe_name = quote(filename, safe="")

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{chattts_url}/audio/{safe_name}",
                timeout=aiohttp.ClientTimeout(total=30)
            ) as resp:
                if resp.status != 200:
                    raise HTTPException(status_code=404, detail="Audio not found")

                audio_data = await resp.read()
                return Response(
                    content=audio_data,
                    media_type="audio/wav",
                    headers={"Cache-Control": "public, max-age=3600"}
                )
    except HTTPException:
        # Keep the 404 above; the generic handler would otherwise turn it into 500.
        raise
    except Exception as e:
        logger.error(f"Proxy audio error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=PORT) |