539 lines
19 KiB
Python
539 lines
19 KiB
Python
"""
AI对话系统 - 主应用

支持网页端和Matrix端实时同步对话
"""
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
|
||
from fastapi.responses import HTMLResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.templating import Jinja2Templates
|
||
from sqlalchemy.orm import Session
|
||
from typing import Dict, Set
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from datetime import datetime
|
||
import os
|
||
|
||
from models import init_db, get_db, SessionLocal, User, Conversation, Message, SystemConfig
|
||
from services import ai_service, ConversationService, matrix_bot
|
||
|
||
# Logging: configure the root logger at INFO and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object every route below registers on.
app = FastAPI(title="AI对话系统", version="1.0.0")

# Static assets and Jinja2 templates are resolved relative to this file's
# directory, so the app works regardless of the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(
    directory=os.path.join(BASE_DIR, "templates"),
)
|
||
|
||
# WebSocket连接管理
|
||
class ConnectionManager:
|
||
def __init__(self):
|
||
# user_id -> Set[WebSocket]
|
||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||
|
||
async def connect(self, websocket: WebSocket, user_id: str):
|
||
await websocket.accept()
|
||
if user_id not in self.active_connections:
|
||
self.active_connections[user_id] = set()
|
||
self.active_connections[user_id].add(websocket)
|
||
logger.info(f"WebSocket连接: {user_id}, 当前连接数: {len(self.active_connections[user_id])}")
|
||
|
||
def disconnect(self, websocket: WebSocket, user_id: str):
|
||
if user_id in self.active_connections:
|
||
self.active_connections[user_id].discard(websocket)
|
||
if not self.active_connections[user_id]:
|
||
del self.active_connections[user_id]
|
||
logger.info(f"WebSocket断开: {user_id}")
|
||
|
||
async def send_to_user(self, user_id: str, message: dict):
|
||
"""发送消息给用户的所有连接"""
|
||
if user_id in self.active_connections:
|
||
dead_connections = set()
|
||
for connection in self.active_connections[user_id]:
|
||
try:
|
||
await connection.send_json(message)
|
||
except:
|
||
dead_connections.add(connection)
|
||
|
||
# 清理断开的连接
|
||
for conn in dead_connections:
|
||
self.active_connections[user_id].discard(conn)
|
||
|
||
|
||
manager = ConnectionManager()
|
||
|
||
|
||
# ==================== Page routes ====================


