Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | ae08e01e55 | | |
| | 9048d94e33 | | |
| | 291de733a4 | | |
| | 10f67a807a | | |
@@ -33,7 +33,7 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 创建应用
|
||||
app = FastAPI(title="AI对话系统 v2.0", version="2.0.0")
|
||||
app = FastAPI(title="AI对话系统 v3.0", version="3.0.1")
|
||||
|
||||
# 静态文件和模板
|
||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -1121,7 +1121,6 @@ async def websocket_endpoint(websocket: WebSocket, user_id: str):
|
||||
messages=history_with_tools,
|
||||
provider_config=agent_config['provider'],
|
||||
agent_config=agent_config['agent'],
|
||||
tool_results=tool_results,
|
||||
enable_thinking=enable_thinking
|
||||
)
|
||||
|
||||
|
||||
@@ -137,6 +137,9 @@ class AgentService:
|
||||
'api_key': provider.api_key if provider else None,
|
||||
'supports_thinking': provider.supports_thinking if provider else False,
|
||||
'thinking_model': provider.thinking_model if provider else None,
|
||||
'supports_vision': provider.supports_vision if provider else False,
|
||||
'vision_model': provider.vision_model if provider else None,
|
||||
'supports_function_calling': provider.supports_function_calling if provider else False,
|
||||
'default_model': provider.default_model if provider else 'auto',
|
||||
'max_tokens': provider.max_tokens if provider else 4096,
|
||||
'temperature': provider.temperature if provider else 0.7,
|
||||
|
||||
@@ -514,17 +514,15 @@ class LLMService:
|
||||
messages: List[Dict],
|
||||
provider_config: dict,
|
||||
agent_config: dict,
|
||||
tool_results: List[Dict],
|
||||
enable_thinking: bool = True
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""
|
||||
第二阶段调用:将工具执行结果返回给LLM
|
||||
第二阶段调用:使用包含工具调用和结果的完整消息历史
|
||||
|
||||
Args:
|
||||
messages: 对话历史(包含工具调用和结果)
|
||||
messages: 已包含assistant tool_calls和tool结果的完整消息历史
|
||||
provider_config: LLM Provider配置
|
||||
agent_config: Agent配置
|
||||
tool_results: 工具执行结果 [{"tool_call_id": "xxx", "content": "..."}]
|
||||
|
||||
Returns:
|
||||
Tuple[str, Optional[str]]: (回复内容, 思考过程)
|
||||
@@ -535,14 +533,17 @@ class LLMService:
|
||||
max_tokens = provider_config.get('max_tokens', 4096)
|
||||
temperature = agent_config.get('temperature_override') or provider_config.get('temperature', 0.7)
|
||||
|
||||
# 将工具结果添加到消息历史
|
||||
# 消息历史已经包含了assistant的tool_calls和tool结果,直接使用
|
||||
final_messages = messages.copy()
|
||||
for result in tool_results:
|
||||
final_messages.append({
|
||||
"role": "tool",
|
||||
"tool_call_id": result['tool_call_id'],
|
||||
"content": result['content']
|
||||
})
|
||||
|
||||
# 添加提示:告诉模型直接根据工具结果回答,不要再调用工具
|
||||
# 添加一个系统级别的提示
|
||||
tool_hint = {
|
||||
"role": "system",
|
||||
"content": "请根据工具返回的结果直接回答用户的问题,不要再调用任何工具或搜索。如果结果不足以回答问题,请根据现有信息给出最好的回答,并说明信息的局限性。"
|
||||
}
|
||||
# 在工具结果之后添加提示
|
||||
final_messages.append(tool_hint)
|
||||
|
||||
# 调用LLM生成最终回复
|
||||
url = f"{api_base.rstrip('/')}/chat/completions"
|
||||
@@ -557,19 +558,50 @@ class LLMService:
|
||||
"max_tokens": max_tokens
|
||||
}
|
||||
|
||||
logger.info(f"工具结果返回LLM: url={url}, model={model}")
|
||||
logger.info(f"工具结果返回LLM: url={url}, model={model}, 消息数={len(final_messages)}")
|
||||
# 打印消息内容(调试)
|
||||
for i, msg in enumerate(final_messages):
|
||||
role = msg.get('role')
|
||||
content_preview = str(msg.get('content', ''))[:100] if msg.get('content') else 'None'
|
||||
if role == 'tool':
|
||||
logger.info(f"消息[{i}] role={role}, tool_call_id={msg.get('tool_call_id')}, content长度={len(msg.get('content',''))}")
|
||||
elif role == 'assistant' and msg.get('tool_calls'):
|
||||
logger.info(f"消息[{i}] role={role}, tool_calls={len(msg['tool_calls'])}")
|
||||
else:
|
||||
logger.info(f"消息[{i}] role={role}, content={content_preview}...")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||
response = await client.post(url, headers=headers, json=payload)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"API返回错误: status={response.status_code}")
|
||||
logger.error(f"API返回错误: status={response.status_code}, body={response.text[:500]}")
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
content = data['choices'][0]['message']['content']
|
||||
|
||||
# 过滤掉伪工具调用格式(某些模型如Kimi会输出这种内部格式)
|
||||
# 模式:<|tool_calls_section_begin|>...<|tool_calls_section_end|>
|
||||
import re
|
||||
tool_pattern = r'<\|tool_calls_section_begin\|>.*?<\|tool_calls_section_end\|>'
|
||||
content = re.sub(tool_pattern, '', content, flags=re.DOTALL)
|
||||
|
||||
# 也过滤单个 tool_call 格式
|
||||
tool_call_pattern = r'<\|tool_call_begin\|>.*?<\|tool_call_end\|>'
|
||||
content = re.sub(tool_call_pattern, '', content, flags=re.DOTALL)
|
||||
|
||||
# 清理可能残留的格式标记
|
||||
content = content.replace('<|tool_calls_section_begin|>', '')
|
||||
content = content.replace('<|tool_calls_section_end|>', '')
|
||||
content = content.replace('<|tool_call_begin|>', '')
|
||||
content = content.replace('<|tool_call_end|>', '')
|
||||
content = content.replace('<|tool_call_argument_begin|>', '')
|
||||
content = content.replace('<|tool_call_argument_end|>', '')
|
||||
|
||||
# 清理多余空行
|
||||
content = re.sub(r'\n{3,}', '\n\n', content).strip()
|
||||
|
||||
return content, None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user