# 修复Kimi-K2.5模型在第二轮调用时输出伪工具调用格式的问题:
# - 添加系统提示告诉模型直接根据工具结果回答
# - 过滤 <|tool_calls_section_begin|> 等内部格式标记
# - 清理多余空行
# 版本: v3.0.1
# (Python, 1311 lines, 53 KiB)
"""
|
||
AI对话系统 v3.0.0 - 主应用
|
||
支持:大模型池、Agent管理、渠道独立绑定、思考功能开关、Function Calling(LLM自主调用工具)
|
||
"""
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request
|
||
from fastapi.responses import HTMLResponse, JSONResponse
|
||
from fastapi.staticfiles import StaticFiles
|
||
from fastapi.templating import Jinja2Templates
|
||
from sqlalchemy.orm import Session
|
||
from typing import Dict, Set, Optional
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
from datetime import datetime
|
||
import os
|
||
import base64
|
||
import uuid
|
||
import time
|
||
|
||
# 使用新的数据模型
|
||
from models_v2 import (
|
||
init_db, get_db, SessionLocal,
|
||
User, Conversation, Message, SystemConfig,
|
||
LLMProvider, Agent, Channel, ChannelAgentMapping, MatrixRoomMapping, ToolConfig, ToolUsageLog,
|
||
init_default_data
|
||
)
|
||
from services.llm_service import llm_service
|
||
from services.agent_service import AgentService, LLMProviderService, ChannelService, ToolService
|
||
from services.conversation_service import ConversationService
|
||
|
||
# Process-wide logging at INFO level, plus this module's own logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# The FastAPI application object every route below is registered on.
app = FastAPI(
    title="AI对话系统 v3.0",
    version="3.0.1",
)
|
||
|
||
# All on-disk paths are resolved relative to the directory that contains this module.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Destination folder for images written by upload_image.
UPLOADS_DIR = os.path.join(BASE_DIR, "uploads", "images")

# Create the upload folder up front; this also creates the parent "uploads"
# directory that the /uploads mount below serves from.
os.makedirs(UPLOADS_DIR, exist_ok=True)

# Front-end assets and user uploads are served directly from disk.
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(BASE_DIR, "static")),
    name="static",
)
app.mount(
    "/uploads",
    StaticFiles(directory=os.path.join(BASE_DIR, "uploads")),
    name="uploads",
)

# Jinja2 environment used by the HTML page routes.
templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates"))
|
||
|
||
class ConnectionManager:
    """Registry of open WebSocket connections, grouped by user id.

    One user may hold several sockets at once (e.g. several browser tabs), so each
    user id maps to a set of sockets; the entry is removed once its set is empty.
    """

    def __init__(self):
        # user_id -> set of currently registered sockets for that user
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: "WebSocket", user_id: str):
        """Accept the WebSocket handshake and register *websocket* under *user_id*."""
        await websocket.accept()
        self.active_connections.setdefault(user_id, set()).add(websocket)
        logger.info(f"WebSocket连接: {user_id}")

    def disconnect(self, websocket: "WebSocket", user_id: str):
        """Unregister *websocket*; drop the user entry when no sockets remain.

        Unknown users or sockets are ignored.
        """
        sockets = self.active_connections.get(user_id)
        if sockets is None:
            return
        sockets.discard(websocket)
        if not sockets:
            del self.active_connections[user_id]

    async def send_to_user(self, user_id: str, message: dict):
        """Send *message* as JSON to every socket registered for *user_id*.

        A failure on one socket is logged and does not stop delivery to the others.
        """
        sockets = self.active_connections.get(user_id)
        if not sockets:
            return
        # Iterate over a snapshot: every send awaits, and while it is suspended another
        # task may connect()/disconnect() and mutate the live set, which would make the
        # loop raise "RuntimeError: Set changed size during iteration" (outside the
        # per-socket try below).
        for connection in list(sockets):
            try:
                await connection.send_json(message)
                logger.info(f"发送消息到用户 {user_id}: {message.get('type', 'unknown')}")
            except Exception as e:
                logger.error(f"发送消息失败: {e}")

    async def ping(self, user_id: str):
        """Send a heartbeat ``{"type": "ping"}`` message to all of *user_id*'s sockets."""
        await self.send_to_user(user_id, {"type": "ping"})
|
||
|
||
|
||
# The single fixed user id; the WebSocket endpoint maps every client onto it.
MAIN_USER_ID = "main_user"

