170 lines
5.7 KiB
Python
170 lines
5.7 KiB
Python
"""
|
||
会话管理服务 - v2.0 兼容版
|
||
支持思考内容、Agent追踪等新字段
|
||
"""
|
||
import uuid
from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy.orm import Session
|
||
|
||
# Prefer the v2 models (thinking content, agent tracking); fall back to the
# v1 models when models_v2 cannot be imported. USE_V2 records which set loaded.
try:
    from models_v2 import User, Conversation, Message, Agent
except ImportError:
    from models import User, Conversation, Message
    USE_V2 = False
else:
    USE_V2 = True
|
||
|
||
|
||
class ConversationService:
    """Persist users, conversations and messages through a SQLAlchemy session.

    Works against either the v2 models (``models_v2``: thinking content, agent
    tracking, channels) or the v1 models (``models``), depending on which
    import succeeded at module load (``USE_V2``). Arguments that only the v2
    models receive are silently ignored when running on v1.
    """

    # Maximum number of characters of the first user message used as an
    # automatic conversation title.
    TITLE_MAX_LENGTH = 50

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for every query; each mutating method commits it.
        """
        self.db = db
        # Snapshot of the module-level model selection made at import time.
        self.use_v2 = USE_V2

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Same value as the deprecated ``datetime.utcnow()`` (tzinfo stripped),
        so stored timestamps keep the same naive-UTC form as before.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _make_title(cls, content: str) -> str:
        """Build a title from the first TITLE_MAX_LENGTH chars of *content*.

        Appends ``'...'`` only when the text was actually truncated.
        """
        limit = cls.TITLE_MAX_LENGTH
        return content[:limit] + ('...' if len(content) > limit else '')

    def _commit(self) -> None:
        """Commit the session, rolling it back if the commit fails.

        After a failed flush/commit a SQLAlchemy Session refuses further work
        until ``rollback()`` is called; rolling back here keeps the shared
        session usable for later calls. The original exception is re-raised.
        """
        try:
            self.db.commit()
        except Exception:
            self.db.rollback()
            raise

    # -------------------------------------------------------------------- users

    def get_or_create_user(
        self,
        user_id: str,
        display_name: Optional[str] = None,
        user_type: str = 'web',
        matrix_user_id: Optional[str] = None,
    ) -> User:
        """Return the user whose ``User.user_id`` equals *user_id*, creating it if absent.

        Args:
            user_id: External string identifier matched against ``User.user_id``.
            display_name: Display name for a newly created user; defaults to
                *user_id*.
            user_type: Stored only when the user is created.
            matrix_user_id: Stored only when the user is created.

        Returns:
            The existing user with ``last_active_at`` set to now, or the newly
            created user refreshed from the database.
        """
        user = self.db.query(User).filter(User.user_id == user_id).first()
        if user:
            # Existing users only get their activity timestamp bumped; the other
            # arguments are not applied to them.
            user.last_active_at = self._utcnow()
            self._commit()
            return user

        user = User(
            user_id=user_id,
            display_name=display_name or user_id,
            user_type=user_type,
            matrix_user_id=matrix_user_id,
        )
        self.db.add(user)
        self._commit()
        self.db.refresh(user)
        return user

    # ------------------------------------------------------------ conversations

    def create_conversation(
        self,
        user_id: int,
        title: Optional[str] = None,
        channel_id: Optional[int] = None,
    ) -> Conversation:
        """Create and persist a new conversation with a random public id.

        Args:
            user_id: Value stored in ``Conversation.user_id`` (presumably the
                owner's integer primary key — TODO confirm).
            title: Optional title; if left empty, ``add_message`` fills it from
                the first user message.
            channel_id: Only passed to the model when running on v2.

        Returns:
            The refreshed conversation; its ``conversation_id`` has the form
            ``conv_<12 hex chars>``.
        """
        fields = {
            'conversation_id': f"conv_{uuid.uuid4().hex[:12]}",
            'user_id': user_id,
            'title': title,
        }
        if self.use_v2:
            fields['channel_id'] = channel_id
        conversation = Conversation(**fields)

        self.db.add(conversation)
        self._commit()
        self.db.refresh(conversation)
        return conversation

    def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
        """Look up a conversation by its public string id; ``None`` if not found.

        Soft-deleted (inactive) conversations are still returned.
        """
        return self.db.query(Conversation).filter(
            Conversation.conversation_id == conversation_id
        ).first()

    def get_user_conversations(self, user_id: int) -> List[Conversation]:
        """Return the user's active conversations, most recently updated first."""
        return self.db.query(Conversation).filter(
            Conversation.user_id == user_id,
            # SQLAlchemy needs the explicit comparison to build SQL.
            Conversation.is_active == True,  # noqa: E712
        ).order_by(Conversation.updated_at.desc()).all()

    # ----------------------------------------------------------------- messages

    def add_message(
        self,
        conversation_id: int,
        role: str,
        content: str,
        source: str = 'web',
        extra_data: Optional[dict] = None,
        thinking_content: Optional[str] = None,  # v2 only
        agent_id: Optional[int] = None,  # v2 only
        model_used: Optional[str] = None,  # v2 only
    ) -> Message:
        """Append a message to a conversation and update the conversation's metadata.

        Args:
            conversation_id: The conversation's integer primary key
                (``Conversation.id``), NOT its public string ``conversation_id``.
            role: Message role, e.g. ``'user'``.
            content: Message text.
            source: Origin of the message; defaults to ``'web'``.
            extra_data: Arbitrary extra payload stored with the message.
            thinking_content: v2 only; ignored on v1 models.
            agent_id: v2 only; also becomes the conversation's current agent
                when truthy.
            model_used: v2 only; ignored on v1 models.

        Returns:
            The persisted message, refreshed from the database.

        Side effects:
            If the conversation exists, its ``updated_at`` is set to now. If it
            has no title and *role* is ``'user'``, the title is set from
            *content*.
        """
        fields = {
            'conversation_id': conversation_id,
            'role': role,
            'content': content,
            'source': source,
            'extra_data': extra_data,
        }
        if self.use_v2:
            fields.update(
                thinking_content=thinking_content,
                agent_id=agent_id,
                model_used=model_used,
            )
        message = Message(**fields)
        self.db.add(message)

        conversation = self.db.query(Conversation).filter(
            Conversation.id == conversation_id
        ).first()
        if conversation:
            conversation.updated_at = self._utcnow()
            # Auto-title untitled conversations from their first user message.
            if not conversation.title and role == 'user':
                conversation.title = self._make_title(content)
            if self.use_v2 and agent_id:
                conversation.current_agent_id = agent_id

        self._commit()
        self.db.refresh(message)
        return message

    def get_messages(self, conversation_id: int, limit: int = 50) -> List[Message]:
        """Return up to *limit* messages of a conversation in chronological order.

        Args:
            conversation_id: The conversation's integer primary key.
            limit: Maximum number of messages.

        Note:
            Ordering is ascending before the limit is applied, so this returns
            the OLDEST *limit* messages. ``get_conversation_history`` returns
            the most recent ones instead.
        """
        return self.db.query(Message).filter(
            Message.conversation_id == conversation_id
        ).order_by(Message.created_at.asc()).limit(limit).all()

    def get_conversation_history(self, conversation_id: str, limit: int = 20) -> List[dict]:
        """Return the latest messages as ``{"role", "content"}`` dicts for AI context.

        Args:
            conversation_id: The conversation's public string id.
            limit: Maximum number of messages (the most recent ones are kept).

        Returns:
            Messages oldest-first, restricted to the ``user``/``assistant``/
            ``system`` roles. Returns an empty list if the conversation is
            unknown.
        """
        conversation = self.get_conversation(conversation_id)
        if not conversation:
            return []

        # Take the newest `limit` rows, then flip them back to chronological
        # order. The role filter drops any other role (presumably 'thinking').
        recent = self.db.query(Message).filter(
            Message.conversation_id == conversation.id,
            Message.role.in_(['user', 'assistant', 'system']),
        ).order_by(Message.created_at.desc()).limit(limit).all()

        return [{"role": m.role, "content": m.content} for m in reversed(recent)]

    # ----------------------------------------------------------------- mutation

    def delete_conversation(self, conversation_id: str) -> bool:
        """Soft-delete a conversation by setting ``is_active`` to False.

        Returns:
            True if the conversation existed, False otherwise.
        """
        conversation = self.get_conversation(conversation_id)
        if not conversation:
            return False
        conversation.is_active = False
        self._commit()
        return True

    def set_conversation_agent(self, conversation_id: str, agent_id: int) -> bool:
        """Set the conversation's current agent (v2 models only).

        Returns:
            True if the agent was set; False when running on v1 models or when
            the conversation does not exist.
        """
        if not self.use_v2:
            return False
        conversation = self.get_conversation(conversation_id)
        if not conversation:
            return False
        conversation.current_agent_id = agent_id
        self._commit()
        return True