Files
ai-chat-system/main.py
hubian a07de626ad fix: 修复AI重复调用问题,统一消息处理流程
- _process_message 只保存用户消息并通知
- 新增 generate_ai_response action 统一处理AI回复
- 避免重复调用AI和重复发送消息
2026-04-12 00:34:30 +08:00

597 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 - 主应用
支持网页端和Matrix端实时同步对话
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set
import asyncio
import json
import logging
from datetime import datetime
import os
from models import init_db, get_db, SessionLocal, User, Conversation, Message, SystemConfig
from services import ai_service, ConversationService, matrix_bot
# Logging: root handler at INFO level, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object that every route below registers on.
app = FastAPI(version="1.0.0", title="AI对话系统")

# Static files and Jinja2 templates live next to this module, so resolve them
# from this file's directory rather than the process's working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
# WebSocket connection management
class ConnectionManager:
    """Track open WebSocket connections grouped by user id and fan messages out to them."""

    def __init__(self):
        # user_id -> set of that user's currently registered WebSocket connections
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: "WebSocket", user_id: str):
        """Accept *websocket* and register it under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}, 当前连接数: {len(self.active_connections[user_id])}")

    def disconnect(self, websocket: "WebSocket", user_id: str):
        """Unregister *websocket* from *user_id*.

        The user's entry is dropped once it has no connections left. Calling
        this for an unknown user or an already-removed socket is a no-op
        (apart from the log line).
        """
        self._discard(websocket, user_id)
        logger.info(f"WebSocket断开: {user_id}")

    async def send_to_user(self, user_id: str, message: dict):
        """Send *message* as JSON to every connection registered for *user_id*.

        Delivery is best-effort. A connection whose send raises is treated as
        dead and unregistered, and the remaining connections still receive the
        message. An unknown *user_id* is a no-op.
        """
        # Iterate over a snapshot. Each ``await`` yields to the event loop, and
        # connect()/disconnect() may mutate or delete the live set meanwhile.
        # That would raise "Set changed size during iteration" or a KeyError.
        connections = list(self.active_connections.get(user_id, ()))
        dead_connections = []
        for connection in connections:
            try:
                await connection.send_json(message)
            except Exception:
                # Catch Exception, not everything, so task cancellation still
                # propagates.
                dead_connections.append(connection)
        for connection in dead_connections:
            self._discard(connection, user_id)

    def _discard(self, websocket: "WebSocket", user_id: str) -> None:
        """Remove *websocket* from *user_id*'s set and delete the entry when it becomes empty."""
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[user_id]


manager = ConnectionManager()
# ==================== Page routes ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page from the ``index.html`` template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Render the admin dashboard from the ``admin/index.html`` template."""
    template_name = "admin/index.html"
    return templates.TemplateResponse(template_name, {"request": request})
