Files
ai-chat-system/main_v2.py

846 lines
30 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 v2.0.0 - 主应用
支持大模型池、Agent管理、渠道独立绑定、思考功能开关
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set, Optional
import asyncio
import json
import logging
from datetime import datetime
import os
# 使用新的数据模型
from models_v2 import (
init_db, get_db, SessionLocal,
User, Conversation, Message, SystemConfig,
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping,
init_default_data
)
from services.llm_service import llm_service
from services.agent_service import AgentService, LLMProviderService, ChannelService
from services.conversation_service import ConversationService
# Logging: process-wide INFO level plus the module logger used throughout this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application instance.
app = FastAPI(version="2.0.0", title="AI对话系统 v2.0")

# Static assets and templates are resolved relative to this file's directory,
# not the current working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
# WebSocket connection management
class ConnectionManager:
    """Track open WebSocket connections per user id and fan messages out to them."""

    def __init__(self):
        # user_id -> set of currently registered sockets for that user
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the handshake and register *websocket* under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister *websocket*; drop the user's entry once it has no sockets left.

        Safe to call repeatedly or for a user that is not registered.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*.

        Per-socket failures are logged and do not stop delivery to the others.
        """
        # Iterate over a snapshot: every ``await`` yields to the event loop, where
        # another handler may connect/disconnect and mutate the live set, which
        # would raise "Set changed size during iteration".
        for connection in list(self.active_connections.get(user_id, ())):
            try:
                await connection.send_json(message)
                logger.info(f"发送消息到用户 {user_id}: {message.get('type', 'unknown')}")
            except Exception as e:
                logger.error(f"发送消息失败: {e}")

    async def ping(self, user_id: str):
        """Send a heartbeat ``{"type": "ping"}`` to all of *user_id*'s sockets."""
        await self.send_to_user(user_id, {"type": "ping"})
# Every web client is mapped onto this single fixed user id.
MAIN_USER_ID = "main_user"
# One process-wide registry of open WebSocket connections.
manager = ConnectionManager()
# ==================== Page routes ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat UI from ``index.html``."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Render the v2 admin console from ``admin_v2/index.html``."""
    context = {"request": request}
    return templates.TemplateResponse("admin_v2/index.html", context)
# ==================== v2 API routes ====================
# ===== LLM provider pool API =====
@app.get("/api/v2/providers")
async def get_providers(db: Session = Depends(get_db)):
    """List every configured LLM provider with all of its settings.

    NOTE(review): the payload includes each provider's ``api_key`` in plain
    text, so anyone who can reach this endpoint can read the keys. Masking it
    would also require the admin UI to treat a masked value as "unchanged"
    when it sends an update.
    """
    def serialize(provider):
        return {
            "id": provider.id,
            "name": provider.name,
            "api_base": provider.api_base,
            "api_key": provider.api_key,
            "models": provider.models,
            "default_model": provider.default_model,
            "supports_thinking": provider.supports_thinking,
            "thinking_model": provider.thinking_model,
            "max_tokens": provider.max_tokens,
            "temperature": provider.temperature,
            "is_active": provider.is_active,
            "priority": provider.priority,
            "description": provider.description,
        }

    all_providers = LLMProviderService(db).get_all_providers()
    return {"providers": [serialize(provider) for provider in all_providers]}
@app.post("/api/v2/providers")
async def create_provider(data: dict, db: Session = Depends(get_db)):
    """Create an LLM provider from *data*; rejects a name that is already used."""
    provider_service = LLMProviderService(db)
    name_taken = provider_service.get_provider_by_name(data.get('name'))
    if name_taken:
        return {"success": False, "message": "名称已存在"}
    created = provider_service.create_provider(data)
    return {"success": True, "provider": {"id": created.id, "name": created.name}}
@app.put("/api/v2/providers/{provider_id}")
async def update_provider(provider_id: int, data: dict, db: Session = Depends(get_db)):
    """Apply *data* to provider *provider_id*; reports failure if it does not exist."""
    updated = LLMProviderService(db).update_provider(provider_id, data)
    if updated:
        return {"success": True, "provider": {"id": updated.id, "name": updated.name}}
    return {"success": False, "message": "Provider不存在"}
