Files
ai-chat-system/main_v2.py

819 lines
28 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 v2.0.0 - 主应用
支持大模型池、Agent管理、渠道独立绑定、思考功能开关
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set, Optional
import asyncio
import json
import logging
from datetime import datetime
import os
# 使用新的数据模型
from models_v2 import (
init_db, get_db, SessionLocal,
User, Conversation, Message, SystemConfig,
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping,
init_default_data
)
from services.llm_service import llm_service
from services.agent_service import AgentService, LLMProviderService, ChannelService
from services.conversation_service import ConversationService
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Create the application
app = FastAPI(title="AI对话系统 v2.0", version="2.0.0")
# Static files and templates, both resolved relative to this file's directory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static")
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
# WebSocket connection management
class ConnectionManager:
    """Track open WebSocket connections per user and fan messages out to them."""

    def __init__(self):
        # Maps a user id to the set of sockets currently open for that user.
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept ``websocket`` and register it under ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket`` for ``user_id``.

        Does nothing when the user or the socket is not registered. The
        user's entry is removed once its last socket is gone.
        """
        sockets = self.active_connections.get(user_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send ``message`` as JSON to every socket registered for ``user_id``.

        Delivery is best effort: a failure on one socket is logged and the
        remaining sockets are still attempted. Nothing is sent when the user
        has no registered sockets.
        """
        # Iterate over a snapshot: every await yields to the event loop, and a
        # connect()/disconnect() running in between would otherwise mutate the
        # set while it is being iterated.
        for connection in list(self.active_connections.get(user_id, ())):
            try:
                await connection.send_json(message)
            except Exception as exc:
                # Catch Exception rather than everything so that task
                # cancellation still propagates to the caller.
                logger.warning(f"WebSocket发送失败: {user_id}: {exc}")
# Module-wide connection registry, used by the WebSocket route and the Matrix callback
manager = ConnectionManager()
# Fixed id of the main user
MAIN_USER_ID = "main_user"
# ==================== Page routes ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Home page: render the chat interface template."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Admin console (v2.0): render the admin template."""
    context = {"request": request}
    return templates.TemplateResponse("admin_v2/index.html", context)
# ==================== v2 API routes ====================
# ===== LLM provider pool API =====
@app.get("/api/v2/providers")
async def get_providers(db: Session = Depends(get_db)):
    """List every provider in the LLM pool.

    Returns ``{"providers": [...]}`` with one dict per provider holding the
    attributes named in ``exposed_fields``.
    """
    # NOTE(review): "api_key" is returned exactly as stored - confirm that
    # every consumer of this endpoint is allowed to see it.
    exposed_fields = (
        "id",
        "name",
        "api_base",
        "api_key",
        "models",
        "default_model",
        "supports_thinking",
        "thinking_model",
        "max_tokens",
        "temperature",
        "is_active",
        "priority",
        "description",
    )
    payload = []
    for provider in LLMProviderService(db).get_all_providers():
        payload.append({field: getattr(provider, field) for field in exposed_fields})
    return {"providers": payload}
@app.post("/api/v2/providers")
async def create_provider(data: dict, db: Session = Depends(get_db)):
    """Create an LLM provider.

    Responds with ``success: False`` when a provider with the requested name
    already exists; otherwise returns the new provider's id and name.
    """
    provider_service = LLMProviderService(db)

    # Refuse duplicates before creating anything.
    existing = provider_service.get_provider_by_name(data.get('name'))
    if existing:
        return {"success": False, "message": "名称已存在"}

    created = provider_service.create_provider(data)
    summary = {"id": created.id, "name": created.name}
    return {"success": True, "provider": summary}
