架构级重构: - Supervisor节点:分析任务、分解子任务、智能调度Agent - Worker节点:各专业Agent(subgraph)独立执行子任务 - Aggregator节点:并行结果自动聚合 - Send API并行:多Agent同时处理不同子任务 - Agent注册表:AgentRegistry管理5个Agent - weather_agent: 天气专家 - math_agent: 数学专家 - knowledge_agent: 知识专家 - mcp_agent: MCP工具调用 - general_agent: 通用助手(兜底) - 共享State:messages/subtasks/results/final_answer - Supervisor输出JSON格式任务计划(parallel/single/direct)
822 lines
30 KiB
Python
822 lines
30 KiB
Python
"""
|
||
黄庄三号 Agent v3.0 - 多Agent交互版
|
||
====================================
|
||
架构: Supervisor + Worker(Agent) + Aggregator
|
||
- Supervisor: 分析任务,分解子任务,决定分给哪个Agent
|
||
- Worker: 各专业Agent(subgraph),独立执行子任务
|
||
- Aggregator: 聚合多Agent结果,生成最终回复
|
||
- 支持: 并行分发(Send API)、串行交接(Command handoff)
|
||
- Agent间通信: 共享State + Command + 消息总线
|
||
|
||
运行方式:
|
||
python3 agent.py --test 自动测试
|
||
python3 agent.py --mcp --test 带MCP测试
|
||
python3 agent.py --mcp 交互模式(带MCP)
|
||
python3 agent.py 交互模式(不带MCP)
|
||
"""
|
||
import os
|
||
import re
|
||
import json
|
||
import asyncio
|
||
import argparse
|
||
import importlib.util
|
||
from typing import Annotated, Literal
|
||
from typing_extensions import TypedDict
|
||
from pydantic import BaseModel, Field
|
||
from contextlib import AsyncExitStack
|
||
from pathlib import Path
|
||
from operator import add
|
||
|
||
import yaml
|
||
from langchain_openai import ChatOpenAI
|
||
from langchain_core.tools import tool
|
||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||
from langgraph.graph import StateGraph, START, END
|
||
from langgraph.graph.message import add_messages
|
||
from langgraph.prebuilt import ToolNode
|
||
from langgraph.types import Send, Command
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 基础路径
|
||
# ════════════════════════════════════════════
|
||
# Resolve all project directories relative to this file so the agent can be
# launched from any working directory.
BASE_DIR = Path(__file__).parent.resolve()
CONFIG_PATH = BASE_DIR / "config.yaml"  # LLM / MCP server configuration
TOOLS_DIR = BASE_DIR / "tools"          # auto-scanned tool modules (see scan_tools)
SKILLS_DIR = BASE_DIR / "skills"        # skill YAML files / OpenClaw skill dirs (see scan_skills)
AGENTS_DIR = BASE_DIR / "agents"        # NOTE(review): not referenced anywhere in this module
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 配置加载
|
||
# ════════════════════════════════════════════
|
||
def load_config() -> dict:
    """Read config.yaml from the project root and return its parsed contents."""
    raw = CONFIG_PATH.read_text(encoding="utf-8")
    return yaml.safe_load(raw)
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 工具自动扫描注册
|
||
# ════════════════════════════════════════════
|
||
def scan_tools() -> list:
    """Import every non-underscore tools/*.py module and collect its TOOLS list.

    A module that fails to import (or exposes a broken TOOLS list) is logged
    and skipped; discovery never aborts the startup sequence.
    """
    collected = []
    if not TOOLS_DIR.exists():
        return collected

    candidates = [p for p in sorted(TOOLS_DIR.glob("*.py")) if not p.name.startswith("_")]
    for path in candidates:
        try:
            spec = importlib.util.spec_from_file_location(f"tools.{path.stem}", str(path))
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            if hasattr(module, "TOOLS"):
                exported = module.TOOLS
                collected.extend(exported)
                print(f" [工具] {path.name}: {[t.name for t in exported]}")
        except Exception as e:
            print(f" [工具] {path.name}: 加载失败 -> {e}")

    return collected
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 技能自动扫描注册(兼容 OpenClaw)
|
||
# ════════════════════════════════════════════
|
||
class SkillDef(BaseModel):
    """Declarative description of one skill (flat YAML file or OpenClaw directory)."""
    name: str                   # unique skill identifier
    description: str            # one-line human-readable summary
    prompt: str                 # prompt template (may contain an {input} placeholder)
    tools: list[str] = []       # tool names the skill relies on (pydantic copies mutable defaults per instance)
    skill_dir: str = ""         # directory path — OpenClaw skills only
    scripts: list[str] = []     # executable script paths — OpenClaw skills only
    is_openclaw: bool = False   # True when loaded from a SKILL.md directory