@app.delete("/api/v2/providers/{provider_id}")
async def delete_provider(provider_id: int, db: Session = Depends(get_db)):
    """Delete provider *provider_id*.

    Returns a failure message when the service refuses; per the message text,
    the reason may be that agents still use this provider.
    """
    deleted = LLMProviderService(db).delete_provider(provider_id)
    if deleted:
        return {"success": True}
    return {"success": False, "message": "无法删除可能有Agent在使用"}
@app.post("/api/v2/providers/models")
async def fetch_models(data: dict):
    """Ask the remote API at ``data['api_base']`` for its model list."""
    available = await llm_service.get_available_models(
        data.get('api_base'),
        data.get('api_key', ''),
    )
    return {"models": available}
@app.post("/api/v2/providers/test")
async def test_provider(data: dict):
    """Run a connectivity test against an LLM endpoint and return the service's result.

    ``model`` defaults to ``'auto'`` and ``api_key`` to an empty string.
    """
    return await llm_service.test_connection(
        data.get('api_base'),
        data.get('api_key', ''),
        data.get('model', 'auto'),
    )
# ===== Agent API =====
@app.get("/api/v2agents".replace("v2agents", "v2/agents"))
async def get_agents(db: Session = Depends(get_db)):
    """List all agents, each annotated with the name of its LLM provider."""
    all_agents = AgentService(db).get_all_agents()

    # provider id -> provider name, for the ``llm_provider_name`` field
    provider_names = {}
    for provider in db.query(LLMProvider).all():
        provider_names[provider.id] = provider.name

    def serialize(agent):
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "llm_provider_id": agent.llm_provider_id,
            "llm_provider_name": provider_names.get(agent.llm_provider_id),
            "model_override": agent.model_override,
            "system_prompt": agent.system_prompt,
            "enable_thinking": agent.enable_thinking,
            "thinking_prompt": agent.thinking_prompt,
            "thinking_prefix": agent.thinking_prefix,
            "thinking_suffix": agent.thinking_suffix,
            "max_history": agent.max_history,
            "temperature_override": agent.temperature_override,
            "is_default": agent.is_default,
            "is_active": agent.is_active,
            "description": agent.description,
        }

    return {"agents": [serialize(agent) for agent in all_agents]}
@app.post("/api/v2/agents")
async def create_agent(data: dict, db: Session = Depends(get_db)):
    """Create an agent from *data*.

    Rejects a duplicate agent name or an unknown ``llm_provider_id``.
    """
    agents = AgentService(db)
    if agents.get_agent_by_name(data.get('name')):
        return {"success": False, "message": "Agent名称已存在"}
    if not LLMProviderService(db).get_provider(data.get('llm_provider_id')):
        return {"success": False, "message": "选择的LLM Provider不存在"}

    new_agent = agents.create_agent(data)
    # Marking this agent as default goes through the service, which (per the
    # original intent) also updates the other agents.
    if data.get('is_default'):
        agents.set_default_agent(new_agent.id)
    return {"success": True, "agent": {"id": new_agent.id, "name": new_agent.name}}
@app.put("/api/v2/agents/{agent_id}")
async def update_agent(agent_id: int, data: dict, db: Session = Depends(get_db)):
    """Update agent *agent_id* with *data*; optionally make it the default agent.

    The agent's existence is checked before anything is written. Previously
    ``set_default_agent`` ran first, so a request for a missing id with
    ``is_default`` set could still alter the default flags (depending on
    ``AgentService.set_default_agent``) before reporting "Agent不存在".
    """
    service = AgentService(db)
    if not service.get_agent(agent_id):
        return {"success": False, "message": "Agent不存在"}
    # Promote to default before applying the remaining fields (original order kept).
    if data.get('is_default'):
        service.set_default_agent(agent_id)
    agent = service.update_agent(agent_id, data)
    if not agent:
        return {"success": False, "message": "Agent不存在"}
    return {"success": True, "agent": {"id": agent.id, "name": agent.name}}
