Files
ai-chat-system/main_v2.py

1220 lines
47 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 v2.0.0 - 主应用
支持大模型池、Agent管理、渠道独立绑定、思考功能开关
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set, Optional
import asyncio
import json
import logging
from datetime import datetime
import os
import base64
import uuid
import time
# 使用新的数据模型
from models_v2 import (
init_db, get_db, SessionLocal,
User, Conversation, Message, SystemConfig,
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, ToolConfig, ToolUsageLog,
init_default_data
)
from services.llm_service import llm_service
from services.agent_service import AgentService, LLMProviderService, ChannelService, ToolService
from services.conversation_service import ConversationService
# 配置日志
# Logging: root handler at INFO, plus a module-level logger for this app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The FastAPI application object.
app = FastAPI(title="AI对话系统 v2.0", version="2.0.0")

# Filesystem layout, resolved relative to this file rather than the working directory.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOADS_DIR = os.path.join(BASE_DIR, "uploads", "images")

# Create the image upload directory up front (no error if it already exists).
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Serve static assets and uploaded files directly; each mount is named after its folder.
for _mount_path, _folder in (("/static", "static"), ("/uploads", "uploads")):
    app.mount(_mount_path, StaticFiles(directory=os.path.join(BASE_DIR, _folder)), name=_folder)
del _mount_path, _folder

# HTML pages are rendered from Jinja2 templates.
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
# WebSocket连接管理
class ConnectionManager:
    """Registry of open WebSocket connections, grouped by user id.

    One user may hold several connections at once (e.g. several browser tabs);
    ``send_to_user`` fans a message out to all of them.
    """

    def __init__(self):
        # user_id -> set of accepted WebSocket connections
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the handshake and register ``websocket`` under ``user_id``."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister ``websocket``; drop the user's entry once it has none left.

        Safe to call for an unknown user or an already-removed connection.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send ``message`` as JSON to every connection registered for ``user_id``.

        Iterates over a snapshot of the set. Each send awaits, and during that
        await another coroutine may connect or disconnect and mutate the live set
        (which would raise "Set changed size during iteration"). Connections whose
        send fails are unregistered, so later broadcasts do not keep hitting them.
        """
        failed = []
        for connection in list(self.active_connections.get(user_id, ())):
            try:
                await connection.send_json(message)
                logger.info(f"发送消息到用户 {user_id}: {message.get('type', 'unknown')}")
            except Exception as e:
                logger.error(f"发送消息失败: {e}")
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection, user_id)

    async def ping(self, user_id: str):
        """Broadcast a ``{"type": "ping"}`` heartbeat to all of ``user_id``'s connections."""
        await self.send_to_user(user_id, {"type": "ping"})
# Process-wide registry of open WebSocket connections.
manager = ConnectionManager()
# Fixed user id: every web client (and every Matrix event forwarded to the UI) is attributed to it.
MAIN_USER_ID: str = "main_user"
# ==================== 页面路由 ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chat UI page (``index.html``)."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
    """Render the v2 admin console (``admin_v2/index.html``)."""
    template_name = "admin_v2/index.html"
    return templates.TemplateResponse(template_name, {"request": request})
# ==================== v2 API路由 ====================
# ===== 大模型池 API =====
@app.get("/api/v2/providers")
async def get_providers(db: Session = Depends(get_db)):
    """List every LLM provider with its full configuration.

    NOTE(review): the response includes ``api_key`` in clear text; consider
    masking it if the admin UI does not need the raw value.
    """
    # Serialized attributes, in response key order.
    fields = (
        "id", "name", "api_base", "api_key", "models", "default_model",
        "supports_thinking", "thinking_model", "supports_vision", "vision_model",
        "max_tokens", "temperature", "is_active", "priority", "description",
    )
    rows = LLMProviderService(db).get_all_providers()
    return {"providers": [{field: getattr(row, field) for field in fields} for row in rows]}
@app.post("/api/v2/providers")
async def create_provider(data: dict, db: Session = Depends(get_db)):
    """Create an LLM provider from ``data``; provider names must be unique."""
    provider_service = LLMProviderService(db)
    duplicate = provider_service.get_provider_by_name(data.get('name'))
    if duplicate:
        return {"success": False, "message": "名称已存在"}
    created = provider_service.create_provider(data)
    return {"success": True, "provider": {"id": created.id, "name": created.name}}
@app.put("/api/v2/providers/{provider_id}")
async def update_provider(provider_id: int, data: dict, db: Session = Depends(get_db)):
    """Apply ``data`` to provider ``provider_id``; report failure if it does not exist."""
    updated = LLMProviderService(db).update_provider(provider_id, data)
    if updated:
        return {"success": True, "provider": {"id": updated.id, "name": updated.name}}
    return {"success": False, "message": "Provider不存在"}
@app.delete("/api/v2/providers/{provider_id}")
async def delete_provider(provider_id: int, db: Session = Depends(get_db)):
    """Delete a provider; the service refuses (returns falsy) e.g. while agents may still use it."""
    deleted = LLMProviderService(db).delete_provider(provider_id)
    return {"success": True} if deleted else {"success": False, "message": "无法删除可能有Agent在使用"}
