"""
LLM服务 - 大模型池管理,支持思考功能
"""
|
||
import httpx
|
||
from typing import List, Dict, AsyncGenerator, Optional, Tuple
|
||
import json
|
||
import logging
|
||
import re
|
||
|
||
# Module logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
class LLMService:
    """Client for OpenAI-compatible chat-completion APIs with optional "thinking".

    Provider and agent settings are passed in as plain dicts on every call.
    Reasoning ("thinking") text is obtained in one of two ways:

    * if the provider sets ``supports_thinking`` and a ``thinking_model``,
      that model is queried first and its answer is returned as the thinking
      text (``chat`` only);
    * when the agent defines ``thinking_prefix`` / ``thinking_suffix``
      delimiters, the delimited section(s) of the model's reply are split out
      as thinking and removed from the answer.
    """

    class _ThinkingSplitter:
        """Incrementally split streamed text into "thinking" and "content" events.

        Text between ``prefix`` and ``suffix`` is reported as thinking and the
        delimiters themselves are dropped. Only trailing characters that could
        still be the start of the next delimiter are held back, so a delimiter
        split across two chunks is still recognised and ordinary text
        (including a lone ``<``) is forwarded without loss. Both delimiters
        must be non-empty.
        """

        def __init__(self, prefix: str, suffix: str):
            self.prefix = prefix
            self.suffix = suffix
            self.in_thinking = False
            self.pending = ""  # received text not yet emitted as an event

        def _kind(self) -> str:
            return "thinking" if self.in_thinking else "content"

        def feed(self, text: str) -> List[dict]:
            """Consume one chunk; return the events that are now unambiguous."""
            events = []
            self.pending += text
            while self.pending:
                marker = self.suffix if self.in_thinking else self.prefix
                idx = self.pending.find(marker)
                if idx == -1:
                    # No complete delimiter yet: emit everything except a tail
                    # that might turn out to be the beginning of one.
                    keep = self._partial_marker_len(self.pending, marker)
                    cut = len(self.pending) - keep
                    if cut:
                        events.append({"type": self._kind(), "text": self.pending[:cut]})
                    self.pending = self.pending[cut:]
                    break
                if idx:
                    events.append({"type": self._kind(), "text": self.pending[:idx]})
                self.pending = self.pending[idx + len(marker):]
                self.in_thinking = not self.in_thinking
            return events

        def flush(self) -> List[dict]:
            """Emit whatever is still buffered once the stream has ended."""
            if not self.pending:
                return []
            events = [{"type": self._kind(), "text": self.pending}]
            self.pending = ""
            return events

        @staticmethod
        def _partial_marker_len(text: str, marker: str) -> int:
            """Length of the longest proper prefix of *marker* that *text* ends with."""
            for size in range(min(len(marker) - 1, len(text)), 0, -1):
                if text.endswith(marker[:size]):
                    return size
            return 0

    def __init__(self):
        # Provider settings registered via load_provider(), keyed by provider id.
        # NOTE(review): no method of this class reads the cache back;
        # chat()/chat_stream() receive the provider config as an argument.
        self.providers_cache = {}

    def load_provider(self, provider_config: dict):
        """Store the relevant fields of *provider_config* under its ``id``."""
        provider_id = provider_config.get('id')
        self.providers_cache[provider_id] = {
            'api_base': provider_config.get('api_base'),
            'api_key': provider_config.get('api_key'),
            'supports_thinking': provider_config.get('supports_thinking', False),
            'thinking_model': provider_config.get('thinking_model'),
            'default_model': provider_config.get('default_model'),
            'max_tokens': provider_config.get('max_tokens', 4096),
            'temperature': provider_config.get('temperature', 0.7),
        }

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _endpoint(api_base: Optional[str], path: str) -> str:
        """Join *api_base* and *path* without producing a double slash."""
        return f"{(api_base or '').rstrip('/')}/{path}"

    @staticmethod
    def _headers(api_key: Optional[str]) -> Dict[str, str]:
        """Bearer-auth JSON headers used by the chat-completion calls."""
        return {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _resolve_model(provider_config: dict, agent_config: dict) -> str:
        """Agent ``model_override``, else provider ``default_model``, else ``'auto'``."""
        return (
            agent_config.get('model_override')
            or provider_config.get('default_model')
            or 'auto'
        )

    @staticmethod
    def _resolve_temperature(provider_config: dict, agent_config: dict):
        """Agent ``temperature_override`` if set (0 included), else the provider's, else 0.7."""
        override = agent_config.get('temperature_override')
        if override is not None:
            return override
        return provider_config.get('temperature', 0.7)

    @staticmethod
    def _prepare_messages(messages: List[Dict], agent_config: dict) -> List[Dict]:
        """Return a copy of *messages* that is guaranteed to start with a system message.

        A caller-supplied leading system message is kept; otherwise the agent's
        ``system_prompt`` is prepended. The input list is not mutated.
        """
        final_messages = list(messages)
        if not final_messages or final_messages[0].get('role') != 'system':
            system_prompt = agent_config.get('system_prompt', '你是一个有用的AI助手。')
            final_messages.insert(0, {"role": "system", "content": system_prompt})
        return final_messages

    @staticmethod
    def _add_thinking_instructions(
        final_messages: List[Dict], agent_config: dict, thinking_prompt: str
    ) -> None:
        """Extend the leading system message with "think first, wrap it in delimiters" instructions.

        Expects ``final_messages[0]`` to be a system message (see
        ``_prepare_messages``); its content is extended, not replaced.
        """
        thinking_prefix = agent_config.get('thinking_prefix', '')
        thinking_suffix = agent_config.get('thinking_suffix', '')
        # NOTE(review): with empty delimiters the instruction below reads
        # "wrap it in <nothing> and <nothing>"; confirm agents always set both.
        base = final_messages[0].get('content') or ''
        enhanced_system = (
            f"{base}\n\n在回答之前,请先思考问题。思考过程请用{thinking_prefix}和{thinking_suffix}包裹,然后再给出正式回答。"
        )
        if thinking_prompt:
            enhanced_system += f"\n思考指导:{thinking_prompt}"
        final_messages[0] = {"role": "system", "content": enhanced_system}

    @staticmethod
    def _extract_thinking(
        text: str, thinking_prefix: str, thinking_suffix: str
    ) -> Tuple[str, Optional[str]]:
        """Split delimiter-wrapped reasoning out of a complete reply.

        Returns:
            ``(answer, thinking)``. When delimiters are missing or no section is
            found, *text* is returned unchanged with ``None``. Otherwise all
            sections are removed from the answer (which is then stripped) and
            their stripped contents are joined with blank lines.
        """
        if not (text and thinking_prefix and thinking_suffix):
            return text, None
        pattern = re.compile(
            f"{re.escape(thinking_prefix)}(.*?){re.escape(thinking_suffix)}", re.DOTALL
        )
        sections = [section.strip() for section in pattern.findall(text)]
        if not sections:
            return text, None
        return pattern.sub('', text).strip(), "\n\n".join(sections)

    # --------------------------------------------------------------- public API

    async def get_available_models(self, api_base: str, api_key: str) -> List[dict]:
        """Fetch the model list from ``{api_base}/models``.

        Returns:
            A list of ``{"id", "name", "owned_by"}`` dicts (entries without an
            id are skipped), or ``[]`` when ``api_base`` is empty or the request
            fails for any reason. Failures are logged as warnings, not raised.
        """
        if not api_base:
            return []

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(
                    self._endpoint(api_base, "models"),
                    headers={"Authorization": f"Bearer {api_key}"},
                )
            if response.status_code != 200:
                logger.warning(f"获取模型列表失败: HTTP {response.status_code}")
                return []
            return [
                {
                    "id": entry['id'],
                    "name": entry.get('name', entry['id']),
                    "owned_by": entry.get('owned_by', 'unknown'),
                }
                for entry in response.json().get('data', [])
                if entry.get('id')
            ]
        except Exception as e:
            logger.warning(f"获取模型列表失败: {e}")
            return []

    async def test_connection(self, api_base: str, api_key: str, model: str) -> dict:
        """Send a tiny chat request to check that *api_base*, *api_key* and *model* work.

        Returns:
            ``{"success": True, "message": ..., "model": model}`` on HTTP 200,
            otherwise ``{"success": False, "message": ...}``. Any ``Exception``
            is reported in the returned dict instead of being raised.
        """
        payload = {
            "model": model,
            "messages": [{"role": "user", "content": "测试连接"}],
            "max_tokens": 50,
        }
        try:
            async with httpx.AsyncClient(timeout=15.0) as client:
                response = await client.post(
                    self._endpoint(api_base, "chat/completions"),
                    headers=self._headers(api_key),
                    json=payload,
                )
            if response.status_code != 200:
                return {
                    "success": False,
                    "message": f"连接失败: HTTP {response.status_code}",
                }
            # Some models answer with `"content": null`; slicing None would raise.
            content = response.json()['choices'][0]['message'].get('content') or ''
            return {
                "success": True,
                "message": f"连接成功!模型响应: {content[:100]}",
                "model": model,
            }
        except httpx.ConnectError:
            return {"success": False, "message": "无法连接到API地址"}
        except httpx.TimeoutException:
            return {"success": False, "message": "连接超时"}
        except Exception as e:
            return {"success": False, "message": f"连接失败: {str(e)}"}

    async def chat(
        self,
        messages: List[Dict],
        provider_config: dict,
        agent_config: dict,
        enable_thinking: bool = True
    ) -> Tuple[str, Optional[str]]:
        """Run one non-streaming chat turn, optionally with "thinking".

        Args:
            messages: Conversation history as OpenAI-style message dicts; not mutated.
            provider_config: Provider settings (``api_base``, ``api_key``,
                ``default_model``, ``supports_thinking``, ``thinking_model``,
                ``max_tokens``, ``temperature``).
            agent_config: Agent settings (``system_prompt``, ``model_override``,
                ``temperature_override``, ``enable_thinking``, ``thinking_prompt``,
                ``thinking_prefix``, ``thinking_suffix``).
            enable_thinking: Per-call switch. Querying the thinking model or
                adding thinking instructions also requires the agent's
                ``enable_thinking`` (default True); splitting delimited
                reasoning out of the reply only requires this flag.

        Returns:
            ``(reply, thinking)``; *thinking* is None when no reasoning was obtained.

        Raises:
            Whatever ``_call_api`` raises for the main model (logged first).
            A failing thinking-model call is only logged.
        """
        api_base = provider_config.get('api_base')
        api_key = provider_config.get('api_key')
        model = self._resolve_model(provider_config, agent_config)
        supports_thinking = provider_config.get('supports_thinking', False)
        thinking_model = provider_config.get('thinking_model')
        max_tokens = provider_config.get('max_tokens', 4096)
        temperature = self._resolve_temperature(provider_config, agent_config)

        final_messages = self._prepare_messages(messages, agent_config)
        thinking_content = None

        if enable_thinking and agent_config.get('enable_thinking', True):
            thinking_prompt = agent_config.get('thinking_prompt')
            if supports_thinking and thinking_model:
                # Ask the dedicated thinking model first; its answer is the thinking text.
                thinking_messages = list(final_messages)
                if thinking_prompt:
                    thinking_messages.append({"role": "system", "content": thinking_prompt})
                try:
                    thinking_content = await self._call_api(
                        api_base, api_key, thinking_model, thinking_messages,
                        max_tokens=min(max_tokens, 1000),
                        temperature=0.3,  # lower temperature for reasoning
                    )
                except Exception as e:
                    logger.warning(f"思考模型调用失败: {e}")
            elif not supports_thinking and thinking_prompt:
                # No native thinking: ask the model to delimit its reasoning in the reply.
                self._add_thinking_instructions(final_messages, agent_config, thinking_prompt)
            # supports_thinking without a dedicated model: any delimited
            # reasoning is split out of the reply below.

        try:
            response = await self._call_api(
                api_base, api_key, model, final_messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            raise

        if enable_thinking and thinking_content is None:
            response, thinking_content = self._extract_thinking(
                response,
                agent_config.get('thinking_prefix', ''),
                agent_config.get('thinking_suffix', ''),
            )
        return response, thinking_content

    async def _call_api(
        self,
        api_base: str,
        api_key: str,
        model: str,
        messages: List[Dict],
        max_tokens: int = 4096,
        temperature: float = 0.7
    ) -> str:
        """POST a non-streaming chat-completion request and return the reply text.

        Returns ``""`` when the API reports ``content: null``.

        Raises:
            httpx.HTTPStatusError: on a 4xx/5xx response.
            httpx.TransportError: on connection / timeout / protocol failures.
            KeyError, IndexError: if the response lacks ``choices[0].message``.
        """
        url = self._endpoint(api_base, "chat/completions")
        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        logger.info(f"调用LLM: url={url}, model={model}")

        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(url, headers=self._headers(api_key), json=payload)
            response.raise_for_status()
            data = response.json()
        return data['choices'][0]['message'].get('content') or ''

    async def chat_stream(
        self,
        messages: List[Dict],
        provider_config: dict,
        agent_config: dict,
        enable_thinking: bool = True
    ) -> AsyncGenerator[dict, None]:
        """Stream one chat turn from the model.

        Message preparation and the prompt-based thinking instructions follow
        the same rules as ``chat`` (the dedicated thinking model is not used
        here). Whenever the agent defines both ``thinking_prefix`` and
        ``thinking_suffix``, delimited text is reported as thinking and the
        delimiters are removed.

        Yields:
            dict: ``{"type": "thinking" | "content", "text": "..."}`` chunks.

        Raises:
            httpx.HTTPStatusError: if the API answers with a 4xx/5xx status.
        """
        api_base = provider_config.get('api_base')
        api_key = provider_config.get('api_key')
        model = self._resolve_model(provider_config, agent_config)
        max_tokens = provider_config.get('max_tokens', 4096)
        temperature = self._resolve_temperature(provider_config, agent_config)

        final_messages = self._prepare_messages(messages, agent_config)

        # Thinking requested but the provider cannot think natively:
        # ask the model to delimit its reasoning in the reply.
        if enable_thinking and agent_config.get('enable_thinking', True):
            supports_thinking = provider_config.get('supports_thinking', False)
            thinking_prompt = agent_config.get('thinking_prompt')
            if not supports_thinking and thinking_prompt:
                self._add_thinking_instructions(final_messages, agent_config, thinking_prompt)

        payload = {
            "model": model,
            "messages": final_messages,
            "stream": True,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        thinking_prefix = agent_config.get('thinking_prefix', '')
        thinking_suffix = agent_config.get('thinking_suffix', '')
        splitter = (
            self._ThinkingSplitter(thinking_prefix, thinking_suffix)
            if thinking_prefix and thinking_suffix
            else None
        )

        async with httpx.AsyncClient(timeout=60.0) as client:
            async with client.stream(
                "POST",
                self._endpoint(api_base, "chat/completions"),
                headers=self._headers(api_key),
                json=payload,
            ) as response:
                # Without this, an error status would just produce an empty stream.
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # SSE: "data:" optionally followed by a single space.
                    if not line.startswith("data:"):
                        continue
                    data_str = line[5:].strip()
                    if data_str == "[DONE]":
                        break
                    try:
                        data = json.loads(data_str)
                    except json.JSONDecodeError:
                        continue
                    if not isinstance(data, dict):
                        continue
                    choices = data.get('choices') or []
                    if not choices:
                        continue
                    # Providers may send `"content": null` (e.g. role-only chunks).
                    text = (choices[0].get('delta') or {}).get('content')
                    if not text:
                        continue
                    if splitter is None:
                        yield {"type": "content", "text": text}
                    else:
                        for event in splitter.feed(text):
                            yield event

        if splitter is not None:
            # Unclosed thinking or a held-back partial delimiter at stream end.
            for event in splitter.flush():
                yield event
# Shared module-level LLMService instance.
llm_service: LLMService = LLMService()