# ==================== API routes ====================
# Fixed id of the single shared "main" user, associated with the Matrix AI
# user. The conversation API routes and every WebSocket connection act as it.
MAIN_USER_ID: str = "main_user"
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the main user's conversations.

    The main user is created on first use. Each entry carries ``id``,
    ``title``, ``created_at`` and ``updated_at``. An untitled conversation is
    given the default title "新对话".
    """
    service = ConversationService(db)
    main_user = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    def _summary(conv):
        return {
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        }

    summaries = [_summary(conv) for conv in service.get_user_conversations(main_user.id)]
    return {"conversations": summaries}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the main user's most recently updated conversation (used for Matrix sync).

    Responds with ``{"conversation": None}`` when the user has no conversations.
    """
    service = ConversationService(db)
    main_user = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    conversations = service.get_user_conversations(main_user.id)
    if not conversations:
        return {"conversation": None}
    # get_user_conversations already returns newest-updated first.
    newest = conversations[0]
    return {
        "conversation": {
            "id": newest.conversation_id,
            "title": newest.title or "新对话",
            "updated_at": newest.updated_at.isoformat(),
        }
    }
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Create a new conversation owned by the main user and return its summary."""
    service = ConversationService(db)
    owner = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    created = service.create_conversation(owner.id)
    return dict(
        id=created.conversation_id,
        title=created.title or "新对话",
        created_at=created.created_at.isoformat(),
    )
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return the stored messages of *conversation_id*.

    Raises:
        HTTPException: 404 when the conversation does not exist.
    """
    service = ConversationService(db)
    conv = service.get_conversation(conversation_id)
    if not conv:
        raise HTTPException(status_code=404, detail="会话不存在")

    def _as_dict(msg):
        return {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "source": msg.source,
            "created_at": msg.created_at.isoformat(),
        }

    return {"messages": list(map(_as_dict, service.get_messages(conv.id)))}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete *conversation_id*.

    Raises:
        HTTPException: 404 when the service reports nothing was deleted.
    """
    deleted = ConversationService(db).delete_conversation(conversation_id)
    if deleted:
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
# ==================== WebSocket routes ====================
async def _sync_to_matrix(text: str) -> None:
    """Best-effort mirror of *text* into the Matrix room the bot last used.

    Does nothing unless the bot is running and has a ``last_room_id``. Send
    failures are logged instead of raised, so a Matrix problem cannot end the
    caller's WebSocket session.
    """
    if not (matrix_bot.is_running and matrix_bot.last_room_id):
        return
    try:
        await matrix_bot.send_message(matrix_bot.last_room_id, text)
    except Exception as e:
        logger.exception(f"同步到Matrix失败: {e}")


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str, db: Session = Depends(get_db)):
    """Real-time chat over a WebSocket.

    Every connection is registered under ``MAIN_USER_ID``. The ``user_id`` path
    segment is accepted but otherwise ignored, so all clients share one
    conversation list and receive each other's broadcasts.

    Client -> server JSON objects, dispatched on ``action``:

    * ``select_conversation`` (``conversation_id``): replies to this socket
      with a ``history`` event. Nothing is sent for an unknown id.
    * ``chat`` (``message``, optional ``conversation_id``): without an id, a
      new conversation is created and ``conversation_created`` is sent to this
      socket. The user message is stored, broadcast as ``user_message`` and
      mirrored to Matrix. The AI reply is then stored, broadcast as
      ``assistant_message`` and mirrored to Matrix. An unknown conversation id
      or an AI failure is reported to this socket as ``{"type": "error", ...}``.
    """
    # All web clients act as the single shared main user.
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)
    conv_service = ConversationService(db)
    user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    current_conversation_id = None
    try:
        while True:
            data = await websocket.receive_json()
            # Ignore well-formed JSON that is not an object (e.g. a bare list)
            # instead of crashing the session on ``data.get``.
            if not isinstance(data, dict):
                continue
            action = data.get("action")

            if action == "select_conversation":
                current_conversation_id = data.get("conversation_id")
                conversation = conv_service.get_conversation(current_conversation_id)
                if conversation:
                    messages = conv_service.get_messages(conversation.id)
                    await websocket.send_json({
                        "type": "history",
                        "conversation_id": current_conversation_id,
                        "messages": [
                            {
                                "role": m.role,
                                "content": m.content,
                                "source": m.source,
                                "created_at": m.created_at.isoformat()
                            }
                            for m in messages
                        ]
                    })

            elif action == "chat":
                message = data.get("message", "")
                conversation_id = data.get("conversation_id")
                if not isinstance(message, str) or not message.strip():
                    continue

                # Resolve the target conversation, or create one.
                if conversation_id:
                    conversation = conv_service.get_conversation(conversation_id)
                    if not conversation:
                        # Report it instead of failing on ``conversation.id``,
                        # which would end the whole session.
                        await websocket.send_json({
                            "type": "error",
                            "message": "会话不存在"
                        })
                        continue
                else:
                    conversation = conv_service.create_conversation(user.id)
                    conversation_id = conversation.conversation_id
                    await websocket.send_json({
                        "type": "conversation_created",
                        "conversation_id": conversation_id
                    })

                # Store the user message and broadcast it to all clients.
                user_msg = conv_service.add_message(
                    conversation_id=conversation.id,
                    role='user',
                    content=message,
                    source='web'
                )
                await manager.send_to_user(MAIN_USER_ID, {
                    "type": "user_message",
                    "conversation_id": conversation_id,
                    "message": {
                        "id": user_msg.id,
                        "role": "user",
                        "content": message,
                        "source": "web",
                        "created_at": user_msg.created_at.isoformat()
                    }
                })
                await _sync_to_matrix(message)

                # Ask the AI, store the reply and broadcast it.
                try:
                    history = conv_service.get_conversation_history(conversation_id, limit=20)
                    ai_response = await ai_service.chat(history)
                    assistant_msg = conv_service.add_message(
                        conversation_id=conversation.id,
                        role='assistant',
                        content=ai_response,
                        source='web'
                    )
                    await manager.send_to_user(MAIN_USER_ID, {
                        "type": "assistant_message",
                        "conversation_id": conversation_id,
                        "message": {
                            "id": assistant_msg.id,
                            "role": "assistant",
                            "content": ai_response,
                            "source": "web",
                            "created_at": assistant_msg.created_at.isoformat()
                        }
                    })
                except Exception as e:
                    logger.exception(f"AI调用失败: {e}")
                    await websocket.send_json({
                        "type": "error",
                        "message": "AI服务暂时不可用请稍后重试"
                    })
                else:
                    await _sync_to_matrix(ai_response)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.exception(f"WebSocket错误: {e}")
    finally:
        # Unregister under the same key used by manager.connect() above. The
        # path ``user_id`` is not the registration key.
        manager.disconnect(websocket, actual_user_id)
# ==================== Admin API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return the total number of users, conversations and messages."""
    return {
        "total_users": db.query(User).count(),
        "total_conversations": db.query(Conversation).count(),
        "total_messages": db.query(Message).count(),
    }