|
||
|
||
class SkillRegistry:
    """In-memory name -> SkillDef lookup table."""

    def __init__(self):
        self._skills: dict[str, SkillDef] = {}

    def register(self, skill: SkillDef):
        """Store *skill* under its own name (overwrites any duplicate)."""
        self._skills[skill.name] = skill

    def get(self, name: str) -> SkillDef | None:
        """Exact-name lookup; None when absent."""
        return self._skills.get(name)

    def list_skills(self) -> list[SkillDef]:
        """All registered skills, in insertion order."""
        return list(self._skills.values())

    def format_list(self) -> str:
        """Render one ' - name: description' line per skill for prompt injection."""
        rendered = []
        for skill in self._skills.values():
            suffix = " [OpenClaw]" if skill.is_openclaw else ""
            rendered.append(f" - {skill.name}: {skill.description}{suffix}")
        return "\n".join(rendered)
|
||
|
||
|
||
def _parse_skill_md_frontmatter(content: str) -> dict:
    """Extract the YAML front-matter block (``--- ... ---``) from a SKILL.md file.

    Returns an empty dict when the content has no well-formed front-matter.
    """
    if not content.startswith("---"):
        return {}
    closing = content.find("---", 3)
    if closing == -1:
        return {}
    block = content[3:closing].strip()
    return yaml.safe_load(block) or {}