@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat page (``index.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
||
|
||
|
||
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Render the admin dashboard page (``admin/index.html``)."""
    template_name = "admin/index.html"
    return templates.TemplateResponse(template_name, {"request": request})
|
||
|
||
|
||
# ==================== API routes ====================

# Fixed id of the single "main" user that every web connection acts as; the
# original note says it is associated with the Matrix AI user.
MAIN_USER_ID: str = "main_user"
|
||
|
||
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the conversations owned by the shared main user.

    The main user row is created on first use. Ordering is whatever
    ``ConversationService.get_user_conversations`` returns.

    Returns:
        ``{"conversations": [...]}``. Each item has the public ``id``, a
        ``title`` (falling back to "新对话" when unset), and ISO-8601
        ``created_at`` / ``updated_at`` strings.
    """
    service = ConversationService(db)
    owner = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    def _summary(conv):
        return {
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        }

    summaries = [_summary(conv) for conv in service.get_user_conversations(owner.id)]
    return {"conversations": summaries}
|
||
|
||
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the main user's latest conversation (used for Matrix sync).

    Returns ``{"conversation": None}`` when the user has no conversations.
    """
    service = ConversationService(db)
    owner = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    conversations = service.get_user_conversations(owner.id)

    if not conversations:
        return {"conversation": None}

    # Relies on the service returning rows newest-first by updated_at (per the
    # original note), so the first element is the most recent one.
    newest = conversations[0]
    summary = {
        "id": newest.conversation_id,
        "title": newest.title or "新对话",
        "updated_at": newest.updated_at.isoformat(),
    }
    return {"conversation": summary}
|
||
|
||
|
||
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Create a new conversation for the shared main user.

    Returns the public ``id``, the display ``title`` ("新对话" when unset) and
    the ISO-8601 ``created_at`` timestamp.
    """
    service = ConversationService(db)
    owner = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    created = service.create_conversation(owner.id)

    payload = {"id": created.conversation_id}
    payload["title"] = created.title or "新对话"
    payload["created_at"] = created.created_at.isoformat()
    return payload
|
||
|
||
|
||
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return every message of one conversation.

    Args:
        conversation_id: public conversation id taken from the URL.

    Raises:
        HTTPException: 404 when no conversation has that id.
    """
    service = ConversationService(db)
    conversation = service.get_conversation(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="会话不存在")

    def _as_dict(msg):
        return {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "source": msg.source,
            "created_at": msg.created_at.isoformat(),
        }

    rows = service.get_messages(conversation.id)
    return {"messages": [_as_dict(msg) for msg in rows]}
|
||
|
||
|
||
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete a conversation by its public id.

    Raises:
        HTTPException: 404 when the service reports that nothing was deleted.
    """
    deleted = ConversationService(db).delete_conversation(conversation_id)
    if deleted:
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
|
||
|
||
|
||
# ==================== WebSocket routes ====================


async def _sync_to_matrix(text: str) -> None:
    """Best-effort mirror of *text* into the Matrix room the bot last used.

    Does nothing unless the bot is running and ``last_room_id`` is set.
    Failures are logged and swallowed so that a Matrix outage cannot tear down
    the web client's WebSocket session.
    """
    if not (matrix_bot.is_running and matrix_bot.last_room_id):
        return
    try:
        await matrix_bot.send_message(matrix_bot.last_room_id, text)
    except Exception:
        logger.exception("同步消息到Matrix失败")


async def _ws_send_history(websocket: WebSocket, conv_service, conversation_id) -> None:
    """Send the full history of *conversation_id* to this socket only.

    Unknown ids get no reply (unchanged from the original behavior).
    """
    conversation = conv_service.get_conversation(conversation_id)
    if not conversation:
        return
    messages = conv_service.get_messages(conversation.id)
    await websocket.send_json({
        "type": "history",
        "conversation_id": conversation_id,
        "messages": [
            {
                "role": m.role,
                "content": m.content,
                "source": m.source,
                "created_at": m.created_at.isoformat(),
            }
            for m in messages
        ],
    })


async def _ws_handle_chat(websocket: WebSocket, conv_service, user, data: dict) -> None:
    """Handle one ``chat`` request from a web client.

    Steps: resolve the conversation (or create one), store and broadcast the
    user message, mirror it to Matrix, ask the AI, then store, broadcast and
    mirror the reply. Errors that concern only this request are reported to
    the requesting socket instead of ending the session.
    """
    message = data.get("message") or ""
    conversation_id = data.get("conversation_id")
    # A null or non-string "message" used to raise on .strip() and drop the socket.
    if not isinstance(message, str) or not message.strip():
        return

    if conversation_id:
        conversation = conv_service.get_conversation(conversation_id)
        if not conversation:
            # Previously conversation.id was read from None; the AttributeError
            # terminated the whole WebSocket session.
            await websocket.send_json({"type": "error", "message": "会话不存在"})
            return
    else:
        conversation = conv_service.create_conversation(user.id)
        conversation_id = conversation.conversation_id
        await websocket.send_json({
            "type": "conversation_created",
            "conversation_id": conversation_id,
        })

    # Persist the user message, then broadcast it to every client of the main user.
    user_msg = conv_service.add_message(
        conversation_id=conversation.id,
        role='user',
        content=message,
        source='web',
    )
    await manager.send_to_user(MAIN_USER_ID, {
        "type": "user_message",
        "conversation_id": conversation_id,
        "message": {
            "id": user_msg.id,
            "role": "user",
            "content": message,
            "source": "web",
            "created_at": user_msg.created_at.isoformat(),
        },
    })
    await _sync_to_matrix(message)

    history = conv_service.get_conversation_history(conversation_id, limit=20)
    try:
        ai_response = await ai_service.chat(history)
        assistant_msg = conv_service.add_message(
            conversation_id=conversation.id,
            role='assistant',
            content=ai_response,
            source='web',
        )
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": ai_response,
                "source": "web",
                "created_at": assistant_msg.created_at.isoformat(),
            },
        })
    except Exception as e:
        logger.error(f"AI调用失败: {e}", exc_info=True)
        await websocket.send_json({
            "type": "error",
            "message": "AI服务暂时不可用,请稍后重试",
        })
        return

    # Outside the try: a Matrix failure must not be reported as an AI failure.
    await _sync_to_matrix(ai_response)


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str, db: Session = Depends(get_db)):
    """Real-time chat WebSocket; every connection acts as ``MAIN_USER_ID``.

    The path ``user_id`` is accepted for URL compatibility but not used for
    routing. Each client frame must be a JSON object with an ``action`` field:

    * ``select_conversation``: reply with that conversation's history.
    * ``chat``: store the message, call the AI and broadcast both messages.

    Other actions are ignored. Malformed frames get an ``error`` reply instead
    of closing the session.
    """
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)
    try:
        conv_service = ConversationService(db)
        user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                data = None
            if not isinstance(data, dict):
                await websocket.send_json({"type": "error", "message": "消息格式错误"})
                continue

            action = data.get("action")
            if action == "select_conversation":
                await _ws_send_history(websocket, conv_service, data.get("conversation_id"))
            elif action == "chat":
                await _ws_handle_chat(websocket, conv_service, user, data)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket错误: {e}", exc_info=True)
    finally:
        # Unregister under the SAME key used by connect(). The original passed
        # the path user_id here, so sockets were never removed from the manager.
        manager.disconnect(websocket, actual_user_id)