@app.put("/api/v2/providers/{provider_id}")
async def update_provider(provider_id: int, data: dict, db: Session = Depends(get_db)):
    """Update an LLM provider.

    Responds with ``success: False`` when the service returns nothing for
    ``provider_id``; otherwise returns the provider's id and name.
    """
    updated = LLMProviderService(db).update_provider(provider_id, data)
    if updated:
        summary = {"id": updated.id, "name": updated.name}
        return {"success": True, "provider": summary}
    return {"success": False, "message": "Provider不存在"}
@app.delete("/api/v2/providers/{provider_id}")
async def delete_provider(provider_id: int, db: Session = Depends(get_db)):
    """Delete an LLM provider.

    Responds with ``success: False`` when the service reports that the
    deletion did not happen.
    """
    deleted = LLMProviderService(db).delete_provider(provider_id)
    if deleted:
        return {"success": True}
    return {"success": False, "message": "无法删除可能有Agent在使用"}
@app.post("/api/v2/providers/models")
async def fetch_models(data: dict):
    """Fetch the model list from the API described in the request body.

    Reads ``api_base`` and ``api_key`` (default ``''``) from ``data`` and
    returns ``{"models": ...}`` with whatever the LLM service reports.
    """
    base_url = data.get('api_base')
    key = data.get('api_key', '')
    available = await llm_service.get_available_models(base_url, key)
    return {"models": available}
@app.post("/api/v2/providers/test")
async def test_provider(data: dict):
    """Test connectivity to an LLM endpoint.

    Reads ``api_base``, ``api_key`` (default ``''``) and ``model`` (default
    ``'auto'``) from ``data`` and returns the LLM service's result unchanged.
    """
    return await llm_service.test_connection(
        data.get('api_base'),
        data.get('api_key', ''),
        data.get('model', 'auto'),
    )
# ===== Agent API =====
@app.get("/api/v2/agents")
async def get_agents(db: Session = Depends(get_db)):
    """List every agent, each annotated with the name of its LLM provider.

    ``llm_provider_name`` is ``None`` when the agent's provider id matches no
    provider row.
    """
    agent_rows = AgentService(db).get_all_agents()

    # Provider id -> provider name, for resolving llm_provider_name below.
    provider_names = {}
    for provider in db.query(LLMProvider).all():
        provider_names[provider.id] = provider.name

    def serialize(agent):
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "llm_provider_id": agent.llm_provider_id,
            "llm_provider_name": provider_names.get(agent.llm_provider_id),
            "model_override": agent.model_override,
            "system_prompt": agent.system_prompt,
            "enable_thinking": agent.enable_thinking,
            "thinking_prompt": agent.thinking_prompt,
            "thinking_prefix": agent.thinking_prefix,
            "thinking_suffix": agent.thinking_suffix,
            "max_history": agent.max_history,
            "temperature_override": agent.temperature_override,
            "is_default": agent.is_default,
            "is_active": agent.is_active,
            "description": agent.description,
        }

    return {"agents": [serialize(agent) for agent in agent_rows]}
@app.post("/api/v2/agents")
async def create_agent(data: dict, db: Session = Depends(get_db)):
    """Create an agent.

    Responds with ``success: False`` when the name is already taken or the
    referenced LLM provider cannot be found. When ``is_default`` is truthy in
    ``data`` the new agent is also made the default agent.
    """
    agents = AgentService(db)

    # The agent name must not already be in use.
    if agents.get_agent_by_name(data.get('name')):
        return {"success": False, "message": "Agent名称已存在"}

    # The referenced provider must exist.
    provider = LLMProviderService(db).get_provider(data.get('llm_provider_id'))
    if not provider:
        return {"success": False, "message": "选择的LLM Provider不存在"}

    created = agents.create_agent(data)

    # Promote to default when requested.
    if data.get('is_default'):
        agents.set_default_agent(created.id)

    summary = {"id": created.id, "name": created.name}
    return {"success": True, "agent": summary}