|
||
|
||
|
||
def scan_skills() -> SkillRegistry:
    """Build the skill registry from skills/: flat YAML files plus OpenClaw dirs.

    Two discovery passes:
      1. every top-level ``*.yaml`` file becomes a plain skill;
      2. every subdirectory containing a ``SKILL.md`` becomes an OpenClaw
         skill, with its ``scripts/*.py`` registered as runnable scripts.
    Individual load failures are logged and skipped.
    """
    registry = SkillRegistry()
    if not SKILLS_DIR.exists():
        return registry

    # Pass 1: flat YAML skill definitions.
    for yaml_file in sorted(SKILLS_DIR.glob("*.yaml")):
        try:
            with open(yaml_file, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            registry.register(SkillDef(
                name=data["name"],
                description=data.get("description", ""),
                prompt=data.get("prompt", ""),
                tools=data.get("tools", []),
            ))
            print(f" [技能] {yaml_file.name}: {data['name']}")
        except Exception as e:
            print(f" [技能] {yaml_file.name}: 加载失败 -> {e}")

    # Pass 2: OpenClaw-style skill directories.
    for skill_dir in sorted(SKILLS_DIR.iterdir()):
        if not skill_dir.is_dir():
            continue
        skill_md = skill_dir / "SKILL.md"
        if not skill_md.exists():
            continue
        try:
            content = skill_md.read_text(encoding="utf-8")
            fm = _parse_skill_md_frontmatter(content)
            if not fm:
                continue
            name = fm.get("name", skill_dir.name)
            description = fm.get("description", "")
            # Everything after the closing front-matter fence is the prompt body.
            body_start = content.find("---", 3)
            prompt = content[body_start + 3:].strip() if body_start != -1 else ""

            scripts_dir = skill_dir / "scripts"
            script_paths: list = []
            tool_names: list = []
            if scripts_dir.exists():
                for sf in sorted(scripts_dir.glob("*.py")):
                    if not sf.name.startswith("_"):
                        script_paths.append(str(sf))
                        tool_names.append(sf.stem)

            registry.register(SkillDef(
                name=name,
                description=description,
                prompt=prompt,
                tools=tool_names,
                skill_dir=str(skill_dir),
                scripts=script_paths,
                is_openclaw=True,
            ))
            print(f" [技能] {skill_dir.name}/ (OpenClaw): {name}")
        except Exception as e:
            print(f" [技能] {skill_dir.name}/SKILL.md: 加载失败 -> {e}")

    return registry
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# MCP 管理器
|
||
# ════════════════════════════════════════════
|
||
class MCPManager:
    """Owns stdio connections to MCP servers and routes keyword-matched tool calls.

    Lifecycle: connect_all() opens each server's transport and session on a
    shared AsyncExitStack; close() tears everything down at once.
    route_map maps a trigger keyword to the (session, tool_name) that serves it.
    """

    def __init__(self):
        self.exit_stack = AsyncExitStack()      # owns every transport/session context
        self.sessions: dict[str, object] = {}   # server name -> ClientSession
        self.route_map: dict[str, tuple] = {}   # keyword -> (session, tool_name)
        self.mcp_tools: list = []               # all tools loaded across servers

    async def connect_all(self, mcp_configs: list[dict]):
        """Connect every configured stdio MCP server; failures are logged, not raised."""
        # Imported lazily so the module is usable without MCP deps installed.
        from langchain_mcp_adapters.tools import load_mcp_tools
        from mcp.client.stdio import stdio_client, StdioServerParameters
        from mcp.client.session import ClientSession

        for srv in mcp_configs:
            name = srv["name"]
            try:
                sp = StdioServerParameters(command=srv["command"], args=srv.get("args", []))
                r, w = await self.exit_stack.enter_async_context(stdio_client(sp))
                session = await self.exit_stack.enter_async_context(ClientSession(r, w))
                await session.initialize()
                mcp_tools = await load_mcp_tools(session)
                self.mcp_tools.extend(mcp_tools)
                self.sessions[name] = session
                # Config may declare per-tool trigger keywords; a later server
                # silently overwrites an earlier server's identical keyword.
                route_kw = srv.get("route_keywords", {})
                for tool_name, keywords in route_kw.items():
                    for kw in keywords:
                        self.route_map[kw] = (session, tool_name)
                print(f" [MCP] {name}: 已连接,{len(mcp_tools)} 个工具")
            except Exception as e:
                print(f" [MCP] {name}: 连接失败 -> {e}")

    def match_route(self, user_input: str) -> tuple | None:
        """Return the first (session, tool_name) whose keyword occurs in user_input."""
        for keyword, (session, tool_name) in self.route_map.items():
            if keyword in user_input:
                return (session, tool_name)
        return None

    async def call_tool(self, session, tool_name: str, user_input: str) -> str:
        """Invoke an MCP tool, deriving its arguments heuristically from user_input.

        Returns the joined text parts of the result, or an error string —
        never raises.
        """
        try:
            args = {}
            # Locate the tool's schema so we know which argument names to fill.
            for t in self.mcp_tools:
                if t.name == tool_name:
                    schema = getattr(t, "args_schema", None)
                    if schema:
                        # args_schema may be a plain JSON-schema dict or a pydantic model.
                        if isinstance(schema, dict):
                            sf = schema.get("properties", {})
                        elif hasattr(schema, "model_json_schema"):
                            sf = schema.model_json_schema().get("properties", {})
                        else:
                            sf = {}
                        args = _parse_tool_args(tool_name, sf, user_input)
                    break
            result = await session.call_tool(tool_name, args)
            if result.content:
                # Concatenate any text parts; fall back to the raw result repr.
                texts = [c.text for c in result.content if hasattr(c, "text")]
                return "\n".join(texts) if texts else str(result)
            return str(result)
        except Exception as e:
            return f"[MCP工具{tool_name}错误] {e}"

    async def close(self):
        """Close all transports/sessions and forget them."""
        await self.exit_stack.aclose()
        self.sessions.clear()
|
||
|
||
|
||
def _parse_tool_args(tool_name: str, schema_fields: dict, user_input: str) -> dict:
|
||
args = {}
|
||
for field_name, field_info in schema_fields.items():
|
||
if field_name in ("timezone",):
|
||
args[field_name] = "Asia/Shanghai"
|
||
elif field_name in ("text",):
|
||
match = re.search(r"['\"\u201c\u201d](.+?)['\"\u201c\u201d]", user_input)
|
||
args[field_name] = match.group(1) if match else user_input
|
||
elif field_name in ("city",):
|
||
args[field_name] = user_input
|
||
return args
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# ★ 核心:多 Agent 架构
|
||
# ════════════════════════════════════════════
|
||
|
||
# --- 共享状态 ---
|
||
class SharedState(TypedDict):
    """State shared by supervisor, workers and aggregator."""
    messages: Annotated[list, add_messages]  # conversation messages (append-merged)
    subtasks: list[dict]                     # subtasks produced by the supervisor
    results: Annotated[list[str], add]       # worker outputs (list-concatenated across parallel branches)
    active_agent: str                        # agent selected in single-task mode
    final_answer: str                        # final reply text
|
||
|
||
|
||
class SubTaskState(TypedDict):
    """Per-subtask state handed to a parallel worker via the Send API."""
    task: dict                               # {"agent": ..., "query": ...}
    messages: Annotated[list, add_messages]  # recent conversation context
|
||
|
||
|
||
# --- Agent 定义 ---
|
||
class AgentDef(BaseModel):
    """Static definition of one worker agent."""
    name: str               # unique agent identifier (e.g. "weather_agent")
    description: str        # one-line summary shown to the supervisor
    system_prompt: str      # base system prompt for the worker LLM
    tools: list[str] = []   # names of local tools the agent may invoke
    skill: str = ""         # optional skill providing a richer prompt / scripts
|
||
|
||
|
||
class AgentRegistry:
    """Name-keyed lookup table for AgentDef entries."""

    def __init__(self):
        self._agents: dict[str, AgentDef] = {}

    def register(self, agent: AgentDef):
        """Store *agent* under its own name (overwrites duplicates)."""
        self._agents[agent.name] = agent

    def get(self, name: str) -> AgentDef | None:
        """Exact-name lookup; None when absent."""
        return self._agents.get(name)

    def list_agents(self) -> list[AgentDef]:
        """All registered agents, in insertion order."""
        return list(self._agents.values())

    def format_list(self) -> str:
        """One ' - name: description' line per agent, for the supervisor prompt."""
        rows = [f" - {a.name}: {a.description}" for a in self._agents.values()]
        return "\n".join(rows)

    def match(self, task_type: str) -> AgentDef | None:
        """Resolve a task type to an agent: exact name first, then substring overlap."""
        exact = self._agents.get(task_type)
        if exact is not None:
            return exact
        for candidate in self._agents.values():
            if candidate.name in task_type or task_type in candidate.name:
                return candidate
        return None
|
||
|
||
|
||
def init_agents(skills_reg: SkillRegistry, tools_list: list) -> AgentRegistry:
    """Register the five built-in agents and return the populated registry.

    skills_reg / tools_list are accepted for interface compatibility; the
    built-in definitions reference skills and tools by name only.
    """
    registry = AgentRegistry()

    builtin = [
        # Weather specialist
        AgentDef(
            name="weather_agent",
            description="天气专家 - 查询天气、出行建议",
            system_prompt="你是天气专家。根据天气数据给出专业的出行建议,包括穿衣、活动安排等。",
            tools=["get_weather"],
            skill="weather_analyst",
        ),
        # Math specialist
        AgentDef(
            name="math_agent",
            description="数学专家 - 计算、数学问题解答",
            system_prompt="你是数学专家。解答数学问题,给出计算过程和原理解释。",
            tools=["calculate"],
            skill="math_tutor",
        ),
        # Knowledge specialist
        AgentDef(
            name="knowledge_agent",
            description="知识专家 - 搜索知识、深入解释概念",
            system_prompt="你是知识专家。搜索知识库,给出深入浅出的解释,结构化呈现信息。",
            tools=["search_knowledge"],
            skill="knowledge_explorer",
        ),
        # General-purpose fallback
        AgentDef(
            name="general_agent",
            description="通用助手 - 处理一般对话和简单问题",
            system_prompt="你是黄庄三号通用助手。处理一般对话、问候和简单问题。",
            tools=[],
            skill="",
        ),
        # MCP tool-calling specialist
        AgentDef(
            name="mcp_agent",
            description="MCP工具调用 - 时间查询、字符统计、UUID生成等",
            system_prompt="你是MCP工具调用专家。通过MCP协议调用外部工具获取实时数据。",
            tools=[],
            skill="",
        ),
    ]
    for agent in builtin:
        registry.register(agent)

    return registry
|
||
|
||
|
||
# --- Supervisor 节点 ---
|
||
def make_supervisor_node(config, agent_registry, skills_reg, tools_list, mcp_mgr):
    """Build the supervisor node: plan the user's request and schedule agents.

    The returned coroutine reads the latest human message, asks the LLM for a
    JSON plan, and writes one of three outcomes into the shared state:
      - direct:   final_answer set, no subtasks
      - single:   exactly one subtask (serial worker)
      - parallel: several subtasks (fanned out later via the Send API)

    skills_reg / tools_list / mcp_mgr are unused here but kept for signature
    parity with the other node factories.
    """
    llm_cfg = config["llm"]

    SUPERVISOR_PROMPT = """你是黄庄三号的任务协调者(Supervisor)。你的职责是:

1. 分析用户请求,判断是否需要分解为多个子任务
2. 如果需要多个Agent协作,输出JSON格式的子任务列表
3. 如果只需单个Agent处理,输出单个任务
4. 如果是简单对话,直接回复

可用Agent:
{agent_list}

MCP工具(当用户请求匹配这些关键词时,分配给对应Agent或在subtasks中指定agent为"mcp_agent"):
- 时间/几点/当前时间 → mcp_agent (get_current_time)
- 统计字符/字符数 → mcp_agent (count_chars)
- 生成UUID → mcp_agent (generate_uuid)

重要规则:
- 你的名字是"黄庄三号"
- 对于简单问候,直接回复,不要分配任务
- 输出格式必须是严格的JSON:
  - 多任务: {{"mode": "parallel", "subtasks": [{{"agent": "agent名", "query": "具体任务"}}]}}
  - 单任务: {{"mode": "single", "agent": "agent名", "query": "具体任务"}}
  - 直接回复: {{"mode": "direct", "answer": "你的回复"}}
- 只输出JSON,不要其他内容"""

    # Invariants hoisted out of the per-invocation path: neither the agent
    # list nor the LLM client changes between calls (previously both were
    # rebuilt on every request).
    system_prompt = SUPERVISOR_PROMPT.format(agent_list=agent_registry.format_list())
    llm = ChatOpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        model=llm_cfg["model"],
        temperature=0.1,  # planning should be near-deterministic
    )

    async def supervisor_node(state: SharedState) -> dict:
        # The most recent human message is the request to plan for.
        user_msg = ""
        for msg in reversed(state["messages"]):
            if isinstance(msg, HumanMessage):
                user_msg = msg.content
                break

        messages = [SystemMessage(content=system_prompt), *state["messages"][-6:]]
        resp = await llm.ainvoke(messages)

        # The model is told to emit bare JSON, but may still wrap it in a
        # markdown fence — strip that before parsing.
        content = resp.content.strip()
        if content.startswith("```"):
            content = re.sub(r'^```\w*\n?', '', content)
            content = re.sub(r'\n?```$', '', content)
            content = content.strip()

        try:
            plan = json.loads(content)
        except json.JSONDecodeError:
            # Unparseable output: treat the raw reply as a direct answer.
            plan = {"mode": "direct", "answer": resp.content}

        mode = plan.get("mode", "direct")

        if mode == "direct":
            answer = plan.get("answer", resp.content)
            return {
                "messages": [AIMessage(content=answer)],
                "final_answer": answer,
                "subtasks": [],
            }
        if mode == "single":
            # Robustness fix: a plan missing the "agent" key previously raised
            # KeyError; fall back to the general-purpose agent instead.
            agent = plan.get("agent", "general_agent")
            return {
                "subtasks": [{"agent": agent, "query": plan.get("query", user_msg)}],
                "active_agent": agent,
            }
        if mode == "parallel":
            return {
                "subtasks": plan.get("subtasks", []),
            }
        # Unknown mode: nothing to schedule.
        return {"subtasks": []}

    return supervisor_node