# Process-wide registry of open WebSocket connections.
manager = ConnectionManager()
|
||
|
||
|
||
# ==================== 页面路由 ====================
|
||
|
||
@app.get("/", response_class=HTMLResponse)
|
||
async def index(request: Request):
|
||
"""主页 - 聊天界面"""
|
||
return templates.TemplateResponse("index.html", {"request": request})
|
||
|
||
|
||
@app.get("/admin", response_class=HTMLResponse)
|
||
async def admin(request: Request):
|
||
"""后台管理 - v2.0版本"""
|
||
return templates.TemplateResponse("admin_v2/index.html", {"request": request})
|
||
|
||
|
||
# ==================== v2 API路由 ====================
|
||
|
||
# ===== 大模型池 API =====
|
||
|
||
@app.get("/api/v2/providers")
|
||
async def get_providers(db: Session = Depends(get_db)):
|
||
"""获取所有大模型池"""
|
||
service = LLMProviderService(db)
|
||
providers = service.get_all_providers()
|
||
|
||
return {
|
||
"providers": [
|
||
{
|
||
"id": p.id,
|
||
"name": p.name,
|
||
"api_base": p.api_base,
|
||
"api_key": p.api_key,
|
||
"models": p.models,
|
||
"default_model": p.default_model,
|
||
"supports_thinking": p.supports_thinking,
|
||
"thinking_model": p.thinking_model,
|
||
"supports_vision": p.supports_vision,
|
||
"vision_model": p.vision_model,
|
||
"supports_function_calling": p.supports_function_calling,
|
||
"max_tokens": p.max_tokens,
|
||
"temperature": p.temperature,
|
||
"is_active": p.is_active,
|
||
"priority": p.priority,
|
||
"description": p.description
|
||
}
|
||
for p in providers
|
||
]
|
||
}
|
||
|
||
|
||
@app.post("/api/v2/providers")
|
||
async def create_provider(data: dict, db: Session = Depends(get_db)):
|
||
"""创建大模型池"""
|
||
service = LLMProviderService(db)
|
||
|
||
# 检查名称是否已存在
|
||
if service.get_provider_by_name(data.get('name')):
|
||
return {"success": False, "message": "名称已存在"}
|
||
|
||
provider = service.create_provider(data)
|
||
return {"success": True, "provider": {"id": provider.id, "name": provider.name}}
|
||
|
||
|
||
@app.put("/api/v2/providers/{provider_id}")
|
||
async def update_provider(provider_id: int, data: dict, db: Session = Depends(get_db)):
|
||
"""更新大模型池"""
|
||
service = LLMProviderService(db)
|
||
provider = service.update_provider(provider_id, data)
|
||
|
||
if not provider:
|
||
return {"success": False, "message": "Provider不存在"}
|
||
|
||
return {"success": True, "provider": {"id": provider.id, "name": provider.name}}
|
||
|
||
|
||
@app.delete("/api/v2/providers/{provider_id}")
|
||
async def delete_provider(provider_id: int, db: Session = Depends(get_db)):
|
||
"""删除大模型池"""
|
||
service = LLMProviderService(db)
|
||
success = service.delete_provider(provider_id)
|
||
|
||
if not success:
|
||
return {"success": False, "message": "无法删除(可能有Agent在使用)"}
|
||
|
||
return {"success": True}
|
||
|
||
|
||
@app.post("/api/v2/providers/models")
|
||
async def fetch_models(data: dict):
|
||
"""从API获取模型列表"""
|
||
api_base = data.get('api_base')
|
||
api_key = data.get('api_key', '')
|
||
|
||
models = await llm_service.get_available_models(api_base, api_key)
|
||
return {"models": models}
|
||
|
||
|
||
@app.post("/api/v2/providers/test")
|
||
async def test_provider(data: dict):
|
||
"""测试大模型连接"""
|
||
api_base = data.get('api_base')
|
||
api_key = data.get('api_key', '')
|
||
model = data.get('model', 'auto')
|
||
|
||
result = await llm_service.test_connection(api_base, api_key, model)
|
||
return result
|
||
|
||
|
||
# ===== Agent API =====
|
||
|
||
@app.get("/api/v2/agents")
|
||
async def get_agents(db: Session = Depends(get_db)):
|
||
"""获取所有Agent"""
|
||
service = AgentService(db)
|
||
agents = service.get_all_agents()
|
||
|
||
# 获取Provider名称
|
||
providers = {p.id: p.name for p in db.query(LLMProvider).all()}
|
||
|
||
return {
|
||
"agents": [
|
||
{
|
||
"id": a.id,
|
||
"name": a.name,
|
||
"display_name": a.display_name,
|
||
"llm_provider_id": a.llm_provider_id,
|
||
"llm_provider_name": providers.get(a.llm_provider_id),
|
||
"model_override": a.model_override,
|
||
"system_prompt": a.system_prompt,
|
||
"enable_thinking": a.enable_thinking,
|
||
"thinking_prompt": a.thinking_prompt,
|
||
"thinking_prefix": a.thinking_prefix,
|
||
"thinking_suffix": a.thinking_suffix,
|
||
"tools": a.tools or [], # 工具列表
|
||
"max_history": a.max_history,
|
||
"temperature_override": a.temperature_override,
|
||
"is_default": a.is_default,
|
||
"is_active": a.is_active,
|
||
"description": a.description
|
||
}
|
||
for a in agents
|
||
]
|
||
}
|
||
|
||
|
||
@app.post("/api/v2/agents")
|
||
async def create_agent(data: dict, db: Session = Depends(get_db)):
|
||
"""创建Agent"""
|
||
service = AgentService(db)
|
||
|
||
# 检查名称是否已存在
|
||
if service.get_agent_by_name(data.get('name')):
|
||
return {"success": False, "message": "Agent名称已存在"}
|
||
|
||
# 检查Provider是否存在
|
||
provider_service = LLMProviderService(db)
|
||
if not provider_service.get_provider(data.get('llm_provider_id')):
|
||
return {"success": False, "message": "选择的LLM Provider不存在"}
|
||
|
||
agent = service.create_agent(data)
|
||
|
||
# 如果设为默认,更新其他
|
||
if data.get('is_default'):
|
||
service.set_default_agent(agent.id)
|
||
|
||
return {"success": True, "agent": {"id": agent.id, "name": agent.name}}
|
||
|
||
|
||
@app.put("/api/v2/agents/{agent_id}")
|
||
async def update_agent(agent_id: int, data: dict, db: Session = Depends(get_db)):
|
||
"""更新Agent"""
|
||
service = AgentService(db)
|
||
|
||
# 如果设为默认,先更新
|
||
if data.get('is_default'):
|
||
service.set_default_agent(agent_id)
|
||
|
||
agent = service.update_agent(agent_id, data)
|
||
|
||
if not agent:
|
||
return {"success": False, "message": "Agent不存在"}
|
||
|
||
return {"success": True, "agent": {"id": agent.id, "name": agent.name}}
|
||
|
||
|
||
@app.delete("/api/v2/agents/{agent_id}")
|
||
async def delete_agent(agent_id: int, db: Session = Depends(get_db)):
|
||
"""删除Agent"""
|
||
service = AgentService(db)
|
||
success = service.delete_agent(agent_id)
|
||
|
||
if not success:
|
||
return {"success": False, "message": "无法删除默认Agent"}
|
||
|
||
return {"success": True}
|
||
|
||
|
||
@app.post("/api/v2/agents/{agent_id}/default")
|
||
async def set_agent_default(agent_id: int, db: Session = Depends(get_db)):
|
||
"""设置默认Agent"""
|
||
service = AgentService(db)
|
||
success = service.set_default_agent(agent_id)
|
||
|
||
return {"success": success}
|
||
|
||
|
||
@app.get("/api/v2/agents/{agent_id}/config")
|
||
async def get_agent_config(agent_id: int, db: Session = Depends(get_db)):
|
||
"""获取Agent完整配置(含LLM Provider)"""
|
||
service = AgentService(db)
|
||
config = service.get_agent_config(agent_id)
|
||
|
||
return {"config": config}
|
||
|
||
|
||
# ===== 渠道 API =====
|
||
|
||
@app.get("/api/v2/channels")
|
||
async def get_channels(db: Session = Depends(get_db)):
|
||
"""获取所有渠道及绑定关系"""
|
||
service = ChannelService(db)
|
||
channels = service.get_all_channels()
|
||
|
||
# 获取所有绑定关系
|
||
mappings = db.query(ChannelAgentMapping).order_by(ChannelAgentMapping.priority).all()
|
||
|
||
agents = {a.id: a for a in db.query(Agent).all()}
|
||
|
||
return {
|
||
"channels": [
|
||
{
|
||
"id": c.id,
|
||
"channel_type": c.channel_type,
|
||
"name": c.name,
|
||
"config": c.config,
|
||
"is_active": c.is_active,
|
||
"is_primary": c.is_primary,
|
||
"agent_mappings": [
|
||
{
|
||
"id": m.id,
|
||
"agent": {
|
||
"id": agents[m.agent_id].id if m.agent_id in agents else None,
|
||
"name": agents[m.agent_id].name if m.agent_id in agents else None,
|
||
"display_name": agents[m.agent_id].display_name if m.agent_id in agents else None,
|
||
"is_active": agents[m.agent_id].is_active if m.agent_id in agents else None
|
||
},
|
||
"priority": m.priority,
|
||
"mode": m.mode,
|
||
"conditions": m.conditions
|
||
}
|
||
for m in mappings if m.channel_id == c.id
|
||
]
|
||
}
|
||
for c in channels
|
||
],
|
||
"mappings": [
|
||
{
|
||
"id": m.id,
|
||
"channel_id": m.channel_id,
|
||
"channel_type": channels[0].channel_type if channels else None, # 需要从Channel获取
|
||
"channel_name": next((c.name for c in channels if c.id == m.channel_id), ""),
|
||
"agent_id": m.agent_id,
|
||
"agent_name": agents[m.agent_id].display_name if m.agent_id in agents else "",
|
||
"priority": m.priority,
|
||
"mode": m.mode,
|
||
"conditions": m.conditions
|
||
}
|
||
for m in mappings
|
||
]
|
||
}
|
||
|
||
|
||
@app.post("/api/v2/channels")
|
||
async def create_channel(data: dict, db: Session = Depends(get_db)):
|
||
"""创建渠道"""
|
||
service = ChannelService(db)
|
||
channel = service.create_channel(data)
|
||
|
||
return {"success": True, "channel": {"id": channel.id, "name": channel.name}}
|
||
|
||
|
||
@app.put("/api/v2/channels/{channel_id}")
|
||
async def update_channel(channel_id: int, data: dict, db: Session = Depends(get_db)):
|
||
"""更新渠道"""
|
||
service = ChannelService(db)
|
||
channel = service.update_channel(channel_id, data)
|
||
|
||
if not channel:
|
||
return {"success": False, "message": "渠道不存在"}
|
||
|
||
return {"success": True}
|
||
|
||
|
||
@app.put("/api/v2/channels/{channel_id}/config")
|
||
async def update_channel_config(channel_id: int, data: dict, db: Session = Depends(get_db)):
|
||
"""更新渠道配置(如Matrix配置)"""
|
||
channel = db.query(Channel).filter(Channel.id == channel_id).first()
|
||
|
||
if not channel:
|
||
return {"success": False, "message": "渠道不存在"}
|
||
|
||
channel.config = data.get('config', {})
|
||
db.commit()
|
||
|
||
# 如果是Matrix渠道,重新初始化Matrix Bot
|
||
if channel.channel_type == 'matrix':
|
||
# TODO: 重启Matrix Bot
|
||
pass
|
||
|
||
return {"success": True}
|
||
|
||
|
||
@app.post("/api/v2/channels/bind")
|
||
async def bind_agent_to_channel(data: dict, db: Session = Depends(get_db)):
|
||
"""绑定Agent到渠道"""
|
||
service = ChannelService(db)
|
||
|
||
channel_id = data.get('channel_id')
|
||
agent_id = data.get('agent_id')
|
||
|
||
# 检查渠道和Agent是否存在
|
||
if not service.get_channel(channel_id):
|
||
return {"success": False, "message": "渠道不存在"}
|
||
|
||
agent_service = AgentService(db)
|
||
if not agent_service.get_agent(agent_id):
|
||
return {"success": False, "message": "Agent不存在"}
|
||
|
||
mapping = service.bind_agent(
|
||
channel_id=channel_id,
|
||
agent_id=agent_id,
|
||
priority=data.get('priority', 0),
|
||
mode=data.get('mode', 'single'),
|
||
conditions=data.get('conditions')
|
||
)
|
||
|
||
return {"success": True, "mapping": {"id": mapping.id}}
|
||
|
||
|
||
@app.delete("/api/v2/channels/unbind/{mapping_id}")
|
||
async def unbind_agent(mapping_id: int, db: Session = Depends(get_db)):
|
||
"""解绑Agent"""
|
||
service = ChannelService(db)
|
||
success = service.unbind_agent(mapping_id)
|
||
|
||
return {"success": success}
|
||
|
||
|
||
# ==================== 工具配置 API ====================
|
||
|
||
@app.get("/api/v2/tools")
|
||
async def get_tools(tool_type: str = None, db: Session = Depends(get_db)):
|
||
"""获取所有工具配置"""
|
||
service = ToolService(db)
|
||
if tool_type:
|
||
tools = service.get_tools_by_type(tool_type)
|
||
else:
|
||
tools = service.get_all_tools()
|
||
|
||
return {
|
||
"tools": [
|
||
{
|
||
"id": t.id,
|
||
"name": t.name,
|
||
"tool_type": t.tool_type,
|
||
"provider": t.provider,
|
||
"config": t.config,
|
||
"is_active": t.is_active,
|
||
"is_default": t.is_default,
|
||
"total_calls": t.total_calls,
|
||
"success_calls": t.success_calls,
|
||
"failed_calls": t.failed_calls
|
||
}
|
||
for t in tools
|
||
]
|
||
}
|
||
|
||
|
||
@app.post("/api/v2/tools")
|
||
async def create_tool(data: dict, db: Session = Depends(get_db)):
|
||
"""创建工具配置"""
|
||
service = ToolService(db)
|
||
tool = service.create_tool(data)
|
||
return {"success": True, "tool": {"id": tool.id, "name": tool.name, "tool_type": tool.tool_type}}
|
||
|
||
|
||
@app.put("/api/v2/tools/{tool_id}")
|
||
async def update_tool(tool_id: int, data: dict, db: Session = Depends(get_db)):
|
||
"""更新工具配置"""
|
||
service = ToolService(db)
|
||
|
||
if data.get('is_default'):
|
||
service.set_default_tool(tool_id)
|
||
|
||
tool = service.update_tool(tool_id, data)
|
||
|
||
if not tool:
|
||
return {"success": False, "message": "工具不存在"}
|
||
|
||
return {"success": True, "tool": {"id": tool.id, "name": tool.name}}
|
||
|
||
|
||
@app.delete("/api/v2/tools/{tool_id}")
|
||
async def delete_tool(tool_id: int, db: Session = Depends(get_db)):
|
||
"""删除工具配置"""
|
||
service = ToolService(db)
|
||
success = service.delete_tool(tool_id)
|
||
|
||
return {"success": success}
|
||
|
||
|
||
@app.post("/api/v2/tools/{tool_id}/default")
|
||
async def set_tool_default(tool_id: int, db: Session = Depends(get_db)):
|
||
"""设置默认工具"""
|
||
service = ToolService(db)
|
||
success = service.set_default_tool(tool_id)
|
||
|
||
return {"success": success}
|
||
|
||
|
||
@app.get("/api/v2/tools/stats")
|
||
async def get_tool_stats(days: int = 7, db: Session = Depends(get_db)):
|
||
"""获取工具使用统计"""
|
||
service = ToolService(db)
|
||
stats = service.get_usage_stats(days=days)
|
||
return stats
|
||
|
||
|
||
@app.post("/api/v2/tools/search")
|
||
async def perform_search(data: dict, db: Session = Depends(get_db)):
|
||
"""执行搜索(供前端或Agent调用)"""
|
||
import httpx
|
||
import time
|
||
|
||
query = data.get('query')
|
||
if not query:
|
||
return {"success": False, "message": "缺少搜索关键词"}
|
||
|
||
service = ToolService(db)
|
||
tool_id = data.get('tool_id')
|
||
|
||
if tool_id:
|
||
tool = service.get_tool(tool_id)
|
||
else:
|
||
tool = service.get_default_tool('search')
|
||
|
||
if not tool or not tool.config.get('api_key'):
|
||
return {"success": False, "message": "未配置搜索工具"}
|
||
|
||
# Tavily Search API
|
||
if tool.provider == 'tavily' or tool.tool_type == 'search':
|
||
start_time = time.time()
|
||
try:
|
||
tavily_url = "https://api.tavily.com/search"
|
||
config = tool.config
|
||
payload = {
|
||
"api_key": config.get('api_key'),
|
||
"query": query,
|
||
"max_results": config.get('max_results', 5),
|
||
"include_raw_content": config.get('include_raw_content', False),
|
||
"search_depth": config.get('search_depth', 'basic')
|
||
}
|
||
|
||
async with httpx.AsyncClient(timeout=30) as client:
|
||
response = await client.post(tavily_url, json=payload)
|
||
result = response.json()
|
||
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
|
||
# 更新统计和日志
|
||
service.increment_stats(tool.id, True)
|
||
service.log_usage({
|
||
'tool_id': tool.id,
|
||
'tool_type': 'search',
|
||
'query': query,
|
||
'success': True,
|
||
'result_summary': f'{len(result.get("results", []))} results',
|
||
'conversation_id': data.get('conversation_id'),
|
||
'agent_id': data.get('agent_id'),
|
||
'duration_ms': duration_ms
|
||
})
|
||
|
||
return {
|
||
"success": True,
|
||
"results": result.get("results", []),
|
||
"query": query
|
||
}
|
||
except Exception as e:
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
service.increment_stats(tool.id, False)
|
||
service.log_usage({
|
||
'tool_id': tool.id,
|
||
'tool_type': 'search',
|
||
'query': query,
|
||
'success': False,
|
||
'error_message': str(e),
|
||
'conversation_id': data.get('conversation_id'),
|
||
'duration_ms': duration_ms
|
||
})
|
||
return {"success": False, "message": str(e)}
|
||
|
||
return {"success": False, "message": "不支持的搜索提供商"}
|
||
|
||
|
||
# ==================== 图片上传 API ====================
|
||
|
||
@app.post("/api/v2/upload-image")
|
||
async def upload_image(data: dict):
|
||
"""上传图片到服务器,返回文件路径"""
|
||
try:
|
||
image_data = data.get('image')
|
||
file_name = data.get('name', 'image.png')
|
||
|
||
if not image_data:
|
||
return {"success": False, "message": "缺少图片数据"}
|
||
|
||
# 解析 base64 数据
|
||
if image_data.startswith('data:image/'):
|
||
# 提取格式和base64内容
|
||
header, base64_content = image_data.split(',', 1)
|
||
# 从header中提取图片格式
|
||
format_match = header.split(':')[1].split(';')[0] # 如 'image/png'
|
||
ext = format_match.split('/')[1] if '/' in format_match else 'png'
|
||
else:
|
||
base64_content = image_data
|
||
ext = 'png'
|
||
|
||
# 生成唯一文件名
|
||
timestamp = int(time.time())
|
||
unique_id = uuid.uuid4().hex[:8]
|
||
safe_name = f"{timestamp}_{unique_id}.{ext}"
|
||
|
||
# 保存文件
|
||
file_path = os.path.join(UPLOADS_DIR, safe_name)
|
||
image_bytes = base64.b64decode(base64_content)
|
||
|
||
# 检查文件大小(限制10MB)
|
||
if len(image_bytes) > 10 * 1024 * 1024:
|
||
return {"success": False, "message": "图片大小超过10MB限制"}
|
||
|
||
with open(file_path, 'wb') as f:
|
||
f.write(image_bytes)
|
||
|
||
# 返回可访问的URL路径
|
||
url_path = f"/uploads/images/{safe_name}"
|
||
logger.info(f"图片已保存: {file_path}, URL: {url_path}")
|
||
|
||
return {"success": True, "path": url_path, "name": safe_name}
|
||
|
||
except Exception as e:
|
||
logger.error(f"图片上传失败: {e}")
|
||
return {"success": False, "message": str(e)}
|
||
|
||
|
||
# ==================== 对话 API(保留原有) ====================
|
||
|
||
@app.get("/api/conversations")
|
||
async def get_conversations(db: Session = Depends(get_db)):
|
||
"""获取会话列表"""
|
||
conv_service = ConversationService(db)
|
||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||
conversations = conv_service.get_user_conversations(user.id)
|
||
|
||
return {
|
||
"conversations": [
|
||
{
|
||
"id": c.conversation_id,
|
||
"title": c.title or "新对话",
|
||
"created_at": c.created_at.isoformat(),
|
||
"updated_at": c.updated_at.isoformat()
|
||
}
|
||
for c in conversations
|
||
]
|
||
}
|
||
|
||
|
||
@app.get("/api/conversations/latest")
|
||
async def get_latest_conversation(db: Session = Depends(get_db)):
|
||
"""获取最新会话"""
|
||
conv_service = ConversationService(db)
|
||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||
conversations = conv_service.get_user_conversations(user.id)
|
||
|
||
if conversations:
|
||
latest = conversations[0]
|
||
return {
|
||
"conversation": {
|
||
"id": latest.conversation_id,
|
||
"title": latest.title or "新对话",
|
||
"updated_at": latest.updated_at.isoformat()
|
||
}
|
||
}
|
||
return {"conversation": None}
|
||
|
||
|
||
@app.post("/api/conversations")
|
||
async def create_conversation(db: Session = Depends(get_db)):
|
||
"""创建新会话"""
|
||
conv_service = ConversationService(db)
|
||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||
conversation = conv_service.create_conversation(user.id)
|
||
|
||
return {
|
||
"id": conversation.conversation_id,
|
||
"title": conversation.title or "新对话",
|
||
"created_at": conversation.created_at.isoformat()
|
||
}
|
||
|
||
|
||
@app.get("/api/conversations/{conversation_id}/messages")
|
||
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
|
||
"""获取会话消息"""
|
||
conv_service = ConversationService(db)
|
||
conversation = conv_service.get_conversation(conversation_id)
|
||
|
||
if not conversation:
|
||
raise HTTPException(status_code=404, detail="会话不存在")
|
||
|
||
messages = conv_service.get_messages(conversation.id)
|
||
|
||
return {
|
||
"messages": [
|
||
{
|
||
"id": m.id,
|
||
"role": m.role,
|
||
"content": m.content,
|
||
"thinking_content": m.thinking_content, # v2新增
|
||
"extra_data": m.extra_data, # 包含搜索结果等
|
||
"source": m.source,
|
||
"agent_id": m.agent_id, # v2新增
|
||
"model_used": m.model_used, # v2新增
|
||
"created_at": m.created_at.isoformat()
|
||
}
|
||
for m in messages
|
||
]
|
||
}
|
||
|
||
|
||
@app.delete("/api/conversations/{conversation_id}")
|
||
async def delete_conversation(conversation_id: str, db: Session = Depends(get_db)):
|
||
"""删除会话"""
|
||
conv_service = ConversationService(db)
|
||
success = conv_service.delete_conversation(conversation_id)
|
||
|
||
if not success:
|
||
raise HTTPException(status_code=404, detail="会话不存在")
|
||
|
||
return {"success": True}
|
||
|
||
|
||
# ==================== WebSocket路由 ====================
|
||
|
||
@app.websocket("/ws/{user_id}")
|
||
async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||
"""WebSocket连接 - 实时对话"""
|
||
actual_user_id = MAIN_USER_ID
|
||
await manager.connect(websocket, actual_user_id)
|
||
|
||
# 初始化时获取默认Agent ID
|
||
db = SessionLocal()
|
||
try:
|
||
agent_service = AgentService(db)
|
||
default_agent = agent_service.get_default_agent()
|
||
default_agent_id = default_agent.id if default_agent else None
|
||
finally:
|
||
db.close()
|
||
|
||
current_conversation_id = None
|
||
current_agent_id = default_agent_id
|
||
|
||
try:
|
||
while True:
|
||
try:
|
||
data = await websocket.receive_json()
|
||
except Exception as json_err:
|
||
logger.error(f"JSON解析错误: {json_err}")
|
||
# 如果连接已断开,退出循环
|
||
if "disconnect" in str(json_err).lower() or "closed" in str(json_err).lower():
|
||
logger.info("WebSocket已断开,退出循环")
|
||
break
|
||
try:
|
||
text_data = await websocket.receive_text()
|
||
if text_data.strip():
|
||
data = json.loads(text_data)
|
||
else:
|
||
continue
|
||
except Exception as text_err:
|
||
logger.error(f"文本消息解析错误: {text_err}")
|
||
if "disconnect" in str(text_err).lower() or "closed" in str(text_err).lower():
|
||
logger.info("WebSocket已断开,退出循环")
|
||
break
|
||
continue
|
||
|
||
action = data.get("action")
|
||
logger.info(f"WebSocket收到消息: action={action}")
|
||
|
||
# 每次消息处理时创建新的数据库会话,处理完后关闭
|
||
try:
|
||
db = SessionLocal()
|
||
conv_service = ConversationService(db)
|
||
agent_service = AgentService(db)
|
||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||
|
||
if action == "select_conversation":
|
||
current_conversation_id = data.get("conversation_id")
|
||
conversation = conv_service.get_conversation(current_conversation_id)
|
||
|
||
if conversation:
|
||
messages = conv_service.get_messages(conversation.id)
|
||
|
||
# 获取对话使用的Agent ID
|
||
conv_agent_id = conversation.current_agent_id
|
||
|
||
await websocket.send_json({
|
||
"type": "history",
|
||
"conversation_id": current_conversation_id,
|
||
"agent_id": conv_agent_id, # 返回对话的Agent ID
|
||
"messages": [
|
||
{
|
||
"role": m.role,
|
||
"content": m.content,
|
||
"thinking_content": m.thinking_content,
|
||
"extra_data": m.extra_data, # 包含搜索结果等
|
||
"agent_id": m.agent_id, # 每条消息的Agent ID
|
||
"source": m.source,
|
||
"created_at": m.created_at.isoformat()
|
||
}
|
||
for m in messages
|
||
]
|
||
})
|
||
|
||
elif action == "switch_agent":
|
||
# 切换Agent
|
||
new_agent_id = data.get("agent_id")
|
||
agent = agent_service.get_agent(new_agent_id)
|
||
if agent and agent.is_active:
|
||
current_agent_id = new_agent_id
|
||
await websocket.send_json({
|
||
"type": "agent_switched",
|
||
"agent_id": current_agent_id,
|
||
"agent_name": agent.display_name or agent.name
|
||
})
|
||
|
||
elif action == "chat":
|
||
message = data.get("message", "")
|
||
files = data.get("files", []) # 上传的文件
|
||
conversation_id = data.get("conversation_id")
|
||
enable_thinking = data.get("enable_thinking", True)
|
||
agent_id_override = data.get("agent_id")
|
||
# v3.0: 移除 disabled_tools,由LLM自主决定
|
||
|
||
if agent_id_override:
|
||
agent = agent_service.get_agent(agent_id_override)
|
||
if agent and agent.is_active:
|
||
current_agent_id = agent_id_override
|
||
|
||
# 如果没有消息但有文件,构造消息
|
||
if not message.strip() and files:
|
||
message = "[上传文件]"
|
||
|
||
if not message.strip() and not files:
|
||
continue
|
||
|
||
# 处理文件内容
|
||
image_contents = []
|
||
text_contents = []
|
||
image_paths = []
|
||
if files:
|
||
for f in files:
|
||
if f.get('type') and f['type'].startswith('image/'):
|
||
image_contents.append({
|
||
'name': f['name'],
|
||
'type': f['type'],
|
||
'data': f.get('content', '')
|
||
})
|
||
if f.get('serverPath'):
|
||
image_paths.append({
|
||
'name': f['name'],
|
||
'type': f['type'],
|
||
'url': f['serverPath']
|
||
})
|
||
elif f.get('content'):
|
||
text_contents.append(f['content'][:3000])
|
||
if len(f['content']) > 3000:
|
||
text_contents[-1] += "...(内容过长已截断)"
|
||
|
||
if text_contents:
|
||
for content in text_contents:
|
||
message += f"\n\n{content}"
|
||
|
||
# 保存文件信息到 extra_data
|
||
extra_data_for_msg = None
|
||
if image_paths:
|
||
extra_data_for_msg = {
|
||
'images': image_paths,
|
||
'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')]
|
||
}
|
||
elif image_contents:
|
||
extra_data_for_msg = {
|
||
'images': [{'name': i['name'], 'type': i['type']} for i in image_contents],
|
||
'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')]
|
||
}
|
||
|
||
# 1. 获取Agent配置
|
||
agent_config = agent_service.get_agent_config(current_agent_id)
|
||
agent_tools = agent_config.get('agent', {}).get('tools', [])
|
||
supports_function_calling = agent_config.get('provider', {}).get('supports_function_calling', False)
|
||
|
||
# 2. 获取或创建会话
|
||
if conversation_id:
|
||
conversation = conv_service.get_conversation(conversation_id)
|
||
else:
|
||
conversation = conv_service.create_conversation(user.id)
|
||
conversation_id = conversation.conversation_id
|
||
await websocket.send_json({
|
||
"type": "conversation_created",
|
||
"conversation_id": conversation_id
|
||
})
|
||
|
||
# 3. 广播用户消息
|
||
await manager.send_to_user(MAIN_USER_ID, {
|
||
"type": "user_message",
|
||
"conversation_id": conversation_id,
|
||
"message": {
|
||
"id": None,
|
||
"role": "user",
|
||
"content": message,
|
||
"source": "web",
|
||
"created_at": datetime.utcnow().isoformat()
|
||
}
|
||
})
|
||
|
||
# 4. 保存用户消息
|
||
user_msg = conv_service.add_message(
|
||
conversation_id=conversation.id,
|
||
role='user',
|
||
content=message,
|
||
source='web',
|
||
extra_data=extra_data_for_msg
|
||
)
|
||
|
||
# 5. 获取对话历史
|
||
history = conv_service.get_conversation_history(conversation_id, limit=agent_config['agent'].get('max_history', 20))
|
||
|
||
# 6. 构建工具 schema(Function Calling)
|
||
tools_schema = []
|
||
if supports_function_calling and agent_tools:
|
||
# 搜索工具
|
||
if 'search' in agent_tools:
|
||
tool_service = ToolService(db)
|
||
search_tool = tool_service.get_default_tool('search')
|
||
if search_tool and search_tool.config.get('api_key'):
|
||
tools_schema.append({
|
||
"type": "function",
|
||
"function": {
|
||
"name": "web_search",
|
||
"description": "搜索互联网获取实时信息、新闻、数据等。当用户询问需要最新信息的问题时使用此工具。",
|
||
"parameters": {
|
||
"type": "object",
|
||
"properties": {
|
||
"query": {
|
||
"type": "string",
|
||
"description": "搜索关键词或问题"
|
||
}
|
||
},
|
||
"required": ["query"]
|
||
}
|
||
}
|
||
})
|
||
|
||
# 7. 调用LLM(Function Calling模式)
|
||
if not agent_config or not agent_config.get('provider'):
|
||
await websocket.send_json({
|
||
"type": "error",
|
||
"message": "Agent配置不完整"
|
||
})
|
||
continue
|
||
|
||
try:
|
||
response = None
|
||
thinking_content = None
|
||
tool_calls_record = []
|
||
|
||
# 第一阶段:让LLM决定是否调用工具
|
||
if tools_schema:
|
||
response, thinking_content, tool_calls = await llm_service.chat_with_tools(
|
||
messages=history,
|
||
provider_config=agent_config['provider'],
|
||
agent_config=agent_config['agent'],
|
||
tools=tools_schema,
|
||
enable_thinking=enable_thinking,
|
||
images=image_contents
|
||
)
|
||
|
||
# 如果LLM请求调用工具
|
||
if tool_calls:
|
||
logger.info(f"LLM请求调用工具: {tool_calls}")
|
||
|
||
# 发送工具调用通知给前端
|
||
await websocket.send_json({
|
||
"type": "tool_calls",
|
||
"conversation_id": conversation_id,
|
||
"tool_calls": [
|
||
{"name": tc['name'], "arguments": tc['arguments']}
|
||
for tc in tool_calls
|
||
]
|
||
})
|
||
|
||
# 执行工具调用
|
||
tool_results = []
|
||
tool_service = ToolService(db)
|
||
search_tool = tool_service.get_default_tool('search')
|
||
|
||
for tc in tool_calls:
|
||
if tc['name'] == 'web_search':
|
||
query = tc['arguments'].get('query', message)
|
||
logger.info(f"执行搜索: query={query}")
|
||
|
||
import httpx
|
||
import time
|
||
start_time = time.time()
|
||
|
||
try:
|
||
tavily_url = "https://api.tavily.com/search"
|
||
config = search_tool.config
|
||
payload = {
|
||
"api_key": config.get('api_key'),
|
||
"query": query,
|
||
"max_results": config.get('max_results', 5),
|
||
"search_depth": config.get('search_depth', 'basic')
|
||
}
|
||
|
||
with httpx.Client(timeout=30) as client:
|
||
resp = client.post(tavily_url, json=payload)
|
||
search_result = resp.json()
|
||
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
|
||
if search_result.get("results"):
|
||
# 构建搜索结果
|
||
search_content = []
|
||
for i, r in enumerate(search_result["results"][:5], 1):
|
||
search_content.append({
|
||
"title": r.get('title', 'N/A'),
|
||
"content": r.get('content', r.get('snippet', ''))[:300],
|
||
"url": r.get('url', 'N/A')
|
||
})
|
||
|
||
tool_results.append({
|
||
"tool_call_id": tc['id'],
|
||
"content": json.dumps(search_content)
|
||
})
|
||
|
||
# 发送搜索结果给前端
|
||
await websocket.send_json({
|
||
"type": "search_results",
|
||
"conversation_id": conversation_id,
|
||
"results": [
|
||
{"title": r.get('title'), "snippet": r.get('content', '')[:150], "url": r.get('url')}
|
||
for r in search_result["results"][:5]
|
||
],
|
||
"query": query
|
||
})
|
||
|
||
# 记录日志
|
||
tool_service.increment_stats(search_tool.id, True)
|
||
tool_service.log_usage({
|
||
'tool_id': search_tool.id,
|
||
'tool_type': 'search',
|
||
'query': query,
|
||
'success': True,
|
||
'result_summary': f'{len(search_result["results"])} results',
|
||
'conversation_id': conversation_id,
|
||
'agent_id': current_agent_id,
|
||
'duration_ms': duration_ms
|
||
})
|
||
|
||
tool_calls_record.append({
|
||
"name": "web_search",
|
||
"query": query,
|
||
"results_count": len(search_result["results"])
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"搜索失败: {e}")
|
||
duration_ms = int((time.time() - start_time) * 1000)
|
||
tool_service.increment_stats(search_tool.id, False)
|
||
tool_service.log_usage({
|
||
'tool_id': search_tool.id,
|
||
'tool_type': 'search',
|
||
'query': query,
|
||
'success': False,
|
||
'error_message': str(e),
|
||
'conversation_id': conversation_id,
|
||
'duration_ms': duration_ms
|
||
})
|
||
tool_results.append({
|
||
"tool_call_id": tc['id'],
|
||
"content": json.dumps({"error": str(e)})
|
||
})
|
||
|
||
# 将工具调用消息添加到历史
|
||
# 注意:这里需要将 assistant 的 tool_calls 消息添加到历史
|
||
# 但我们用的是简化的历史格式,需要重新构建
|
||
|
||
# 第二阶段:将工具结果返回给LLM
|
||
if tool_results:
|
||
# 重新获取完整历史(包含工具调用)
|
||
history_with_tools = history.copy()
|
||
# 添加 assistant 的 tool_calls 消息
|
||
history_with_tools.append({
|
||
"role": "assistant",
|
||
"content": None,
|
||
"tool_calls": [
|
||
{
|
||
"id": tc['id'],
|
||
"type": "function",
|
||
"function": {
|
||
"name": tc['name'],
|
||
"arguments": json.dumps(tc['arguments'])
|
||
}
|
||
}
|
||
for tc in tool_calls
|
||
]
|
||
})
|
||
# 添加工具结果
|
||
for tr in tool_results:
|
||
history_with_tools.append({
|
||
"role": "tool",
|
||
"tool_call_id": tr['tool_call_id'],
|
||
"content": tr['content']
|
||
})
|
||
|
||
response, thinking_content = await llm_service.chat_with_tool_results(
|
||
messages=history_with_tools,
|
||
provider_config=agent_config['provider'],
|
||
agent_config=agent_config['agent'],
|
||
enable_thinking=enable_thinking
|
||
)
|
||
|
||
# 如果不支持 Function Calling 或没有工具,直接调用普通 chat
|
||
if response is None:
|
||
response, thinking_content = await llm_service.chat(
|
||
messages=history,
|
||
provider_config=agent_config['provider'],
|
||
agent_config=agent_config['agent'],
|
||
enable_thinking=enable_thinking,
|
||
images=image_contents
|
||
)
|
||
|
||
logger.info(f"LLM响应: response长度={len(response)}, thinking长度={len(thinking_content) if thinking_content else 0}")
|
||
|
||
# 保存AI回复
|
||
extra_data_to_save = None
|
||
if tool_calls_record:
|
||
extra_data_to_save = {'tool_calls': tool_calls_record}
|
||
|
||
assistant_msg = conv_service.add_message(
|
||
conversation_id=conversation.id,
|
||
role='assistant',
|
||
content=response,
|
||
source='web',
|
||
thinking_content=thinking_content if thinking_content else None,
|
||
agent_id=current_agent_id,
|
||
model_used=agent_config['provider'].get('default_model'),
|
||
extra_data=extra_data_to_save
|
||
)
|
||
|
||
# 发送AI回复
|
||
await websocket.send_json({
|
||
"type": "assistant_message",
|
||
"conversation_id": conversation_id,
|
||
"message": {
|
||
"id": assistant_msg.id,
|
||
"role": "assistant",
|
||
"content": response,
|
||
"thinking_content": thinking_content if thinking_content else None,
|
||
"source": "web",
|
||
"agent_id": current_agent_id,
|
||
"agent_name": agent_config['agent'].get('display_name'),
|
||
"tool_calls": tool_calls_record, # v3.0: 返回工具调用记录
|
||
"created_at": assistant_msg.created_at.isoformat()
|
||
}
|
||
})
|
||
|
||
logger.info(f"AI回复已发送: conversation_id={conversation_id}")
|
||
|
||
await websocket.send_json({
|
||
"type": "stream_end",
|
||
"conversation_id": conversation_id
|
||
})
|
||
|
||
except Exception as e:
|
||
logger.error(f"LLM调用失败: {e}")
|
||
await websocket.send_json({
|
||
"type": "error",
|
||
"message": f"AI服务暂时不可用: {str(e)}"
|
||
})
|
||
|
||
finally:
|
||
db.close()
|
||
|
||
except WebSocketDisconnect:
|
||
manager.disconnect(websocket, user_id)
|
||
except Exception as e:
|
||
logger.error(f"WebSocket错误: {e}")
|
||
manager.disconnect(websocket, user_id)
|
||
|
||
|
||
# ==================== 后台管理API ====================
|
||
|
||
@app.get("/api/admin/stats")
async def get_stats(db: Session = Depends(get_db)):
    """Return aggregate row counts for the admin dashboard.

    The JSON response holds the total number of users, conversations and
    messages currently stored in the database.
    """
    # Dict literal values are evaluated left to right, so the three COUNT
    # queries run in the same order as before.
    return {
        "total_users": db.query(User).count(),
        "total_conversations": db.query(Conversation).count(),
        "total_messages": db.query(Message).count(),
    }
|
||
|
||
|
||
# ==================== Matrix消息处理回调 ====================
|
||
|
||
async def handle_matrix_message(
    action: str,
    conversation_id: Optional[str] = None,
    user_message: Optional[str] = None,
    room_id: Optional[str] = None,
    message_id: Optional[int] = None,
    sender: Optional[str] = None,
    content: Optional[str] = None,
    thinking_content: Optional[str] = None,
    agent_id: Optional[int] = None,
    agent_name: Optional[str] = None,
):
    """Forward an event received from Matrix to the main user's web clients.

    This coroutine is the callback handed to ``matrix_bot.start_sync`` at
    startup. Depending on ``action`` it pushes one JSON payload to
    ``MAIN_USER_ID`` through ``manager.send_to_user``:

    * ``"new_conversation"`` -- announces ``conversation_id``.
    * ``"user_message"`` -- relays a user message tagged ``source="matrix"``.
    * ``"assistant_message"`` -- relays an assistant reply tagged
      ``source="matrix"``.

    Any other ``action`` is logged as a warning and otherwise ignored.

    Args:
        action: Which kind of event this is (see above).
        conversation_id: Conversation the event belongs to.
        user_message: Text of the user's message (``"user_message"`` only).
        room_id: Matrix room the event came from. Accepted so the callback
            signature stays stable; not used by this function.
        message_id: ID of the message being relayed.
        sender: Matrix sender of the user message (``"user_message"`` only).
        content: Assistant reply text (``"assistant_message"`` only).
        thinking_content: Assistant reasoning text, if any
            (``"assistant_message"`` only).
        agent_id: ID of the agent that produced the reply
            (``"assistant_message"`` only).
        agent_name: Display name of that agent (``"assistant_message"`` only).
    """
    if action == "new_conversation":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "new_conversation",
            "conversation_id": conversation_id
        })
        return

    if action == "user_message":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "user_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "user",
                "content": user_message,
                "source": "matrix",
                "sender": sender,
                # NOTE(review): this is the relay time, which may differ from
                # the created_at of the stored message row.
                "created_at": datetime.now().isoformat()
            }
        })
        return

    if action == "assistant_message":
        await manager.send_to_user(MAIN_USER_ID, {
            "type": "assistant_message",
            "conversation_id": conversation_id,
            "message": {
                "id": message_id,
                "role": "assistant",
                "content": content,
                "thinking_content": thinking_content,
                "source": "matrix",
                "agent_id": agent_id,
                "agent_name": agent_name,
                # NOTE(review): relay time, see the user_message branch.
                "created_at": datetime.now().isoformat()
            }
        })
        return

    # Unknown actions used to be dropped silently. Log them so that a
    # mismatch between the Matrix service and this handler is visible.
    logger.warning("未处理的Matrix消息类型: action=%s", action)