@app.put("/api/v2/agents/{agent_id}")
async def update_agent(agent_id: int, data: dict, db: Session = Depends(get_db)):
    """Update an agent.

    When ``is_default`` is truthy in ``data`` the agent is made the default
    before the update is applied. Responds with ``success: False`` when the
    update returns nothing for ``agent_id``.
    """
    agents = AgentService(db)

    # NOTE(review): the default switch runs before the agent is known to
    # exist - confirm set_default_agent tolerates an unknown id.
    wants_default = data.get('is_default')
    if wants_default:
        agents.set_default_agent(agent_id)

    updated = agents.update_agent(agent_id, data)
    if updated:
        summary = {"id": updated.id, "name": updated.name}
        return {"success": True, "agent": summary}
    return {"success": False, "message": "Agent不存在"}
@app.delete("/api/v2/agents/{agent_id}")
async def delete_agent(agent_id: int, db: Session = Depends(get_db)):
    """Delete an agent.

    Responds with ``success: False`` when the service reports that the
    deletion did not happen.
    """
    deleted = AgentService(db).delete_agent(agent_id)
    if deleted:
        return {"success": True}
    return {"success": False, "message": "无法删除默认Agent"}
@app.post("/api/v2/agents/{agent_id}/default")
async def set_agent_default(agent_id: int, db: Session = Depends(get_db)):
    """Make ``agent_id`` the default agent and report the service's result."""
    outcome = AgentService(db).set_default_agent(agent_id)
    return {"success": outcome}
@app.get("/api/v2/agents/{agent_id}/config")
async def get_agent_config(agent_id: int, db: Session = Depends(get_db)):
    """Return the agent's full configuration, LLM provider included.

    The value under ``config`` is whatever ``AgentService.get_agent_config``
    returns for ``agent_id``.
    """
    full_config = AgentService(db).get_agent_config(agent_id)
    return {"config": full_config}
# ===== Channel API =====
@app.get("/api/v2/channels")
async def get_channels(db: Session = Depends(get_db)):
    """Return every channel with its agent bindings, plus a flat list of all bindings.

    The response has two keys:

    * ``channels`` - one entry per channel, each with an ``agent_mappings``
      list holding the bindings that point at that channel.
    * ``mappings`` - every binding, annotated with the type and name of the
      channel it belongs to and the display name of its agent.

    Bindings are queried ordered by ``priority`` and keep that order in both
    lists. A binding whose agent is missing gets ``None`` agent fields in
    ``agent_mappings`` and an empty ``agent_name`` in ``mappings``; a binding
    whose channel is missing gets ``channel_type`` ``None`` and an empty
    ``channel_name``.
    """
    channels = ChannelService(db).get_all_channels()
    mappings = db.query(ChannelAgentMapping).order_by(ChannelAgentMapping.priority).all()
    agents_by_id = {agent.id: agent for agent in db.query(Agent).all()}

    # Index once so each binding resolves its channel and each channel its
    # bindings without rescanning the full lists.
    channels_by_id = {channel.id: channel for channel in channels}
    mappings_by_channel = {}
    for mapping in mappings:
        mappings_by_channel.setdefault(mapping.channel_id, []).append(mapping)

    def agent_summary(agent_id):
        agent = agents_by_id.get(agent_id)
        if agent is None:
            return {"id": None, "name": None, "display_name": None, "is_active": None}
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "is_active": agent.is_active,
        }

    def nested_mapping(mapping):
        return {
            "id": mapping.id,
            "agent": agent_summary(mapping.agent_id),
            "priority": mapping.priority,
            "mode": mapping.mode,
            "conditions": mapping.conditions,
        }

    def flat_mapping(mapping):
        # Resolve the binding's own channel; the type of the first channel in
        # the list must not be reported for every binding.
        channel = channels_by_id.get(mapping.channel_id)
        agent = agents_by_id.get(mapping.agent_id)
        return {
            "id": mapping.id,
            "channel_id": mapping.channel_id,
            "channel_type": channel.channel_type if channel is not None else None,
            "channel_name": channel.name if channel is not None else "",
            "agent_id": mapping.agent_id,
            "agent_name": agent.display_name if agent is not None else "",
            "priority": mapping.priority,
            "mode": mapping.mode,
            "conditions": mapping.conditions,
        }

    return {
        "channels": [
            {
                "id": channel.id,
                "channel_type": channel.channel_type,
                "name": channel.name,
                "config": channel.config,
                "is_active": channel.is_active,
                "is_primary": channel.is_primary,
                "agent_mappings": [
                    nested_mapping(mapping)
                    for mapping in mappings_by_channel.get(channel.id, [])
                ],
            }
            for channel in channels
        ],
        "mappings": [flat_mapping(mapping) for mapping in mappings],
    }
