Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 26a76b030d | |||
| daccc625c3 | |||
| a2a7fd46c3 | |||
| baf5913bfb | |||
| ae08e01e55 | |||
| 9048d94e33 | |||
| 291de733a4 |
45
main_v2.py
45
main_v2.py
@@ -647,17 +647,19 @@ async def get_conversations(db: Session = Depends(get_db)):
|
||||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||||
conversations = conv_service.get_user_conversations(user.id)
|
||||
|
||||
return {
|
||||
"conversations": [
|
||||
{
|
||||
"id": c.conversation_id,
|
||||
"title": c.title or "新对话",
|
||||
"created_at": c.created_at.isoformat(),
|
||||
"updated_at": c.updated_at.isoformat()
|
||||
}
|
||||
for c in conversations
|
||||
]
|
||||
}
|
||||
# 为每个对话计算消息数量
|
||||
result = []
|
||||
for c in conversations:
|
||||
msg_count = db.query(Message).filter(Message.conversation_id == c.id).count()
|
||||
result.append({
|
||||
"id": c.conversation_id,
|
||||
"title": c.title or "新对话",
|
||||
"created_at": c.created_at.isoformat(),
|
||||
"updated_at": c.updated_at.isoformat(),
|
||||
"message_count": msg_count
|
||||
})
|
||||
|
||||
return {"conversations": result}
|
||||
|
||||
|
||||
@app.get("/api/conversations/latest")
|
||||
@@ -693,6 +695,26 @@ async def create_conversation(db: Session = Depends(get_db)):
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/conversations/{conversation_id}")
|
||||
async def get_conversation(conversation_id: str, db: Session = Depends(get_db)):
|
||||
"""获取单个对话详情"""
|
||||
conv_service = ConversationService(db)
|
||||
conversation = conv_service.get_conversation(conversation_id)
|
||||
|
||||
if not conversation:
|
||||
raise HTTPException(status_code=404, detail="会话不存在")
|
||||
|
||||
msg_count = db.query(Message).filter(Message.conversation_id == conversation.id).count()
|
||||
|
||||
return {
|
||||
"id": conversation.conversation_id,
|
||||
"title": conversation.title or "新对话",
|
||||
"created_at": conversation.created_at.isoformat(),
|
||||
"updated_at": conversation.updated_at.isoformat(),
|
||||
"message_count": msg_count
|
||||
}
|
||||
|
||||
|
||||
@app.get("/api/conversations/{conversation_id}/messages")
|
||||
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
|
||||
"""获取会话消息"""
|
||||
@@ -1121,7 +1143,6 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||||
messages=history_with_tools,
|
||||
provider_config=agent_config['provider'],
|
||||
agent_config=agent_config['agent'],
|
||||
tool_results=tool_results,
|
||||
enable_thinking=enable_thinking
|
||||
)
|
||||
|
||||
|
||||
@@ -514,17 +514,15 @@ class LLMService:
|
||||
messages: List[Dict],
|
||||
provider_config: dict,
|
||||
agent_config: dict,
|
||||
tool_results: List[Dict],
|
||||
enable_thinking: bool = True
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
第二阶段调用:将工具执行结果返回给LLM
|
||||
第二阶段调用:使用包含工具调用和结果的完整消息历史
|
||||
|
||||
Args:
|
||||
messages: 对话历史(包含工具调用和结果)
|
||||
messages: 已包含assistant tool_calls和tool结果的完整消息历史
|
||||
provider_config: LLM Provider配置
|
||||
agent_config: Agent配置
|
||||
tool_results: 工具执行结果 [{"tool_call_id": "xxx", "content": "..."}]
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[str]]: (回复内容, 思考过程)
|
||||
@@ -535,14 +533,8 @@ class LLMService:
|
||||
max_tokens = provider_config.get('max_tokens', 4096)
|
||||
temperature = agent_config.get('temperature_override') or provider_config.get('temperature', 0.7)
|
||||
|
||||
# 将工具结果添加到消息历史
|
||||
# 消息历史已经包含了assistant的tool_calls和tool结果,直接使用
|
||||
final_messages = messages.copy()
|
||||
for result in tool_results:
|
||||
final_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result['tool_call_id'],
|
||||
"content": result['content']
|
||||
})
|
||||
|
||||
# 调用LLM生成最终回复
|
||||
url = f"{api_base.rstrip('/')}/chat/completions"
|
||||
@@ -557,14 +549,14 @@ class LLMService:
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
logger.info(f"工具结果返回LLM: url={url}, model={model}")
|
||||
logger.info(f"工具结果返回LLM: url={url}, model={model}, 消息数={len(final_messages)}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"API返回错误: status={response.status_code}")
|
||||
logger.error(f"API返回错误: status={response.status_code}, body={response.text[:500]}")
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user