334 lines
9.3 KiB
Python
334 lines
9.3 KiB
Python
"""
|
||
Qwen2-Audio 模型服务
|
||
语音转文字 + 多轮对话
|
||
"""
|
||
|
||
import os
|
||
import io
|
||
import uuid
|
||
import tempfile
|
||
import logging
|
||
from typing import Optional, List, Dict, Any
|
||
from datetime import datetime
|
||
|
||
import librosa
|
||
import soundfile as sf
|
||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
|
||
# Model handles (lazily loaded by load_model(); kept as module globals)
model = None
processor = None

# Conversation history store (in-memory; could be swapped for Redis)
# conversation_id -> list of {"role": ..., "content": ...} turn dicts
conversations: Dict[str, List[Dict]] = {}

# Configuration (overridable via environment variables)
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2-Audio-7B-Instruct")
MAX_HISTORY_TURNS = int(os.getenv("MAX_HISTORY_TURNS", "10"))  # keep at most 10 dialogue turns
SAMPLE_RATE = 16000  # sampling rate expected by Qwen2-Audio

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Qwen2-Audio Voice Service",
    description="语音交互模型服务",
    version="1.0.0"
)

# CORS — wide open; NOTE(review): restrict allow_origins for production
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||
|
||
|
||
class VoiceResponse(BaseModel):
    """Response payload returned by the voice and text inference endpoints."""
    reply: str              # model-generated reply text
    conversation_id: str    # ID the client passes back for follow-up turns
    timestamp: str          # ISO-8601 time the reply was produced
|
||
|
||
|
||
class ConversationHistory(BaseModel):
    """Stored dialogue history for one conversation."""
    conversation_id: str            # key into the in-memory `conversations` store
    history: List[Dict[str, Any]]   # alternating user/assistant turn dicts
|
||
|
||
|
||
def load_model():
    """Return the (model, processor) pair, loading them on first use.

    Caches the pair in module globals so subsequent calls are free.
    The modelscope import is deferred so the server can start without it
    until inference is actually needed.
    """
    global model, processor
    if model is not None:
        # Already loaded — reuse the cached handles.
        return model, processor

    logger.info(f"Loading model: {MODEL_NAME}")
    from modelscope import Qwen2AudioForConditionalGeneration, AutoProcessor

    processor = AutoProcessor.from_pretrained(MODEL_NAME)
    model = Qwen2AudioForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        device_map="auto",
    )
    logger.info("Model loaded successfully")
    return model, processor
|
||
|
||
|
||
def process_audio(audio_bytes: bytes) -> tuple:
    """Decode uploaded audio bytes into a mono float array at SAMPLE_RATE.

    librosa handles many container formats (WAV, WebM, MP3, FLAC, ...).

    Args:
        audio_bytes: raw bytes of the uploaded audio file.

    Returns:
        (audio_array, sample_rate) — sample_rate is always SAMPLE_RATE,
        since librosa resamples to the requested rate.
    """
    audio_io = io.BytesIO(audio_bytes)

    # sr=SAMPLE_RATE makes librosa resample on load and mono=True downmixes,
    # so the returned rate is always SAMPLE_RATE — discard the duplicate.
    audio, _ = librosa.load(audio_io, sr=SAMPLE_RATE, mono=True)

    return audio, SAMPLE_RATE
|
||
|
||
|
||
def build_conversation(
    history: List[Dict],
    audio_array: Optional[Any] = None
) -> List[Dict]:
    """Convert stored history turns into the chat-template message format.

    User turns become content lists ({"type": "audio"/"text", ...});
    assistant turns keep their plain content unchanged. When audio_array
    is given it is appended as the newest user audio message.

    NOTE(review): this helper is not called by the endpoints, which build
    their conversations inline — candidate for reuse or removal.
    """
    def _as_message(turn: Dict) -> Dict:
        # Map one stored turn to one chat-template message.
        if turn["role"] != "user":
            return {"role": "assistant", "content": turn["content"]}
        if turn.get("audio"):
            # Audio message referenced by URL.
            return {
                "role": "user",
                "content": [{"type": "audio", "audio_url": turn["audio"]}],
            }
        # Plain text message.
        return {
            "role": "user",
            "content": [{"type": "text", "text": turn["content"]}],
        }

    conversation = [_as_message(turn) for turn in history]

    # Append the current audio turn, if any.
    if audio_array is not None:
        conversation.append({
            "role": "user",
            "content": [{"type": "audio", "audio": audio_array}],
        })

    return conversation
