Files
voice-chat-web/server.py

260 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
语音交互网页后端
代理转发到 Qwen2-Audio 模型服务
"""
import asyncio
import logging
import os
from datetime import datetime
from typing import Optional
from urllib.parse import quote

import aiohttp
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

# 导入 TTS 服务
from tts_service import tts_manager, AUDIO_DIR
# Configuration (all values can be overridden via environment variables).
MODEL_SERVICE_URL = os.getenv("MODEL_SERVICE_URL", "http://localhost:19018")
PORT = int(os.getenv("PORT", "19019"))
# Comma-separated list of allowed CORS origins. The default "*" keeps the
# previous allow-any-origin behaviour; blank entries are ignored, and an
# all-blank value falls back to "*".
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Voice Chat Web",
    description="语音交互网页后端",
    version="1.0.0",
)

# CORS.
# NOTE: the CORS spec forbids `Access-Control-Allow-Origin: *` on credentialed
# requests; with origins="*" and credentials enabled, Starlette instead echoes
# the caller's Origin back, i.e. *every* site may make credentialed calls.
# Set CORS_ORIGINS to an explicit list in production to lock this down.
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class VoiceResponse(BaseModel):
    """Reply returned by the chat endpoints (/voice/chat and /voice/text)."""
    # Reply text taken from the "reply" key of the model service's JSON.
    reply: str
    # Conversation ID taken from the model service's JSON; clients presumably
    # send it back on later requests to continue the conversation — confirm.
    conversation_id: str
    # The model service's "timestamp" value, or this server's local ISO-8601
    # time when the model service omits it.
    timestamp: str
class StatusResponse(BaseModel):
    """Health report returned by GET /status."""
    # "ok" when the model-service probe succeeded, otherwise "partial".
    status: str
    # Base URL of the model service that was probed (MODEL_SERVICE_URL).
    model_service: str
    # True only when the model-service probe succeeded.
    model_online: bool
@app.get("/")
async def root():
    """Liveness check for this web backend itself.

    The model service is not contacted here; GET /status does that.
    """
    payload = dict(status="ok", service="voice-chat-web")
    return payload
@app.get("/status", response_model=StatusResponse)
async def get_status():
    """Probe the model service and report whether it is reachable.

    Sends ``GET {MODEL_SERVICE_URL}/`` with a 5-second timeout. An HTTP 200
    answer counts as online. A non-200 status, a timeout or a connection
    error yields ``status="partial"`` / ``model_online=False``. The endpoint
    never raises because the model service is unavailable.
    """
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(
                f"{MODEL_SERVICE_URL}/",
                timeout=aiohttp.ClientTimeout(total=5),
            ) as resp:
                # Only the status code matters for a health probe. The body
                # is deliberately not parsed, so a non-JSON 200 response is
                # not misreported as "offline".
                if resp.status == 200:
                    return StatusResponse(
                        status="ok",
                        model_service=MODEL_SERVICE_URL,
                        model_online=True,
                    )
                logger.warning("Model service check returned HTTP %s", resp.status)
    except Exception as e:  # best-effort probe: any failure means "offline"
        # %r so that exceptions with an empty message (e.g. timeouts) still
        # show their type in the log.
        logger.warning("Model service check failed: %r", e)
    return StatusResponse(
        status="partial",
        model_service=MODEL_SERVICE_URL,
        model_online=False,
    )
@app.post("/voice/chat", response_model=VoiceResponse)
async def voice_chat(
    audio: UploadFile = File(..., description="音频文件"),
    conversation_id: Optional[str] = Form(None, description="对话ID")
):
    """Voice chat: forward an uploaded audio clip to the model service.

    The clip is re-posted as multipart form data (field ``audio``, plus
    ``conversation_id`` when given) to
    ``{MODEL_SERVICE_URL}/api/voice/inference``. The model's JSON reply is
    returned as a VoiceResponse.

    Raises:
        HTTPException: 400 if the uploaded file is empty; the model
            service's own status code and body if it answers non-200; 504 if
            it does not answer within 120 s; 503 if it cannot be reached;
            500 for any other unexpected error.
    """
    try:
        audio_bytes = await audio.read()
        if not audio_bytes:
            raise HTTPException(status_code=400, detail="音频文件为空")

        form = aiohttp.FormData()
        form.add_field(
            'audio',
            audio_bytes,
            filename=audio.filename or 'audio.wav',
            content_type=audio.content_type or 'audio/wav'
        )
        if conversation_id:
            form.add_field('conversation_id', conversation_id)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{MODEL_SERVICE_URL}/api/voice/inference",
                data=form,
                timeout=aiohttp.ClientTimeout(total=120)  # model inference can be slow
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error("Model service error: %s", error_text)
                    raise HTTPException(status_code=resp.status, detail=error_text)
                data = await resp.json()
        return VoiceResponse(
            reply=data["reply"],
            conversation_id=data["conversation_id"],
            timestamp=data.get("timestamp", datetime.now().isoformat())
        )
    except HTTPException:
        # Already carries the intended status (400 above, or the upstream
        # status); without this the catch-all below would turn it into a 500.
        raise
    except asyncio.TimeoutError as e:
        # Must precede ClientError: some aiohttp timeout errors subclass both.
        logger.error("Model service timed out: %r", e)
        raise HTTPException(status_code=504, detail="模型服务响应超时") from e
    except aiohttp.ClientError as e:
        logger.error("Connection error: %s", e)
        raise HTTPException(status_code=503, detail="模型服务连接失败") from e
    except Exception as e:
        logger.error(f"Voice chat error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/voice/text", response_model=VoiceResponse)
async def text_chat(
    text: str = Form(..., description="文本消息"),
    conversation_id: Optional[str] = Form(None, description="对话ID")
):
    """Text chat: forward a text message to the model service.

    The message is posted as form data (field ``text``, plus
    ``conversation_id`` when given) to ``{MODEL_SERVICE_URL}/api/voice/text``.
    The model's JSON reply is returned as a VoiceResponse.

    Raises:
        HTTPException: 400 if ``text`` is blank; the model service's own
            status code and body if it answers non-200; 504 if it does not
            answer within 120 s; 503 if it cannot be reached; 500 for any
            other unexpected error.
    """
    try:
        if not text.strip():
            raise HTTPException(status_code=400, detail="文本消息为空")

        form = aiohttp.FormData()
        form.add_field('text', text)
        if conversation_id:
            form.add_field('conversation_id', conversation_id)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{MODEL_SERVICE_URL}/api/voice/text",
                data=form,
                timeout=aiohttp.ClientTimeout(total=120)
            ) as resp:
                if resp.status != 200:
                    error_text = await resp.text()
                    logger.error("Model service error: %s", error_text)
                    raise HTTPException(status_code=resp.status, detail=error_text)
                data = await resp.json()
        return VoiceResponse(
            reply=data["reply"],
            conversation_id=data["conversation_id"],
            timestamp=data.get("timestamp", datetime.now().isoformat())
        )
    except HTTPException:
        # Keep the intended status (400 above, or the upstream status)
        # instead of letting the catch-all convert it into a 500.
        raise
    except asyncio.TimeoutError as e:
        # Must precede ClientError: some aiohttp timeout errors subclass both.
        logger.error("Model service timed out: %r", e)
        raise HTTPException(status_code=504, detail="模型服务响应超时") from e
    except aiohttp.ClientError as e:
        logger.error("Connection error: %s", e)
        raise HTTPException(status_code=503, detail="模型服务连接失败") from e
    except Exception as e:
        logger.error(f"Text chat error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.delete("/conversation/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Delete a conversation by forwarding the request to the model service.

    Sends ``DELETE {MODEL_SERVICE_URL}/api/voice/conversation/{id}`` with a
    10-second timeout and returns the model service's JSON body. The body is
    None if the response is empty, e.g. a 204.

    Raises:
        HTTPException: 400 if ``conversation_id`` is "." or "..", because
            URL normalisation would otherwise retarget the upstream path; the
            model service's status and body if it answers >= 400; 504 on
            timeout; 503 if it cannot be reached; 500 for anything else.
    """
    # The ID comes straight from the client: reject dot segments and
    # percent-escape everything else (including "/", "?" and "#") so it stays
    # a single path segment in the upstream URL.
    if conversation_id in {".", ".."}:
        raise HTTPException(status_code=400, detail="无效的对话ID")
    safe_id = quote(conversation_id, safe="")
    try:
        async with aiohttp.ClientSession() as session:
            async with session.delete(
                f"{MODEL_SERVICE_URL}/api/voice/conversation/{safe_id}",
                timeout=aiohttp.ClientTimeout(total=10)
            ) as resp:
                if resp.status >= 400:
                    # Pass upstream errors (e.g. 404 unknown conversation)
                    # through instead of returning their body with a 200.
                    error_text = await resp.text()
                    logger.error("Model service error: %s", error_text)
                    raise HTTPException(status_code=resp.status, detail=error_text)
                # content_type=None: don't insist on a JSON Content-Type
                # header, so an empty success body yields None instead of
                # raising.
                return await resp.json(content_type=None)
    except HTTPException:
        raise
    except asyncio.TimeoutError as e:
        logger.error("Model service timed out: %r", e)
        raise HTTPException(status_code=504, detail="模型服务响应超时") from e
    except aiohttp.ClientError as e:
        logger.error("Connection error: %s", e)
        raise HTTPException(status_code=503, detail="模型服务连接失败") from e
    except Exception as e:
        logger.error(f"Delete conversation error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# ========== TTS-related endpoints ==========
class TTSSettings(BaseModel):
    """Request body for POST /tts/settings."""
    # Name of the TTS provider to activate. The default "none" presumably
    # disables speech synthesis — confirm against tts_service.
    provider: str = "none"
    # Voice name; only applied when provider == "edge" (see set_tts_settings).
    voice: Optional[str] = None
class TTSResponse(BaseModel):
    """Response body for POST /tts/synthesize."""
    # URL of the synthesized audio as returned by tts_manager.synthesize; may
    # be None (presumably when no audio is produced — confirm in tts_service).
    audio_url: Optional[str]
    # Provider requested by the caller, or the manager's current provider.
    provider: str
@app.get("/tts/providers")
async def get_tts_providers():
    """Report the TTS options: providers, Edge TTS voices, active provider."""
    # Evaluated in the same order as before: providers, voices, current.
    return {
        "providers": tts_manager.list_providers(),
        "voices": tts_manager.get_edge_voices(),
        "current": tts_manager.current_provider,
    }
@app.post("/tts/settings")
async def set_tts_settings(settings: TTSSettings):
    """Switch the active TTS provider and, for Edge TTS, optionally its voice.

    Echoes back the provider and voice exactly as submitted.
    """
    chosen = settings.provider
    tts_manager.set_provider(chosen)

    # Voices only apply to Edge TTS; for any other provider `voice` is ignored.
    wants_edge_voice = chosen == "edge" and bool(settings.voice)
    if wants_edge_voice:
        edge = tts_manager.get_provider("edge")
        if hasattr(edge, 'set_voice'):
            edge.set_voice(settings.voice)

    return {"provider": chosen, "voice": settings.voice}
@app.post("/tts/synthesize")
async def synthesize_tts(text: str = Form(...), provider: Optional[str] = Form(None)):
    """Synthesize ``text`` to speech and return the URL of the audio file.

    Args:
        text: Text to speak; must contain at least one non-whitespace character.
        provider: Optional TTS provider name, passed to
            ``tts_manager.synthesize``. When omitted, the response reports the
            manager's current provider.

    Returns:
        TTSResponse with the audio URL and the provider name.

    Raises:
        HTTPException: 400 if ``text`` is blank; 500 if synthesis fails.
    """
    if not text.strip():
        raise HTTPException(status_code=400, detail="文本为空")
    try:
        audio_url = await tts_manager.synthesize(text, provider)
    except Exception as e:
        # Log the traceback, consistent with the chat endpoints.
        logger.error(f"TTS synthesis error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return TTSResponse(
        audio_url=audio_url,
        provider=provider or tts_manager.current_provider,
    )
# Serve files from AUDIO_DIR under /audio (presumably where the audio_url
# values from /tts/synthesize point — confirm against tts_service).
# StaticFiles raises at construction time if the directory does not exist,
# so make sure it is there before mounting.
os.makedirs(AUDIO_DIR, exist_ok=True)
app.mount("/audio", StaticFiles(directory=AUDIO_DIR), name="audio")

if __name__ == "__main__":
    import uvicorn
    # HOST defaults to all interfaces, matching the previous hard-coded value.
    uvicorn.run(app, host=os.getenv("HOST", "0.0.0.0"), port=PORT)