@app.get("/api/admin/users")
async def get_users(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of users, newest first, paginated by ``skip``/``limit``."""
    accounts = (
        db.query(User)
        .order_by(User.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    payload = []
    for account in accounts:
        payload.append({
            "id": account.id,
            "user_id": account.user_id,
            "display_name": account.display_name,
            "user_type": account.user_type,
            "created_at": account.created_at.isoformat(),
            "last_active_at": account.last_active_at.isoformat(),
            "is_active": account.is_active,
        })
    return {"users": payload}
@app.get("/api/admin/conversations")
async def admin_get_conversations(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of all users' conversations, most recently updated first."""
    page = (
        db.query(Conversation)
        .order_by(Conversation.updated_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    def _row(conv):
        owner = conv.user
        return {
            "id": conv.id,
            "conversation_id": conv.conversation_id,
            "user_id": owner.user_id if owner else None,
            "title": conv.title,
            # NOTE(review): len(conv.messages) presumably loads every message
            # of each conversation; a COUNT query would be cheaper for long
            # histories.
            "message_count": len(conv.messages),
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
            "is_active": conv.is_active,
        }

    return {"conversations": [_row(conv) for conv in page]}
@app.get("/api/admin/config")
async def get_config(db: Session = Depends(get_db)):
    """Return every ``SystemConfig`` entry as ``key``/``value``/``description``.

    NOTE(review): values are returned unmasked. This may include secrets such
    as ``ai_api_key``, and no authentication dependency is visible on this
    route. Confirm access is restricted elsewhere.
    """
    entries = []
    for row in db.query(SystemConfig).all():
        entries.append({
            "key": row.key,
            "value": row.value,
            "description": row.description,
        })
    return {"configs": entries}
@app.post("/api/admin/config")
async def update_config(data: dict, db: Session = Depends(get_db)):
    """Create or update one ``SystemConfig`` entry and apply it to the live services.

    The body is ``{"key": str, "value": ..., "description": str (optional)}``.
    An existing entry keeps its description unless a non-empty one is given.
    After the commit:

    * ``ai_*`` keys reload the whole AI configuration into ``ai_service``.
      Missing entries fall back to built-in defaults.
    * ``matrix_*`` keys reconnect the Matrix bot if it is running.

    Raises:
        HTTPException: 400 when ``key`` is missing/empty or not a string.
    """
    key = data.get("key")
    value = data.get("value")
    description = data.get("description", "")
    if not key:
        raise HTTPException(status_code=400, detail="key不能为空")
    # Reject non-string keys (e.g. a JSON number) before anything is persisted.
    # Otherwise the row would be saved and ``key.startswith`` below would fail
    # with a 500.
    if not isinstance(key, str):
        raise HTTPException(status_code=400, detail="key必须是字符串")

    config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
    if config:
        config.value = value
        # An empty description means "keep the existing one".
        if description:
            config.description = description
    else:
        config = SystemConfig(key=key, value=value, description=description)
        db.add(config)
    db.commit()

    # Apply the change to the running services.
    if key.startswith('ai_'):
        # Reload the full AI configuration so related keys stay consistent.
        configs = {c.key: c.value for c in db.query(SystemConfig).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)
    if key.startswith('matrix_') and matrix_bot.is_running:
        # NOTE(review): startup() also starts matrix_bot.start_sync(...) as a
        # background task, and that loop is not restarted here. Confirm that
        # disconnect()/init_from_config() keep Matrix syncing working.
        await matrix_bot.disconnect()
        await matrix_bot.init_from_config()
    return {"success": True, "config": {"key": key, "value": value}}
# ==================== Matrix message callback ====================
async def _set_matrix_typing(room_id: str, typing_state: bool) -> None:
    """Best-effort toggle of the bot's typing indicator in *room_id*.

    The indicator is cosmetic, so failures are logged at debug level and never
    propagated.
    """
    try:
        await matrix_bot.client.room_typing(room_id, typing_state=typing_state)
    except Exception:
        logger.debug("设置Matrix输入状态失败", exc_info=True)


async def _reply_with_ai(conversation_id: str, room_id: str, *, show_typing: bool, failure_log: str) -> None:
    """Generate, store and deliver an AI reply for a Matrix-originated turn.

    If *conversation_id* exists, this asks ``ai_service`` for a reply based on
    the last 20 history entries. The reply is stored as an ``assistant``
    message with source ``matrix``, posted to *room_id* and broadcast to web
    clients as ``assistant_message``. Unknown conversations are ignored.

    Any ``Exception`` is logged under *failure_log* and reported back into the
    Matrix room; none is propagated to the caller. When *show_typing* is set,
    the typing indicator is always switched off again, even on failure.
    """
    db = SessionLocal()
    typing_on = False
    try:
        conv_service = ConversationService(db)
        conversation = conv_service.get_conversation(conversation_id)
        # Check the conversation first so no AI call is wasted on a missing one.
        if not conversation:
            return
        if show_typing:
            await _set_matrix_typing(room_id, True)
            typing_on = True

        history = conv_service.get_conversation_history(conversation_id, limit=20)
        ai_response = await ai_service.chat(history)
        assistant_msg = conv_service.add_message(
            conversation_id=conversation.id,
            role='assistant',
            content=ai_response,
            source='matrix'
        )

        await matrix_bot.send_message(room_id, ai_response)
        if typing_on:
            await _set_matrix_typing(room_id, False)
            typing_on = False

        # Mirror the reply to the web clients.
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": ai_response,
                "source": "matrix",
                "created_at": assistant_msg.created_at.isoformat()
            }
        })
        logger.info(f"Matrix AI回复已发送: {ai_response[:50]}")
    except Exception as e:
        logger.exception(f"{failure_log}: {e}")
        # Reporting the error must not itself raise into the Matrix sync loop.
        try:
            await matrix_bot.send_message(room_id, f"处理消息时出错: {str(e)}")
        except Exception:
            logger.exception("向Matrix发送错误提示失败")
    finally:
        if typing_on:
            await _set_matrix_typing(room_id, False)
        db.close()