@app.post("/api/v2/providers/models")
async def fetch_models(data: dict):
    """Ask the API at ``data['api_base']`` (optionally with ``api_key``) for its model list."""
    base_url, key = data.get('api_base'), data.get('api_key', '')
    return {"models": await llm_service.get_available_models(base_url, key)}
@app.post("/api/v2/providers/test")
async def test_provider(data: dict):
    """Probe an LLM endpoint and return the service's test result unchanged.

    ``model`` defaults to ``'auto'`` and ``api_key`` to an empty string.
    """
    return await llm_service.test_connection(
        data.get('api_base'),
        data.get('api_key', ''),
        data.get('model', 'auto'),
    )
# ===== Agent API =====
@app.get("/api/v2/agents")
async def get_agents(db: Session = Depends(get_db)):
    """List all agents, each annotated with the name of its LLM provider."""
    agent_rows = AgentService(db).get_all_agents()
    provider_names = dict((p.id, p.name) for p in db.query(LLMProvider).all())

    def serialize(agent):
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "llm_provider_id": agent.llm_provider_id,
            "llm_provider_name": provider_names.get(agent.llm_provider_id),
            "model_override": agent.model_override,
            "system_prompt": agent.system_prompt,
            "enable_thinking": agent.enable_thinking,
            "thinking_prompt": agent.thinking_prompt,
            "thinking_prefix": agent.thinking_prefix,
            "thinking_suffix": agent.thinking_suffix,
            "tools": agent.tools or [],  # never null in the response
            "max_history": agent.max_history,
            "temperature_override": agent.temperature_override,
            "is_default": agent.is_default,
            "is_active": agent.is_active,
            "description": agent.description,
        }

    return {"agents": [serialize(agent) for agent in agent_rows]}
@app.post("/api/v2/agents")
async def create_agent(data: dict, db: Session = Depends(get_db)):
    """Create an agent.

    Rejects a duplicate name or an unknown ``llm_provider_id``. When
    ``is_default`` is truthy, the new agent is made the default via
    ``set_default_agent``.
    """
    agents = AgentService(db)
    if agents.get_agent_by_name(data.get('name')):
        return {"success": False, "message": "Agent名称已存在"}
    if not LLMProviderService(db).get_provider(data.get('llm_provider_id')):
        return {"success": False, "message": "选择的LLM Provider不存在"}

    new_agent = agents.create_agent(data)
    if data.get('is_default'):
        agents.set_default_agent(new_agent.id)
    return {"success": True, "agent": {"id": new_agent.id, "name": new_agent.name}}
@app.put("/api/v2/agents/{agent_id}")
async def update_agent(agent_id: int, data: dict, db: Session = Depends(get_db)):
    """Update an agent; when ``is_default`` is set, promote it before applying the update."""
    agents = AgentService(db)
    if data.get('is_default'):
        agents.set_default_agent(agent_id)
    updated = agents.update_agent(agent_id, data)
    if updated:
        return {"success": True, "agent": {"id": updated.id, "name": updated.name}}
    return {"success": False, "message": "Agent不存在"}
@app.delete("/api/v2/agents/{agent_id}")
async def delete_agent(agent_id: int, db: Session = Depends(get_db)):
    """Delete an agent; a falsy service result is reported as "cannot delete the default agent"."""
    deleted = AgentService(db).delete_agent(agent_id)
    return {"success": True} if deleted else {"success": False, "message": "无法删除默认Agent"}
@app.post("/api/v2/agents/{agent_id}/default")
async def set_agent_default(agent_id: int, db: Session = Depends(get_db)):
    """Make ``agent_id`` the default agent; echoes the service's success flag."""
    return {"success": AgentService(db).set_default_agent(agent_id)}
@app.get("/api/v2/agents/{agent_id}/config")
async def get_agent_config(agent_id: int, db: Session = Depends(get_db)):
    """Return the agent's full configuration, including its LLM provider section."""
    return {"config": AgentService(db).get_agent_config(agent_id)}
# ===== 渠道 API =====
@app.get("/api/v2/channels")
async def get_channels(db: Session = Depends(get_db)):
    """List all channels together with their agent bindings.

    Returns:
        ``channels``: each channel with a nested ``agent_mappings`` list.
        ``mappings``: a flat list of every binding, annotated with its channel's
        type/name and its agent's display name.
        Mappings are ordered by ``ChannelAgentMapping.priority`` in both lists.
        Bindings that reference a missing agent get ``None`` fields in the
        nested summary and ``""`` as ``agent_name`` in the flat list.
    """
    channels = ChannelService(db).get_all_channels()
    mappings = db.query(ChannelAgentMapping).order_by(ChannelAgentMapping.priority).all()
    agents = {a.id: a for a in db.query(Agent).all()}
    channels_by_id = {c.id: c for c in channels}

    # Group once instead of rescanning all mappings per channel; the per-channel
    # order stays the priority order returned by the query.
    mappings_by_channel = {}
    for m in mappings:
        mappings_by_channel.setdefault(m.channel_id, []).append(m)

    def agent_brief(agent_id):
        agent = agents.get(agent_id)
        if agent is None:
            return {"id": None, "name": None, "display_name": None, "is_active": None}
        return {
            "id": agent.id,
            "name": agent.name,
            "display_name": agent.display_name,
            "is_active": agent.is_active,
        }

    def nested_mapping(m):
        return {
            "id": m.id,
            "agent": agent_brief(m.agent_id),
            "priority": m.priority,
            "mode": m.mode,
            "conditions": m.conditions,
        }

    def flat_mapping(m):
        channel = channels_by_id.get(m.channel_id)
        agent = agents.get(m.agent_id)
        return {
            "id": m.id,
            "channel_id": m.channel_id,
            # Type of the mapping's own channel (not the first channel in the list).
            "channel_type": channel.channel_type if channel else None,
            "channel_name": channel.name if channel else "",
            "agent_id": m.agent_id,
            "agent_name": agent.display_name if agent else "",
            "priority": m.priority,
            "mode": m.mode,
            "conditions": m.conditions,
        }

    return {
        "channels": [
            {
                "id": c.id,
                "channel_type": c.channel_type,
                "name": c.name,
                "config": c.config,
                "is_active": c.is_active,
                "is_primary": c.is_primary,
                "agent_mappings": [nested_mapping(m) for m in mappings_by_channel.get(c.id, [])],
            }
            for c in channels
        ],
        "mappings": [flat_mapping(m) for m in mappings],
    }