@app.delete("/api/v2/agents/{agent_id}")
async def delete_agent(agent_id: int, db: Session = Depends(get_db)):
    """Delete agent *agent_id*; the service refuses to delete the default agent."""
    deleted = AgentService(db).delete_agent(agent_id)
    if deleted:
        return {"success": True}
    return {"success": False, "message": "无法删除默认Agent"}
@app.post("/api/v2/agents/{agent_id}/default")
async def set_agent_default(agent_id: int, db: Session = Depends(get_db)):
    """Make *agent_id* the default agent and report the service's result."""
    outcome = AgentService(db).set_default_agent(agent_id)
    return {"success": outcome}
@app.get("/api/v2/agents/{agent_id}/config")
async def get_agent_config(agent_id: int, db: Session = Depends(get_db)):
    """Return the agent's merged configuration, including its LLM provider settings."""
    merged = AgentService(db).get_agent_config(agent_id)
    return {"config": merged}
# ===== Channel API =====
@app.get("/api/v2/channels")
async def get_channels(db: Session = Depends(get_db)):
    """List every channel together with its agent bindings.

    The same ChannelAgentMapping rows (ordered by priority) are returned in two
    shapes:
    - ``channels[*].agent_mappings`` nests each channel's bindings;
    - ``mappings`` is a flat list annotated with channel and agent names.

    Missing agents/channels yield ``None`` (nested agent fields,
    ``channel_type``) or ``""`` (``channel_name``, ``agent_name``).
    """
    service = ChannelService(db)
    channels = service.get_all_channels()
    mappings = db.query(ChannelAgentMapping).order_by(ChannelAgentMapping.priority).all()
    agents = {a.id: a for a in db.query(Agent).all()}
    channels_by_id = {c.id: c for c in channels}

    # Group bindings per channel once (preserving priority order) instead of
    # rescanning every mapping for every channel.
    mappings_by_channel = {}
    for m in mappings:
        mappings_by_channel.setdefault(m.channel_id, []).append(m)

    def agent_summary(agent_id):
        agent = agents.get(agent_id)
        if agent is None:
            return {"id": None, "name": None, "display_name": None, "is_active": None}
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "is_active": agent.is_active,
        }

    def nested_mapping(m):
        return {
            "id": m.id,
            "agent": agent_summary(m.agent_id),
            "priority": m.priority,
            "mode": m.mode,
            "conditions": m.conditions,
        }

    def flat_mapping(m):
        channel = channels_by_id.get(m.channel_id)
        agent = agents.get(m.agent_id)
        return {
            "id": m.id,
            "channel_id": m.channel_id,
            # Previously `channels[0].channel_type`, which labelled every
            # mapping with the FIRST channel's type; use the mapping's own channel.
            "channel_type": channel.channel_type if channel is not None else None,
            "channel_name": channel.name if channel is not None else "",
            "agent_id": m.agent_id,
            "agent_name": agent.display_name if agent is not None else "",
            "priority": m.priority,
            "mode": m.mode,
            "conditions": m.conditions,
        }

    return {
        "channels": [
            {
                "id": c.id,
                "channel_type": c.channel_type,
                "name": c.name,
                "config": c.config,
                "is_active": c.is_active,
                "is_primary": c.is_primary,
                "agent_mappings": [nested_mapping(m) for m in mappings_by_channel.get(c.id, [])],
            }
            for c in channels
        ],
        "mappings": [flat_mapping(m) for m in mappings],
    }
@app.post("/api/v2/channels")
async def create_channel(data: dict, db: Session = Depends(get_db)):
    """Create a channel from *data* and return its id and name."""
    created = ChannelService(db).create_channel(data)
    return {"success": True, "channel": {"id": created.id, "name": created.name}}
