Compare commits
4 Commits
ae08e01e55
...
v3.1.0
| Author | SHA1 | Date | |
|---|---|---|---|
| 26a76b030d | |||
| daccc625c3 | |||
| a2a7fd46c3 | |||
| baf5913bfb |
46
main_v2.py
46
main_v2.py
@@ -33,7 +33,7 @@ logging.basicConfig(level=logging.INFO)
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# 创建应用
|
# 创建应用
|
||||||
app = FastAPI(title="AI对话系统 v3.0", version="3.0.1")
|
app = FastAPI(title="AI对话系统 v2.0", version="2.0.0")
|
||||||
|
|
||||||
# 静态文件和模板
|
# 静态文件和模板
|
||||||
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
@@ -647,17 +647,19 @@ async def get_conversations(db: Session = Depends(get_db)):
|
|||||||
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
user = conv_service.get_or_create_user(MAIN_USER_ID, display_name="主用户", user_type='web')
|
||||||
conversations = conv_service.get_user_conversations(user.id)
|
conversations = conv_service.get_user_conversations(user.id)
|
||||||
|
|
||||||
return {
|
# 为每个对话计算消息数量
|
||||||
"conversations": [
|
result = []
|
||||||
{
|
for c in conversations:
|
||||||
"id": c.conversation_id,
|
msg_count = db.query(Message).filter(Message.conversation_id == c.id).count()
|
||||||
"title": c.title or "新对话",
|
result.append({
|
||||||
"created_at": c.created_at.isoformat(),
|
"id": c.conversation_id,
|
||||||
"updated_at": c.updated_at.isoformat()
|
"title": c.title or "新对话",
|
||||||
}
|
"created_at": c.created_at.isoformat(),
|
||||||
for c in conversations
|
"updated_at": c.updated_at.isoformat(),
|
||||||
]
|
"message_count": msg_count
|
||||||
}
|
})
|
||||||
|
|
||||||
|
return {"conversations": result}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/conversations/latest")
|
@app.get("/api/conversations/latest")
|
||||||
@@ -693,6 +695,26 @@ async def create_conversation(db: Session = Depends(get_db)):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/api/conversations/{conversation_id}")
|
||||||
|
async def get_conversation(conversation_id: str, db: Session = Depends(get_db)):
|
||||||
|
"""获取单个对话详情"""
|
||||||
|
conv_service = ConversationService(db)
|
||||||
|
conversation = conv_service.get_conversation(conversation_id)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
raise HTTPException(status_code=404, detail="会话不存在")
|
||||||
|
|
||||||
|
msg_count = db.query(Message).filter(Message.conversation_id == conversation.id).count()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": conversation.conversation_id,
|
||||||
|
"title": conversation.title or "新对话",
|
||||||
|
"created_at": conversation.created_at.isoformat(),
|
||||||
|
"updated_at": conversation.updated_at.isoformat(),
|
||||||
|
"message_count": msg_count
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/conversations/{conversation_id}/messages")
|
@app.get("/api/conversations/{conversation_id}/messages")
|
||||||
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
|
async def get_messages(conversation_id: str, db: Session = Depends(get_db)):
|
||||||
"""获取会话消息"""
|
"""获取会话消息"""
|
||||||
|
|||||||
@@ -536,15 +536,6 @@ class LLMService:
|
|||||||
# 消息历史已经包含了assistant的tool_calls和tool结果,直接使用
|
# 消息历史已经包含了assistant的tool_calls和tool结果,直接使用
|
||||||
final_messages = messages.copy()
|
final_messages = messages.copy()
|
||||||
|
|
||||||
# 添加提示:告诉模型直接根据工具结果回答,不要再调用工具
|
|
||||||
# 添加一个系统级别的提示
|
|
||||||
tool_hint = {
|
|
||||||
"role": "system",
|
|
||||||
"content": "请根据工具返回的结果直接回答用户的问题,不要再调用任何工具或搜索。如果结果不足以回答问题,请根据现有信息给出最好的回答,并说明信息的局限性。"
|
|
||||||
}
|
|
||||||
# 在工具结果之后添加提示
|
|
||||||
final_messages.append(tool_hint)
|
|
||||||
|
|
||||||
# 调用LLM生成最终回复
|
# 调用LLM生成最终回复
|
||||||
url = f"{api_base.rstrip('/')}/chat/completions"
|
url = f"{api_base.rstrip('/')}/chat/completions"
|
||||||
headers = {
|
headers = {
|
||||||
@@ -559,16 +550,6 @@ class LLMService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
logger.info(f"工具结果返回LLM: url={url}, model={model}, 消息数={len(final_messages)}")
|
logger.info(f"工具结果返回LLM: url={url}, model={model}, 消息数={len(final_messages)}")
|
||||||
# 打印消息内容(调试)
|
|
||||||
for i, msg in enumerate(final_messages):
|
|
||||||
role = msg.get('role')
|
|
||||||
content_preview = str(msg.get('content', ''))[:100] if msg.get('content') else 'None'
|
|
||||||
if role == 'tool':
|
|
||||||
logger.info(f"消息[{i}] role={role}, tool_call_id={msg.get('tool_call_id')}, content长度={len(msg.get('content',''))}")
|
|
||||||
elif role == 'assistant' and msg.get('tool_calls'):
|
|
||||||
logger.info(f"消息[{i}] role={role}, tool_calls={len(msg['tool_calls'])}")
|
|
||||||
else:
|
|
||||||
logger.info(f"消息[{i}] role={role}, content={content_preview}...")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=60.0) as client:
|
async with httpx.AsyncClient(timeout=60.0) as client:
|
||||||
@@ -581,27 +562,6 @@ class LLMService:
|
|||||||
data = response.json()
|
data = response.json()
|
||||||
content = data['choices'][0]['message']['content']
|
content = data['choices'][0]['message']['content']
|
||||||
|
|
||||||
# 过滤掉伪工具调用格式(某些模型如Kimi会输出这种内部格式)
|
|
||||||
# 模式:<|tool_calls_section_begin|>...<|tool_calls_section_end|>
|
|
||||||
import re
|
|
||||||
tool_pattern = r'<\|tool_calls_section_begin\|>.*?<\|tool_calls_section_end\|>'
|
|
||||||
content = re.sub(tool_pattern, '', content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
# 也过滤单个 tool_call 格式
|
|
||||||
tool_call_pattern = r'<\|tool_call_begin\|>.*?<\|tool_call_end\|>'
|
|
||||||
content = re.sub(tool_call_pattern, '', content, flags=re.DOTALL)
|
|
||||||
|
|
||||||
# 清理可能残留的格式标记
|
|
||||||
content = content.replace('<|tool_calls_section_begin|>', '')
|
|
||||||
content = content.replace('<|tool_calls_section_end|>', '')
|
|
||||||
content = content.replace('<|tool_call_begin|>', '')
|
|
||||||
content = content.replace('<|tool_call_end|>', '')
|
|
||||||
content = content.replace('<|tool_call_argument_begin|>', '')
|
|
||||||
content = content.replace('<|tool_call_argument_end|>', '')
|
|
||||||
|
|
||||||
# 清理多余空行
|
|
||||||
content = re.sub(r'\n{3,}', '\n\n', content).strip()
|
|
||||||
|
|
||||||
return content, None
|
return content, None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user