|
||
|
||
|
||
# --- 任务分发路由(Send API并行) ---
|
||
def route_subtasks(state: SharedState):
    """Route after planning: no tasks -> end, one -> serial worker, many -> parallel."""
    pending = state.get("subtasks", [])
    if not pending:
        return "end"
    return "worker_single" if len(pending) == 1 else "worker_parallel"
|
||
|
||
|
||
def dispatch_parallel(state: SharedState):
    """Fan out: build one Send envelope per subtask for the parallel worker node."""
    recent_context = state["messages"][-4:]
    sends = []
    for subtask in state.get("subtasks", []):
        sends.append(Send("worker_node", {"task": subtask, "messages": recent_context}))
    return sends
|
||
|
||
|
||
# --- Worker 节点(执行子任务) ---
|
||
def make_worker_node(config, agent_registry, skills_reg, tools_list, mcp_mgr):
    """Build the worker node that executes one subtask with a specialist agent.

    For the routed agent it performs, best-effort and each step optional:
      1. invokes the agent's declared local tools with heuristic arguments,
      2. invokes a keyword-matched MCP tool (when an MCP manager is attached),
      3. runs any OpenClaw skill scripts as subprocesses,
    then feeds the collected tool output to the LLM under the agent's prompt.
    The reply is appended to the shared ``results`` list, tagged with the
    agent name so the aggregator can attribute it.
    """
    llm_cfg = config["llm"]
    agent_cfg = config.get("agent", {})
    temp = agent_cfg.get("skill_temperature", 0.7)

    # The LLM client is configuration-invariant; build it once per node
    # instead of on every invocation.
    worker_llm = ChatOpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        model=llm_cfg["model"],
        temperature=temp,
    )

    async def _run_local_tools(agent_def, query: str) -> str:
        """Invoke the agent's declared tools; return accumulated result text."""
        info = ""
        for tname in agent_def.tools:
            for t in tools_list:
                if t.name != tname:
                    continue
                try:
                    if tname == "get_weather":
                        # Heuristic: first known city mentioned, defaulting to 北京.
                        cities = ["北京", "上海", "深圳", "黄庄"]
                        city = next((c for c in cities if c in query), "北京")
                        r = await t.ainvoke({"city": city})
                    elif tname == "calculate":
                        # Extract the first arithmetic-looking substring.
                        expr = re.findall(r'[\d+\-*/(). ]+', query)
                        r = await t.ainvoke({"expression": expr[0].strip() if expr else "1+1"})
                    else:
                        r = await t.ainvoke({"query": query})
                    info += f"\n工具{tname}结果: {r}"
                except Exception as e:
                    info += f"\n工具{tname}错误: {e}"
        return info

    async def _run_skill_scripts(sk, query: str) -> str:
        """Run each OpenClaw script with the query as argv[1]; 30s timeout each."""
        info = ""
        for script_path in sk.scripts:
            proc = None
            try:
                proc = await asyncio.create_subprocess_exec(
                    "python3", script_path, query,
                    stdout=asyncio.subprocess.PIPE,
                    stderr=asyncio.subprocess.PIPE,
                )
                stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=30)
                output = stdout.decode("utf-8", errors="replace").strip()
                if output:
                    # Cap injected context to keep the prompt bounded.
                    info += f"\n[脚本输出]\n{output[:2000]}"
            except asyncio.TimeoutError as e:
                # Bug fix: previously a timed-out script was left running
                # forever; kill it and reap it to avoid a zombie process.
                if proc is not None:
                    proc.kill()
                    await proc.communicate()
                info += f"\n[脚本错误] {e}"
            except Exception as e:
                info += f"\n[脚本错误] {e}"
        return info

    async def worker_node(state: SubTaskState) -> dict:
        task = state["task"]
        agent_name = task.get("agent", "general_agent")
        query = task.get("query", "")

        # Unknown agent names fall back to the general-purpose agent.
        agent_def = agent_registry.get(agent_name)
        if not agent_def:
            agent_def = agent_registry.get("general_agent")

        tool_info = await _run_local_tools(agent_def, query)

        # Keyword-routed MCP tool, when an MCP manager is attached.
        if mcp_mgr and query:
            route = mcp_mgr.match_route(query)
            if route:
                session, tool_name = route
                mcp_result = await mcp_mgr.call_tool(session, tool_name, query)
                tool_info += f"\nMCP工具{tool_name}结果: {mcp_result}"

        # OpenClaw skills contribute script output as extra context.
        if agent_def.skill:
            sk = skills_reg.get(agent_def.skill)
            if sk and sk.is_openclaw and sk.scripts:
                tool_info += await _run_skill_scripts(sk, query)

        # Non-OpenClaw skills replace the system prompt with their template.
        prompt = agent_def.system_prompt
        if agent_def.skill:
            sk = skills_reg.get(agent_def.skill)
            if sk and not sk.is_openclaw:
                prompt = sk.prompt.format(input=query)

        messages = [
            SystemMessage(content=prompt),
            HumanMessage(content=query),
        ]
        if tool_info:
            messages.append(SystemMessage(content=f"工具/脚本提供的数据:{tool_info}"))

        resp = await worker_llm.ainvoke(messages)

        # Tag the answer with the agent name for the aggregator.
        result_tag = f"[{agent_def.name}]"
        return {"results": [f"{result_tag} {resp.content}"]}

    return worker_node