@app.post("/api/v2/channels")
async def create_channel(data: dict, db: Session = Depends(get_db)):
    """Create a channel from ``data`` and return its id and name."""
    created = ChannelService(db).create_channel(data)
    summary = {"id": created.id, "name": created.name}
    return {"success": True, "channel": summary}
@app.put("/api/v2/channels/{channel_id}")
async def update_channel(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Apply ``data`` to channel ``channel_id``."""
    updated = ChannelService(db).update_channel(channel_id, data)
    return {"success": True} if updated else {"success": False, "message": "渠道不存在"}
@app.put("/api/v2/channels/{channel_id}/config")
async def update_channel_config(channel_id: int, data: dict, db: Session = Depends(get_db)):
    """Replace a channel's ``config`` (e.g. Matrix settings) with ``data['config']`` ({} if absent)."""
    row = db.query(Channel).filter(Channel.id == channel_id).first()
    if not row:
        return {"success": False, "message": "渠道不存在"}

    row.config = data.get('config', {})
    db.commit()

    if row.channel_type == 'matrix':
        # TODO: restart the Matrix bot so the new config takes effect (not implemented yet).
        pass
    return {"success": True}
@app.post("/api/v2/channels/bind")
async def bind_agent_to_channel(data: dict, db: Session = Depends(get_db)):
    """Bind an agent to a channel.

    ``priority`` defaults to 0, ``mode`` to ``'single'`` and ``conditions`` to None.
    Both the channel and the agent must exist.
    """
    channel_service = ChannelService(db)
    channel_id, agent_id = data.get('channel_id'), data.get('agent_id')

    if not channel_service.get_channel(channel_id):
        return {"success": False, "message": "渠道不存在"}
    if not AgentService(db).get_agent(agent_id):
        return {"success": False, "message": "Agent不存在"}

    binding = channel_service.bind_agent(
        channel_id=channel_id,
        agent_id=agent_id,
        priority=data.get('priority', 0),
        mode=data.get('mode', 'single'),
        conditions=data.get('conditions'),
    )
    return {"success": True, "mapping": {"id": binding.id}}
@app.delete("/api/v2/channels/unbind/{mapping_id}")
async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
    """Remove a channel-agent binding; echoes the service's success flag."""
    return {"success": ChannelService(db).unbind_agent(mapping_id)}
# ==================== 工具配置 API ====================
@app.get("/api/v2/tools")
async def get_tools(tool_type: str = None, db: Session = Depends(get_db)):
    """List tool configurations, optionally only those of ``tool_type``.

    Each entry includes its raw ``config``. For search tools this holds the
    ``api_key``, so it is returned to the caller.
    """
    tool_service = ToolService(db)
    rows = tool_service.get_tools_by_type(tool_type) if tool_type else tool_service.get_all_tools()
    fields = (
        "id", "name", "tool_type", "provider", "config", "is_active", "is_default",
        "total_calls", "success_calls", "failed_calls",
    )
    return {"tools": [{field: getattr(row, field) for field in fields} for row in rows]}
@app.post("/api/v2/tools")
async def create_tool(data: dict, db: Session = Depends(get_db)):
    """Create a tool configuration from ``data``."""
    created = ToolService(db).create_tool(data)
    summary = {"id": created.id, "name": created.name, "tool_type": created.tool_type}
    return {"success": True, "tool": summary}
@app.put("/api/v2/tools/{tool_id}")
async def update_tool(tool_id: int, data: dict, db: Session = Depends(get_db)):
    """Update a tool; when ``is_default`` is set, promote it before applying the update."""
    tool_service = ToolService(db)
    if data.get('is_default'):
        tool_service.set_default_tool(tool_id)
    updated = tool_service.update_tool(tool_id, data)
    if updated:
        return {"success": True, "tool": {"id": updated.id, "name": updated.name}}
    return {"success": False, "message": "工具不存在"}
@app.delete("/api/v2/tools/{tool_id}")
async def delete_tool(tool_id: int, db: Session = Depends(get_db)):
    """Delete a tool configuration; echoes the service's success flag."""
    return {"success": ToolService(db).delete_tool(tool_id)}
@app.post("/api/v2/tools/{tool_id}/default")
async def set_tool_default(tool_id: int, db: Session = Depends(get_db)):
    """Mark ``tool_id`` as the default tool; echoes the service's success flag."""
    return {"success": ToolService(db).set_default_tool(tool_id)}
@app.get("/api/v2/tools/stats")
async def get_tool_stats(days: int = 7, db: Session = Depends(get_db)):
    """Return the service's tool-usage statistics for the last ``days`` days (default 7)."""
    return ToolService(db).get_usage_stats(days=days)
@app.post("/api/v2/tools/search")
async def perform_search(data: dict, db: Session = Depends(get_db)):
    """Run a web search through the Tavily API, for the UI or for an agent.

    ``data`` keys: ``query`` (required), ``tool_id`` (optional; defaults to the
    default ``search`` tool), and ``conversation_id`` / ``agent_id`` (recorded in
    the usage log only).

    Returns ``{"success": True, "results": [...], "query": ...}`` or
    ``{"success": False, "message": ...}``. Every attempted request updates the
    tool's call statistics and writes a usage-log row.
    """
    import httpx

    query = data.get('query')
    if not query:
        return {"success": False, "message": "缺少搜索关键词"}

    service = ToolService(db)
    tool_id = data.get('tool_id')
    tool = service.get_tool(tool_id) if tool_id else service.get_default_tool('search')
    # A tool row may have no config at all; treat that like a missing API key.
    config = (tool.config or {}) if tool else {}
    if not config.get('api_key'):
        return {"success": False, "message": "未配置搜索工具"}
    if tool.provider != 'tavily' and tool.tool_type != 'search':
        return {"success": False, "message": "不支持的搜索提供商"}

    start_time = time.time()
    try:
        payload = {
            "api_key": config.get('api_key'),
            "query": query,
            "max_results": config.get('max_results', 5),
            "include_raw_content": config.get('include_raw_content', False),
            "search_depth": config.get('search_depth', 'basic'),
        }
        async with httpx.AsyncClient(timeout=30) as client:
            response = await client.post("https://api.tavily.com/search", json=payload)
        # Without this, an HTTP error whose body is JSON (bad key, quota, ...) was
        # parsed and recorded as a successful search with 0 results.
        response.raise_for_status()
        result = response.json()
        duration_ms = int((time.time() - start_time) * 1000)

        service.increment_stats(tool.id, True)
        service.log_usage({
            'tool_id': tool.id,
            'tool_type': 'search',
            'query': query,
            'success': True,
            'result_summary': f'{len(result.get("results", []))} results',
            'conversation_id': data.get('conversation_id'),
            'agent_id': data.get('agent_id'),
            'duration_ms': duration_ms
        })
        return {
            "success": True,
            "results": result.get("results", []),
            "query": query
        }
    except Exception as e:
        duration_ms = int((time.time() - start_time) * 1000)
        service.increment_stats(tool.id, False)
        service.log_usage({
            'tool_id': tool.id,
            'tool_type': 'search',
            'query': query,
            'success': False,
            'error_message': str(e),
            'conversation_id': data.get('conversation_id'),
            'agent_id': data.get('agent_id'),
            'duration_ms': duration_ms
        })
        return {"success": False, "message": str(e)}
# ==================== 图片上传 API ====================
@app.post("/api/v2/upload-image")
async def upload_image(data: dict):
    """Save a base64-encoded image under UPLOADS_DIR and return its public URL.

    ``data['image']`` is either a ``data:image/<subtype>;base64,<payload>`` URL
    or bare base64, which is treated as PNG. The stored file name is
    ``<unix_ts>_<8 hex chars>.<ext>``. Nothing supplied by the client reaches
    the path except a whitelisted extension.

    Returns ``{"success": True, "path": url, "name": file_name}`` or
    ``{"success": False, "message": ...}``. Errors are reported, never raised.
    """
    # Raster formats only. The subtype is client-controlled, so it must be
    # whitelisted: "svg+xml", for example, would be served back from /uploads as
    # a script-capable SVG document (stored XSS), and arbitrary subtypes could
    # contain path separators on some platforms.
    allowed_exts = {"png", "jpeg", "jpg", "gif", "webp", "bmp", "avif"}
    max_bytes = 10 * 1024 * 1024
    try:
        image_data = data.get('image')
        if not image_data:
            return {"success": False, "message": "缺少图片数据"}

        if image_data.startswith('data:image/'):
            header, base64_content = image_data.split(',', 1)
            mime_type = header.split(':', 1)[1].split(';', 1)[0]  # e.g. 'image/png'
            ext = mime_type.split('/', 1)[1].lower() if '/' in mime_type else 'png'
        else:
            base64_content = image_data
            ext = 'png'
        if ext not in allowed_exts:
            return {"success": False, "message": f"不支持的图片格式: {ext}"}

        image_bytes = base64.b64decode(base64_content)
        if len(image_bytes) > max_bytes:
            return {"success": False, "message": "图片大小超过10MB限制"}

        # Server-generated unique name.
        safe_name = f"{int(time.time())}_{uuid.uuid4().hex[:8]}.{ext}"
        file_path = os.path.join(UPLOADS_DIR, safe_name)
        with open(file_path, 'wb') as f:
            f.write(image_bytes)

        # URL under the /uploads static mount.
        url_path = f"/uploads/images/{safe_name}"
        logger.info(f"图片已保存: {file_path}, URL: {url_path}")
        return {"success": True, "path": url_path, "name": safe_name}
    except Exception as e:
        logger.error(f"图片上传失败: {e}")
        return {"success": False, "message": str(e)}
# ==================== 对话 API保留原有 ====================
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
    """List the main user's conversations; creates the main user record on first use."""
    conv_service = ConversationService(db)
    owner = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')

    def as_item(conv):
        return {
            "id": conv.conversation_id,
            "title": conv.title or "新对话",
            "created_at": conv.created_at.isoformat(),
            "updated_at": conv.updated_at.isoformat(),
        }

    return {"conversations": [as_item(c) for c in conv_service.get_user_conversations(owner.id)]}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
    """Return the first conversation the service lists for the main user, or ``None`` if there are none."""
    conv_service = ConversationService(db)
    owner = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    conversations = conv_service.get_user_conversations(owner.id)
    if not conversations:
        return {"conversation": None}

    # Presumably the service orders newest first; TODO confirm.
    newest = conversations[0]
    return {
        "conversation": {
            "id": newest.conversation_id,
            "title": newest.title or "新对话",
            "updated_at": newest.updated_at.isoformat(),
        }
    }
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
    """Start a new conversation for the main user."""
    conv_service = ConversationService(db)
    owner = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    created = conv_service.create_conversation(owner.id)
    title = created.title or "新对话"
    return {
        "id": created.conversation_id,
        "title": title,
        "created_at": created.created_at.isoformat(),
    }
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
    """Return all messages of a conversation, including the v2 fields.

    The v2 fields are thinking content, extra data (e.g. search results),
    agent id and model used.

    Raises:
        HTTPException: 404 if the conversation does not exist.
    """
    conv_service = ConversationService(db)
    conversation = conv_service.get_conversation(conversation_id)
    if not conversation:
        raise HTTPException(status_code=404, detail="会话不存在")

    def to_json(msg):
        return {
            "id": msg.id,
            "role": msg.role,
            "content": msg.content,
            "thinking_content": msg.thinking_content,
            "extra_data": msg.extra_data,
            "source": msg.source,
            "agent_id": msg.agent_id,
            "model_used": msg.model_used,
            "created_at": msg.created_at.isoformat(),
        }

    return {"messages": [to_json(m) for m in conv_service.get_messages(conversation.id)]}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
    """Delete a conversation.

    Raises:
        HTTPException: 404 if the service reports that nothing was deleted.
    """
    deleted = ConversationService(db).delete_conversation(conversation_id)
    if deleted:
        return {"success": True}
    raise HTTPException(status_code=404, detail="会话不存在")
# ==================== WebSocket路由 ====================
def _ws_load_default_agent_id():
    """Return the id of the default agent, or None if there is none (uses its own session)."""
    db = SessionLocal()
    try:
        default_agent = AgentService(db).get_default_agent()
        return default_agent.id if default_agent else None
    finally:
        db.close()


async def _ws_receive_payload(websocket: WebSocket):
    """Read one text frame and decode it as a JSON object.

    Returns None (after logging) for blank frames, invalid JSON, JSON that is not
    an object, and non-text frames, so the caller can skip them.
    WebSocketDisconnect propagates to the caller.
    """
    try:
        text_data = await websocket.receive_text()
    except KeyError:
        # NOTE(review): Starlette's receive_text() raises KeyError when the frame
        # is binary (no "text" key); confirm for the installed version.
        logger.error("文本消息解析错误: 收到非文本消息")
        return None
    if not text_data.strip():
        return None
    try:
        payload = json.loads(text_data)
    except json.JSONDecodeError as json_err:
        logger.error(f"JSON解析错误: {json_err}")
        return None
    if not isinstance(payload, dict):
        logger.error(f"JSON解析错误: 期望对象, 收到 {type(payload).__name__}")
        return None
    return payload


async def _ws_send_history(websocket: WebSocket, conv_service, conversation_id):
    """Send a conversation's stored messages as a ``history`` event (no-op if it does not exist)."""
    conversation = conv_service.get_conversation(conversation_id)
    if not conversation:
        return
    messages = conv_service.get_messages(conversation.id)
    await websocket.send_json({
        "type": "history",
        "conversation_id": conversation_id,
        "agent_id": conversation.current_agent_id,  # the conversation's current agent
        "messages": [
            {
                "role": m.role,
                "content": m.content,
                "thinking_content": m.thinking_content,
                "extra_data": m.extra_data,  # e.g. search results, attachments
                "agent_id": m.agent_id,
                "source": m.source,
                "created_at": m.created_at.isoformat()
            }
            for m in messages
        ]
    })


async def _ws_switch_agent(websocket: WebSocket, agent_service, new_agent_id, current_agent_id):
    """Switch to ``new_agent_id`` if it names an active agent; return the agent id now current.

    Sends ``agent_switched`` only on success. Unknown or inactive agents are ignored silently.
    """
    agent = agent_service.get_agent(new_agent_id)
    if not (agent and agent.is_active):
        return current_agent_id
    await websocket.send_json({
        "type": "agent_switched",
        "agent_id": new_agent_id,
        "agent_name": agent.display_name or agent.name
    })
    return new_agent_id


def _ws_collect_attachments(files, text_limit=3000):
    """Split uploaded files into model inputs and history metadata.

    Returns ``(image_contents, text_contents, extra_data)``:
      * image_contents: ``{'name', 'type', 'data'}`` dicts (base64 data for vision models);
      * text_contents: contents of non-image files, each cut to ``text_limit``
        characters with a truncation marker appended;
      * extra_data: image/file metadata to store with the user message, or None
        when there are no images.
    A file without a ``type`` is treated as a non-image file instead of raising KeyError.
    """
    image_contents = []
    image_paths = []
    text_contents = []
    other_files = []
    for f in files:
        file_type = f.get('type') or ''
        if file_type.startswith('image/'):
            image_contents.append({'name': f.get('name'), 'type': file_type, 'data': f.get('content', '')})
            if f.get('serverPath'):
                # Server URL of the already-uploaded image, so history can display it.
                image_paths.append({'name': f.get('name'), 'type': file_type, 'url': f['serverPath']})
            continue
        other_files.append({'name': f.get('name'), 'type': f.get('type')})
        content = f.get('content')
        if content:
            snippet = content[:text_limit]
            if len(content) > text_limit:
                snippet += "...(内容过长已截断)"
            text_contents.append(snippet)

    if image_paths:
        extra_data = {'images': image_paths, 'files': other_files}
    elif image_contents:
        # Images attached but none has a server path (the upload probably failed).
        extra_data = {
            'images': [{'name': i['name'], 'type': i['type']} for i in image_contents],
            'files': other_files,
        }
    else:
        extra_data = None
    return image_contents, text_contents, extra_data


async def _ws_run_search(websocket: WebSocket, db, query, conversation_id, agent_id):
    """Search the web (Tavily) for ``query`` using the default ``search`` tool.

    On success with results: records usage stats, sends a ``search_results``
    event, and returns ``(search_context, results_for_client)``, where
    ``search_context`` is a text block for the LLM. Returns ``(None, None)`` when
    no tool with an API key is configured, the request fails, or nothing is found.
    """
    import httpx

    tool_service = ToolService(db)
    search_tool = tool_service.get_default_tool('search')
    logger.info(f"获取到搜索工具: {search_tool.name if search_tool else 'None'}")
    config = (search_tool.config or {}) if search_tool else {}
    if not config.get('api_key'):
        return None, None

    max_results = config.get('max_results', 5)
    start_time = time.time()
    try:
        logger.info(f"执行搜索: query={query}")
        payload = {
            "api_key": config.get('api_key'),
            "query": query,
            "max_results": max_results,
            "search_depth": config.get('search_depth', 'basic')
        }
        # Async client: a blocking httpx.Client here would stall the whole event
        # loop (all other WebSockets and HTTP requests) for up to 30 seconds.
        async with httpx.AsyncClient(timeout=30) as client:
            resp = await client.post("https://api.tavily.com/search", json=payload)
        resp.raise_for_status()
        search_result = resp.json()
    except Exception as e:
        duration_ms = int((time.time() - start_time) * 1000)
        logger.error(f"搜索失败: {e}")
        tool_service.increment_stats(search_tool.id, False)
        tool_service.log_usage({
            'tool_id': search_tool.id,
            'tool_type': 'search',
            'query': query,
            'success': False,
            'error_message': str(e),
            'conversation_id': conversation_id,
            'agent_id': agent_id,
            'duration_ms': duration_ms
        })
        return None, None

    duration_ms = int((time.time() - start_time) * 1000)
    results = search_result.get("results") if isinstance(search_result, dict) else None
    if not results:
        return None, None

    top = results[:max_results]
    search_context = "\n\n【搜索结果】\n" + "".join(
        f"{i}. {r.get('title', 'N/A')}\n {(r.get('content') or r.get('snippet') or 'N/A')[:200]}\n 来源: {r.get('url', 'N/A')}\n"
        for i, r in enumerate(top, 1)
    )
    logger.info(f"搜索完成: {len(results)} 条结果,使用 {len(top)}")

    search_results_for_client = [
        {
            "title": r.get('title', 'N/A'),
            "snippet": (r.get('content') or r.get('snippet') or '')[:150],
            "url": r.get('url', 'N/A')
        }
        for r in top
    ]
    tool_service.increment_stats(search_tool.id, True)
    tool_service.log_usage({
        'tool_id': search_tool.id,
        'tool_type': 'search',
        'query': query,
        'success': True,
        'result_summary': f'{len(results)} results',
        'conversation_id': conversation_id,
        'agent_id': agent_id,
        'duration_ms': duration_ms
    })
    await websocket.send_json({
        "type": "search_results",
        "conversation_id": conversation_id,
        "results": search_results_for_client,
        "query": query
    })
    return search_context, search_results_for_client


async def _ws_handle_chat(websocket: WebSocket, db, conv_service, agent_service, user, data, current_agent_id):
    """Handle one ``chat`` action and return the (possibly overridden) current agent id.

    Steps:
      1. Apply the agent override and build the message from text plus attachments.
      2. Get or create the conversation, then broadcast the user message.
      3. Run an optional web search, then persist the user message.
      4. Load the history, call the LLM, then persist and send the reply.
    """
    message = data.get("message") or ""
    files = data.get("files") or []
    conversation_id = data.get("conversation_id")
    enable_thinking = data.get("enable_thinking", True)
    agent_id_override = data.get("agent_id")
    disabled_tools = data.get("disabled_tools") or []

    if agent_id_override:
        override = agent_service.get_agent(agent_id_override)
        if override and override.is_active:
            current_agent_id = agent_id_override

    if not message.strip():
        if not files:
            return current_agent_id
        message = "[上传文件]"

    image_contents, text_contents, extra_data_for_msg = _ws_collect_attachments(files)
    for content in text_contents:
        message += f"\n\n{content}"

    # 1. Agent configuration. The agent section is copied because the system
    #    prompt may be extended below; mutating the returned dict in place could
    #    accumulate the extension if the service caches it.
    agent_config = agent_service.get_agent_config(current_agent_id) or {}
    agent_settings = dict(agent_config.get('agent') or {})
    agent_tools = agent_settings.get('tools') or []

    # 2. Conversation: reuse the given one, or create one (also when the id is unknown).
    conversation = conv_service.get_conversation(conversation_id) if conversation_id else None
    if not conversation:
        if conversation_id:
            logger.warning(f"会话不存在,创建新会话: {conversation_id}")
        conversation = conv_service.create_conversation(user.id)
        conversation_id = conversation.conversation_id
        await websocket.send_json({
            "type": "conversation_created",
            "conversation_id": conversation_id
        })

    # 3. Broadcast the user message so every open client shows it immediately.
    await manager.send_to_user(MAIN_USER_ID, {
        "type": "user_message",
        "conversation_id": conversation_id,
        "message": {
            "id": None,  # not persisted yet
            "role": "user",
            "content": message,
            "source": "web",
            "created_at": datetime.utcnow().isoformat()
        }
    })

    # 4. Optional web search.
    search_context = search_results_for_client = None
    logger.info(f"检查搜索条件: agent_tools={agent_tools}, disabled_tools={disabled_tools}")
    if 'search' in agent_tools and 'search' not in disabled_tools:
        logger.info("搜索条件满足,开始执行搜索")
        search_context, search_results_for_client = await _ws_run_search(
            websocket, db, message, conversation_id, current_agent_id
        )

    # 5. Persist the user message with search results and attachment metadata.
    extra_data_to_save = None
    if search_results_for_client:
        extra_data_to_save = {'search_results': search_results_for_client, 'search_query': message}
    if extra_data_for_msg:
        extra_data_to_save = {**(extra_data_to_save or {}), **extra_data_for_msg}
    conv_service.add_message(
        conversation_id=conversation.id,
        role='user',
        content=message,
        source='web',
        extra_data=extra_data_to_save
    )

    # 6. History, including the message just saved.
    history = conv_service.get_conversation_history(conversation_id, limit=agent_settings.get('max_history', 20))

    # 7. Ground the answer in the search results, if any.
    if search_context:
        agent_settings['system_prompt'] = (agent_settings.get('system_prompt') or '') + "\n\n如果提供了搜索结果,请基于搜索结果回答用户问题,并注明信息来源。"
        history.append({"role": "system", "content": f"以下是搜索到的相关信息,请参考这些内容回答用户问题:{search_context}"})

    # 8. Call the LLM and deliver the reply.
    provider_config = agent_config.get('provider')
    if not provider_config:
        await websocket.send_json({
            "type": "error",
            "message": "Agent配置不完整"
        })
        return current_agent_id
    try:
        response, thinking_content = await llm_service.chat(
            messages=history,
            provider_config=provider_config,
            agent_config=agent_settings,
            enable_thinking=enable_thinking,
            images=image_contents  # image data for multimodal models
        )
        logger.info(f"LLM响应: response长度={len(response)}, thinking长度={len(thinking_content) if thinking_content else 0}")
        thinking_content = thinking_content or None

        assistant_msg = conv_service.add_message(
            conversation_id=conversation.id,
            role='assistant',
            content=response,
            source='web',
            thinking_content=thinking_content,
            agent_id=current_agent_id,
            model_used=provider_config.get('default_model')
        )
        await websocket.send_json({
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": assistant_msg.id,
                "role": "assistant",
                "content": response,
                "thinking_content": thinking_content,
                "source": "web",
                "agent_id": current_agent_id,
                "agent_name": agent_settings.get('display_name'),
                "created_at": assistant_msg.created_at.isoformat()
            }
        })
        logger.info(f"AI回复已发送: conversation_id={conversation_id}")
        await websocket.send_json({
            "type": "stream_end",
            "conversation_id": conversation_id
        })
    except Exception as e:
        logger.error(f"LLM调用失败: {e}")
        await websocket.send_json({
            "type": "error",
            "message": f"AI服务暂时不可用: {str(e)}"
        })
    return current_agent_id