async def handle_matrix_message(action: str, conversation_id: str = None, user_message: str = None, room_id: str = None, message_id: int = None, sender: str = None):
    """Handle a callback from the Matrix bot's sync loop (registered in ``startup``).

    Dispatches on *action*:

    * ``new_conversation``: notify web clients that *conversation_id* exists.
    * ``user_message``: if the conversation exists, forward the Matrix user's
      *user_message* (with *message_id* and *sender*) to web clients.
    * ``generate_ai_response``: produce the AI reply for *conversation_id* in
      *room_id*, showing a typing indicator while it is generated.
    * ``chat``: same as ``generate_ai_response`` but without the indicator.

    Unknown actions are logged and ignored.
    """
    if action == "new_conversation":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "new_conversation",
            "conversation_id": conversation_id
        })
        return

    if action == "user_message":
        # Only forwards the message. This branch does not store it; the caller
        # supplies its message_id, so presumably the bot already saved it.
        db = SessionLocal()
        try:
            conv_service = ConversationService(db)
            conversation = conv_service.get_conversation(conversation_id)
            if conversation:
                await manager.send_to_user(MAIN_USER_ID, {
                    "type": "user_message",
                    "conversation_id": conversation_id,
                    "message": {
                        "id": message_id,
                        "role": "user",
                        "content": user_message,
                        "source": "matrix",
                        "sender": sender,
                        # NOTE(review): local wall-clock time, not the stored
                        # message's created_at.
                        "created_at": datetime.now().isoformat()
                    }
                })
        finally:
            db.close()
        return

    if action == "generate_ai_response":
        await _reply_with_ai(conversation_id, room_id, show_typing=True, failure_log="生成AI回复失败")
        return

    if action == "chat":
        await _reply_with_ai(conversation_id, room_id, show_typing=False, failure_log="处理Matrix消息失败")
        return

    logger.warning(f"未知的Matrix回调动作: {action}")