@app.put("/api/v2/channels/{channel_id}")
async def update_channel(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Apply *data* to channel *channel_id*; reports failure if it does not exist."""
    if ChannelService(db).update_channel(channel_id, data):
        return {"success": True}
    return {"success": False, "message": "渠道不存在"}
@app.put("/api/v2/channels/{channel_id}/config")
async def update_channel_config(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Replace a channel's ``config`` (e.g. Matrix settings) and commit.

    If *data* has no ``config`` key, the stored config is reset to ``{}``.
    """
    channel = db.query(Channel).filter(Channel.id == channel_id).first()
    if not channel:
        return {"success": False, "message": "渠道不存在"}

    new_config = data.get('config', {})
    channel.config = new_config
    db.commit()

    if channel.channel_type == 'matrix':
        # TODO: restart the Matrix bot so the new settings take effect.
        # Currently a no-op: the running bot keeps its old configuration.
        pass
    return {"success": True}
@app.post("/api/v2/channels/bind")
async def bind_agent_to_channel(data: dict, db: Session = Depends(get_db)):
    """Bind an agent to a channel.

    ``priority`` defaults to 0, ``mode`` to ``'single'`` and ``conditions``
    to None. Both the channel and the agent must exist.
    """
    channel_service = ChannelService(db)
    channel_id, agent_id = data.get('channel_id'), data.get('agent_id')

    if not channel_service.get_channel(channel_id):
        return {"success": False, "message": "渠道不存在"}
    if not AgentService(db).get_agent(agent_id):
        return {"success": False, "message": "Agent不存在"}

    binding = channel_service.bind_agent(
        channel_id=channel_id,
        agent_id=agent_id,
        priority=data.get('priority', 0),
        mode=data.get('mode', 'single'),
        conditions=data.get('conditions'),
    )
    return {"success": True, "mapping": {"id": binding.id}}
@app.delete("/api/v2/channels/unbind/{mapping_id}")
async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
    """Remove the channel-agent binding *mapping_id* and report the service's result."""
    removed = ChannelService(db).unbind_agent(mapping_id)
    return {"success": removed}
# ==================== Conversation API (kept from v1) ====================
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the main user's conversations, creating that user on first use."""
    conversation_service = ConversationService(db)
    owner = conversation_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    def summarize(conv):
        return {
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        }

    owned = conversation_service.get_user_conversations(owner.id)
    return {"conversations": [summarize(conv) for conv in owned]}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the main user's first listed conversation, or ``None`` if there is none.

    Assumes ``get_user_conversations`` orders newest first (not visible here).
    """
    conversation_service = ConversationService(db)
    owner = conversation_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    owned = conversation_service.get_user_conversations(owner.id)
    if not owned:
        return {"conversation": None}

    newest = owned[0]
    return {
        "conversation": {
            "id": newest.conversation_id,
            "title": newest.title or "新对话",
            "updated_at": newest.updated_at.isoformat(),
        }
    }
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Start a new, empty conversation for the main user."""
    conversation_service = ConversationService(db)
    owner = conversation_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    fresh = conversation_service.create_conversation(owner.id)
    payload = {
        "id": fresh.conversation_id,
        "title": fresh.title or "新对话",
        "created_at": fresh.created_at.isoformat(),
    }
    return payload
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return every stored message of a conversation.

    Includes the v2 fields ``thinking_content``, ``agent_id`` and ``model_used``.

    Raises:
        HTTPException: 404 when the conversation does not exist.
    """
    conversation_service = ConversationService(db)
    found = conversation_service.get_conversation(conversation_id)
    if not found:
        raise HTTPException(status_code=404, detail="会话不存在")

    def serialize(msg):
        return {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "thinking_content": msg.thinking_content,
            "source": msg.source,
            "agent_id": msg.agent_id,
            "model_used": msg.model_used,
            "created_at": msg.created_at.isoformat(),
        }

    stored = conversation_service.get_messages(found.id)
    return {"messages": [serialize(msg) for msg in stored]}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete a conversation by its public id.

    Raises:
        HTTPException: 404 when the conversation does not exist.
    """
    if ConversationService(db).delete_conversation(conversation_id):
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
# ==================== WebSocket route ====================
def _load_default_agent_id() -> Optional[int]:
    """Return the id of the default agent, or None when none is configured."""
    db = SessionLocal()
    try:
        default_agent = AgentService(db).get_default_agent()
        return default_agent.id if default_agent else None
    finally:
        db.close()


def _decode_ws_payload(raw_text) -> Optional[dict]:
    """Parse one text frame as a JSON object.

    Returns None (after logging) for blank frames, invalid JSON, or JSON that
    is not an object, so the caller can simply skip the frame.
    """
    if not raw_text or not raw_text.strip():
        return None
    try:
        payload = json.loads(raw_text)
    except json.JSONDecodeError as parse_err:
        logger.error(f"JSON解析错误: {parse_err}")
        return None
    if not isinstance(payload, dict):
        logger.error(f"JSON解析错误: expected an object, got {type(payload).__name__}")
        return None
    return payload


async def _send_history(websocket: WebSocket, conv_service, conversation_id) -> None:
    """Push the stored messages of *conversation_id* to this socket (no-op if unknown)."""
    conversation = conv_service.get_conversation(conversation_id)
    if not conversation:
        return
    messages = conv_service.get_messages(conversation.id)
    await websocket.send_json({
        "type": "history",
        "conversation_id": conversation_id,
        "agent_id": conversation.current_agent_id,  # agent the conversation is using
        "messages": [
            {
                "role": m.role,
                "content": m.content,
                "thinking_content": m.thinking_content,
                "agent_id": m.agent_id,  # agent that produced each message
                "source": m.source,
                "created_at": m.created_at.isoformat(),
            }
            for m in messages
        ],
    })


async def _switch_agent(websocket: WebSocket, agent_service, new_agent_id, current_agent_id):
    """Switch to *new_agent_id* if it exists and is active; return the agent id now in effect."""
    agent = agent_service.get_agent(new_agent_id)
    if not (agent and agent.is_active):
        return current_agent_id
    await websocket.send_json({
        "type": "agent_switched",
        "agent_id": new_agent_id,
        "agent_name": agent.display_name or agent.name,
    })
    return new_agent_id


async def _handle_chat(websocket: WebSocket, data: dict, conv_service, agent_service, user, current_agent_id):
    """Handle one ``chat`` action.

    Steps: persist the user message, call the LLM (non-streaming), then
    persist and send the reply. Returns the agent id in effect afterwards
    (a valid, active ``agent_id`` in *data* switches it).
    """
    message = data.get("message") or ""
    conversation_id = data.get("conversation_id")
    enable_thinking = data.get("enable_thinking", True)
    agent_id_override = data.get("agent_id")

    if agent_id_override:
        agent = agent_service.get_agent(agent_id_override)
        if agent and agent.is_active:
            current_agent_id = agent_id_override

    # Ignore blank or non-text messages instead of crashing on .strip().
    if not isinstance(message, str) or not message.strip():
        return current_agent_id

    # Fetch the conversation, or create one when the client sent no id.
    if conversation_id:
        conversation = conv_service.get_conversation(conversation_id)
        if not conversation:
            # Previously this fell through and raised on `conversation.id`,
            # tearing down the whole socket.
            await websocket.send_json({"type": "error", "message": "会话不存在"})
            return current_agent_id
    else:
        conversation = conv_service.create_conversation(user.id)
        conversation_id = conversation.conversation_id
        await websocket.send_json({
            "type": "conversation_created",
            "conversation_id": conversation_id,
        })

    user_msg = conv_service.add_message(
        conversation_id=conversation.id,
        role='user',
        content=message,
        source='web',
    )
    # Broadcast to every open socket of the main user (e.g. other tabs).
    await manager.send_to_user(MAIN_USER_ID, {
        "type": "user_message",
        "conversation_id": conversation_id,
        "message": {
            "id": user_msg.id,
            "role": "user",
            "content": message,
            "source": "web",
            "created_at": user_msg.created_at.isoformat(),
        },
    })

    agent_config = agent_service.get_agent_config(current_agent_id)
    if not agent_config or not agent_config.get('provider'):
        await websocket.send_json({"type": "error", "message": "Agent配置不完整"})
        return current_agent_id

    history = conv_service.get_conversation_history(
        conversation_id, limit=agent_config['agent'].get('max_history', 20)
    )

    # Non-streaming LLM call, chosen for stability.
    try:
        response, thinking_content = await llm_service.chat(
            messages=history,
            provider_config=agent_config['provider'],
            agent_config=agent_config['agent'],
            enable_thinking=enable_thinking,
        )
        logger.info(f"LLM响应: response长度={len(response)}, thinking长度={len(thinking_content) if thinking_content else 0}")

        # NOTE(review): records the provider's default_model even when the agent
        # has a model_override; confirm against llm_service.chat.
        assistant_msg = conv_service.add_message(
            conversation_id=conversation.id,
            role='assistant',
            content=response,
            source='web',
            thinking_content=thinking_content if thinking_content else None,
            agent_id=current_agent_id,
            model_used=agent_config['provider'].get('default_model'),
        )
        await websocket.send_json({
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": response,
                "thinking_content": thinking_content if thinking_content else None,
                "source": "web",
                "agent_id": current_agent_id,
                "agent_name": agent_config['agent'].get('display_name'),
                "created_at": assistant_msg.created_at.isoformat(),
            },
        })
        logger.info(f"AI回复已发送: conversation_id={conversation_id}")
        # Signals the end of the turn (the client re-enables its send button).
        await websocket.send_json({
            "type": "stream_end",
            "conversation_id": conversation_id,
        })
    except Exception as e:
        logger.error(f"LLM调用失败: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"AI服务暂时不可用: {str(e)}",
        })
    return current_agent_id


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Real-time chat over a WebSocket.

    The ``user_id`` path segment is accepted for URL compatibility but
    ignored: every socket is registered under MAIN_USER_ID. Each text frame
    must be a JSON object with an ``action`` of ``select_conversation``,
    ``switch_agent`` or ``chat``; other actions are ignored.
    """
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)
    try:
        current_agent_id = _load_default_agent_id()
        while True:
            # Read each frame exactly once. The old receive_json -> receive_text
            # fallback consumed the NEXT frame after a malformed one.
            try:
                raw_text = await websocket.receive_text()
            except WebSocketDisconnect:
                raise
            except Exception as recv_err:
                # e.g. a binary frame without text, or a receive after close.
                logger.error(f"文本消息解析错误: {recv_err}")
                err_text = str(recv_err).lower()
                if "disconnect" in err_text or "closed" in err_text:
                    logger.info("WebSocket已断开退出循环")
                    break
                continue

            data = _decode_ws_payload(raw_text)
            if data is None:
                continue

            action = data.get("action")
            logger.info(f"WebSocket收到消息: action={action}")

            # A fresh DB session per message, always closed afterwards.
            db = SessionLocal()
            try:
                conv_service = ConversationService(db)
                agent_service = AgentService(db)
                user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
                if action == "select_conversation":
                    await _send_history(websocket, conv_service, data.get("conversation_id"))
                elif action == "switch_agent":
                    current_agent_id = await _switch_agent(
                        websocket, agent_service, data.get("agent_id"), current_agent_id
                    )
                elif action == "chat":
                    current_agent_id = await _handle_chat(
                        websocket, data, conv_service, agent_service, user, current_agent_id
                    )
            finally:
                db.close()
    except WebSocketDisconnect:
        logger.info("WebSocket已断开退出循环")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        # Unregister with the SAME key used by connect(). The old code passed the
        # path's user_id (leaking the socket when it wasn't "main_user") and
        # skipped cleanup entirely when the loop exited via `break`.
        manager.disconnect(websocket, actual_user_id)
# ==================== Admin API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return the total number of users, conversations and messages."""
    counts = {
        "total_users": db.query(User).count(),
        "total_conversations": db.query(Conversation).count(),
        "total_messages": db.query(Message).count(),
    }
    return counts
