Files
ai-chat-system/main.py

796 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 - 主应用
支持网页端和Matrix端实时同步对话
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set
import asyncio
import json
import logging
from datetime import datetime
import os
from models import init_db, get_db, SessionLocal, User, Conversation, Message, SystemConfig
from services import ai_service, ConversationService, matrix_bot
# Process-wide logging at INFO level, plus this module's own logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object.
app = FastAPI(version="1.0.0", title="AI对话系统")

# Static assets and templates are resolved relative to this file rather than
# the current working directory, so the app can be launched from anywhere.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(
    directory=os.path.join(BASE_DIR, "templates"),
)
# WebSocket connection registry
class ConnectionManager:
    """Track open WebSocket connections per user and fan messages out to them.

    One user may hold several simultaneous connections (e.g. several browser
    tabs); ``send_to_user`` delivers a message to every one of them.
    """

    def __init__(self):
        # user_id -> set of open WebSocket connections for that user
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept ``websocket`` and register it under ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}, 当前连接数: {len(self.active_connections[user_id])}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket``; drop the user's entry once it is empty.

        Safe to call for an unknown user or an already-removed socket.
        """
        connections = self.active_connections.get(user_id)
        if connections is not None:
            connections.discard(websocket)
            if not connections:
                del self.active_connections[user_id]
        logger.info(f"WebSocket断开: {user_id}")

    async def send_to_user(self, user_id: str, message: dict):
        """Send ``message`` as JSON to every connection registered for ``user_id``.

        Connections whose send raises are treated as dead and unregistered
        (removing the user's entry entirely once no connections remain).
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        # Iterate over a snapshot: every send awaits, and other coroutines may
        # connect/disconnect meanwhile, mutating the live set and raising
        # "Set changed size during iteration".
        dead_connections = []
        for connection in list(connections):
            try:
                await connection.send_json(message)
            except Exception as e:
                # Exception rather than a bare except, so task cancellation
                # is not swallowed.
                logger.warning(f"向 {user_id} 发送消息失败, 移除连接: {e}")
                dead_connections.append(connection)
        for connection in dead_connections:
            self.disconnect(connection, user_id)


# Shared, module-wide connection registry used by the routes and callbacks.
manager = ConnectionManager()
# ==================== 页面路由 ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Serve the chat page rendered from ``index.html``."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Serve the admin dashboard rendered from ``admin/index.html``."""
    # NOTE(review): no authentication dependency is declared on this route.
    context = {"request": request}
    return templates.TemplateResponse("admin/index.html", context)
# ==================== API路由 ====================
# Fixed id of the single main user. All web clients share it, and the Matrix
# bridge broadcasts to it, which keeps web and Matrix conversations in sync.
MAIN_USER_ID: str = "main_user"
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the main user's conversations, creating the user on first use.

    Returns:
        ``{"conversations": [{"id", "title", "created_at", "updated_at"}, ...]}``
        in the order produced by ``ConversationService.get_user_conversations``.
        Untitled conversations are reported as "新对话".
    """
    service = ConversationService(db)
    main_user = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    def summarize(conv):
        return {
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        }

    return {"conversations": [summarize(conv) for conv in service.get_user_conversations(main_user.id)]}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the main user's most recent conversation (used for Matrix sync).

    Returns:
        ``{"conversation": {"id", "title", "updated_at"}}``, or
        ``{"conversation": None}`` when the user has no conversations.
    """
    service = ConversationService(db)
    main_user = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    conversations = service.get_user_conversations(main_user.id)
    if not conversations:
        return {"conversation": None}

    # Assumes the service returns conversations newest-first by updated_at.
    newest = conversations[0]
    return {
        "conversation": dict(
            id=newest.conversation_id,
            title=newest.title or "新对话",
            updated_at=newest.updated_at.isoformat(),
        )
    }
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Create a new conversation owned by the main user.

    Returns:
        ``{"id", "title", "created_at"}`` of the new conversation.
    """
    service = ConversationService(db)
    owner = service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    new_conv = service.create_conversation(owner.id)
    return dict(
        id=new_conv.conversation_id,
        title=new_conv.title or "新对话",
        created_at=new_conv.created_at.isoformat(),
    )
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return every message of the conversation with public id ``conversation_id``.

    Raises:
        HTTPException(404): no such conversation.
    """
    service = ConversationService(db)
    conversation = service.get_conversation(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="会话不存在")

    def to_json(msg):
        return dict(
            id=msg.id,
            role=msg.role,
            content=msg.content,
            source=msg.source,
            created_at=msg.created_at.isoformat(),
        )

    return {"messages": list(map(to_json, service.get_messages(conversation.id)))}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete the conversation with public id ``conversation_id``.

    Raises:
        HTTPException(404): the service reported that nothing was deleted.
    """
    deleted = ConversationService(db).delete_conversation(conversation_id)
    if deleted:
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
# ==================== WebSocket路由 ====================
async def _forward_to_matrix(text: str) -> None:
    """Best-effort mirror of ``text`` into the Matrix room the bot last used.

    Does nothing when the bot is not running or has no known room. Failures
    are logged instead of raised, so they can neither tear down the web
    client's WebSocket nor be misreported as an AI failure.
    """
    if not (matrix_bot.is_running and matrix_bot.last_room_id):
        return
    try:
        await matrix_bot.send_message(matrix_bot.last_room_id, text)
    except Exception as e:
        logger.error(f"同步消息到Matrix失败: {e}")


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str, db: Session = Depends(get_db)):
    """Real-time chat over WebSocket; every connection acts as the main user.

    Client frames are JSON objects with an ``action`` key:

    * ``select_conversation`` (``conversation_id``): replies with a
      ``history`` frame when the conversation exists.
    * ``chat`` (``message``, optional ``conversation_id``): creates a
      conversation when no id is given, then stores the user message. It
      broadcasts the message to all main-user sockets and mirrors it to Matrix.
      It then asks the AI for a reply, which is stored, broadcast and mirrored
      the same way. Errors are reported with ``{"type": "error"}`` frames.

    Args:
        websocket: The client connection.
        user_id: Path segment; kept for URL compatibility but not used for
            routing (all sockets are registered under ``MAIN_USER_ID``).
        db: Database session injected by FastAPI.
    """
    # Register under the shared main user. disconnect() must use the same id;
    # using the path ``user_id`` there left sockets registered forever.
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)
    try:
        conv_service = ConversationService(db)
        user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
        current_conversation_id = None
        while True:
            try:
                data = await websocket.receive_json()
            except ValueError:
                # Malformed JSON frame: skip it rather than drop the connection.
                continue
            if not isinstance(data, dict):
                # Valid JSON but not an object: nothing to dispatch on.
                continue
            action = data.get("action")

            if action == "select_conversation":
                current_conversation_id = data.get("conversation_id")
                conversation = conv_service.get_conversation(current_conversation_id)
                if conversation:
                    messages = conv_service.get_messages(conversation.id)
                    await websocket.send_json({
                        "type": "history",
                        "conversation_id": current_conversation_id,
                        "messages": [
                            {
                                "role": m.role,
                                "content": m.content,
                                "source": m.source,
                                "created_at": m.created_at.isoformat()
                            }
                            for m in messages
                        ]
                    })

            elif action == "chat":
                message = data.get("message", "")
                conversation_id = data.get("conversation_id")
                if not isinstance(message, str) or not message.strip():
                    continue

                if conversation_id:
                    conversation = conv_service.get_conversation(conversation_id)
                    if not conversation:
                        # Unknown id: report it instead of failing on
                        # ``conversation.id`` and dropping the connection.
                        await websocket.send_json({
                            "type": "error",
                            "message": "会话不存在"
                        })
                        continue
                else:
                    conversation = conv_service.create_conversation(user.id)
                    conversation_id = conversation.conversation_id
                    await websocket.send_json({
                        "type": "conversation_created",
                        "conversation_id": conversation_id
                    })

                # Persist the user message and sync it to every client.
                user_msg = conv_service.add_message(
                    conversation_id=conversation.id,
                    role='user',
                    content=message,
                    source='web'
                )
                await manager.send_to_user(MAIN_USER_ID, {
                    "type": "user_message",
                    "conversation_id": conversation_id,
                    "message": {
                        "id": user_msg.id,
                        "role": "user",
                        "content": message,
                        "source": "web",
                        "created_at": user_msg.created_at.isoformat()
                    }
                })
                await _forward_to_matrix(message)

                history = conv_service.get_conversation_history(conversation_id, limit=20)
                try:
                    ai_response = await ai_service.chat(history)
                    assistant_msg = conv_service.add_message(
                        conversation_id=conversation.id,
                        role='assistant',
                        content=ai_response,
                        source='web'
                    )
                except Exception as e:
                    logger.error(f"AI调用失败: {e}")
                    await websocket.send_json({
                        "type": "error",
                        "message": "AI服务暂时不可用请稍后重试"
                    })
                    continue

                # Sync the AI reply to every client and to Matrix.
                await manager.send_to_user(MAIN_USER_ID, {
                    "type": "assistant_message",
                    "conversation_id": conversation_id,
                    "message": {
                        "id": assistant_msg.id,
                        "role": "assistant",
                        "content": ai_response,
                        "source": "web",
                        "created_at": assistant_msg.created_at.isoformat()
                    }
                })
                await _forward_to_matrix(ai_response)
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Always unregister, whatever ended the loop.
        manager.disconnect(websocket, actual_user_id)
# ==================== 后台管理API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return total row counts for users, conversations and messages."""
    # NOTE(review): no authentication dependency is declared on admin routes.
    counted_models = (
        ("total_users", User),
        ("total_conversations", Conversation),
        ("total_messages", Message),
    )
    return {label: db.query(model).count() for label, model in counted_models}
@app.get("/api/admin/users")
async def get_users(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of users, newest first.

    Args:
        skip: Number of users to skip (SQL OFFSET).
        limit: Maximum number of users to return (SQL LIMIT).
    """
    # NOTE(review): assumes created_at / last_active_at are never NULL,
    # otherwise .isoformat() raises — confirm against the User model.
    page = (
        db.query(User)
        .order_by(User.created_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    def describe(u):
        return {
            "id": u.id,
            "user_id": u.user_id,
            "display_name": u.display_name,
            "user_type": u.user_type,
            "created_at": u.created_at.isoformat(),
            "last_active_at": u.last_active_at.isoformat(),
            "is_active": u.is_active,
        }

    return {"users": [describe(u) for u in page]}
@app.get("/api/admin/conversations")
async def admin_get_conversations(skip: int = 0, limit: int = 50, db: Session = Depends(get_db)):
    """Return one page of all conversations, most recently updated first.

    Args:
        skip: Number of conversations to skip (SQL OFFSET).
        limit: Maximum number of conversations to return (SQL LIMIT).
    """
    page = (
        db.query(Conversation)
        .order_by(Conversation.updated_at.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )

    def describe(conv):
        # NOTE(review): ``conv.user`` and ``len(conv.messages)`` load the
        # relationships per row (and every message) unless they are
        # eager-loaded in the model — consider a COUNT query for large data.
        return {
            "id": conv.id,
            "conversation_id": conv.conversation_id,
            "user_id": conv.user.user_id if conv.user else None,
            "title": conv.title,
            "message_count": len(conv.messages),
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
            "is_active": conv.is_active,
        }

    return {"conversations": [describe(conv) for conv in page]}
@app.get("/api/admin/config")
async def get_config(db: Session = Depends(get_db)):
    """Return every ``SystemConfig`` row as ``{"key", "value", "description"}``."""
    # NOTE(review): this exposes every stored value, including secrets such
    # as ``ai_api_key``, and the route declares no authentication.
    rows = db.query(SystemConfig).all()
    entries = [
        dict(key=row.key, value=row.value, description=row.description)
        for row in rows
    ]
    return {"configs": entries}
@app.get("/api/admin/ai-config")
async def get_ai_config(db: Session = Depends(get_db)):
    """Return the stored AI settings, with built-in defaults for missing keys.

    Also reports whether the AI service is in mock mode (``use_mock``).
    """
    ai_rows = db.query(SystemConfig).filter(SystemConfig.key.startswith('ai_')).all()
    stored = {row.key: row.value for row in ai_rows}
    # (response field, config key, default when the key is not stored)
    settings = (
        ("api_base", 'ai_api_base', 'http://192.168.2.17:19007/v1'),
        ("api_key", 'ai_api_key', 'xxxx'),
        ("model", 'ai_model', 'auto'),
    )
    result = {field: stored.get(key, default) for field, key, default in settings}
    result["use_mock"] = ai_service.use_mock
    return result
@app.post("/api/admin/ai-config")
async def update_ai_config(data: dict, db: Session = Depends(get_db)):
    """Upsert the AI settings from ``data`` and apply them to ``ai_service``.

    ``data`` may carry ``api_base``, ``api_key`` and ``model``. Only truthy
    values are written, so an omitted or empty field keeps the stored one.
    After committing, all ``ai_*`` rows are re-read and pushed to the running
    AI service, with built-in defaults for keys that are still missing.
    """
    # (request field, config key, description used when the row is created)
    editable = (
        ("api_base", 'ai_api_base', 'AI API地址'),
        ("api_key", 'ai_api_key', 'AI API密钥'),
        ("model", 'ai_model', 'AI模型名称'),
    )
    for field, config_key, description in editable:
        new_value = data.get(field)
        if not new_value:
            continue
        row = db.query(SystemConfig).filter(SystemConfig.key == config_key).first()
        if row is None:
            db.add(SystemConfig(key=config_key, value=new_value, description=description))
        else:
            row.value = new_value
    db.commit()

    # Push the now-persisted settings to the live AI service.
    stored = {
        row.key: row.value
        for row in db.query(SystemConfig).filter(SystemConfig.key.startswith('ai_')).all()
    }
    ai_service.update_config(
        stored.get('ai_api_base', 'http://192.168.2.17:19007/v1'),
        stored.get('ai_api_key', 'xxxx'),
        stored.get('ai_model', 'auto'),
    )
    return {"success": True, "message": "AI配置已更新"}
@app.get("/api/admin/models")
async def get_available_models(db: Session = Depends(get_db)):
    """Return the models offered by the configured OpenAI-compatible API.

    Reads ``ai_api_base`` / ``ai_api_key`` from ``SystemConfig`` and calls
    ``GET {api_base}/models``. On success the response has ``success: True``
    and the parsed model list.

    A built-in default list with ``success: False`` and an explanatory
    ``message`` is returned instead when:

    * no API base is configured,
    * the request or parsing fails, or
    * the API answers with a non-200 status.
    """
    import httpx

    # Fallback shown whenever the remote list is unavailable.
    default_models = [
        {"id": "auto", "name": "auto (自动选择)", "owned_by": "system"},
        {"id": "qwen3.5-4b", "name": "qwen3.5-4b", "owned_by": "local"},
        {"id": "dsv32", "name": "dsv32", "owned_by": "deepseek"},
        {"id": "glm-4", "name": "glm-4", "owned_by": "zhipu"},
        {"id": "gpt-4o", "name": "gpt-4o", "owned_by": "openai"},
        {"id": "claude-3-opus", "name": "claude-3-opus", "owned_by": "anthropic"}
    ]
    fallback = {
        "models": default_models,
        "success": False,
        "message": "无法从API获取模型列表显示默认列表"
    }

    # Always use the latest persisted configuration.
    configs = {c.key: c.value for c in db.query(SystemConfig).filter(SystemConfig.key.startswith('ai_')).all()}
    api_base = configs.get('ai_api_base', '')
    api_key = configs.get('ai_api_key', 'xxxx')
    if not api_base:
        return {
            "models": default_models,
            "success": False,
            "message": "请先配置API地址"
        }

    try:
        # Strip a trailing slash so the URL never contains "//models".
        url = f"{str(api_base).rstrip('/')}/models"
        headers = {"Authorization": f"Bearer {api_key}"}
        logger.info(f"获取模型列表: url={url}")
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(url, headers=headers)
        if response.status_code != 200:
            # Previously this path fell off the end of the function and the
            # endpoint returned ``null`` to the admin UI.
            logger.error(f"获取模型列表失败: HTTP {response.status_code}")
            return fallback

        payload = response.json()
        models = []
        for entry in payload.get('data', []):
            model_id = entry.get('id', '')
            if model_id:
                models.append({
                    "id": model_id,
                    "name": entry.get('name', model_id),
                    "owned_by": entry.get('owned_by', 'unknown')
                })
        return {"models": models, "success": True}
    except Exception as e:
        logger.error(f"获取模型列表失败: {e}")
        return fallback
@app.post("/api/admin/test-ai")
async def test_ai_connection(db: Session = Depends(get_db)):
    """Send a tiny chat completion to the configured AI API and report the outcome.

    Uses the stored ``ai_api_base`` / ``ai_api_key`` / ``ai_model`` values,
    falling back to built-in defaults for missing or empty ones. The result
    always contains ``success``, ``message``, ``model`` and ``api_base``.
    When the default API base was used, a hint about saving the configuration
    is appended to the success message.
    """
    import httpx

    # Latest persisted configuration; empty values fall back to defaults.
    configs = {c.key: c.value for c in db.query(SystemConfig).filter(SystemConfig.key.startswith('ai_')).all()}
    api_base = configs.get('ai_api_base') or 'http://192.168.2.17:19007/v1'
    api_key = configs.get('ai_api_key') or 'xxxx'
    model = configs.get('ai_model') or 'auto'
    using_defaults = not configs.get('ai_api_base')
    timeout_seconds = 15.0

    def _result(success: bool, message: str) -> dict:
        """Build the uniform response body for this endpoint."""
        return {
            "success": success,
            "message": message,
            "model": model,
            "api_base": api_base
        }

    try:
        # Strip a trailing slash so the URL never contains "//chat".
        url = f"{str(api_base).rstrip('/')}/chat/completions"
        headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "测试连接"}],
            "max_tokens": 50
        }
        logger.info(f"测试AI连接: url={url}, model={model}, using_defaults={using_defaults}")
        async with httpx.AsyncClient(timeout=timeout_seconds) as client:
            response = await client.post(url, headers=headers, json=payload)

        if response.status_code != 200:
            error_text = response.text[:200] if response.text else ""
            return _result(False, f"连接失败: HTTP {response.status_code} - {error_text}")

        data = response.json()
        # Some models answer with ``content: null`` (e.g. reasoning or
        # tool-call replies). The server still answered, so treat it as an
        # empty reply instead of failing on ``None[:100]``.
        content = data['choices'][0]['message']['content'] or ""
        message = f"连接成功!模型响应: {str(content)[:100]}"
        if using_defaults:
            message += "\n(使用默认配置,点击「保存配置」可持久化)"
        return _result(True, message)
    except httpx.ConnectError:
        return _result(False, f"无法连接到API地址: {api_base}")
    except httpx.TimeoutException:
        return _result(False, f"连接超时{timeout_seconds:g}秒")
    except Exception as e:
        return _result(False, f"连接失败: {str(e)}")
@app.post("/api/admin/config")
async def update_config(data: dict, db: Session = Depends(get_db)):
    """Create or update one ``SystemConfig`` row and apply its side effects.

    Body: ``{"key": str, "value": ..., "description": str (optional)}``. An
    existing row keeps its description unless a non-empty one is sent.
    ``ai_*`` keys re-apply the AI service settings; ``matrix_*`` keys
    reconnect the Matrix bot when it is running.

    Raises:
        HTTPException(400): ``key`` is missing/empty or not a string.
    """
    key = data.get("key")
    value = data.get("value")
    description = data.get("description", "")

    if not key:
        raise HTTPException(status_code=400, detail="key不能为空")
    # A non-string key used to be committed and then crash on
    # ``key.startswith`` with a 500; reject it before touching the DB.
    if not isinstance(key, str):
        raise HTTPException(status_code=400, detail="key必须是字符串")

    config = db.query(SystemConfig).filter(SystemConfig.key == key).first()
    if config:
        config.value = value
        if description:
            config.description = description
    else:
        config = SystemConfig(key=key, value=value, description=description)
        db.add(config)
    db.commit()

    # Apply the change to the running services.
    if key.startswith('ai_'):
        configs = {c.key: c.value for c in db.query(SystemConfig).filter(SystemConfig.key.startswith('ai_')).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)

    if key.startswith('matrix_') and matrix_bot.is_running:
        # NOTE(review): this re-initialises the bot but does not start a new
        # ``start_sync`` loop (startup does that once). Confirm whether the
        # existing loop survives disconnect()/init_from_config().
        await matrix_bot.disconnect()
        await matrix_bot.init_from_config()

    return {"success": True, "config": {"key": key, "value": value}}
# ==================== Matrix消息处理回调 ====================
async def _set_matrix_typing(room_id: str, typing: bool) -> None:
    """Best-effort toggle of the bot's typing indicator in ``room_id``."""
    try:
        await matrix_bot.client.room_typing(room_id, typing_state=typing)
    except Exception as e:
        # Purely cosmetic: never let it abort a reply. Catching Exception
        # (not a bare except) keeps task cancellation working.
        logger.debug(f"设置Matrix输入状态失败: {e}")


async def _reply_with_ai(conversation_id: str, room_id: str, *, show_typing: bool, error_log_prefix: str) -> None:
    """Generate, store and fan out an AI reply for a Matrix-originated turn.

    Steps:

    1. Load up to 20 history entries for ``conversation_id`` and ask
       ``ai_service`` for a reply.
    2. Store the reply as an assistant message with source ``matrix``.
    3. Post it to ``room_id`` and broadcast it to the web clients.

    Does nothing when the conversation does not exist. Errors are logged as
    ``"{error_log_prefix}: {error}"`` and reported back to the Matrix room.

    Args:
        conversation_id: Public conversation id.
        room_id: Matrix room to reply in.
        show_typing: Whether to show a typing indicator while generating.
        error_log_prefix: Prefix for the error log line.
    """
    db = SessionLocal()
    try:
        conv_service = ConversationService(db)
        # Check existence first so no AI call is wasted on a missing
        # conversation.
        conversation = conv_service.get_conversation(conversation_id)
        if not conversation:
            return

        if show_typing:
            await _set_matrix_typing(room_id, True)
        try:
            history = conv_service.get_conversation_history(conversation_id, limit=20)
            ai_response = await ai_service.chat(history)
            assistant_msg = conv_service.add_message(
                conversation_id=conversation.id,
                role='assistant',
                content=ai_response,
                source='matrix'
            )
            await matrix_bot.send_message(room_id, ai_response)
        finally:
            # Clear the indicator even when generation failed; previously it
            # stayed on after an AI error.
            if show_typing:
                await _set_matrix_typing(room_id, False)

        # Sync to the web clients.
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": ai_response,
                "source": "matrix",
                "created_at": assistant_msg.created_at.isoformat()
            }
        })
        logger.info(f"Matrix AI回复已发送: {ai_response[:50]}")
    except Exception as e:
        logger.error(f"{error_log_prefix}: {e}")
        try:
            await matrix_bot.send_message(room_id, f"处理消息时出错: {str(e)}")
        except Exception as send_error:
            # Don't let the error report itself escape into the caller.
            logger.error(f"发送错误提示到Matrix失败: {send_error}")
    finally:
        db.close()


