feat: AI对话系统 v1.0.0 - 网页端和Matrix端实时同步
This commit is contained in:
459
main.py
Normal file
459
main.py
Normal file
@@ -0,0 +1,459 @@
|
||||
"""
|
||||
AI对话系统 - 主应用
|
||||
支持网页端和Matrix端实时同步对话
|
||||
"""
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Set
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
from models import init_db, get_db, User, Conversation, Message, SystemConfig
|
||||
from services import ai_service, ConversationService, matrix_bot
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建应用
|
||||
app = FastAPI(title="AI对话系统", version="1.0.0")
|
||||
|
||||
# 静态文件和模板
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
|
||||
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
|
||||
|
||||
# WebSocket连接管理
|
||||
class ConnectionManager:
|
||||
def __init__(self):
|
||||
# user_id -> Set[WebSocket]
|
||||
self.active_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket, user_id: str):
|
||||
await websocket.accept()
|
||||
if user_id not in self.active_connections:
|
||||
self.active_connections[user_id] = set()
|
||||
self.active_connections[user_id].add(websocket)
|
||||
logger.info(f"WebSocket连接: {user_id}, 当前连接数: {len(self.active_connections[user_id])}")
|
||||
|
||||
def disconnect(self, websocket: WebSocket, user_id: str):
|
||||
if user_id in self.active_connections:
|
||||
self.active_connections[user_id].discard(websocket)
|
||||
if not self.active_connections[user_id]:
|
||||
del self.active_connections[user_id]
|
||||
logger.info(f"WebSocket断开: {user_id}")
|
||||
|
||||
async def send_to_user(self, user_id: str, message: dict):
|
||||
"""发送消息给用户的所有连接"""
|
||||
if user_id in self.active_connections:
|
||||
dead_connections = set()
|
||||
for connection in self.active_connections[user_id]:
|
||||
try:
|
||||
await connection.send_json(message)
|
||||
except:
|
||||
dead_connections.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
for conn in dead_connections:
|
||||
self.active_connections[user_id].discard(conn)
|
||||
|
||||
|
||||
manager = ConnectionManager()
|
||||
|
||||
|
||||
# ==================== 页面路由 ====================
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def index(request: Request):
|
||||
"""主页 - 聊天界面"""
|
||||
return templates.TemplateResponse("index.html", {"request": request})
|
||||
|
||||
|
||||
@app.get("/admin", response_class=HTMLResponse)
|
||||
async def admin(request: Request):
|
||||
"""后台管理首页"""
|
||||
return templates.TemplateResponse("admin/index.html", {"request": request})
|
||||
|
||||
|
||||
# ==================== API路由 ====================
|
||||
|
||||
@app.get("/api/conversations")
|
||||
async def get_conversations(user_id: str = None, db: Session = Depends(get_db)):
|
||||
"""获取会话列表"""
|
||||
if not user_id:
|
||||
# 临时用户ID(实际应用中应该从session获取)
|
||||
user_id = "web_anonymous"
|
||||
|
||||
conv_service = ConversationService(db)
|
||||
user = conv_service.get_or_create_user(user_id, user_type='web')
|
||||
conversations = conv_service.get_user_conversations(user.id)
|
||||
|
||||
return {
|
||||
"conversations": [
|
||||
{
|
||||
"id": c.conversation_id,
|
||||
"title": c.title or "新对话",
|
||||
"created_at": c.created_at.isoformat(),
|
||||
"updated_at": c.updated_at.isoformat()
|
||||
}
|
||||
for c in conversations
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/conversations")
|
||||
async def create_conversation(user_id: str = None, db: Session = Depends(get_db)):
|
||||
"""创建新会话"""
|
||||
if not user_id:
|
||||
user_id = "web_anonymous"
|
||||
|
||||
conv_service = ConversationService(db)
|
||||
user = conv_service.get_or_create_user(user_id, user_type='web')
|
||||
conversation = conv_service.create_conversation(user.id)
|
||||
|
||||
return {
|
||||
"id": conversation.conversation_id,
|
||||
"title": conversation.title or "新对话",
|
||||
"created_at": conversation.created_at.isoformat()
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/conversations/{conversation_id}/messages")
|
||||
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
|
||||
"""获取会话消息"""
|
||||
conv_service = ConversationService(db)
|
||||
conversation = conv_service.get_conversation(conversation_id)
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
messages = conv_service.get_messages(conversation.id)
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
{
|
||||
"id": m.id,
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"source": m.source,
|
||||
"created_at": m.created_at.isoformat()
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.delete("/api/conversations/{conversation_id}")
|
||||
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
|
||||
"""删除会话"""
|
||||
conv_service = ConversationService(db)
|
||||
success = conv_service.delete_conversation(conversation_id)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# ==================== WebSocket路由 ====================
|
||||
|
||||
@app.websocket("/ws/{user_id}")
|
||||
async def websocket_endpoint(websocket: WebSocket, user_id: str, db: Session = Depends(get_db)):
|
||||
"""WebSocket连接 - 实时对话"""
|
||||
await manager.connect(websocket, user_id)
|
||||
conv_service = ConversationService(db)
|
||||
user = conv_service.get_or_create_user(user_id, user_type='web')
|
||||
|
||||
current_conversation_id = None
|
||||
|
||||
try:
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
action = data.get("action")
|
||||
|
||||
if action == "select_conversation":
|
||||
# 选择会话
|
||||
current_conversation_id = data.get("conversation_id")
|
||||
conversation = conv_service.get_conversation(current_conversation_id)
|
||||
|
||||
if conversation:
|
||||
# 发送历史消息
|
||||
messages = conv_service.get_messages(conversation.id)
|
||||
await websocket.send_json({
|
||||
"type": "history",
|
||||
"conversation_id": current_conversation_id,
|
||||
"messages": [
|
||||
{
|
||||
"role": m.role,
|
||||
"content": m.content,
|
||||
"source": m.source,
|
||||
"created_at": m.created_at.isoformat()
|
||||
}
|
||||
for m in messages
|
||||
]
|
||||
})
|
||||
|
||||
elif action == "chat":
|
||||
# 对话消息
|
||||
message = data.get("message", "")
|
||||
conversation_id = data.get("conversation_id")
|
||||
|
||||
if not message.strip():
|
||||
continue
|
||||
|
||||
# 获取或创建会话
|
||||
if conversation_id:
|
||||
conversation = conv_service.get_conversation(conversation_id)
|
||||
else:
|
||||
conversation = conv_service.create_conversation(user.id)
|
||||
conversation_id = conversation.conversation_id
|
||||
await websocket.send_json({
|
||||
"type": "conversation_created",
|
||||
"conversation_id": conversation_id
|
||||
})
|
||||
|
||||
# 保存用户消息
|
||||
user_msg = conv_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role='user',
|
||||
content=message,
|
||||
source='web'
|
||||
)
|
||||
|
||||
# 通知用户消息已收到
|
||||
await manager.send_to_user(user_id, {
|
||||
"type": "user_message",
|
||||
"conversation_id": conversation_id,
|
||||
"message": {
|
||||
"id": user_msg.id,
|
||||
"role": "user",
|
||||
"content": message,
|
||||
"source": "web",
|
||||
"created_at": user_msg.created_at.isoformat()
|
||||
}
|
||||
})
|
||||
|
||||
# 获取对话历史
|
||||
history = conv_service.get_conversation_history(conversation_id, limit=20)
|
||||
|
||||
# 调用AI
|
||||
try:
|
||||
ai_response = await ai_service.chat(history)
|
||||
|
||||
# 保存AI回复
|
||||
assistant_msg = conv_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role='assistant',
|
||||
content=ai_response,
|
||||
source='web'
|
||||
)
|
||||
|
||||
# 发送AI回复
|
||||
await manager.send_to_user(user_id, {
|
||||
"type": "assistant_message",
|
||||
"conversation_id": conversation_id,
|
||||
"message": {
|
||||
"id": assistant_msg.id,
|
||||
"role": "assistant",
|
||||
"content": ai_response,
|
||||
"created_at": assistant_msg.created_at.isoformat()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"AI调用失败: {e}")
|
||||
await websocket.send_json({
|
||||
"type": "error",
|
||||
"message": "AI服务暂时不可用,请稍后重试"
|
||||
})
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect(websocket, user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket错误: {e}")
|
||||
manager.disconnect(websocket, user_id)
|
||||
|
||||
|
||||
# ==================== 后台管理API ====================
|
||||
|
||||
@app.get("/api/admin/stats")
|
||||
async def get_stats(db: Session = Depends(get_db)):
|
||||
"""获取统计数据"""
|
||||
total_users = db.query(User).count()
|
||||
total_conversations = db.query(Conversation).count()
|
||||
total_messages = db.query(Message).count()
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"total_conversations": total_conversations,
|
||||
"total_messages": total_messages
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/admin/users")
|
||||
async def get_users(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
|
||||
"""获取用户列表"""
|
||||
users = db.query(User).order_by(User.created_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"users": [
|
||||
{
|
||||
"id": u.id,
|
||||
"user_id": u.user_id,
|
||||
"display_name": u.display_name,
|
||||
"user_type": u.user_type,
|
||||
"created_at": u.created_at.isoformat(),
|
||||
"last_active_at": u.last_active_at.isoformat(),
|
||||
"is_active": u.is_active
|
||||
}
|
||||
for u in users
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/admin/conversations")
|
||||
async def admin_get_conversations(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
|
||||
"""获取所有会话"""
|
||||
conversations = db.query(Conversation).order_by(Conversation.updated_at.desc()).offset(skip).limit(limit).all()
|
||||
|
||||
return {
|
||||
"conversations": [
|
||||
{
|
||||
"id": c.id,
|
||||
"conversation_id": c.conversation_id,
|
||||
"user_id": c.user.user_id if c.user else None,
|
||||
"title": c.title,
|
||||
"message_count": len(c.messages),
|
||||
"created_at": c.created_at.isoformat(),
|
||||
"updated_at": c.updated_at.isoformat(),
|
||||
"is_active": c.is_active
|
||||
}
|
||||
for c in conversations
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/admin/config")
|
||||
async def get_config(db: Session = Depends(get_db)):
|
||||
"""获取系统配置"""
|
||||
configs = db.query(SystemConfig).all()
|
||||
|
||||
return {
|
||||
"configs": [
|
||||
{
|
||||
"key": c.key,
|
||||
"value": c.value,
|
||||
"description": c.description
|
||||
}
|
||||
for c in configs
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@app.post("/api/admin/config")
|
||||
async def update_config(data: dict, db: Session = Depends(get_db)):
|
||||
"""更新系统配置"""
|
||||
key = data.get("key")
|
||||
value = data.get("value")
|
||||
description = data.get("description", "")
|
||||
|
||||
if not key:
|
||||
raise HTTPException(status_code=400, detail="key不能为空")
|
||||
|
||||
config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
|
||||
|
||||
if config:
|
||||
config.value = value
|
||||
if description:
|
||||
config.description = description
|
||||
else:
|
||||
config = SystemConfig(key=key, value=value, description=description)
|
||||
db.add(config)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 如果更新了Matrix配置,重新连接
|
||||
if key.startswith('matrix_') and matrix_bot.is_running:
|
||||
await matrix_bot.disconnect()
|
||||
await matrix_bot.init_from_config()
|
||||
|
||||
return {"success": True, "config": {"key": key, "value": value}}
|
||||
|
||||
|
||||
# ==================== Matrix消息处理回调 ====================
|
||||
|
||||
async def handle_matrix_message(conversation_id: str, user_message: str, user_id: str, room_id: str):
|
||||
"""处理从Matrix收到的消息"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
conv_service = ConversationService(db)
|
||||
|
||||
# 获取会话历史
|
||||
history = conv_service.get_conversation_history(conversation_id, limit=20)
|
||||
|
||||
# 调用AI
|
||||
ai_response = await ai_service.chat(history)
|
||||
|
||||
# 保存AI回复
|
||||
conversation = conv_service.get_conversation(conversation_id)
|
||||
if conversation:
|
||||
conv_service.add_message(
|
||||
conversation_id=conversation.id,
|
||||
role='assistant',
|
||||
content=ai_response,
|
||||
source='matrix'
|
||||
)
|
||||
|
||||
# 发送到Matrix
|
||||
await matrix_bot.send_message(room_id, ai_response)
|
||||
|
||||
# 同时推送到网页端
|
||||
await manager.send_to_user(user_id, {
|
||||
"type": "assistant_message",
|
||||
"conversation_id": conversation_id,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ai_response,
|
||||
"source": "matrix",
|
||||
"created_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理Matrix消息失败: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
# ==================== 启动和关闭 ====================
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
"""应用启动时初始化"""
|
||||
init_db()
|
||||
logger.info("数据库初始化完成")
|
||||
|
||||
# 初始化Matrix Bot
|
||||
await matrix_bot.init_from_config()
|
||||
if matrix_bot.is_running:
|
||||
asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
|
||||
logger.info("Matrix Bot已启动")
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
"""应用关闭时清理"""
|
||||
await matrix_bot.disconnect()
|
||||
logger.info("应用已关闭")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=19020)
|
||||
Reference in New Issue
Block a user