# ==================== Startup and shutdown ====================
def _log_sync_task_exit(task: asyncio.Task) -> None:
    """Done-callback for the Matrix sync task: log it if the task died with an exception."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Matrix同步任务异常退出", exc_info=exc)


@app.on_event("startup")
async def startup():
    """Initialise the database, load the AI configuration and start the Matrix bot.

    AI settings are read from the ``SystemConfig`` table (``ai_api_base``,
    ``ai_api_key``, ``ai_model``), with built-in defaults for missing keys.
    If the Matrix bot comes up, its sync loop runs as a background task with
    ``handle_matrix_message`` as the callback. The task is kept on
    ``app.state.matrix_sync_task`` because the event loop holds only weak
    references to tasks, so an unreferenced task may be garbage-collected.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Load the AI configuration from the database.
    db = SessionLocal()
    try:
        configs = {c.key: c.value for c in db.query(SystemConfig).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)
        logger.info(f"AI配置已加载: {api_base}, model={model}")
    finally:
        db.close()

    # Start the Matrix bot and its sync loop.
    await matrix_bot.init_from_config()
    if matrix_bot.is_running:
        task = asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
        task.add_done_callback(_log_sync_task_exit)
        app.state.matrix_sync_task = task
        logger.info("Matrix Bot已启动")
    else:
        logger.info("Matrix Bot未启动")
@app.on_event("shutdown")
async def shutdown():
    """Disconnect the Matrix bot when the application stops, then log the shutdown."""
    bot = matrix_bot
    await bot.disconnect()
    logger.info("应用已关闭")
if __name__ == "__main__":
    # Run the app directly with uvicorn on all interfaces, port 19020.
    import uvicorn

    uvicorn.run(app, port=19020, host="0.0.0.0")