async def handle_matrix_message(action: str, conversation_id: str = None, user_message: str = None, room_id: str = None, message_id: int = None, sender: str = None):
    """Callback passed to ``matrix_bot.start_sync`` for events from Matrix.

    Args:
        action: What to do. Supported values:

            * ``new_conversation``: notify web clients of a new conversation.
            * ``user_message``: forward a Matrix user message to web clients
              if the conversation exists.
            * ``generate_ai_response``: produce an AI reply, with a typing
              indicator.
            * ``chat``: produce an AI reply, without a typing indicator.

            Any other value is logged and ignored.
        conversation_id: Public conversation id the event belongs to.
        user_message: Text of the Matrix user's message (``user_message``).
        room_id: Matrix room to reply in (AI actions).
        message_id: Stored id of the user message (``user_message``).
        sender: Matrix sender of the user message (``user_message``).
    """
    if action == "new_conversation":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "new_conversation",
            "conversation_id": conversation_id
        })
        return

    if action == "user_message":
        db = SessionLocal()
        try:
            conv_service = ConversationService(db)
            conversation = conv_service.get_conversation(conversation_id)
            if conversation:
                await manager.send_to_user(MAIN_USER_ID, {
                    "type": "user_message",
                    "conversation_id": conversation_id,
                    "message": {
                        "id": message_id,
                        "role": "user",
                        "content": user_message,
                        "source": "matrix",
                        "sender": sender,
                        "created_at": datetime.now().isoformat()
                    }
                })
        finally:
            db.close()
        return

    if action == "generate_ai_response":
        await _reply_with_ai(conversation_id, room_id, show_typing=True, error_log_prefix="生成AI回复失败")
        return

    if action == "chat":
        await _reply_with_ai(conversation_id, room_id, show_typing=False, error_log_prefix="处理Matrix消息失败")
        return

    logger.warning(f"未知的Matrix动作: {action}")