# ==================== Matrix message callback ====================
async def handle_matrix_message(
    action: str,
    conversation_id: str = None,
    user_message: str = None,
    room_id: str = None,
    message_id: int = None,
    sender: str = None,
    content: str = None,
    thinking_content: str = None,
    agent_id: int = None,
    agent_name: str = None
):
    """Relay an event received from Matrix to the web clients of MAIN_USER_ID.

    Supported *action* values:
    - ``"new_conversation"``: announce *conversation_id*.
    - ``"user_message"``: forward *user_message* from *sender*.
    - ``"assistant_message"``: forward the reply *content* with
      *thinking_content* and the *agent_id* / *agent_name* that produced it.

    Any other action is ignored. *room_id* is accepted but unused. Timestamps
    are the local time at relay.
    """
    if action == "new_conversation":
        payload = {"type": "new_conversation", "conversation_id": conversation_id}
    elif action == "user_message":
        payload = {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                "created_at": datetime.now().isoformat(),
            },
        }
    elif action == "assistant_message":
        payload = {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "assistant",
                "content": content,
                "thinking_content": thinking_content,
                "source": "matrix",
                "agent_id": agent_id,
                "agent_name": agent_name,
                "created_at": datetime.now().isoformat(),
            },
        }
    else:
        return
    await manager.send_to_user(MAIN_USER_ID, payload)
