# qwen-audio-server/server.py
# (Pasted web-viewer chrome removed: file listing, "Raw Permalink Blame
# History" toolbar, and the ambiguous-Unicode warning were not part of
# the source.)
"""
Qwen2-Audio 模型服务
语音转文字 + 多轮对话
"""
import os
import io
import uuid
import tempfile
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
import librosa
import soundfile as sf
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Model handles — lazily initialized by load_model() on first use.
model = None
processor = None

# Per-conversation history store (in-memory; could be swapped for Redis).
conversations: Dict[str, List[Dict]] = {}

# Configuration (overridable via environment variables).
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2-Audio-7B-Instruct")
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "10"))  # keep at most 10 turns
SAMPLE_RATE = 16000  # sample rate Qwen2-Audio expects

# Logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Qwen2-Audio Voice Service",
    description="语音交互模型服务",
    version="1.0.0"
)

# CORS: wide open (all origins/methods/headers) — development setting;
# tighten before exposing publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class VoiceResponse(BaseModel):
    """Response payload for the voice and text inference endpoints."""
    # Model-generated reply text.
    reply: str
    # Conversation id to send back for follow-up turns.
    conversation_id: str
    # ISO-8601 timestamp of when the reply was generated.
    timestamp: str
class ConversationHistory(BaseModel):
    """Full stored history of one conversation."""
    conversation_id: str
    # List of {"role": ..., "content": ...} turn dicts as stored server-side.
    history: List[Dict[str, Any]]
def load_model():
    """Return the (model, processor) pair, loading them on first call.

    Subsequent calls are no-ops that return the cached module-level handles.
    """
    global model, processor
    # Fast path: already loaded.
    if model is not None:
        return model, processor

    logger.info(f"Loading model: {MODEL_NAME}")
    # Imported lazily so the server can start without modelscope installed
    # until the first load is actually requested.
    from modelscope import Qwen2AudioForConditionalGeneration, AutoProcessor

    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        device_map="auto",
    )
    logger.info("Model loaded successfully")
    return model, processor
def process_audio(audio_bytes: bytes) -> tuple:
    """Decode raw uploaded audio bytes into a mono waveform.

    Returns:
        (audio_array, sample_rate) — waveform resampled to SAMPLE_RATE.
    """
    # librosa accepts a file-like object and handles many container
    # formats (WAV, WebM, MP3, FLAC, ...).
    waveform, _sr = librosa.load(io.BytesIO(audio_bytes), sr=SAMPLE_RATE, mono=True)
    return waveform, SAMPLE_RATE
def build_conversation(
    history: List[Dict],
    audio_array: Optional[Any] = None
) -> List[Dict]:
    """Convert stored history (plus an optional current audio clip) into
    the chat-template message format expected by the processor.

    NOTE: currently unused — the endpoints build their conversations
    inline instead of calling this helper.
    """
    messages = []
    for turn in history:
        if turn["role"] != "user":
            # Assistant turns: content passes through unchanged.
            messages.append({
                "role": "assistant",
                "content": turn["content"]
            })
        elif turn.get("audio"):
            # User turn referencing a previously uploaded audio file.
            messages.append({
                "role": "user",
                "content": [{"type": "audio", "audio_url": turn["audio"]}]
            })
        else:
            # Plain-text user turn.
            messages.append({
                "role": "user",
                "content": [{"type": "text", "text": turn["content"]}]
            })
    # The clip for the current request is passed as a raw array.
    if audio_array is not None:
        messages.append({
            "role": "user",
            "content": [{"type": "audio", "audio": audio_array}]
        })
    return messages
