"""
|
||
AI服务 - 调用大模型API
|
||
"""
|
||
import httpx
|
||
from typing import List, Dict, AsyncGenerator
|
||
import json
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AIService:
    """Client for an OpenAI-compatible chat-completions API.

    Starts in mock mode and keeps returning canned replies until
    ``update_config`` supplies a usable endpoint.  Live-API failures also
    degrade to a fallback message instead of raising, so callers never see
    an exception from ``chat`` / ``chat_stream``.
    """

    def __init__(self) -> None:
        # Empty config => mock mode until update_config() is called.
        self.api_base: str = ""
        self.api_key: str = ""
        self.model: str = ""
        self.use_mock: bool = True

    def update_config(self, api_base: str, api_key: str, model: str) -> None:
        """Update the endpoint configuration.

        Switches to the real API when both ``api_base`` and ``model`` are
        non-empty; otherwise stays in mock mode.
        """
        self.api_base = api_base
        self.api_key = api_key
        self.model = model
        # NOTE(review): api_key is deliberately not part of this check —
        # presumably some endpoints accept unauthenticated calls; confirm.
        self.use_mock = not (api_base and model)
        logger.info(
            "AI配置已更新: api_base=%s, model=%s, use_mock=%s",
            api_base, model, self.use_mock,
        )

    async def chat(self, messages: List[Dict]) -> str:
        """Send *messages* to the chat model and return the reply text.

        In mock mode — or on any API failure — a canned reply echoing the
        last user message is returned instead of raising.
        """
        if self.use_mock:
            logger.info("使用Mock模式回复")
            last_msg = messages[-1]['content'] if messages else "你好"
            return f"这是一个测试回复。您说的是:{last_msg}\n\n请配置有效的AI服务地址和模型,才能获得真正的AI回复。"

        # Real API call (OpenAI-compatible /chat/completions endpoint).
        url = f"{self.api_base}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "temperature": 0.7,
            "max_tokens": 2000,
        }

        logger.info("调用AI API: %s, model=%s", url, self.model)

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                data = response.json()
                return data['choices'][0]['message']['content']
        except Exception as e:
            # Degrade gracefully: surface the error in the reply rather
            # than crashing the caller's request handler.
            logger.error("AI API调用失败: %s", e)
            last_msg = messages[-1]['content'] if messages else "你好"
            return f"AI服务暂时不可用(错误:{str(e)})。您说的是:{last_msg}"

    async def chat_stream(self, messages: List[Dict]) -> AsyncGenerator[str, None]:
        """Stream the model reply as text chunks.

        Parses the OpenAI-compatible ``stream=True`` SSE protocol.  Mock
        mode streams the canned reply character by character.  HTTP errors
        and transport failures yield a single fallback chunk, mirroring
        ``chat`` (previously they propagated out of the generator).
        """
        if self.use_mock:
            last_msg = messages[-1]['content'] if messages else "你好"
            reply = f"这是一个测试回复。您说的是:{last_msg}"
            for char in reply:
                yield char
            return

        url = f"{self.api_base}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": messages,
            "stream": True,
            "temperature": 0.7,
            "max_tokens": 2000,
        }

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                async with client.stream(
                    "POST", url, headers=headers, json=payload
                ) as response:
                    # Fix: was missing — a 4xx/5xx body used to be parsed
                    # as (empty) SSE and the stream ended silently.
                    response.raise_for_status()
                    async for line in response.aiter_lines():
                        if not line.startswith("data: "):
                            continue
                        data_str = line[6:]  # strip the "data: " prefix
                        if data_str == "[DONE]":
                            break
                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            continue  # skip malformed / keep-alive lines
                        if data.get('choices'):
                            delta = data['choices'][0].get('delta', {})
                            if 'content' in delta:
                                yield delta['content']
        except Exception as e:
            # Consistent with chat(): emit a fallback message instead of
            # propagating the exception into the stream consumer.
            logger.error("AI API调用失败: %s", e)
            last_msg = messages[-1]['content'] if messages else "你好"
            yield f"AI服务暂时不可用(错误:{str(e)})。您说的是:{last_msg}"
|
||
|
||
|
||
# Global singleton instance shared by the application; reconfigure via
# ai_service.update_config(...).
ai_service = AIService()