|
||
|
||
|
||
# ==================== Admin API ====================


@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return total row counts for users, conversations and messages."""
    tables = (
        ("total_users", User),
        ("total_conversations", Conversation),
        ("total_messages", Message),
    )
    return {label: db.query(model).count() for label, model in tables}
|
||
|
||
|
||
@app.get("/api/admin/users")
async def get_users(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Page through all users, most recently created first.

    Args:
        skip: number of rows to skip (SQL OFFSET).
        limit: maximum number of rows to return (SQL LIMIT).
    """
    ordered = db.query(User).order_by(User.created_at.desc())
    page = ordered.offset(skip).limit(limit).all()

    rows = []
    for account in page:
        rows.append({
            "id": account.id,
            "user_id": account.user_id,
            "display_name": account.display_name,
            "user_type": account.user_type,
            "created_at": account.created_at.isoformat(),
            "last_active_at": account.last_active_at.isoformat(),
            "is_active": account.is_active,
        })
    return {"users": rows}
|
||
|
||
|
||
@app.get("/api/admin/conversations")
async def admin_get_conversations(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Page through all conversations, most recently updated first.

    Args:
        skip: number of rows to skip (SQL OFFSET).
        limit: maximum number of rows to return (SQL LIMIT).
    """
    ordered = db.query(Conversation).order_by(Conversation.updated_at.desc())
    page = ordered.offset(skip).limit(limit).all()

    def _row(conv):
        owner = conv.user
        return {
            "id": conv.id,
            "conversation_id": conv.conversation_id,
            "user_id": owner.user_id if owner else None,
            "title": conv.title,
            # NOTE(review): len(conv.messages) loads every message of every
            # listed conversation; a COUNT query would be cheaper.
            "message_count": len(conv.messages),
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
            "is_active": conv.is_active,
        }

    return {"conversations": [_row(conv) for conv in page]}
|
||
|
||
|
||
@app.get("/api/admin/config")
async def get_config(db: Session = Depends(get_db)):
    """Return every SystemConfig row as key/value/description.

    NOTE(review): values are returned verbatim, including ``ai_api_key`` if it
    is stored, and no authentication is visible on this route. Confirm that
    access is restricted elsewhere.
    """
    entries = []
    for row in db.query(SystemConfig).all():
        entries.append({"key": row.key, "value": row.value, "description": row.description})
    return {"configs": entries}
|
||
|
||
|
||
@app.post("/api/admin/config")
async def update_config(data: dict, db: Session = Depends(get_db)):
    """Create or update one SystemConfig entry and apply it to running services.

    Body fields:
        key: required, non-empty string.
        value: value to store (stored as given).
        description: optional. An existing description is only overwritten
            when a non-empty one is supplied.

    Side effects after committing:
        * ``ai_*`` keys: reload the AI settings from all stored configs, using
          the same fallback defaults as startup.
        * ``matrix_*`` keys while the bot is running: reconnect the bot and
          start a new sync loop, mirroring what startup does.

    Raises:
        HTTPException: 400 when ``key`` is missing, empty or not a string.
    """
    key = data.get("key")
    value = data.get("value")
    description = data.get("description", "")

    if not key:
        raise HTTPException(status_code=400, detail="key不能为空")
    if not isinstance(key, str):
        # A non-string key would otherwise crash later on key.startswith() with a 500.
        raise HTTPException(status_code=400, detail="key必须是字符串")

    # Upsert the config row.
    config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
    if config:
        config.value = value
        if description:
            config.description = description
    else:
        config = SystemConfig(key=key, value=value, description=description)
        db.add(config)
    db.commit()

    if key.startswith('ai_'):
        # Reload the full AI settings so a change to one key keeps the others.
        configs = {c.key: c.value for c in db.query(SystemConfig).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)

    if key.startswith('matrix_') and matrix_bot.is_running:
        await matrix_bot.disconnect()
        await matrix_bot.init_from_config()
        if matrix_bot.is_running:
            # startup() launches the sync loop separately after init_from_config(),
            # so a re-initialised bot needs a new loop or it stops receiving
            # Matrix messages. The task is kept on app.state so it is not
            # garbage-collected.
            # NOTE(review): assumes disconnect() ends the previous start_sync
            # loop; confirm in services to rule out two concurrent loops.
            app.state.matrix_sync_task = asyncio.create_task(
                matrix_bot.start_sync(handle_matrix_message)
            )

    return {"success": True, "config": {"key": key, "value": value}}
|
||
|
||
|
||
# ==================== Matrix message callback ====================


async def _relay_matrix_user_message(conversation_id, user_message, message_id, sender):
    """Push a user message that arrived via Matrix to all web clients.

    Nothing is sent when *conversation_id* does not resolve to a conversation.
    ``created_at`` is the local time at relay (``datetime.now()``), not a
    stored timestamp.
    """
    db = SessionLocal()
    try:
        conversation = ConversationService(db).get_conversation(conversation_id)
        if not conversation:
            return
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                "created_at": datetime.now().isoformat(),
            },
        })
    finally:
        db.close()


async def _answer_matrix_chat(conversation_id, room_id):
    """Generate an AI reply for a Matrix conversation and fan it out.

    The reply is stored with source 'matrix', posted to *room_id*, and
    broadcast to web clients. Failures are logged, and a generic error
    notice is posted to the room without exception details.
    """
    db = SessionLocal()
    try:
        conv_service = ConversationService(db)

        # Resolve the conversation first. The original called the AI and then
        # silently discarded the answer when the id was unknown.
        conversation = conv_service.get_conversation(conversation_id)
        if not conversation:
            logger.warning(f"Matrix会话不存在: {conversation_id}")
            return

        history = conv_service.get_conversation_history(conversation_id, limit=20)
        ai_response = await ai_service.chat(history)

        assistant_msg = conv_service.add_message(
            conversation_id=conversation.id,
            role='assistant',
            content=ai_response,
            source='matrix',
        )
        await matrix_bot.send_message(room_id, ai_response)
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": ai_response,
                "source": "matrix",
                "created_at": assistant_msg.created_at.isoformat(),
            },
        })
    except Exception as e:
        logger.exception(f"处理Matrix消息失败: {e}")
        # Do not echo str(e) into the chat room: exception text may expose
        # internal hosts, URLs or credentials to every room member.
        try:
            await matrix_bot.send_message(room_id, "处理消息时出错,请稍后重试")
        except Exception:
            # Keep the failure inside this callback instead of propagating it
            # into the bot's sync loop.
            logger.exception("发送Matrix错误提示失败")
    finally:
        db.close()


async def handle_matrix_message(action: str, conversation_id: str = None, user_message: str = None, room_id: str = None, message_id: int = None, sender: str = None):
    """Dispatch one event delivered by the Matrix bot's sync loop.

    Actions:
        new_conversation: tell web clients a conversation was created.
        user_message: relay a Matrix user's message to web clients.
        chat: generate an AI reply and send it to Matrix and the web clients.

    Unknown actions are logged at debug level and ignored.
    """
    if action == "new_conversation":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "new_conversation",
            "conversation_id": conversation_id,
        })
    elif action == "user_message":
        await _relay_matrix_user_message(conversation_id, user_message, message_id, sender)
    elif action == "chat":
        await _answer_matrix_chat(conversation_id, room_id)
    else:
        logger.debug(f"忽略未知的Matrix动作: {action}")
|
||
|
||
|
||
# ==================== Startup and shutdown ====================


@app.on_event("startup")
async def startup():
    """Initialise the app on startup.

    Steps: create the database schema, load the AI settings stored in
    SystemConfig (with built-in fallbacks), then initialise the Matrix bot and
    start its sync loop if it came up.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Load the AI configuration from the database.
    db = SessionLocal()
    try:
        configs = {c.key: c.value for c in db.query(SystemConfig).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)
        # The API key is deliberately not logged.
        logger.info(f"AI配置已加载: {api_base}, model={model}")
    finally:
        db.close()

    await matrix_bot.init_from_config()
    if matrix_bot.is_running:
        # Keep a strong reference to the task. The event loop only holds weak
        # references, so an unreferenced task can be garbage-collected mid-run.
        app.state.matrix_sync_task = asyncio.create_task(
            matrix_bot.start_sync(handle_matrix_message)
        )
        logger.info("Matrix Bot已启动")
|
||
|
||
|
||
@app.on_event("shutdown")
async def shutdown():
    """Release external resources when the application stops.

    Currently this only disconnects the Matrix bot, then logs the shutdown.
    """
    bot = matrix_bot
    await bot.disconnect()
    logger.info("应用已关闭")
|
||
|
||
|
||
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 19020.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 19020}
    uvicorn.run(app, **server_options)