|
||
|
||
|
||
# --- 单任务Worker(不走Send,直接串行) ---
|
||
def make_worker_single_node(config, agent_registry, skills_reg, tools_list, mcp_mgr):
    """Serial variant: run the single subtask and promote its result to the final answer."""
    worker = make_worker_node(config, agent_registry, skills_reg, tools_list, mcp_mgr)

    async def worker_single_node(state: SharedState) -> dict:
        pending = state.get("subtasks", [])
        if not pending:
            return {"results": [], "final_answer": "没有可执行的任务"}

        sub_state = SubTaskState(task=pending[0], messages=state["messages"][-4:])
        outcome = await worker(sub_state)

        produced = outcome.get("results")
        content = produced[0] if produced else "执行完成"
        # Strip the leading "[agent_name]" tag; users shouldn't see it.
        clean = re.sub(r'^\[.+?\]\s*', '', content)
        return {
            "results": outcome.get("results", []),
            "final_answer": clean,
            "messages": [AIMessage(content=clean)],
        }

    return worker_single_node
|
||
|
||
|
||
# --- 聚合节点 ---
|
||
def make_aggregator_node(config):
    """Build the aggregator node: merge parallel worker results into one reply.

    A single result passes through verbatim (minus its agent tag); multiple
    results are synthesized by the LLM into one structured answer.
    """
    llm_cfg = config["llm"]

    # Invariant across invocations — previously rebuilt on every call.
    agg_llm = ChatOpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        model=llm_cfg["model"],
        temperature=0.5,
    )

    async def aggregator_node(state: SharedState) -> dict:
        results = state.get("results", [])
        if not results:
            return {"final_answer": "所有Agent已完成,但无结果"}

        # Single result: strip the "[agent]" prefix and return it directly.
        if len(results) == 1:
            content = re.sub(r'^\[.+?\]\s*', '', results[0])
            return {
                "final_answer": content,
                "messages": [AIMessage(content=content)],
            }

        # Multiple results: ask the LLM to merge them.
        combined = "\n\n---\n\n".join(results)
        resp = await agg_llm.ainvoke([
            SystemMessage(content="你是黄庄三号。请将以下多个专业Agent的结果整合为一个连贯、结构化的回答。保留关键信息,去除冗余。"),
            HumanMessage(content=f"多个Agent的结果:\n\n{combined}"),
        ])

        return {
            "final_answer": resp.content,
            "messages": [AIMessage(content=resp.content)],
        }

    return aggregator_node