|
||
|
||
|
||
@app.on_event("startup")
|
||
async def startup():
|
||
"""启动时预加载模型"""
|
||
logger.info("Preloading model...")
|
||
load_model()
|
||
logger.info("Server ready!")
|
||
|
||
|
||
@app.get("/")
|
||
async def root():
|
||
"""健康检查"""
|
||
return {
|
||
"status": "ok",
|
||
"model": MODEL_NAME,
|
||
"conversations": len(conversations)
|
||
}
|
||
|
||
|
||
@app.post("/api/voice/inference", response_model=VoiceResponse)
|
||
async def inference(
|
||
audio: UploadFile = File(..., description="音频文件"),
|
||
conversation_id: Optional[str] = Form(None, description="对话ID,不传则创建新对话"),
|
||
max_length: int = Form(256, description="最大生成长度")
|
||
):
|
||
"""
|
||
语音推理接口
|
||
|
||
- 接收音频文件
|
||
- 返回模型回复文本
|
||
- 支持多轮对话
|
||
"""
|
||
try:
|
||
# 加载模型
|
||
model, processor = load_model()
|
||
|
||
# 读取音频
|
||
audio_bytes = await audio.read()
|
||
audio_array, sr = process_audio(audio_bytes)
|
||
|
||
# 获取或创建对话
|
||
if conversation_id is None:
|
||
conversation_id = str(uuid.uuid4())
|
||
conversations[conversation_id] = []
|
||
|
||
history = conversations.get(conversation_id, [])
|
||
|
||
# 构建对话
|
||
conversation = []
|
||
for turn in history:
|
||
conversation.append(turn)
|
||
|
||
# 添加当前音频
|
||
conversation.append({
|
||
"role": "user",
|
||
"content": [{"type": "audio"}]
|
||
})
|
||
|
||
# 处理对话格式
|
||
text = processor.apply_chat_template(
|
||
conversation,
|
||
add_generation_prompt=True,
|
||
tokenize=False
|
||
)
|
||
|
||
# 提取历史音频(如果有)
|
||
audios = []
|
||
for turn in history:
|
||
if turn.get("audio_array") is not None:
|
||
audios.append(turn["audio_array"])
|
||
audios.append(audio_array)
|
||
|
||
# 推理
|
||
inputs = processor(
|
||
text=text,
|
||
audios=audios if audios else None,
|
||
return_tensors="pt",
|
||
padding=True
|
||
)
|
||
inputs.input_ids = inputs.input_ids.to("cuda")
|
||
|
||
generate_ids = model.generate(**inputs, max_length=max_length)
|
||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||
|
||
reply = processor.batch_decode(
|
||
generate_ids,
|
||
skip_special_tokens=True,
|
||
clean_up_tokenization_spaces=False
|
||
)[0]
|
||
|
||
# 保存对话历史(不保存音频数组,太大)
|
||
history.append({
|
||
"role": "user",
|
||
"content": "[audio]",
|
||
"audio_array": None # 不保存,太占内存
|
||
})
|
||
history.append({
|
||
"role": "assistant",
|
||
"content": reply
|
||
})
|
||
|
||
# 限制历史长度
|
||
if len(history) > MAX_HISTORY_TURNS * 2:
|
||
history = history[-MAX_HISTORY_TURNS * 2:]
|
||
|
||
conversations[conversation_id] = history
|
||
|
||
logger.info(f"Conversation {conversation_id}: {len(history)//2} turns")
|
||
|
||
return VoiceResponse(
|
||
reply=reply,
|
||
conversation_id=conversation_id,
|
||
timestamp=datetime.now().isoformat()
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Inference error: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.post("/api/voice/text", response_model=VoiceResponse)
|
||
async def text_inference(
|
||
text: str = Form(..., description="文本消息"),
|
||
conversation_id: Optional[str] = Form(None, description="对话ID")
|
||
):
|
||
"""
|
||
文本推理接口(可选,用于测试或纯文本对话)
|
||
"""
|
||
try:
|
||
model, processor = load_model()
|
||
|
||
if conversation_id is None:
|
||
conversation_id = str(uuid.uuid4())
|
||
conversations[conversation_id] = []
|
||
|
||
history = conversations.get(conversation_id, [])
|
||
|
||
# 构建对话
|
||
conversation = []
|
||
for turn in history:
|
||
conversation.append(turn)
|
||
conversation.append({
|
||
"role": "user",
|
||
"content": [{"type": "text", "text": text}]
|
||
})
|
||
|
||
# 处理
|
||
prompt = processor.apply_chat_template(
|
||
conversation,
|
||
add_generation_prompt=True,
|
||
tokenize=False
|
||
)
|
||
|
||
inputs = processor(text=prompt, return_tensors="pt", padding=True)
|
||
inputs.input_ids = inputs.input_ids.to("cuda")
|
||
|
||
generate_ids = model.generate(**inputs, max_length=256)
|
||
generate_ids = generate_ids[:, inputs.input_ids.size(1):]
|
||
|
||
reply = processor.batch_decode(
|
||
generate_ids,
|
||
skip_special_tokens=True,
|
||
clean_up_tokenization_tokens=False
|
||
)[0]
|
||
|
||
# 保存历史
|
||
history.append({"role": "user", "content": text})
|
||
history.append({"role": "assistant", "content": reply})
|
||
conversations[conversation_id] = history
|
||
|
||
return VoiceResponse(
|
||
reply=reply,
|
||
conversation_id=conversation_id,
|
||
timestamp=datetime.now().isoformat()
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error(f"Text inference error: {e}", exc_info=True)
|
||
raise HTTPException(status_code=500, detail=str(e))
|
||
|
||
|
||
@app.get("/api/voice/conversation/{conversation_id}", response_model=ConversationHistory)
|
||
async def get_conversation(conversation_id: str):
|
||
"""获取对话历史"""
|
||
if conversation_id not in conversations:
|
||
raise HTTPException(status_code=404, detail="Conversation not found")
|
||
|
||
return ConversationHistory(
|
||
conversation_id=conversation_id,
|
||
history=conversations[conversation_id]
|
||
)
|
||
|
||
|
||
@app.delete("/api/voice/conversation/{conversation_id}")
|
||
async def delete_conversation(conversation_id: str):
|
||
"""删除对话"""
|
||
if conversation_id in conversations:
|
||
del conversations[conversation_id]
|
||
return {"status": "deleted"}
|
||
return {"status": "not_found"}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
import uvicorn
|
||
uvicorn.run(app, host="0.0.0.0", port=19018) |