# ==================== 启动和关闭 ====================
@app.on_event("startup")
async def startup():
    """Initialise the database, load AI settings and start the Matrix bot.

    AI settings come from the ``ai_api_base``, ``ai_api_key`` and
    ``ai_model`` rows of ``SystemConfig``, with built-in defaults for missing
    keys. If the Matrix bot comes up, its sync loop runs as a background task
    whose handle is stored on ``app.state.matrix_sync_task``.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Load the AI configuration from the database.
    db = SessionLocal()
    try:
        configs = {c.key: c.value for c in db.query(SystemConfig).all()}
        api_base = configs.get('ai_api_base', 'http://192.168.2.17:19007/v1')
        api_key = configs.get('ai_api_key', 'xxxx')
        model = configs.get('ai_model', 'auto')
        ai_service.update_config(api_base, api_key, model)
        logger.info(f"AI配置已加载: {api_base}, model={model}")
    finally:
        db.close()

    # Start the Matrix bot.
    await matrix_bot.init_from_config()
    if matrix_bot.is_running:
        # Keep a strong reference: the event loop only holds weak references
        # to tasks, so an unreferenced task can be garbage-collected mid-run.
        app.state.matrix_sync_task = asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
        logger.info("Matrix Bot已启动")
@app.on_event("shutdown")
async def shutdown():
    """Stop the Matrix sync task (if one was recorded) and disconnect the bot."""
    # Present only when startup stored the task; absent otherwise.
    sync_task = getattr(app.state, "matrix_sync_task", None)
    if sync_task is not None and not sync_task.done():
        sync_task.cancel()
        try:
            await sync_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Matrix同步任务退出异常: {e}")
    await matrix_bot.disconnect()
    logger.info("应用已关闭")
if __name__ == "__main__":
    # Development entry point: ``python main.py`` serves the app directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 19020}
    uvicorn.run(app, **server_options)