|
||
|
||
|
||
# --- Agent间 Handoff 交接 ---
|
||
def make_handoff_tool(agent_registry):
    """Create the inter-agent handoff tool.

    Bug fixes vs. the original:
      - the original returned the undefined name ``handoff_tool`` (NameError
        on every call of this factory);
      - the original's "docstring" was a string *concatenation* expression,
        which Python does not treat as a docstring, leaving ``__doc__`` empty
        for the ``@tool`` decorator.  The dynamic description is now attached
        via ``__doc__`` before decorating.
    """
    from langchain_core.tools import tool as lc_tool

    agent_names = [a.name for a in agent_registry.list_agents()]

    def handoff_to_agent(target_agent: str, task_description: str) -> str:
        if target_agent not in agent_names:
            return f"Agent {target_agent} 不存在,可用Agent: {agent_names}"
        return f"任务已交接给 {target_agent}: {task_description}"

    # Build the description first, then decorate, so the tool schema carries
    # the concrete list of available agent names.
    handoff_to_agent.__doc__ = (
        "将任务交接给另一个Agent处理。\n"
        "    target_agent: 目标Agent名称,可选: " + ", ".join(agent_names) + "\n"
        "    task_description: 需要交接的任务描述\n"
    )
    return lc_tool(handoff_to_agent)
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 构建多Agent图
|
||
# ════════════════════════════════════════════
|
||
async def build_graph(config, skills_reg, mcp_mgr, tools_list):
    """Assemble and compile the multi-agent graph.

    Topology:
        START -> supervisor -> END                               (direct reply)
                            -> worker_single -> END              (one subtask)
                            -> Send fan-out -> worker_node
                                            -> aggregator -> END (many subtasks)
    """
    agent_registry = init_agents(skills_reg, tools_list)

    supervisor = make_supervisor_node(config, agent_registry, skills_reg, tools_list, mcp_mgr)
    worker = make_worker_node(config, agent_registry, skills_reg, tools_list, mcp_mgr)
    worker_single = make_worker_single_node(config, agent_registry, skills_reg, tools_list, mcp_mgr)
    aggregator = make_aggregator_node(config)

    g = StateGraph(SharedState)

    # Nodes
    g.add_node("supervisor", supervisor)
    g.add_node("worker_node", worker)  # parallel worker (Send API dispatches here)
    g.add_node("worker_single", worker_single)  # serial worker
    g.add_node("aggregator", aggregator)  # result merger

    # Edges
    g.add_edge(START, "supervisor")

    # Supervisor routing: one conditional edge; parallelism realized via Send.
    def route_from_supervisor(state: SharedState):
        subtasks = state.get("subtasks", [])
        final_answer = state.get("final_answer", "")
        # Direct reply (or nothing to do): stop here.
        if final_answer or not subtasks:
            return "end"
        # Exactly one subtask: run it serially.
        if len(subtasks) == 1:
            return "worker_single"
        # Several subtasks: fan out via the Send API.
        return [
            Send("worker_node", {"task": t, "messages": state["messages"][-4:]})
            for t in subtasks
        ]

    g.add_conditional_edges("supervisor", route_from_supervisor, {
        "end": END,
        "worker_single": "worker_single",
        "worker_node": "worker_node",  # Send API target node
    })

    # Flow after the workers
    g.add_edge("worker_node", "aggregator")  # merge parallel results
    g.add_edge("worker_single", END)  # serial path finishes directly
    g.add_edge("aggregator", END)  # finish after merging

    return g.compile()
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 运行入口
|
||
# ════════════════════════════════════════════
|
||
# Module-level handles rebound by main() (declared `global` there); kept at
# module scope so other entry points could inspect them after startup.
all_tools = []                     # tools discovered by scan_tools()
skills_registry = SkillRegistry()  # replaced by scan_skills() during startup
mcp_manager = None                 # MCPManager when --mcp is enabled, else None
|
||
|
||
|
||
async def run_agent(user_input: str, graph):
    """Run one user turn through the graph and summarize the outcome."""
    initial_state = {
        "messages": [HumanMessage(content=user_input)],
        "subtasks": [],
        "results": [],
        "active_agent": "",
        "final_answer": "",
    }
    result = await graph.ainvoke(initial_state)

    last = result["messages"][-1]
    reply = last.content if hasattr(last, "content") else str(last)
    return {
        "reply": reply,
        "subtasks": result.get("subtasks", []),
        "results_count": len(result.get("results", [])),
        "final_answer": result.get("final_answer", ""),
    }