# ==================== Startup and shutdown ====================
# Prefer the v2 Matrix bot; fall back to the v1 module if v2 is unavailable.
# MATRIX_V2 records which implementation was loaded.
try:
    from services.matrix_service_v2 import matrix_bot
except ImportError:
    from services.matrix_service import matrix_bot
    MATRIX_V2 = False
else:
    MATRIX_V2 = True
def _log_matrix_sync_exit(task) -> None:
    """Done-callback for the Matrix sync task: log when it stops with an error."""
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.error(f"Matrix Bot同步任务异常退出: {exc}")


@app.on_event("startup")
async def startup():
    """Initialise the database, seed default data, and start the Matrix bot if configured.

    Matrix failures are logged as warnings and do not abort startup.

    The sync task is kept on ``app.state.matrix_sync_task``: the event loop
    holds only a weak reference to tasks, so an unreferenced fire-and-forget
    task can be garbage-collected mid-run. A done-callback logs crashes that
    would otherwise go unnoticed.
    """
    init_db()
    logger.info("数据库初始化完成")
    # Seed default data (providers/agents/channels, presumably).
    init_default_data()

    try:
        success = await matrix_bot.init_from_config()
        if success and matrix_bot.is_running:
            sync_task = asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
            sync_task.add_done_callback(_log_matrix_sync_exit)
            app.state.matrix_sync_task = sync_task
            # Report which implementation actually loaded (v1 is a fallback).
            logger.info("Matrix Bot v2 已启动" if MATRIX_V2 else "Matrix Bot 已启动")
        else:
            logger.info("Matrix Bot 未配置或未启用")
    except Exception as e:
        logger.warning(f"Matrix Bot初始化失败: {e}")
@app.on_event("shutdown")
async def shutdown():
    """Stop the Matrix bot on application shutdown.

    Cancels the background sync task, if startup stored one on ``app.state``,
    then disconnects the bot. This is best effort: failures are logged as
    warnings rather than raised, and unlike the previous bare ``except:`` they
    are no longer silently discarded.
    """
    # Starlette's State raises AttributeError for unset keys, so getattr's
    # default covers "no task was started".
    sync_task = getattr(app.state, "matrix_sync_task", None)
    if sync_task is not None and not sync_task.done():
        sync_task.cancel()

    try:
        await matrix_bot.disconnect()
    except Exception as e:
        # The bot may never have been configured or connected.
        logger.warning(f"Matrix Bot断开失败: {e}")
    logger.info("应用已关闭")
if __name__ == "__main__":
    # Script entry point. uvicorn is imported here, so it is only loaded when
    # this module is executed directly.
    import uvicorn

    uvicorn.run(app, port=19020, host="0.0.0.0")