Files
ai-chat-system/main_v2.py
hubian ae08e01e55 fix: Kimi模型伪工具调用格式过滤
修复Kimi-K2.5模型在第二轮调用时输出伪工具调用格式的问题:
- 添加系统提示告诉模型直接根据工具结果回答
- 过滤 <|tool_calls_section_begin|> 等内部格式标记
- 清理多余空行

版本: v3.0.1
2026-04-15 09:45:08 +08:00

1311 lines
53 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI对话系统 v3.0.0 - 主应用
支持大模型池、Agent管理、渠道独立绑定、思考功能开关、Function CallingLLM自主调用工具
"""
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy.orm import Session
from typing import Dict, Set, Optional
import asyncio
import json
import logging
from datetime import datetime
import os
import base64
import uuid
import time
# 使用新的数据模型
from models_v2 import (
init_db, get_db, SessionLocal,
User, Conversation, Message, SystemConfig,
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, ToolConfig, ToolUsageLog,
init_default_data
)
from services.llm_service import llm_service
from services.agent_service import AgentService, LLMProviderService, ChannelService, ToolService
from services.conversation_service import ConversationService
# --- Logging -------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Application ---------------------------------------------------------
app = FastAPI(
    title="AI对话系统 v3.0",
    version="3.0.1",
)

# --- Filesystem layout ---------------------------------------------------
# Every asset directory is resolved relative to this file, not the CWD.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
UPLOADS_DIR = os.path.join(BASE_DIR, "uploads", "images")
# Create the image upload directory (and its "uploads" parent) before the
# "/uploads" mount below is set up.
os.makedirs(UPLOADS_DIR, exist_ok=True)

# --- Static file serving and templates -----------------------------------
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
app.mount(
    "/uploads",
    StaticFiles(directory=os.path.join(BASE_DIR, "uploads")),
    name="uploads",
)
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
# WebSocket连接管理
class ConnectionManager:
    """Track open WebSocket connections per user and fan messages out to them.

    One user id may own several sockets at once (for example several browser
    tabs); every registered socket receives the messages sent to that user.
    """

    def __init__(self):
        # user_id -> set of sockets currently registered for that user
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, user_id: str):
        """Accept the handshake and register *websocket* under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}")

    def disconnect(self, websocket: WebSocket, user_id: str):
        """Unregister *websocket*; drop the user entry once it has no sockets.

        Unknown users or sockets are ignored, so calling this twice is safe.
        """
        connections = self.active_connections.get(user_id)
        if connections is None:
            return
        connections.discard(websocket)
        if not connections:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*.

        Iterates over a snapshot of the socket set: ``send_json`` yields to the
        event loop, and while it is suspended another coroutine may connect or
        disconnect and mutate the live set (iterating a set that changes size
        raises ``RuntimeError``). Sockets whose send fails are unregistered so
        a dead connection is not retried on every later broadcast.
        """
        connections = self.active_connections.get(user_id)
        if not connections:
            return
        failed = []
        for connection in list(connections):
            try:
                await connection.send_json(message)
                logger.info(f"发送消息到用户 {user_id}: {message.get('type', 'unknown')}")
            except Exception as e:
                logger.error(f"发送消息失败: {e}")
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection, user_id)

    async def ping(self, user_id: str):
        """Send a heartbeat ``{"type": "ping"}`` message to *user_id*."""
        await self.send_to_user(user_id, {"type": "ping"})
# Fixed identity used for every web client: all browser sessions share this
# one user's conversations and broadcasts.
MAIN_USER_ID = "main_user"

# Process-wide registry of open WebSocket connections.
manager = ConnectionManager()
# ==================== 页面路由 ====================
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""主页 - 聊天界面"""
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/admin", response_class=HTMLResponse)
async def admin(request: Request):
"""后台管理 - v2.0版本"""
return templates.TemplateResponse("admin_v2/index.html", {"request": request})
# ==================== v2 API路由 ====================
# ===== 大模型池 API =====
@app.get("/api/v2/providers")
async def get_providers(db: Session = Depends(get_db)):
"""获取所有大模型池"""
service = LLMProviderService(db)
providers = service.get_all_providers()
return {
"providers": [
{
"id": p.id,
"name": p.name,
"api_base": p.api_base,
"api_key": p.api_key,
"models": p.models,
"default_model": p.default_model,
"supports_thinking": p.supports_thinking,
"thinking_model": p.thinking_model,
"supports_vision": p.supports_vision,
"vision_model": p.vision_model,
"supports_function_calling": p.supports_function_calling,
"max_tokens": p.max_tokens,
"temperature": p.temperature,
"is_active": p.is_active,
"priority": p.priority,
"description": p.description
}
for p in providers
]
}
@app.post("/api/v2/providers")
async def create_provider(data: dict, db: Session = Depends(get_db)):
"""创建大模型池"""
service = LLMProviderService(db)
# 检查名称是否已存在
if service.get_provider_by_name(data.get('name')):
return {"success": False, "message": "名称已存在"}
provider = service.create_provider(data)
return {"success": True, "provider": {"id": provider.id, "name": provider.name}}
@app.put("/api/v2/providers/{provider_id}")
async def update_provider(provider_id: int, data: dict, db: Session = Depends(get_db)):
"""更新大模型池"""
service = LLMProviderService(db)
provider = service.update_provider(provider_id, data)
if not provider:
return {"success": False, "message": "Provider不存在"}
return {"success": True, "provider": {"id": provider.id, "name": provider.name}}
@app.delete("/api/v2/providers/{provider_id}")
async def delete_provider(provider_id: int, db: Session = Depends(get_db)):
"""删除大模型池"""
service = LLMProviderService(db)
success = service.delete_provider(provider_id)
if not success:
return {"success": False, "message": "无法删除可能有Agent在使用"}
return {"success": True}
@app.post("/api/v2/providers/models")
async def fetch_models(data: dict):
"""从API获取模型列表"""
api_base = data.get('api_base')
api_key = data.get('api_key', '')
models = await llm_service.get_available_models(api_base, api_key)
return {"models": models}
@app.post("/api/v2/providers/test")
async def test_provider(data: dict):
"""测试大模型连接"""
api_base = data.get('api_base')
api_key = data.get('api_key', '')
model = data.get('model', 'auto')
result = await llm_service.test_connection(api_base, api_key, model)
return result
# ===== Agent API =====
@app.get("/api/v2/agents")
async def get_agents(db: Session = Depends(get_db)):
"""获取所有Agent"""
service = AgentService(db)
agents = service.get_all_agents()
# 获取Provider名称
providers = {p.id: p.name for p in db.query(LLMProvider).all()}
return {
"agents": [
{
"id": a.id,
"name": a.name,
"display_name": a.display_name,
"llm_provider_id": a.llm_provider_id,
"llm_provider_name": providers.get(a.llm_provider_id),
"model_override": a.model_override,
"system_prompt": a.system_prompt,
"enable_thinking": a.enable_thinking,
"thinking_prompt": a.thinking_prompt,
"thinking_prefix": a.thinking_prefix,
"thinking_suffix": a.thinking_suffix,
"tools": a.tools or [], # 工具列表
"max_history": a.max_history,
"temperature_override": a.temperature_override,
"is_default": a.is_default,
"is_active": a.is_active,
"description": a.description
}
for a in agents
]
}
@app.post("/api/v2/agents")
async def create_agent(data: dict, db: Session = Depends(get_db)):
"""创建Agent"""
service = AgentService(db)
# 检查名称是否已存在
if service.get_agent_by_name(data.get('name')):
return {"success": False, "message": "Agent名称已存在"}
# 检查Provider是否存在
provider_service = LLMProviderService(db)
if not provider_service.get_provider(data.get('llm_provider_id')):
return {"success": False, "message": "选择的LLM Provider不存在"}
agent = service.create_agent(data)
# 如果设为默认,更新其他
if data.get('is_default'):
service.set_default_agent(agent.id)
return {"success": True, "agent": {"id": agent.id, "name": agent.name}}
@app.put("/api/v2/agents/{agent_id}")
async def update_agent(agent_id: int, data: dict, db: Session = Depends(get_db)):
"""更新Agent"""
service = AgentService(db)
# 如果设为默认,先更新
if data.get('is_default'):
service.set_default_agent(agent_id)
agent = service.update_agent(agent_id, data)
if not agent:
return {"success": False, "message": "Agent不存在"}
return {"success": True, "agent": {"id": agent.id, "name": agent.name}}
@app.delete("/api/v2/agents/{agent_id}")
async def delete_agent(agent_id: int, db: Session = Depends(get_db)):
"""删除Agent"""
service = AgentService(db)
success = service.delete_agent(agent_id)
if not success:
return {"success": False, "message": "无法删除默认Agent"}
return {"success": True}
@app.post("/api/v2/agents/{agent_id}/default")
async def set_agent_default(agent_id: int, db: Session = Depends(get_db)):
"""设置默认Agent"""
service = AgentService(db)
success = service.set_default_agent(agent_id)
return {"success": success}
@app.get("/api/v2/agents/{agent_id}/config")
async def get_agent_config(agent_id: int, db: Session = Depends(get_db)):
"""获取Agent完整配置含LLM Provider"""
service = AgentService(db)
config = service.get_agent_config(agent_id)
return {"config": config}
# ===== 渠道 API =====
@app.get("/api/v2/channels")
async def get_channels(db: Session = Depends(get_db)):
"""获取所有渠道及绑定关系"""
service = ChannelService(db)
channels = service.get_all_channels()
# 获取所有绑定关系
mappings = db.query(ChannelAgentMapping).order_by(ChannelAgentMapping.priority).all()
agents = {a.id: a for a in db.query(Agent).all()}
return {
"channels": [
{
"id": c.id,
"channel_type": c.channel_type,
"name": c.name,
"config": c.config,
"is_active": c.is_active,
"is_primary": c.is_primary,
"agent_mappings": [
{
"id": m.id,
"agent": {
"id": agents[m.agent_id].id if m.agent_id in agents else None,
"name": agents[m.agent_id].name if m.agent_id in agents else None,
"display_name": agents[m.agent_id].display_name if m.agent_id in agents else None,
"is_active": agents[m.agent_id].is_active if m.agent_id in agents else None
},
"priority": m.priority,
"mode": m.mode,
"conditions": m.conditions
}
for m in mappings if m.channel_id == c.id
]
}
for c in channels
],
"mappings": [
{
"id": m.id,
"channel_id": m.channel_id,
"channel_type": channels[0].channel_type if channels else None, # 需要从Channel获取
"channel_name": next((c.name for c in channels if c.id == m.channel_id), ""),
"agent_id": m.agent_id,
"agent_name": agents[m.agent_id].display_name if m.agent_id in agents else "",
"priority": m.priority,
"mode": m.mode,
"conditions": m.conditions
}
for m in mappings
]
}
@app.post("/api/v2/channels")
async def create_channel(data: dict, db: Session = Depends(get_db)):
"""创建渠道"""
service = ChannelService(db)
channel = service.create_channel(data)
return {"success": True, "channel": {"id": channel.id, "name": channel.name}}
@app.put("/api/v2/channels/{channel_id}")
async def update_channel(channel_id: int, data: dict, db: Session = Depends(get_db)):
"""更新渠道"""
service = ChannelService(db)
channel = service.update_channel(channel_id, data)
if not channel:
return {"success": False, "message": "渠道不存在"}
return {"success": True}
@app.put("/api/v2/channels/{channel_id}/config")
async def update_channel_config(channel_id: int, data: dict, db: Session = Depends(get_db)):
"""更新渠道配置如Matrix配置"""
channel = db.query(Channel).filter(Channel.id == channel_id).first()
if not channel:
return {"success": False, "message": "渠道不存在"}
channel.config = data.get('config', {})
db.commit()
# 如果是Matrix渠道重新初始化Matrix Bot
if channel.channel_type == 'matrix':
# TODO: 重启Matrix Bot
pass
return {"success": True}
@app.post("/api/v2/channels/bind")
async def bind_agent_to_channel(data: dict, db: Session = Depends(get_db)):
"""绑定Agent到渠道"""
service = ChannelService(db)
channel_id = data.get('channel_id')
agent_id = data.get('agent_id')
# 检查渠道和Agent是否存在
if not service.get_channel(channel_id):
return {"success": False, "message": "渠道不存在"}
agent_service = AgentService(db)
if not agent_service.get_agent(agent_id):
return {"success": False, "message": "Agent不存在"}
mapping = service.bind_agent(
channel_id=channel_id,
agent_id=agent_id,
priority=data.get('priority', 0),
mode=data.get('mode', 'single'),
conditions=data.get('conditions')
)
return {"success": True, "mapping": {"id": mapping.id}}
@app.delete("/api/v2/channels/unbind/{mapping_id}")
async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
"""解绑Agent"""
service = ChannelService(db)
success = service.unbind_agent(mapping_id)
return {"success": success}
# ==================== 工具配置 API ====================
@app.get("/api/v2/tools")
async def get_tools(tool_type: str = None, db: Session = Depends(get_db)):
"""获取所有工具配置"""
service = ToolService(db)
if tool_type:
tools = service.get_tools_by_type(tool_type)
else:
tools = service.get_all_tools()
return {
"tools": [
{
"id": t.id,
"name": t.name,
"tool_type": t.tool_type,
"provider": t.provider,
"config": t.config,
"is_active": t.is_active,
"is_default": t.is_default,
"total_calls": t.total_calls,
"success_calls": t.success_calls,
"failed_calls": t.failed_calls
}
for t in tools
]
}
@app.post("/api/v2/tools")
async def create_tool(data: dict, db: Session = Depends(get_db)):
"""创建工具配置"""
service = ToolService(db)
tool = service.create_tool(data)
return {"success": True, "tool": {"id": tool.id, "name": tool.name, "tool_type": tool.tool_type}}
@app.put("/api/v2/tools/{tool_id}")
async def update_tool(tool_id: int, data: dict, db: Session = Depends(get_db)):
"""更新工具配置"""
service = ToolService(db)
if data.get('is_default'):
service.set_default_tool(tool_id)
tool = service.update_tool(tool_id, data)
if not tool:
return {"success": False, "message": "工具不存在"}
return {"success": True, "tool": {"id": tool.id, "name": tool.name}}
@app.delete("/api/v2/tools/{tool_id}")
async def delete_tool(tool_id: int, db: Session = Depends(get_db)):
"""删除工具配置"""
service = ToolService(db)
success = service.delete_tool(tool_id)
return {"success": success}
@app.post("/api/v2/tools/{tool_id}/default")
async def set_tool_default(tool_id: int, db: Session = Depends(get_db)):
"""设置默认工具"""
service = ToolService(db)
success = service.set_default_tool(tool_id)
return {"success": success}
@app.get("/api/v2/tools/stats")
async def get_tool_stats(days: int = 7, db: Session = Depends(get_db)):
"""获取工具使用统计"""
service = ToolService(db)
stats = service.get_usage_stats(days=days)
return stats
@app.post("/api/v2/tools/search")
async def perform_search(data: dict, db: Session = Depends(get_db)):
"""执行搜索供前端或Agent调用"""
import httpx
import time
query = data.get('query')
if not query:
return {"success": False, "message": "缺少搜索关键词"}
service = ToolService(db)
tool_id = data.get('tool_id')
if tool_id:
tool = service.get_tool(tool_id)
else:
tool = service.get_default_tool('search')
if not tool or not tool.config.get('api_key'):
return {"success": False, "message": "未配置搜索工具"}
# Tavily Search API
if tool.provider == 'tavily' or tool.tool_type == 'search':
start_time = time.time()
try:
tavily_url = "https://api.tavily.com/search"
config = tool.config
payload = {
"api_key": config.get('api_key'),
"query": query,
"max_results": config.get('max_results', 5),
"include_raw_content": config.get('include_raw_content', False),
"search_depth": config.get('search_depth', 'basic')
}
async with httpx.AsyncClient(timeout=30) as client:
response = await client.post(tavily_url, json=payload)
result = response.json()
duration_ms = int((time.time() - start_time) * 1000)
# 更新统计和日志
service.increment_stats(tool.id, True)
service.log_usage({
'tool_id': tool.id,
'tool_type': 'search',
'query': query,
'success': True,
'result_summary': f'{len(result.get("results", []))} results',
'conversation_id': data.get('conversation_id'),
'agent_id': data.get('agent_id'),
'duration_ms': duration_ms
})
return {
"success": True,
"results": result.get("results", []),
"query": query
}
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
service.increment_stats(tool.id, False)
service.log_usage({
'tool_id': tool.id,
'tool_type': 'search',
'query': query,
'success': False,
'error_message': str(e),
'conversation_id': data.get('conversation_id'),
'duration_ms': duration_ms
})
return {"success": False, "message": str(e)}
return {"success": False, "message": "不支持的搜索提供商"}
# ==================== 图片上传 API ====================
@app.post("/api/v2/upload-image")
async def upload_image(data: dict):
"""上传图片到服务器,返回文件路径"""
try:
image_data = data.get('image')
file_name = data.get('name', 'image.png')
if not image_data:
return {"success": False, "message": "缺少图片数据"}
# 解析 base64 数据
if image_data.startswith('data:image/'):
# 提取格式和base64内容
header, base64_content = image_data.split(',', 1)
# 从header中提取图片格式
format_match = header.split(':')[1].split(';')[0] # 如 'image/png'
ext = format_match.split('/')[1] if '/' in format_match else 'png'
else:
base64_content = image_data
ext = 'png'
# 生成唯一文件名
timestamp = int(time.time())
unique_id = uuid.uuid4().hex[:8]
safe_name = f"{timestamp}_{unique_id}.{ext}"
# 保存文件
file_path = os.path.join(UPLOADS_DIR, safe_name)
image_bytes = base64.b64decode(base64_content)
# 检查文件大小限制10MB
if len(image_bytes) > 10 * 1024 * 1024:
return {"success": False, "message": "图片大小超过10MB限制"}
with open(file_path, 'wb') as f:
f.write(image_bytes)
# 返回可访问的URL路径
url_path = f"/uploads/images/{safe_name}"
logger.info(f"图片已保存: {file_path}, URL: {url_path}")
return {"success": True, "path": url_path, "name": safe_name}
except Exception as e:
logger.error(f"图片上传失败: {e}")
return {"success": False, "message": str(e)}
# ==================== 对话 API保留原有 ====================
@app.get("/api/conversations")
async def get_conversations(db: Session = Depends(get_db)):
"""获取会话列表"""
conv_service = ConversationService(db)
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
conversations = conv_service.get_user_conversations(user.id)
return {
"conversations": [
{
"id": c.conversation_id,
"title": c.title or "新对话",
"created_at": c.created_at.isoformat(),
"updated_at": c.updated_at.isoformat()
}
for c in conversations
]
}
@app.get("/api/conversations/latest")
async def get_latest_conversation(db: Session = Depends(get_db)):
"""获取最新会话"""
conv_service = ConversationService(db)
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
conversations = conv_service.get_user_conversations(user.id)
if conversations:
latest = conversations[0]
return {
"conversation": {
"id": latest.conversation_id,
"title": latest.title or "新对话",
"updated_at": latest.updated_at.isoformat()
}
}
return {"conversation": None}
@app.post("/api/conversations")
async def create_conversation(db: Session = Depends(get_db)):
"""创建新会话"""
conv_service = ConversationService(db)
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
conversation = conv_service.create_conversation(user.id)
return {
"id": conversation.conversation_id,
"title": conversation.title or "新对话",
"created_at": conversation.created_at.isoformat()
}
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
"""获取会话消息"""
conv_service = ConversationService(db)
conversation = conv_service.get_conversation(conversation_id)
if not conversation:
raise HTTPException(status_code=404, detail="会话不存在")
messages = conv_service.get_messages(conversation.id)
return {
"messages": [
{
"id": m.id,
"role": m.role,
"content": m.content,
"thinking_content": m.thinking_content, # v2新增
"extra_data": m.extra_data, # 包含搜索结果等
"source": m.source,
"agent_id": m.agent_id, # v2新增
"model_used": m.model_used, # v2新增
"created_at": m.created_at.isoformat()
}
for m in messages
]
}
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
"""删除会话"""
conv_service = ConversationService(db)
success = conv_service.delete_conversation(conversation_id)
if not success:
raise HTTPException(status_code=404, detail="会话不存在")
return {"success": True}
# ==================== WebSocket路由 ====================
@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
"""WebSocket连接 - 实时对话"""
actual_user_id = MAIN_USER_ID
await manager.connect(websocket, actual_user_id)
# 初始化时获取默认Agent ID
db = SessionLocal()
try:
agent_service = AgentService(db)
default_agent = agent_service.get_default_agent()
default_agent_id = default_agent.id if default_agent else None
finally:
db.close()
current_conversation_id = None
current_agent_id = default_agent_id
try:
while True:
try:
data = await websocket.receive_json()
except Exception as json_err:
logger.error(f"JSON解析错误: {json_err}")
# 如果连接已断开,退出循环
if "disconnect" in str(json_err).lower() or "closed" in str(json_err).lower():
logger.info("WebSocket已断开退出循环")
break
try:
text_data = await websocket.receive_text()
if text_data.strip():
data = json.loads(text_data)
else:
continue
except Exception as text_err:
logger.error(f"文本消息解析错误: {text_err}")
if "disconnect" in str(text_err).lower() or "closed" in str(text_err).lower():
logger.info("WebSocket已断开退出循环")
break
continue
action = data.get("action")
logger.info(f"WebSocket收到消息: action={action}")
# 每次消息处理时创建新的数据库会话,处理完后关闭
try:
db = SessionLocal()
conv_service = ConversationService(db)
agent_service = AgentService(db)
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
if action == "select_conversation":
current_conversation_id = data.get("conversation_id")
conversation = conv_service.get_conversation(current_conversation_id)
if conversation:
messages = conv_service.get_messages(conversation.id)
# 获取对话使用的Agent ID
conv_agent_id = conversation.current_agent_id
await websocket.send_json({
"type": "history",
"conversation_id": current_conversation_id,
"agent_id": conv_agent_id, # 返回对话的Agent ID
"messages": [
{
"role": m.role,
"content": m.content,
"thinking_content": m.thinking_content,
"extra_data": m.extra_data, # 包含搜索结果等
"agent_id": m.agent_id, # 每条消息的Agent ID
"source": m.source,
"created_at": m.created_at.isoformat()
}
for m in messages
]
})
elif action == "switch_agent":
# 切换Agent
new_agent_id = data.get("agent_id")
agent = agent_service.get_agent(new_agent_id)
if agent and agent.is_active:
current_agent_id = new_agent_id
await websocket.send_json({
"type": "agent_switched",
"agent_id": current_agent_id,
"agent_name": agent.display_name or agent.name
})
elif action == "chat":
message = data.get("message", "")
files = data.get("files", []) # 上传的文件
conversation_id = data.get("conversation_id")
enable_thinking = data.get("enable_thinking", True)
agent_id_override = data.get("agent_id")
# v3.0: 移除 disabled_tools由LLM自主决定
if agent_id_override:
agent = agent_service.get_agent(agent_id_override)
if agent and agent.is_active:
current_agent_id = agent_id_override
# 如果没有消息但有文件,构造消息
if not message.strip() and files:
message = "[上传文件]"
if not message.strip() and not files:
continue
# 处理文件内容
image_contents = []
text_contents = []
image_paths = []
if files:
for f in files:
if f.get('type') and f['type'].startswith('image/'):
image_contents.append({
'name': f['name'],
'type': f['type'],
'data': f.get('content', '')
})
if f.get('serverPath'):
image_paths.append({
'name': f['name'],
'type': f['type'],
'url': f['serverPath']
})
elif f.get('content'):
text_contents.append(f['content'][:3000])
if len(f['content']) > 3000:
text_contents[-1] += "...(内容过长已截断)"
if text_contents:
for content in text_contents:
message += f"\n\n{content}"
# 保存文件信息到 extra_data
extra_data_for_msg = None
if image_paths:
extra_data_for_msg = {
'images': image_paths,
'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')]
}
elif image_contents:
extra_data_for_msg = {
'images': [{'name': i['name'], 'type': i['type']} for i in image_contents],
'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')]
}
# 1. 获取Agent配置
agent_config = agent_service.get_agent_config(current_agent_id)
agent_tools = agent_config.get('agent', {}).get('tools', [])
supports_function_calling = agent_config.get('provider', {}).get('supports_function_calling', False)
# 2. 获取或创建会话
if conversation_id:
conversation = conv_service.get_conversation(conversation_id)
else:
conversation = conv_service.create_conversation(user.id)
conversation_id = conversation.conversation_id
await websocket.send_json({
"type": "conversation_created",
"conversation_id": conversation_id
})
# 3. 广播用户消息
await manager.send_to_user(MAIN_USER_ID, {
"type": "user_message",
"conversation_id": conversation_id,
"message": {
"id": None,
"role": "user",
"content": message,
"source": "web",
"created_at": datetime.utcnow().isoformat()
}
})
# 4. 保存用户消息
user_msg = conv_service.add_message(
conversation_id=conversation.id,
role='user',
content=message,
source='web',
extra_data=extra_data_for_msg
)
# 5. 获取对话历史
history = conv_service.get_conversation_history(conversation_id, limit=agent_config['agent'].get('max_history', 20))
# 6. 构建工具 schemaFunction Calling
tools_schema = []
if supports_function_calling and agent_tools:
# 搜索工具
if 'search' in agent_tools:
tool_service = ToolService(db)
search_tool = tool_service.get_default_tool('search')
if search_tool and search_tool.config.get('api_key'):
tools_schema.append({
"type": "function",
"function": {
"name": "web_search",
"description": "搜索互联网获取实时信息、新闻、数据等。当用户询问需要最新信息的问题时使用此工具。",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "搜索关键词或问题"
}
},
"required": ["query"]
}
}
})
# 7. 调用LLMFunction Calling模式
if not agent_config or not agent_config.get('provider'):
await websocket.send_json({
"type": "error",
"message": "Agent配置不完整"
})
continue
try:
response = None
thinking_content = None
tool_calls_record = []
# 第一阶段让LLM决定是否调用工具
if tools_schema:
response, thinking_content, tool_calls = await llm_service.chat_with_tools(
messages=history,
provider_config=agent_config['provider'],
agent_config=agent_config['agent'],
tools=tools_schema,
enable_thinking=enable_thinking,
images=image_contents
)
# 如果LLM请求调用工具
if tool_calls:
logger.info(f"LLM请求调用工具: {tool_calls}")
# 发送工具调用通知给前端
await websocket.send_json({
"type": "tool_calls",
"conversation_id": conversation_id,
"tool_calls": [
{"name": tc['name'], "arguments": tc['arguments']}
for tc in tool_calls
]
})
# 执行工具调用
tool_results = []
tool_service = ToolService(db)
search_tool = tool_service.get_default_tool('search')
for tc in tool_calls:
if tc['name'] == 'web_search':
query = tc['arguments'].get('query', message)
logger.info(f"执行搜索: query={query}")
import httpx
import time
start_time = time.time()
try:
tavily_url = "https://api.tavily.com/search"
config = search_tool.config
payload = {
"api_key": config.get('api_key'),
"query": query,
"max_results": config.get('max_results', 5),
"search_depth": config.get('search_depth', 'basic')
}
with httpx.Client(timeout=30) as client:
resp = client.post(tavily_url, json=payload)
search_result = resp.json()
duration_ms = int((time.time() - start_time) * 1000)
if search_result.get("results"):
# 构建搜索结果
search_content = []
for i, r in enumerate(search_result["results"][:5], 1):
search_content.append({
"title": r.get('title', 'N/A'),
"content": r.get('content', r.get('snippet', ''))[:300],
"url": r.get('url', 'N/A')
})
tool_results.append({
"tool_call_id": tc['id'],
"content": json.dumps(search_content)
})
# 发送搜索结果给前端
await websocket.send_json({
"type": "search_results",
"conversation_id": conversation_id,
"results": [
{"title": r.get('title'), "snippet": r.get('content', '')[:150], "url": r.get('url')}
for r in search_result["results"][:5]
],
"query": query
})
# 记录日志
tool_service.increment_stats(search_tool.id, True)
tool_service.log_usage({
'tool_id': search_tool.id,
'tool_type': 'search',
'query': query,
'success': True,
'result_summary': f'{len(search_result["results"])} results',
'conversation_id': conversation_id,
'agent_id': current_agent_id,
'duration_ms': duration_ms
})
tool_calls_record.append({
"name": "web_search",
"query": query,
"results_count": len(search_result["results"])
})
except Exception as e:
logger.error(f"搜索失败: {e}")
duration_ms = int((time.time() - start_time) * 1000)
tool_service.increment_stats(search_tool.id, False)
tool_service.log_usage({
'tool_id': search_tool.id,
'tool_type': 'search',
'query': query,
'success': False,
'error_message': str(e),
'conversation_id': conversation_id,
'duration_ms': duration_ms
})
tool_results.append({
"tool_call_id": tc['id'],
"content": json.dumps({"error": str(e)})
})
# 将工具调用消息添加到历史
# 注意:这里需要将 assistant 的 tool_calls 消息添加到历史
# 但我们用的是简化的历史格式,需要重新构建
# 第二阶段将工具结果返回给LLM
if tool_results:
# 重新获取完整历史(包含工具调用)
history_with_tools = history.copy()
# 添加 assistant 的 tool_calls 消息
history_with_tools.append({
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": tc['id'],
"type": "function",
"function": {
"name": tc['name'],
"arguments": json.dumps(tc['arguments'])
}
}
for tc in tool_calls
]
})
# 添加工具结果
for tr in tool_results:
history_with_tools.append({
"role": "tool",
"tool_call_id": tr['tool_call_id'],
"content": tr['content']
})
response, thinking_content = await llm_service.chat_with_tool_results(
messages=history_with_tools,
provider_config=agent_config['provider'],
agent_config=agent_config['agent'],
enable_thinking=enable_thinking
)
# 如果不支持 Function Calling 或没有工具,直接调用普通 chat
if response is None:
response, thinking_content = await llm_service.chat(
messages=history,
provider_config=agent_config['provider'],
agent_config=agent_config['agent'],
enable_thinking=enable_thinking,
images=image_contents
)
logger.info(f"LLM响应: response长度={len(response)}, thinking长度={len(thinking_content) if thinking_content else 0}")
# 保存AI回复
extra_data_to_save = None
if tool_calls_record:
extra_data_to_save = {'tool_calls': tool_calls_record}
assistant_msg = conv_service.add_message(
conversation_id=conversation.id,
role='assistant',
content=response,
source='web',
thinking_content=thinking_content if thinking_content else None,
agent_id=current_agent_id,
model_used=agent_config['provider'].get('default_model'),
extra_data=extra_data_to_save
)
# 发送AI回复
await websocket.send_json({
"type": "assistant_message",
"conversation_id": conversation_id,
"message": {
"id": assistant_msg.id,
"role": "assistant",
"content": response,
"thinking_content": thinking_content if thinking_content else None,
"source": "web",
"agent_id": current_agent_id,
"agent_name": agent_config['agent'].get('display_name'),
"tool_calls": tool_calls_record, # v3.0: 返回工具调用记录
"created_at": assistant_msg.created_at.isoformat()
}
})
logger.info(f"AI回复已发送: conversation_id={conversation_id}")
await websocket.send_json({
"type": "stream_end",
"conversation_id": conversation_id
})
except Exception as e:
logger.error(f"LLM调用失败: {e}")
await websocket.send_json({
"type": "error",
"message": f"AI服务暂时不可用: {str(e)}"
})
finally:
db.close()
except WebSocketDisconnect:
manager.disconnect(websocket, user_id)
except Exception as e:
logger.error(f"WebSocket错误: {e}")
manager.disconnect(websocket, user_id)
# ==================== 后台管理API ====================
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
"""获取统计数据"""
total_users = db.query(User).count()
total_conversations = db.query(Conversation).count()
total_messages = db.query(Message).count()
return {
"total_users": total_users,
"total_conversations": total_conversations,
"total_messages": total_messages
}
# ==================== Matrix消息处理回调 ====================
async def handle_matrix_message(
    action: str,
    conversation_id: str = None,
    user_message: str = None,
    room_id: str = None,
    message_id: int = None,
    sender: str = None,
    content: str = None,
    thinking_content: str = None,
    agent_id: int = None,
    agent_name: str = None
):
    """Relay an event from the Matrix bot to the web clients of MAIN_USER_ID.

    Handled actions: ``new_conversation``, ``user_message`` and
    ``assistant_message``; any other action is ignored. ``room_id`` is
    accepted but not used here.
    """
    if action == "new_conversation":
        event = {"type": "new_conversation", "conversation_id": conversation_id}
    elif action == "user_message":
        event = {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                "created_at": datetime.now().isoformat(),
            },
        }
    elif action == "assistant_message":
        event = {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "assistant",
                "content": content,
                "thinking_content": thinking_content,
                "source": "matrix",
                "agent_id": agent_id,
                "agent_name": agent_name,
                "created_at": datetime.now().isoformat(),
            },
        }
    else:
        return
    await manager.send_to_user(MAIN_USER_ID, event)
# ==================== 启动和关闭 ====================
# 导入Matrix Bot
try:
    from services.matrix_service_v2 import matrix_bot
except ImportError:
    # Fall back to the legacy Matrix integration when v2 is unavailable.
    from services.matrix_service import matrix_bot
    MATRIX_V2 = False
else:
    MATRIX_V2 = True
@app.on_event("startup")
async def startup():
"""应用启动"""
init_db()
logger.info("数据库初始化完成")
# 初始化默认数据
init_default_data()
# 初始化Matrix Bot
try:
success = await matrix_bot.init_from_config()
if success and matrix_bot.is_running:
asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
logger.info("Matrix Bot v2 已启动")
else:
logger.info("Matrix Bot 未配置或未启用")
except Exception as e:
logger.warning(f"Matrix Bot初始化失败: {e}")
@app.on_event("shutdown")
async def shutdown():
"""应用关闭"""
try:
await matrix_bot.disconnect()
except:
pass
logger.info("应用已关闭")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=19020)