@app.on_event("startup")
async def startup():
"""启动时预加载模型"""
logger.info("Preloading model...")
load_model()
logger.info("Server ready!")
@app.get("/")
async def root():
"""健康检查"""
return {
"status": "ok",
"model": MODEL_NAME,
"conversations": len(conversations)
}
@app.post("/api/voice/inference", response_model=VoiceResponse)
async def inference(
audio: UploadFile = File(..., description="音频文件"),
conversation_id: Optional[str] = Form(None, description="对话ID不传则创建新对话"),
max_length: int = Form(256, description="最大生成长度")
):
"""
语音推理接口
- 接收音频文件
- 返回模型回复文本
- 支持多轮对话
"""
try:
# 加载模型
model, processor = load_model()
# 读取音频
audio_bytes = await audio.read()
audio_array, sr = process_audio(audio_bytes)
# 获取或创建对话
if conversation_id is None:
conversation_id = str(uuid.uuid4())
conversations[conversation_id] = []
history = conversations.get(conversation_id, [])
# 构建对话
conversation = []
for turn in history:
conversation.append(turn)
# 添加当前音频
conversation.append({
"role": "user",
"content": [{"type": "audio"}]
})
# 处理对话格式
text = processor.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=False
)
# 提取历史音频(如果有)
audios = []
for turn in history:
if turn.get("audio_array") is not None:
audios.append(turn["audio_array"])
audios.append(audio_array)
# 推理
inputs = processor(
text=text,
audios=audios if audios else None,
return_tensors="pt",
padding=True
)
inputs.input_ids = inputs.input_ids.to("cuda")
generate_ids = model.generate(**inputs, max_length=max_length)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
reply = processor.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)[0]
# 保存对话历史(不保存音频数组,太大)
history.append({
"role": "user",
"content": "[audio]",
"audio_array": None # 不保存,太占内存
})
history.append({
"role": "assistant",
"content": reply
})
# 限制历史长度
if len(history) > MAX_HISTORY_TURNS * 2:
history = history[-MAX_HISTORY_TURNS * 2:]
conversations[conversation_id] = history
logger.info(f"Conversation {conversation_id}: {len(history)//2} turns")
return VoiceResponse(
reply=reply,
conversation_id=conversation_id,
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Inference error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/voice/text", response_model=VoiceResponse)
async def text_inference(
text: str = Form(..., description="文本消息"),
conversation_id: Optional[str] = Form(None, description="对话ID")
):
"""
文本推理接口(可选,用于测试或纯文本对话)
"""
try:
model, processor = load_model()
if conversation_id is None:
conversation_id = str(uuid.uuid4())
conversations[conversation_id] = []
history = conversations.get(conversation_id, [])
# 构建对话
conversation = []
for turn in history:
conversation.append(turn)
conversation.append({
"role": "user",
"content": [{"type": "text", "text": text}]
})
# 处理
prompt = processor.apply_chat_template(
conversation,
add_generation_prompt=True,
tokenize=False
)
inputs = processor(text=prompt, return_tensors="pt", padding=True)
inputs.input_ids = inputs.input_ids.to("cuda")
generate_ids = model.generate(**inputs, max_length=256)
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
reply = processor.batch_decode(
generate_ids,
skip_special_tokens=True,
clean_up_tokenization_tokens=False
)[0]
# 保存历史
history.append({"role": "user", "content": text})
history.append({"role": "assistant", "content": reply})
conversations[conversation_id] = history
return VoiceResponse(
reply=reply,
conversation_id=conversation_id,
timestamp=datetime.now().isoformat()
)
except Exception as e:
logger.error(f"Text inference error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/voice/conversation/{conversation_id}", response_model=ConversationHistory)
async def get_conversation(conversation_id: str):
"""获取对话历史"""
if conversation_id not in conversations:
raise HTTPException(status_code=404, detail="Conversation not found")
return ConversationHistory(
conversation_id=conversation_id,
history=conversations[conversation_id]
)
@app.delete("/api/voice/conversation/{conversation_id}")
async def delete_conversation(conversation_id: str):
"""删除对话"""
if conversation_id in conversations:
del conversations[conversation_id]
return {"status": "deleted"}
return {"status": "not_found"}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=19018)