@app.post("/api/v2/channels")
async def create_channel(data: dict, db: Session = Depends(get_db)):
    """Create a channel from the request body and return its id and name."""
    created = ChannelService(db).create_channel(data)
    summary = {"id": created.id, "name": created.name}
    return {"success": True, "channel": summary}
@app.put("/api/v2/channels/{channel_id}")
async def update_channel(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Update a channel.

    Responds with ``success: False`` when the service returns nothing for
    ``channel_id``.
    """
    updated = ChannelService(db).update_channel(channel_id, data)
    if updated:
        return {"success": True}
    return {"success": False, "message": "渠道不存在"}
@app.put("/api/v2/channels/{channel_id}/config")
async def update_channel_config(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Replace a channel's configuration (for example its Matrix settings).

    The channel's ``config`` is overwritten with ``data['config']`` (``{}``
    when absent) and committed. Responds with ``success: False`` when no
    channel has the given id.
    """
    target = db.query(Channel).filter(Channel.id == channel_id).first()
    if not target:
        return {"success": False, "message": "渠道不存在"}

    new_config = data.get('config', {})
    target.config = new_config
    db.commit()

    # A Matrix channel should have its bot re-initialised after a config change.
    if target.channel_type == 'matrix':
        # TODO: restart the Matrix bot
        pass

    return {"success": True}
@app.post("/api/v2/channels/bind")
async def bind_agent_to_channel(data: dict, db: Session = Depends(get_db)):
    """Bind an agent to a channel.

    Reads ``channel_id``, ``agent_id``, ``priority`` (default ``0``), ``mode``
    (default ``'single'``) and ``conditions`` from ``data``. Responds with
    ``success: False`` when either the channel or the agent cannot be found;
    otherwise returns the id of the new mapping.
    """
    channels = ChannelService(db)
    channel_id = data.get('channel_id')
    agent_id = data.get('agent_id')

    # Both ends of the binding must exist.
    if not channels.get_channel(channel_id):
        return {"success": False, "message": "渠道不存在"}
    if not AgentService(db).get_agent(agent_id):
        return {"success": False, "message": "Agent不存在"}

    binding = channels.bind_agent(
        channel_id=channel_id,
        agent_id=agent_id,
        priority=data.get('priority', 0),
        mode=data.get('mode', 'single'),
        conditions=data.get('conditions'),
    )
    return {"success": True, "mapping": {"id": binding.id}}
@app.delete("/api/v2/channels/unbind/{mapping_id}")
async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
    """Remove a channel-agent binding and report the service's result."""
    outcome = ChannelService(db).unbind_agent(mapping_id)
    return {"success": outcome}
# ==================== Conversation API (retained from the original version) ====================
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the main user's conversations.

    The main user is fetched or created first. Untitled conversations are
    reported with the placeholder title.
    """
    conversations_api = ConversationService(db)
    owner = conversations_api.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    items = []
    for conv in conversations_api.get_user_conversations(owner.id):
        items.append({
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        })
    return {"conversations": items}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the main user's most recent conversation, or ``None`` if there is none.

    The first element of the service's conversation list is reported.
    NOTE(review): this assumes get_user_conversations returns newest first -
    confirm against the service.
    """
    conversations_api = ConversationService(db)
    owner = conversations_api.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    owned = conversations_api.get_user_conversations(owner.id)

    if not owned:
        return {"conversation": None}

    head = owned[0]
    return {
        "conversation": {
            "id": head.conversation_id,
            "title": head.title or "新对话",
            "updated_at": head.updated_at.isoformat(),
        }
    }
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Create a new conversation for the main user and return its summary."""
    conversations_api = ConversationService(db)
    owner = conversations_api.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    created = conversations_api.create_conversation(owner.id)

    summary = {
        "id": created.conversation_id,
        "title": created.title or "新对话",
        "created_at": created.created_at.isoformat(),
    }
    return summary
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return the messages of a conversation.

    Raises:
        HTTPException: 404 when the conversation cannot be found.
    """
    conversations_api = ConversationService(db)
    found = conversations_api.get_conversation(conversation_id)
    if not found:
        raise HTTPException(status_code=404, detail="会话不存在")

    def serialize(msg):
        return {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "thinking_content": msg.thinking_content,  # added in v2
            "source": msg.source,
            "agent_id": msg.agent_id,  # added in v2
            "model_used": msg.model_used,  # added in v2
            "created_at": msg.created_at.isoformat(),
        }

    rows = conversations_api.get_messages(found.id)
    return {"messages": [serialize(msg) for msg in rows]}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete a conversation.

    Raises:
        HTTPException: 404 when the service reports nothing was deleted.
    """
    deleted = ConversationService(db).delete_conversation(conversation_id)
    if deleted:
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
# ==================== WebSocket routes ====================
@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Real-time chat over WebSocket.

    Every connection is registered under ``MAIN_USER_ID``; the ``user_id``
    path parameter is part of the URL but is not used to key the connection.

    Client frames are JSON objects dispatched on their ``"action"`` field:

    * ``select_conversation`` - reply with a ``history`` frame when the
      conversation exists.
    * ``switch_agent`` - make the given agent current for this connection
      when it exists and is active, and confirm with ``agent_switched``.
    * ``chat`` - store the user message, call the LLM, store the reply and
      broadcast both messages to every socket of the main user.

    Frames with any other action are ignored. The socket is always removed
    from the connection manager when the handler ends.
    """
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)

    try:
        # Resolve the default agent once, using a short-lived session.
        db = SessionLocal()
        try:
            default_agent = AgentService(db).get_default_agent()
            current_agent_id = default_agent.id if default_agent else None
        finally:
            db.close()

        while True:
            data = await websocket.receive_json()
            action = data.get("action")

            # One database session per incoming frame, closed once the frame
            # has been handled. It is created outside the try so the finally
            # only ever closes the session opened for this frame.
            db = SessionLocal()
            try:
                conv_service = ConversationService(db)
                agent_service = AgentService(db)
                user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

                if action == "select_conversation":
                    selected_id = data.get("conversation_id")
                    conversation = conv_service.get_conversation(selected_id)
                    if conversation:
                        messages = conv_service.get_messages(conversation.id)
                        await websocket.send_json({
                            "type": "history",
                            "conversation_id": selected_id,
                            "messages": [
                                {
                                    "role": m.role,
                                    "content": m.content,
                                    "thinking_content": m.thinking_content,
                                    "source": m.source,
                                    "created_at": m.created_at.isoformat()
                                }
                                for m in messages
                            ]
                        })

                elif action == "switch_agent":
                    # Only an existing, active agent can become current.
                    new_agent_id = data.get("agent_id")
                    agent = agent_service.get_agent(new_agent_id)
                    if agent and agent.is_active:
                        current_agent_id = new_agent_id
                        await websocket.send_json({
                            "type": "agent_switched",
                            "agent_id": current_agent_id,
                            "agent_name": agent.display_name or agent.name
                        })

                elif action == "chat":
                    message = data.get("message", "")
                    conversation_id = data.get("conversation_id")
                    enable_thinking = data.get("enable_thinking", True)

                    # An agent id sent with the message also becomes the
                    # connection's current agent, provided it is active.
                    agent_id_override = data.get("agent_id")
                    if agent_id_override:
                        agent = agent_service.get_agent(agent_id_override)
                        if agent and agent.is_active:
                            current_agent_id = agent_id_override

                    if not message.strip():
                        continue

                    # Look up the conversation, or create one when no id was sent.
                    if conversation_id:
                        conversation = conv_service.get_conversation(conversation_id)
                        if not conversation:
                            # Report the unknown id instead of dereferencing
                            # None below, which would end the connection.
                            await websocket.send_json({
                                "type": "error",
                                "message": "会话不存在"
                            })
                            continue
                    else:
                        conversation = conv_service.create_conversation(user.id)
                        conversation_id = conversation.conversation_id
                        await websocket.send_json({
                            "type": "conversation_created",
                            "conversation_id": conversation_id
                        })

                    # Store the user message.
                    user_msg = conv_service.add_message(
                        conversation_id=conversation.id,
                        role='user',
                        content=message,
                        source='web'
                    )

                    # Broadcast the user message to every socket of the main user.
                    await manager.send_to_user(MAIN_USER_ID, {
                        "type": "user_message",
                        "conversation_id": conversation_id,
                        "message": {
                            "id": user_msg.id,
                            "role": "user",
                            "content": message,
                            "source": "web",
                            "created_at": user_msg.created_at.isoformat()
                        }
                    })

                    # Load the current agent's configuration.
                    agent_config = agent_service.get_agent_config(current_agent_id)
                    if not agent_config or not agent_config.get('provider'):
                        await websocket.send_json({
                            "type": "error",
                            "message": "Agent配置不完整"
                        })
                        continue

                    # Load the conversation history, capped by the agent's max_history.
                    history = conv_service.get_conversation_history(
                        conversation_id,
                        limit=agent_config['agent'].get('max_history', 20)
                    )

                    # Call the LLM.
                    try:
                        if agent_config['agent'].get('enable_thinking') and enable_thinking:
                            await websocket.send_json({
                                "type": "thinking_start",
                                "conversation_id": conversation_id
                            })

                        response, thinking_content = await llm_service.chat(
                            messages=history,
                            provider_config=agent_config['provider'],
                            agent_config=agent_config['agent'],
                            enable_thinking=enable_thinking
                        )

                        # NOTE(review): thinking_end is sent only together
                        # with thinking_content - confirm the client copes
                        # with a thinking_start that gets no thinking_end.
                        if thinking_content:
                            await websocket.send_json({
                                "type": "thinking_content",
                                "conversation_id": conversation_id,
                                "content": thinking_content
                            })
                            await websocket.send_json({
                                "type": "thinking_end",
                                "conversation_id": conversation_id
                            })

                        assistant_msg = conv_service.add_message(
                            conversation_id=conversation.id,
                            role='assistant',
                            content=response,
                            source='web',
                            thinking_content=thinking_content,
                            agent_id=current_agent_id,
                            model_used=agent_config['provider'].get('default_model')
                        )

                        await manager.send_to_user(MAIN_USER_ID, {
                            "type": "assistant_message",
                            "conversation_id": conversation_id,
                            "message": {
                                "id": assistant_msg.id,
                                "role": "assistant",
                                "content": response,
                                "thinking_content": thinking_content,
                                "source": "web",
                                "agent_id": current_agent_id,
                                "agent_name": agent_config['agent'].get('display_name'),
                                "created_at": assistant_msg.created_at.isoformat()
                            }
                        })
                    except Exception as e:
                        logger.error(f"LLM调用失败: {e}")
                        await websocket.send_json({
                            "type": "error",
                            "message": f"AI服务暂时不可用: {str(e)}"
                        })
            finally:
                db.close()
    except WebSocketDisconnect:
        # Normal client disconnect; cleanup happens in the finally below.
        pass
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Unregister under the same key used by connect(). Using the path
        # user_id here would leave the socket registered forever.
        manager.disconnect(websocket, actual_user_id)