|
||
|
||
|
||
# ==================== 启动和关闭 ====================
|
||
|
||
# Import the Matrix bot: prefer the v2 service module and fall back to the
# legacy ``services.matrix_service`` when the v2 import raises ImportError.
# MATRIX_V2 records which implementation was loaded.
try:
    from services.matrix_service_v2 import matrix_bot
except ImportError:
    from services.matrix_service import matrix_bot
    MATRIX_V2 = False
else:
    MATRIX_V2 = True
|
||
|
||
def _on_matrix_sync_done(task: asyncio.Task) -> None:
    """Log how the background Matrix sync task finished.

    Attached as a done-callback so that an exception raised inside
    ``matrix_bot.start_sync`` is reported as soon as the task ends. Without
    it, the error would only show up as "Task exception was never retrieved"
    when the task is garbage-collected.
    """
    if task.cancelled():
        logger.info("Matrix 同步任务已取消")
        return
    exc = task.exception()
    if exc is not None:
        logger.error("Matrix 同步任务异常退出: %s", exc, exc_info=exc)


@app.on_event("startup")
async def startup():
    """Application startup hook.

    Runs ``init_db()`` and ``init_default_data()``, then tries to start the
    Matrix bot. Any exception raised while starting the Matrix bot is logged
    and does not abort startup.
    """
    init_db()
    logger.info("数据库初始化完成")

    # Seed default data (see models_v2.init_default_data).
    init_default_data()

    # Matrix integration is optional: failures are logged, never raised.
    try:
        success = await matrix_bot.init_from_config()
        if success and matrix_bot.is_running:
            sync_task = asyncio.create_task(matrix_bot.start_sync(handle_matrix_message))
            # The event loop keeps only weak references to tasks. Hold a strong
            # reference on app.state so the long-running sync loop cannot be
            # garbage-collected mid-execution.
            app.state.matrix_sync_task = sync_task
            sync_task.add_done_callback(_on_matrix_sync_done)
            # Report the implementation that was actually imported instead of
            # always claiming v2.
            logger.info("Matrix Bot %s 已启动", "v2" if MATRIX_V2 else "v1")
        else:
            logger.info("Matrix Bot 未配置或未启用")
    except Exception as e:
        logger.warning("Matrix Bot初始化失败: %s", e, exc_info=True)
|
||
|
||
|
||
@app.on_event("shutdown")
async def shutdown():
    """Application shutdown hook: disconnect the Matrix bot on a best-effort basis.

    A failure to disconnect is logged but never prevents shutdown from
    completing.
    """
    try:
        await matrix_bot.disconnect()
    except Exception as e:
        # Best effort: the bot may never have started, e.g. when startup found
        # it unconfigured. Catching Exception instead of using a bare except
        # lets cancellation and interpreter-exit signals propagate, and the
        # failure is now recorded rather than silently discarded.
        logger.warning("Matrix Bot断开连接失败: %s", e)
    logger.info("应用已关闭")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: serve the app directly with uvicorn on all
    # interfaces, port 19020.
    import uvicorn

    uvicorn.run(app, port=19020, host="0.0.0.0")