async def _ws_dispatch(websocket: WebSocket, db, data, current_agent_id):
    """Route one decoded client message by its ``action``; return the current agent id."""
    conv_service = ConversationService(db)
    agent_service = AgentService(db)
    user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
    action = data.get("action")
    if action == "select_conversation":
        await _ws_send_history(websocket, conv_service, data.get("conversation_id"))
    elif action == "switch_agent":
        current_agent_id = await _ws_switch_agent(websocket, agent_service, data.get("agent_id"), current_agent_id)
    elif action == "chat":
        current_agent_id = await _ws_handle_chat(
            websocket, db, conv_service, agent_service, user, data, current_agent_id
        )
    return current_agent_id


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    """Serve one real-time chat WebSocket connection.

    The ``user_id`` path segment is accepted for URL compatibility but ignored:
    every web client is registered under MAIN_USER_ID.

    Client messages are JSON objects with an ``action``:
      * ``select_conversation``: reply with the conversation's history.
      * ``switch_agent``: make another active agent current for this socket.
      * ``chat``: optionally search, call the LLM, persist and send the reply.

    Each message gets its own DB session, which is always closed. An error while
    handling one message is logged and reported to the client as an ``error``
    event instead of dropping the connection. The connection is always
    unregistered from ``manager`` (under the same key used to register it)
    when the loop ends.
    """
    actual_user_id = MAIN_USER_ID
    await manager.connect(websocket, actual_user_id)
    try:
        current_agent_id = _ws_load_default_agent_id()
        while True:
            data = await _ws_receive_payload(websocket)
            if data is None:
                continue
            logger.info(f"WebSocket收到消息: action={data.get('action')}")
            db = SessionLocal()
            try:
                current_agent_id = await _ws_dispatch(websocket, db, data, current_agent_id)
            except WebSocketDisconnect:
                raise
            except Exception as e:
                # One bad message must not tear down the whole connection.
                logger.exception(f"处理WebSocket消息失败: {e}")
                await websocket.send_json({"type": "error", "message": f"处理消息失败: {str(e)}"})
            finally:
                db.close()
    except WebSocketDisconnect:
        logger.info("WebSocket已断开退出循环")
    except Exception as e:
        logger.error(f"WebSocket错误: {e}")
    finally:
        manager.disconnect(websocket, actual_user_id)