# ==================== Admin API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return row counts for users, conversations and messages."""
    counted_models = {
        "total_users": User,
        "total_conversations": Conversation,
        "total_messages": Message,
    }
    return {key: db.query(model).count() for key, model in counted_models.items()}
# ==================== Matrix message callback ====================
async def handle_matrix_message(
    action: str,
    conversation_id: str = None,
    user_message: str = None,
    room_id: str = None,
    message_id: int = None,
    sender: str = None,
    content: str = None,
    thinking_content: str = None,
    agent_id: int = None,
    agent_name: str = None
):
    """Forward an event received from Matrix to the main user's web sockets.

    ``action`` selects the frame to broadcast: ``new_conversation``,
    ``user_message`` or ``assistant_message``. Any other action is ignored.
    Message frames are stamped with the current local time. ``room_id`` is
    accepted but not used here.
    """
    if action == "new_conversation":
        event = {
            "type": "new_conversation",
            "conversation_id": conversation_id,
        }
    elif action == "user_message":
        event = {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                "created_at": datetime.now().isoformat(),
            },
        }
    elif action == "assistant_message":
        event = {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "assistant",
                "content": content,
                "thinking_content": thinking_content,
                "source": "matrix",
                "agent_id": agent_id,
                "agent_name": agent_name,
                "created_at": datetime.now().isoformat(),
            },
        }
    else:
        # Unknown action: nothing to forward.
        return

    await manager.send_to_user(MAIN_USER_ID, event)
