Files
ai-chat-system/services/conversation_service.py

170 lines
5.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
会话管理服务 - v2.0 兼容版
支持思考内容、Agent追踪等新字段
"""
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import Session

# Compatible with both the v1 and v2 models: prefer models_v2 and fall back
# to the legacy `models` module when models_v2 cannot be imported.
try:
    from models_v2 import User, Conversation, Message, Agent
    USE_V2 = True
except ImportError:
    from models import User, Conversation, Message
    USE_V2 = False


class ConversationService:
    """Database service for users, conversations and messages.

    Works with either the v2 models (``models_v2``) or the legacy v1 models
    (``models``). ``self.use_v2`` mirrors the module-level ``USE_V2`` flag and
    gates the v2-only columns: ``channel_id``, ``thinking_content``,
    ``agent_id``, ``model_used`` and ``current_agent_id``.

    Every write runs inside ``_transaction()``, which rolls the session back
    when the unit of work fails, so a single failed commit does not leave
    ``self.db`` unusable for later calls.
    """

    # Characters of the first user message kept when it becomes the
    # auto-generated conversation title ("..." is appended when truncated).
    _TITLE_MAX_LENGTH = 50

    # Roles included in the AI context built by get_conversation_history();
    # every other role (the original code notes "thinking") is excluded.
    _CONTEXT_ROLES = ('user', 'assistant', 'system')

    def __init__(self, db: Session):
        """Bind the service to an SQLAlchemy session.

        Args:
            db: Session used for every query and commit made by this service.
        """
        self.db = db
        self.use_v2 = USE_V2

    # -- internal helpers ---------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Same value as ``datetime.utcnow()`` (deprecated since Python 3.12);
        the tzinfo is dropped so stored timestamps stay naive UTC as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @contextmanager
    def _transaction(self):
        """Run the enclosed unit of work, then commit it.

        If the body (including any autoflush it triggers) or the commit
        raises, the session is rolled back before the exception is re-raised.
        Without the rollback SQLAlchemy rejects further use of the session.
        """
        try:
            yield
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    @classmethod
    def _make_title(cls, content: Optional[str]) -> str:
        """Derive a single-line conversation title from a message body.

        Runs of whitespace (including newlines) collapse to one space and the
        text is cut to ``_TITLE_MAX_LENGTH`` characters followed by "...".
        Returns "" for empty or whitespace-only content.
        """
        text = " ".join((content or "").split())
        if len(text) > cls._TITLE_MAX_LENGTH:
            return text[:cls._TITLE_MAX_LENGTH] + '...'
        return text

    def _recent_messages(
        self,
        conversation_pk: int,
        limit: int,
        roles: Optional[tuple] = None,
    ) -> List[Message]:
        """Return the newest ``limit`` messages of a conversation, oldest first.

        The query orders newest-first *before* applying LIMIT so that long
        conversations yield their latest messages, then the list is reversed
        into chronological order.

        Args:
            conversation_pk: Integer primary key of the conversation
                (the value stored in ``Message.conversation_id``).
            limit: Maximum number of messages to return.
            roles: Optional collection of roles to restrict the query to.
        """
        query = self.db.query(Message).filter(
            Message.conversation_id == conversation_pk
        )
        if roles is not None:
            query = query.filter(Message.role.in_(roles))
        newest_first = query.order_by(Message.created_at.desc()).limit(limit).all()
        newest_first.reverse()
        return newest_first

    # -- users ----------------------------------------------------------------

    def get_or_create_user(
        self,
        user_id: str,
        display_name: Optional[str] = None,
        user_type: str = 'web',
        matrix_user_id: Optional[str] = None,
    ) -> User:
        """Return the user identified by ``user_id``, creating it if missing.

        Args:
            user_id: External identifier matched against ``User.user_id``.
            display_name: Display name for a new user; defaults to ``user_id``.
            user_type: Type stored on a new user (default ``'web'``).
            matrix_user_id: Optional Matrix id stored on a new user.

        Returns:
            The existing user with ``last_active_at`` set to now (UTC), or the
            newly created user after commit and refresh. ``display_name``,
            ``user_type`` and ``matrix_user_id`` are only used on creation.
        """
        user = self.db.query(User).filter(User.user_id == user_id).first()
        if user is None:
            user = User(
                user_id=user_id,
                display_name=display_name or user_id,
                user_type=user_type,
                matrix_user_id=matrix_user_id,
            )
            with self._transaction():
                self.db.add(user)
            self.db.refresh(user)
        else:
            with self._transaction():
                user.last_active_at = self._utcnow()
        return user

    # -- conversations --------------------------------------------------------

    def create_conversation(
        self,
        user_id: int,
        title: Optional[str] = None,
        channel_id: Optional[int] = None,
    ) -> Conversation:
        """Create, commit and return a new conversation.

        Args:
            user_id: Integer id of the owner (presumably ``User.id``).
            title: Optional title; when None, ``add_message`` later derives
                one from the first user message.
            channel_id: Channel to attach the conversation to (v2 only;
                ignored when running against the v1 models).

        Returns:
            The refreshed conversation, whose public id has the form
            ``conv_`` + 12 hex characters.
        """
        fields = {
            "conversation_id": f"conv_{uuid.uuid4().hex[:12]}",
            "user_id": user_id,
            "title": title,
        }
        if self.use_v2:
            fields["channel_id"] = channel_id
        conversation = Conversation(**fields)
        with self._transaction():
            self.db.add(conversation)
        self.db.refresh(conversation)
        return conversation

    def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
        """Return the conversation with this public id, or None.

        Soft-deleted conversations (``is_active`` False) are still returned.
        """
        return self.db.query(Conversation).filter(
            Conversation.conversation_id == conversation_id
        ).first()

    def get_user_conversations(self, user_id: int) -> List[Conversation]:
        """Return the user's active conversations, most recently updated first."""
        return self.db.query(Conversation).filter(
            Conversation.user_id == user_id,
            Conversation.is_active == True,  # noqa: E712 -- SQL expression
        ).order_by(Conversation.updated_at.desc()).all()

    # -- messages -------------------------------------------------------------

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        source: str = 'web',
        extra_data: Optional[dict] = None,
        thinking_content: Optional[str] = None,  # v2 only
        agent_id: Optional[int] = None,  # v2 only
        model_used: Optional[str] = None,  # v2 only
    ) -> Message:
        """Append a message to a conversation and commit it.

        Args:
            conversation_id: Integer primary key of the conversation
                (``Conversation.id``), not the public ``conv_...`` string.
            role: Message role, e.g. ``'user'`` or ``'assistant'``.
            content: Message text.
            source: Origin of the message (default ``'web'``).
            extra_data: Optional extra payload stored with the message.
            thinking_content: v2 only; ignored on the v1 models.
            agent_id: v2 only; when truthy it also becomes the
                conversation's ``current_agent_id``.
            model_used: v2 only; ignored on the v1 models.

        Side effects:
            If the conversation exists, its ``updated_at`` is set to now (UTC).
            If it has no title, the first non-blank user message becomes its
            title (see ``_make_title``).

        Returns:
            The committed and refreshed message.
        """
        columns = {
            "conversation_id": conversation_id,
            "role": role,
            "content": content,
            "source": source,
            "extra_data": extra_data,
        }
        if self.use_v2:
            columns.update(
                thinking_content=thinking_content,
                agent_id=agent_id,
                model_used=model_used,
            )
        message = Message(**columns)

        # The conversation query autoflushes the pending message, so it must
        # sit inside the same rollback-protected unit of work.
        with self._transaction():
            self.db.add(message)
            conversation = self.db.query(Conversation).filter(
                Conversation.id == conversation_id
            ).first()
            if conversation is not None:
                conversation.updated_at = self._utcnow()
                if not conversation.title and role == 'user':
                    title = self._make_title(content)
                    if title:  # never lock in a blank title
                        conversation.title = title
                if self.use_v2 and agent_id:
                    conversation.current_agent_id = agent_id
        self.db.refresh(message)
        return message

    def get_messages(self, conversation_id: int, limit: int = 50) -> List[Message]:
        """Return the latest ``limit`` messages of a conversation, oldest first.

        Args:
            conversation_id: Integer primary key of the conversation.
            limit: Maximum number of messages to return (default 50).
        """
        return self._recent_messages(conversation_id, limit)

    def get_conversation_history(self, conversation_id: str, limit: int = 20) -> List[dict]:
        """Build the chat history used as AI context.

        Args:
            conversation_id: Public conversation id (``conv_...``).
            limit: Maximum number of messages to include (default 20).

        Returns:
            Up to ``limit`` of the most recent messages whose role is user,
            assistant or system, oldest first, as ``{"role", "content"}``
            dicts; an empty list when the conversation does not exist.
        """
        conversation = self.get_conversation(conversation_id)
        if conversation is None:
            return []
        recent = self._recent_messages(conversation.id, limit, roles=self._CONTEXT_ROLES)
        return [{"role": m.role, "content": m.content} for m in recent]

    def delete_conversation(self, conversation_id: str) -> bool:
        """Soft-delete a conversation by setting ``is_active`` to False.

        Returns:
            True if the conversation was found and deactivated, else False.
        """
        conversation = self.get_conversation(conversation_id)
        if conversation is None:
            return False
        with self._transaction():
            conversation.is_active = False
        return True

    def set_conversation_agent(self, conversation_id: str, agent_id: int) -> bool:
        """Set the agent used by a conversation (v2 models only).

        Returns:
            True if ``current_agent_id`` was updated; False on the v1 models
            or when the conversation does not exist.
        """
        if not self.use_v2:
            return False
        conversation = self.get_conversation(conversation_id)
        if conversation is None:
            return False
        with self._transaction():
            conversation.current_agent_id = agent_id
        return True