# ==================== 后台管理API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return row counts of users, conversations and messages for the admin dashboard."""
    return {
        "total_users": db.query(User).count(),
        "total_conversations": db.query(Conversation).count(),
        "total_messages": db.query(Message).count(),
    }
# ==================== Matrix消息处理回调 ====================
async def handle_matrix_message(
    action: str,
    conversation_id: str = None,
    user_message: str = None,
    room_id: str = None,
    message_id: int = None,
    sender: str = None,
    content: str = None,
    thinking_content: str = None,
    agent_id: int = None,
    agent_name: str = None
):
    """Forward a Matrix-side event to the main user's WebSocket clients.

    Supported ``action`` values are ``new_conversation``, ``user_message`` and
    ``assistant_message``; any other action is ignored. ``room_id`` is accepted
    but not used here. Timestamps are the local ``datetime.now()``.
    """
    if action == "new_conversation":
        event = {"type": "new_conversation", "conversation_id": conversation_id}
    elif action == "user_message":
        event = {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                "created_at": datetime.now().isoformat(),
            },
        }
    elif action == "assistant_message":
        event = {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "assistant",
                "content": content,
                "thinking_content": thinking_content,
                "source": "matrix",
                "agent_id": agent_id,
                "agent_name": agent_name,
                "created_at": datetime.now().isoformat(),
            },
        }
    else:
        return
    await manager.send_to_user(MAIN_USER_ID, event)