# ==================== Startup and shutdown ====================
# Import the Matrix bot
# Prefer the v2 Matrix service; fall back to the original one when the v2
# module cannot be imported. MATRIX_V2 records which one was loaded.
try:
    from services.matrix_service_v2 import matrix_bot
except ImportError:
    from services.matrix_service import matrix_bot
    MATRIX_V2 = False
else:
    MATRIX_V2 = True
@app.on_event("startup")
async def startup():
    """Application startup: prepare the database and start the Matrix bot.

    Initialises the database and the default data, then tries to start the
    Matrix bot from its stored configuration. A Matrix failure is logged as a
    warning and does not prevent the application from starting.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Seed the default data.
    init_default_data()

    # Start the Matrix bot when it is configured and enabled.
    try:
        success = await matrix_bot.init_from_config()
        if success and matrix_bot.is_running:
            # Keep a strong reference to the task: the event loop only holds
            # weak references, so an unreferenced task may be garbage
            # collected before it finishes.
            app.state.matrix_sync_task = asyncio.create_task(
                matrix_bot.start_sync(handle_matrix_message)
            )
            # Name the implementation that was actually imported.
            if MATRIX_V2:
                logger.info("Matrix Bot v2 已启动")
            else:
                logger.info("Matrix Bot 已启动")
        else:
            logger.info("Matrix Bot 未配置或未启用")
    except Exception as e:
        logger.warning(f"Matrix Bot初始化失败: {e}")
@app.on_event("shutdown")
async def shutdown():
    """Application shutdown: disconnect the Matrix bot on a best-effort basis."""
    try:
        await matrix_bot.disconnect()
    except Exception as e:
        # Best effort: a failed disconnect must not block shutdown, but it is
        # logged instead of silently discarded. Catching Exception (not a
        # bare except) lets task cancellation propagate.
        logger.warning(f"Matrix Bot断开连接失败: {e}")
    logger.info("应用已关闭")
# Run the app with uvicorn when this file is executed directly.
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface, port 19020.
    uvicorn.run(app, port=19020, host="0.0.0.0")