|
||
|
||
|
||
async def interactive_mode(graph):
    """Interactive REPL: read input, run the graph, print routing info and the reply."""
    banner = "=" * 60
    print(banner)
    print(" 黄庄三号 Agent v3.0 - 多Agent交互版")
    print(" Supervisor + Worker(Agent) + Aggregator")
    print(banner)
    print(" 输入 quit 退出")
    print(banner)

    while True:
        try:
            user_input = input("\n你> ").strip()
        except (EOFError, KeyboardInterrupt):
            break
        if not user_input or user_input.lower() in ("quit", "exit", "q"):
            break

        result = await run_agent(user_input, graph)
        if result["subtasks"]:
            agents = [t.get("agent", "?") for t in result["subtasks"]]
            print(f"\n[调度] 分配给: {', '.join(agents)}")
        if result["results_count"] > 1:
            print(f"[聚合] 合并 {result['results_count']} 个Agent结果")
        print(f"\n黄庄三号> {result['reply']}")
|
||
|
||
|
||
async def main():
    """CLI entry point: load config, discover tools/skills, build the graph, then run tests or the REPL."""
    global all_tools, skills_registry, mcp_manager

    parser = argparse.ArgumentParser(description="黄庄三号 Agent v3.0")
    parser.add_argument("--mcp", action="store_true", help="启用MCP")
    parser.add_argument("--test", action="store_true", help="自动测试")
    args = parser.parse_args()

    print("=" * 60)
    print(" 黄庄三号 Agent v3.0 - 多Agent交互版")
    print("=" * 60)

    # Load configuration
    print("\n[配置] 加载 config.yaml ...")
    config = load_config()

    # Discover local tools
    print("\n[工具] 扫描 tools/ ...")
    all_tools = scan_tools()
    print(f" 工具总数: {len(all_tools)}")

    # Discover skills
    print("\n[技能] 扫描 skills/ ...")
    skills_registry = scan_skills()
    print(f" 技能总数: {len(skills_registry.list_skills())}")

    # Connect MCP servers (opt-in via --mcp and only when configured)
    if args.mcp and config.get("mcp_servers"):
        print("\n[MCP] 连接服务器 ...")
        mcp_manager = MCPManager()
        await mcp_manager.connect_all(config["mcp_servers"])

    # Register the built-in agents
    agent_registry = init_agents(skills_registry, all_tools)
    print(f"\n[Agent] 已注册 {len(agent_registry.list_agents())} 个Agent:")
    for a in agent_registry.list_agents():
        print(f" - {a.name}: {a.description}")

    # Build the multi-agent graph
    print("\n[构建] 多Agent图 ...")
    graph = await build_graph(config, skills_registry, mcp_manager, all_tools)

    if args.test:
        # Canned scenarios covering each routing mode (direct / single / parallel).
        tests = [
            ("单Agent:天气", "黄庄天气怎么样?"),
            ("单Agent:数学", "算一下 99*88+77"),
            ("单Agent:知识", "MCP是什么?"),
            ("直接回复", "你好你是谁?"),
            ("多Agent并行", "帮我查一下北京和上海的天气,再算一下123+456"),
        ]

        if args.mcp and mcp_manager:
            tests.append(("MCP:时间", "现在几点了?"))

        for label, query in tests:
            print(f"\n{'─'*55}")
            print(f"[测试:{label}] {query}")
            r = await run_agent(query, graph)
            if r["subtasks"]:
                agents = [t.get("agent", "?") for t in r["subtasks"]]
                print(f" 调度: {', '.join(agents)}")
            if r["results_count"] > 1:
                print(f" 聚合: {r['results_count']} 个Agent结果")
            print(f" 回复: {r['reply'][:150]}...")

        print(f"\n{'='*60}")
        print(" 多Agent测试完成!")
        print("=" * 60)
    else:
        await interactive_mode(graph)

    # Tear down MCP sessions before exit
    if mcp_manager:
        await mcp_manager.close()
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry: drive the async main() on a fresh event loop.
    asyncio.run(main())
|