# ==================== 启动和关闭 ====================
# 导入Matrix Bot
# Prefer the v2 Matrix service; fall back to the legacy one if it is not installed.
try:
    from services.matrix_service_v2 import matrix_bot
except ImportError:
    from services.matrix_service import matrix_bot
    MATRIX_V2 = False
else:
    MATRIX_V2 = True
@app.on_event("startup")
async def startup():
    """Initialise the database, seed default data and start the Matrix bot.

    Matrix failures are logged as warnings and never abort startup.
    """
    init_db()
    logger.info("数据库初始化完成")
    init_default_data()

    try:
        success = await matrix_bot.init_from_config()
        if success and matrix_bot.is_running:
            # asyncio keeps only weak references to tasks. Hold a strong reference
            # on app.state so the long-running sync loop cannot be garbage-collected
            # mid-run, and so the task can be found again later.
            app.state.matrix_sync_task = asyncio.create_task(
                matrix_bot.start_sync(handle_matrix_message)
            )
            logger.info(f"Matrix Bot {'v2' if MATRIX_V2 else 'v1'} 已启动")
        else:
            logger.info("Matrix Bot 未配置或未启用")
    except Exception as e:
        logger.warning(f"Matrix Bot初始化失败: {e}")
@app.on_event("shutdown")
async def shutdown():
    """Stop the Matrix sync task (if one is stored on ``app.state``) and disconnect the bot.

    Best effort: failures are logged, never raised.
    """
    sync_task = getattr(app.state, "matrix_sync_task", None)
    if sync_task is not None and not sync_task.done():
        sync_task.cancel()
        # Give the task a moment to process the cancellation without blocking shutdown for long.
        await asyncio.wait({sync_task}, timeout=5)
    try:
        await matrix_bot.disconnect()
    except Exception as e:
        # Previously a bare `except: pass`; keep it non-fatal but leave a trace.
        logger.warning(f"Matrix Bot断开失败: {e}")
    logger.info("应用已关闭")
if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 19020.
    import uvicorn

    uvicorn.run(app, port=19020, host="0.0.0.0")