diff --git a/main_v2.py b/main_v2.py index 1eb6fde..4bb09dd 100644 --- a/main_v2.py +++ b/main_v2.py @@ -3,7 +3,7 @@ AI对话系统 v2.0.0 - 主应用 支持:大模型池、Agent管理、渠道独立绑定、思考功能开关 """ from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Depends, HTTPException, Request -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from sqlalchemy.orm import Session @@ -13,6 +13,9 @@ import json import logging from datetime import datetime import os +import base64 +import uuid +import time # 使用新的数据模型 from models_v2 import ( @@ -34,7 +37,14 @@ app = FastAPI(title="AI对话系统 v2.0", version="2.0.0") # 静态文件和模板 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +UPLOADS_DIR = os.path.join(BASE_DIR, "uploads", "images") + +# 确保上传目录存在 +os.makedirs(UPLOADS_DIR, exist_ok=True) + +# 静态文件服务 app.mount("/static", StaticFiles(directory=os.path.join(BASE_DIR, "static")), name="static") +app.mount("/uploads", StaticFiles(directory=os.path.join(BASE_DIR, "uploads")), name="uploads") templates = Jinja2Templates(directory=os.path.join(BASE_DIR, "templates")) # WebSocket连接管理 @@ -575,6 +585,56 @@ async def perform_search(data: dict, db: Session = Depends(get_db)): return {"success": False, "message": "不支持的搜索提供商"} +# ==================== 图片上传 API ==================== + +@app.post("/api/v2/upload-image") +async def upload_image(data: dict): + """上传图片到服务器,返回文件路径""" + try: + image_data = data.get('image') + file_name = data.get('name', 'image.png') + + if not image_data: + return {"success": False, "message": "缺少图片数据"} + + # 解析 base64 数据 + if image_data.startswith('data:image/'): + # 提取格式和base64内容 + header, base64_content = image_data.split(',', 1) + # 从header中提取图片格式 + format_match = header.split(':')[1].split(';')[0] # 如 'image/png' + ext = format_match.split('/')[1] if '/' in format_match else 'png' + else: + base64_content = image_data + ext = 'png' + + # 生成唯一文件名 + timestamp = int(time.time()) + unique_id = uuid.uuid4().hex[:8] + safe_name = f"{timestamp}_{unique_id}.{ext}" + + # 保存文件 + file_path = os.path.join(UPLOADS_DIR, safe_name) + image_bytes = base64.b64decode(base64_content) + + # 检查文件大小(限制10MB) + if len(image_bytes) > 10 * 1024 * 1024: + return {"success": False, "message": "图片大小超过10MB限制"} + + with open(file_path, 'wb') as f: + f.write(image_bytes) + + # 返回可访问的URL路径 + url_path = f"/uploads/images/{safe_name}" + logger.info(f"图片已保存: {file_path}, URL: {url_path}") + + return {"success": True, "path": url_path, "name": safe_name} + + except Exception as e: + logger.error(f"图片上传失败: {e}") + return {"success": False, "message": str(e)} + + # ==================== 对话 API(保留原有) ==================== @app.get("/api/conversations") @@ -787,6 +847,7 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): # 处理文件内容,添加到消息 image_contents = [] # 图片内容(用于视觉模型) text_contents = [] # 文本文件内容 + image_paths = [] # 图片服务器路径(用于历史记录显示) if files: for f in files: if f.get('type') and f['type'].startswith('image/'): @@ -796,6 +857,13 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): 'type': f['type'], 'data': f.get('content', '') # base64 数据 }) + # 记录服务器路径(用于历史记录) + if f.get('serverPath'): + image_paths.append({ + 'name': f['name'], + 'type': f['type'], + 'url': f['serverPath'] # 服务器文件路径 + }) # 不添加文件名文本,图片信息保存在 extra_data 中 elif f.get('content'): # 文本文件:直接添加内容,不带文件名前缀 @@ -808,10 +876,16 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str): for content in text_contents: message += f"\n\n{content}" - # 保存图片信息到 extra_data(用于历史记录) + # 保存图片和文件信息到 extra_data(用于历史记录) extra_data_for_msg = None - if image_contents: - # 只保存图片 URL(不保存完整 base64) + if image_paths: + # 图片保存服务器路径URL,历史记录可以显示 + extra_data_for_msg = { + 'images': image_paths, + 'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')] + } + elif image_contents: + # 没有服务器路径但有问题(可能上传失败) extra_data_for_msg = { 'images': [{'name': i['name'], 'type': i['type']} for i in image_contents], 'files': [{'name': f['name'], 'type': f['type']} for f in files if not f['type'].startswith('image/')] diff --git a/templates/index.html b/templates/index.html index f776a65..e366ada 100644 --- a/templates/index.html +++ b/templates/index.html @@ -228,6 +228,28 @@ 放大图片 + + +