Changelog:
- scan_skills() 支持两种格式: .yaml 文件 和 目录/SKILL.md
- 解析 SKILL.md 的 YAML frontmatter 提取 name/description
- 自动扫描 scripts/ 子目录的 .py 脚本
- skill_exec 支持 asyncio 子进程执行脚本
- 新增示例 OpenClaw 技能: skills/time-tool/

740 lines, 28 KiB, Python
"""
Huangzhuang No.3 Agent v2.0 - configuration-driven edition
==========================================================
All tools, skills, MCP servers and routing keywords are loaded from config.
New: drop a file into tools/ or skills/ and it is picked up — no source changes.

Usage:
    python3 agent.py --test        automatic tests
    python3 agent.py --mcp --test  tests with MCP enabled
    python3 agent.py --mcp         interactive mode (with MCP)
    python3 agent.py               interactive mode (without MCP)
"""
|
||
import os
|
||
import sys
|
||
import re
|
||
import asyncio
|
||
import argparse
|
||
import importlib.util
|
||
from typing import Annotated
|
||
from typing_extensions import TypedDict
|
||
from pydantic import BaseModel, Field
|
||
from contextlib import AsyncExitStack
|
||
from pathlib import Path
|
||
|
||
import yaml
|
||
from langchain_openai import ChatOpenAI
|
||
from langchain_core.tools import tool
|
||
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
|
||
from langgraph.graph import StateGraph, START, END
|
||
from langgraph.graph.message import add_messages
|
||
from langgraph.prebuilt import ToolNode
|
||
|
||
|
||
# ════════════════════════════════════════════
# Base paths
# ════════════════════════════════════════════
# Directory containing this file; all config/tool/skill paths hang off it.
BASE_DIR = Path(__file__).parent.resolve()
# Main YAML configuration file read by load_config().
CONFIG_PATH = BASE_DIR / "config.yaml"
# Auto-scanned tool modules (*.py files exposing a TOOLS list).
TOOLS_DIR = BASE_DIR / "tools"
# Auto-scanned skills (*.yaml files or <dir>/SKILL.md OpenClaw layout).
SKILLS_DIR = BASE_DIR / "skills"
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 配置加载
|
||
# ════════════════════════════════════════════
|
||
def load_config() -> dict:
    """Read and parse config.yaml, returning the configuration mapping."""
    raw = CONFIG_PATH.read_text(encoding="utf-8")
    return yaml.safe_load(raw)
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 工具自动扫描注册
|
||
# ════════════════════════════════════════════
|
||
def scan_tools() -> list:
    """Scan every .py file under tools/ and collect their tools.

    Each tool module must expose a module-level ``TOOLS`` list; files whose
    names start with "_" are skipped.  Import failures are logged and do not
    abort the scan.

    Returns:
        A flat list of all tool objects collected from the modules.
    """
    all_tools = []
    if not TOOLS_DIR.exists():
        print(" [工具] tools/ 目录不存在,跳过")
        return all_tools

    for py_file in sorted(TOOLS_DIR.glob("*.py")):
        if py_file.name.startswith("_"):
            continue
        try:
            spec = importlib.util.spec_from_file_location(
                f"tools.{py_file.stem}", str(py_file)
            )
            if spec is None or spec.loader is None:
                print(f" [工具] {py_file.name}: 加载失败 -> 无法创建导入规格")
                continue
            mod = importlib.util.module_from_spec(spec)
            # Register the module before executing it so code inside the
            # module (dataclasses, pickling, relative lookups) can resolve
            # itself through sys.modules — per the importlib documentation.
            sys.modules[spec.name] = mod
            spec.loader.exec_module(mod)

            if hasattr(mod, "TOOLS"):
                tool_list = mod.TOOLS
                all_tools.extend(tool_list)
                print(f" [工具] {py_file.name}: 加载 {len(tool_list)} 个 -> {[t.name for t in tool_list]}")
            else:
                print(f" [工具] {py_file.name}: 无 TOOLS 变量,跳过")
        except Exception as e:
            print(f" [工具] {py_file.name}: 加载失败 -> {e}")

    return all_tools
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# 技能自动扫描注册(兼容 OpenClaw 格式)
|
||
# ════════════════════════════════════════════
|
||
class SkillDef(BaseModel):
    """Declarative definition of a single skill (plain YAML or OpenClaw)."""
    name: str
    description: str
    prompt: str
    tools: list[str] = []  # names of local tools this skill invokes
    # OpenClaw extension fields
    skill_dir: str = ""  # path of the OpenClaw skill directory
    scripts: list[str] = []  # script paths found under scripts/
    is_openclaw: bool = False  # whether this skill uses the OpenClaw format
|
||
|
||
class SkillRegistry:
    """In-memory name -> SkillDef lookup table."""

    def __init__(self):
        self._skills: dict[str, SkillDef] = {}

    def register(self, skill: SkillDef):
        """Add (or replace) a skill, keyed by its name."""
        self._skills[skill.name] = skill

    def get(self, name: str) -> SkillDef | None:
        """Return the skill registered under *name*, or None if absent."""
        return self._skills.get(name)

    def list_skills(self) -> list[SkillDef]:
        """Return all registered skills in insertion order."""
        return [*self._skills.values()]

    def format_list(self) -> str:
        """Render one human-readable line per skill for prompt injection."""
        rendered = []
        for skill in self._skills.values():
            suffix = " [OpenClaw]" if skill.is_openclaw else ""
            extra = f", scripts: {skill.scripts}" if skill.scripts else ""
            rendered.append(f" - {skill.name}: {skill.description}{suffix}{extra}")
        return "\n".join(rendered)
|
||
|
||
|
||
def _parse_skill_md_frontmatter(content: str) -> dict:
    """Extract the YAML frontmatter ("---" ... "---") from SKILL.md text.

    Returns an empty dict when no frontmatter is present, it is
    unterminated, or it parses to a falsy value.
    """
    has_opening = content.startswith("---")
    closing = content.find("---", 3) if has_opening else -1
    if closing < 0:
        return {}
    block = content[3:closing].strip()
    parsed = yaml.safe_load(block)
    return parsed if parsed else {}
|
||
|
||
|
||
def _load_yaml_skill(yaml_file) -> SkillDef:
    """Build a SkillDef from a stand-alone YAML skill definition file."""
    with open(yaml_file, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return SkillDef(
        name=data["name"],
        description=data.get("description", ""),
        prompt=data.get("prompt", ""),
        tools=data.get("tools", []),
    )


def _load_openclaw_skill(skill_dir) -> SkillDef | None:
    """Build a SkillDef from an OpenClaw-style directory containing SKILL.md.

    Returns None when SKILL.md has no parseable YAML frontmatter.
    """
    skill_md = skill_dir / "SKILL.md"
    with open(skill_md, "r", encoding="utf-8") as f:
        content = f.read()

    fm = _parse_skill_md_frontmatter(content)
    if not fm:
        return None

    name = fm.get("name", skill_dir.name)
    description = fm.get("description", "")

    # Use the markdown body after the closing "---" as the skill prompt.
    body_start = content.find("---", 3)
    prompt = content[body_start + 3:].strip() if body_start != -1 else ""

    # Collect executable scripts from the scripts/ subdirectory
    # (underscore-prefixed files are treated as private and skipped).
    scripts_dir = skill_dir / "scripts"
    scripts = []
    if scripts_dir.exists():
        scripts = [
            str(p) for p in sorted(scripts_dir.glob("*.py"))
            if not p.name.startswith("_")
        ]

    # Script basenames double as the tool names the skill references;
    # OpenClaw skill bodies conventionally mention these script names.
    tools = [Path(s).stem for s in scripts]

    return SkillDef(
        name=name,
        description=description,
        prompt=prompt,
        tools=tools,
        skill_dir=str(skill_dir),
        scripts=scripts,
        is_openclaw=True,
    )


def scan_skills() -> SkillRegistry:
    """Scan skills/ and register every skill found.

    Two layouts are supported:
      1. ``*.yaml`` / ``*.yml`` files — compact skill definitions.
      2. ``<dir>/SKILL.md`` — OpenClaw format (optionally with scripts/).

    Individual load failures are logged and skipped so one broken skill
    cannot prevent the rest from loading.
    """
    registry = SkillRegistry()
    if not SKILLS_DIR.exists():
        print(" [技能] skills/ 目录不存在,跳过")
        return registry

    # 1) Stand-alone YAML definitions (.yml accepted in addition to .yaml).
    yaml_files = sorted(SKILLS_DIR.glob("*.yaml")) + sorted(SKILLS_DIR.glob("*.yml"))
    for yaml_file in yaml_files:
        try:
            skill = _load_yaml_skill(yaml_file)
            registry.register(skill)
            print(f" [技能] {yaml_file.name}: {skill.name}")
        except Exception as e:
            print(f" [技能] {yaml_file.name}: 加载失败 -> {e}")

    # 2) OpenClaw directories (<dir>/SKILL.md).
    for skill_dir in sorted(SKILLS_DIR.iterdir()):
        if not skill_dir.is_dir() or not (skill_dir / "SKILL.md").exists():
            continue
        try:
            skill = _load_openclaw_skill(skill_dir)
            if skill is None:
                print(f" [技能] {skill_dir.name}/SKILL.md: 无 frontmatter,跳过")
                continue
            registry.register(skill)
            scripts_str = f", scripts: {len(skill.scripts)}" if skill.scripts else ""
            print(f" [技能] {skill_dir.name}/ (OpenClaw): {skill.name}{scripts_str}")
        except Exception as e:
            print(f" [技能] {skill_dir.name}/SKILL.md: 加载失败 -> {e}")

    return registry
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# MCP 管理器(配置驱动,支持多服务器)
|
||
# ════════════════════════════════════════════
|
||
class MCPManager:
    """Manage connections to multiple MCP servers.

    Holds one stdio session per configured server, a keyword ->
    (session, tool_name) routing table, and the LangChain tool wrappers
    loaded from every server.  All connections live on a single
    AsyncExitStack so close() tears everything down in one call.
    """

    def __init__(self):
        self.exit_stack = AsyncExitStack()
        self.sessions: dict[str, object] = {}  # server_name -> session
        self.route_map: dict[str, tuple] = {}  # keyword -> (session, tool_name)
        self.mcp_tools: list = []

    async def connect_all(self, mcp_configs: list[dict]):
        """Connect every configured MCP server and load its tools.

        Failures are logged per server and do not abort the remaining
        connections.  Loaded tools are also appended to the module-level
        ``all_tools`` list so the LLM can bind them.
        """
        global all_tools

        # Imported lazily so the agent can run without MCP packages installed.
        from langchain_mcp_adapters.tools import load_mcp_tools
        from mcp.client.stdio import stdio_client, StdioServerParameters
        from mcp.client.session import ClientSession

        for srv in mcp_configs:
            name = srv["name"]
            try:
                server_params = StdioServerParameters(
                    command=srv["command"],
                    args=srv.get("args", []),
                )

                read, write = await self.exit_stack.enter_async_context(
                    stdio_client(server_params)
                )
                session = await self.exit_stack.enter_async_context(
                    ClientSession(read, write)
                )
                await session.initialize()

                # Load the server's tools as LangChain tool wrappers.
                mcp_tools = await load_mcp_tools(session)
                self.mcp_tools.extend(mcp_tools)
                self.sessions[name] = session

                # Register routing keywords declared in the server config.
                route_kw = srv.get("route_keywords", {})
                for tool_name, keywords in route_kw.items():
                    for kw in keywords:
                        self.route_map[kw] = (session, tool_name)

                print(f" [MCP] {name}: 已连接,加载 {len(mcp_tools)} 个工具")
                for t in mcp_tools:
                    print(f" - {t.name}: {t.description[:50]}")
                if route_kw:
                    print(f" 路由关键词: {list(route_kw.keys())}")

            except Exception as e:
                print(f" [MCP] {name}: 连接失败 -> {e}")

        # Expose MCP tools to the global list (used by LLM bind_tools).
        all_tools.extend(self.mcp_tools)

    def match_route(self, user_input: str) -> tuple | None:
        """Return (session, tool_name) for the first routing keyword found in *user_input*, else None."""
        for keyword, (session, tool_name) in self.route_map.items():
            if keyword in user_input:
                return (session, tool_name)
        return None

    async def call_tool(self, session, tool_name: str, user_input: str) -> str:
        """Invoke *tool_name* on *session*, deriving arguments from *user_input*.

        Returns the concatenated text content of the result, or an error
        string on failure (never raises).
        """
        try:
            # Heuristic argument parsing based on the tool's declared schema.
            args = {}

            # Look the tool up in the loaded MCP tool list to get its schema.
            for t in self.mcp_tools:
                if t.name == tool_name:
                    # An MCP tool's args_schema may be a dict or a Pydantic model.
                    schema = getattr(t, "args_schema", None)
                    if schema:
                        if isinstance(schema, dict):
                            schema_fields = schema.get("properties", {})
                        elif hasattr(schema, "model_json_schema"):
                            schema_fields = schema.model_json_schema().get("properties", {})
                        else:
                            schema_fields = {}
                        args = _parse_tool_args(tool_name, schema_fields, user_input)
                    break

            result = await session.call_tool(tool_name, args)
            if result.content:
                # Join every text part of the MCP result into one string.
                texts = [c.text for c in result.content if hasattr(c, "text")]
                return "\n".join(texts) if texts else str(result)
            return str(result)
        except Exception as e:
            return f"[MCP工具{tool_name}调用错误] {e}"

    async def close(self):
        """Close every MCP connection via the exit stack and clear sessions."""
        await self.exit_stack.aclose()
        self.sessions.clear()
        print(" [MCP] 所有连接已关闭")
|
||
|
||
|
||
def _parse_tool_args(tool_name: str, schema_fields: dict, user_input: str) -> dict:
|
||
"""根据工具参数schema从用户输入中解析参数"""
|
||
args = {}
|
||
for field_name, field_info in schema_fields.items():
|
||
if field_name in ("timezone",):
|
||
args[field_name] = "Asia/Shanghai"
|
||
elif field_name in ("text",):
|
||
# 提取引号内的文本
|
||
match = re.search(r"['\"\u201c\u201d](.+?)['\"\u201c\u201d]", user_input)
|
||
args[field_name] = match.group(1) if match else user_input
|
||
elif field_name in ("city",):
|
||
args[field_name] = user_input
|
||
return args
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# Agent 状态
|
||
# ════════════════════════════════════════════
|
||
class AgentState(TypedDict):
    """LangGraph state threaded through every node."""
    # Conversation history; add_messages appends rather than replaces.
    messages: Annotated[list, add_messages]
    # Output of the think node, injected into the agent system prompt.
    thinking: str
    # Name of the skill selected by keyword routing, if any.
    active_skill: str | None
    # Final answer produced by a skill / MCP route (short-circuits the agent).
    skill_output: str | None
    # Safety counter limiting think/tool-call loops.
    iteration: int
|
||
|
||
|
||
# ════════════════════════════════════════════
|
||
# LangGraph 节点
|
||
# ════════════════════════════════════════════
|
||
|
||
# --- Think node ---
async def make_think_node(config, skills_reg, tools_list):
    """Build the "think" node: a short pre-analysis pass before the agent acts.

    The returned coroutine summarizes the user's intent and the tools/skills
    likely needed, storing the result in ``state["thinking"]``.
    """
    llm_cfg = config["llm"]
    agent_cfg = config.get("agent", {})
    temp = agent_cfg.get("think_temperature", 0.3)

    # Hoisted out of the per-turn closure: the client is pure configuration,
    # so building it once avoids re-instantiating it on every call.
    think_llm = ChatOpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        model=llm_cfg["model"],
        temperature=temp,
    )

    async def think_node(state: AgentState) -> dict:
        iteration = state.get("iteration", 0) + 1
        # After a few loops, skip the extra LLM call and run in fast mode.
        if iteration > 3:
            return {"iteration": iteration, "thinking": "(快速模式)"}

        # Compact transcript of the last few turns for the thinking prompt.
        conv = []
        for msg in state["messages"][-4:]:
            role = "用户" if isinstance(msg, HumanMessage) else "AI"
            conv.append(f"{role}: {msg.content[:150]}")

        tool_names = [t.name for t in tools_list]
        resp = await think_llm.ainvoke([
            SystemMessage(content="你是思考模块。简洁输出:用户意图、需要的工具/技能、注意事项。不要说没有工具。"),
            HumanMessage(content=f"对话:\n{chr(10).join(conv)}\n\n可用技能:\n{skills_reg.format_list()}\n\n可用工具: {', '.join(tool_names)}"),
        ])
        return {"iteration": iteration, "thinking": resp.content}

    return think_node
|
||
|
||
|
||
# --- Skill routing node ---
async def make_skill_route_node(config, skills_reg, mcp_mgr):
    """Build the routing node that picks an MCP tool or a skill for the turn."""
    skill_keywords = config.get("skill_keywords", {})

    async def skill_route_node(state: AgentState) -> dict:
        # The most recent human message drives the routing decision.
        latest_human = ""
        for message in reversed(state["messages"]):
            if isinstance(message, HumanMessage):
                latest_human = message.content
                break

        # 1. Deterministic MCP keyword routing takes priority.
        if mcp_mgr:
            route = mcp_mgr.match_route(latest_human)
            if route is not None:
                session, tool_name = route
                mcp_result = await mcp_mgr.call_tool(session, tool_name, latest_human)
                return {"active_skill": None, "skill_output": mcp_result}

        # 2. Keyword-based skill routing.
        for sname, keywords in skill_keywords.items():
            matched = any(kw in latest_human for kw in keywords)
            if matched and skills_reg.get(sname):
                return {"active_skill": sname, "skill_output": None}

        return {"active_skill": None, "skill_output": None}

    return skill_route_node
|
||
|
||
|
||
# --- Skill execution node ---
async def make_skill_exec_node(config, skills_reg, tools_list):
    """Build the node that executes the routed skill (scripts and/or tools)."""
    llm_cfg = config["llm"]
    agent_cfg = config.get("agent", {})
    temp = agent_cfg.get("skill_temperature", 0.7)

    async def skill_execute_node(state: AgentState) -> dict:
        # No-op when routing selected no (or an unknown) skill.
        sname = state.get("active_skill")
        if not sname:
            return {"skill_output": None}
        sk = skills_reg.get(sname)
        if not sk:
            return {"skill_output": None}

        # The latest human message is the skill's input.
        user_input = ""
        for msg in reversed(state["messages"]):
            if isinstance(msg, HumanMessage):
                user_input = msg.content
                break

        tool_info = ""

        # ---- OpenClaw skill: run each script under scripts/ ----
        if sk.is_openclaw and sk.scripts:
            for script_path in sk.scripts:
                try:
                    # Scripts receive the raw user input as argv[1].
                    # NOTE(review): hard-codes "python3" as the interpreter —
                    # consider sys.executable for portability.
                    proc = await asyncio.create_subprocess_exec(
                        "python3", script_path, user_input,
                        stdout=asyncio.subprocess.PIPE,
                        stderr=asyncio.subprocess.PIPE,
                    )
                    stdout, stderr = await asyncio.wait_for(proc.communicate(), timeout=30)
                    output = stdout.decode("utf-8", errors="replace").strip()
                    err = stderr.decode("utf-8", errors="replace").strip()
                    script_name = Path(script_path).name
                    if output:
                        tool_info += f"\n[脚本{script_name}输出]\n{output[:2000]}"
                    # Ignore warning-only stderr noise.
                    if err and "warning" not in err.lower():
                        tool_info += f"\n[脚本{script_name}错误]\n{err[:500]}"
                except asyncio.TimeoutError:
                    tool_info += f"\n[脚本{Path(script_path).name}] 执行超时"
                except Exception as e:
                    tool_info += f"\n[脚本{Path(script_path).name}] 执行错误: {e}"

        # ---- Generic skill: invoke the local tools it depends on ----
        for tname in sk.tools:
            # Skip tools already covered by OpenClaw script execution.
            if sk.is_openclaw:
                continue
            for t in tools_list:
                if t.name == tname:
                    try:
                        if tname == "get_weather":
                            # Pick the first known city mentioned; default 北京.
                            cities = ["北京", "上海", "深圳", "黄庄"]
                            city = next((c for c in cities if c in user_input), "北京")
                            r = await t.ainvoke({"city": city})
                        elif tname == "calculate":
                            # Pull the first arithmetic-looking substring.
                            expr = re.findall(r'[\d+\-*/(). ]+', user_input)
                            r = await t.ainvoke({"expression": expr[0].strip() if expr else "1+1"})
                        else:
                            r = await t.ainvoke({"query": user_input})
                        tool_info += f"\n工具{tname}结果: {r}"
                    except Exception as e:
                        tool_info += f"\n工具{tname}错误: {e}"

        # Build the final prompt for the skill LLM.
        if sk.is_openclaw:
            # OpenClaw skill: SKILL.md body as guidance plus script output.
            prompt = f"""你是技能"{sk.name}"的执行者。

技能说明:
{sk.prompt[:2000]}

脚本执行结果:
{tool_info if tool_info else "(无脚本输出)"}

用户请求:{user_input}

请基于技能说明和脚本输出回答用户。"""
        else:
            # Generic skill: fill the prompt template and append tool output.
            prompt = sk.prompt.format(input=user_input) + tool_info

        sk_llm = ChatOpenAI(
            base_url=llm_cfg["base_url"],
            api_key=llm_cfg["api_key"],
            model=llm_cfg["model"],
            temperature=temp,
        )
        resp = await sk_llm.ainvoke([
            SystemMessage(content=prompt),
            HumanMessage(content="请基于以上信息回答。"),
        ])
        return {"skill_output": resp.content}

    return skill_execute_node
|
||
|
||
|
||
# --- Main agent node ---
async def make_agent_node(config, skills_reg, tools_list):
    """Build the main agent node: a tool-calling LLM with the persona prompt.

    Skill/MCP output, when present in state, is returned verbatim instead of
    invoking the LLM again.
    """
    llm_cfg = config["llm"]
    agent_cfg = config.get("agent", {})
    max_iter = agent_cfg.get("max_iterations", 5)

    SYSTEM_PROMPT = """你是黄庄三号,严肃、认真、听话、聪明的AI助手。你的名字是"黄庄三号",你不是Claude,不是ChatGPT。

你具备四种能力:
1. 工具调用(FC) - 调用内置工具获取信息
2. MCP集成 - 通过MCP协议连接外部服务
3. 思考模式 - 回答前进行深度思考
4. 技能系统(Skill) - 调用注册技能完成复杂任务

可用技能:
{skill_list}

重要规则(必须严格遵守):
- 当被问"你是谁",必须回答"我是黄庄三号"
- 对于工具能提供的数据,必须调用工具获取,不要自己猜测"""

    # Hoisted out of the per-turn closure: client construction and tool
    # binding are configuration-only, so doing them once per graph build
    # avoids repeating the work on every turn.
    llm = ChatOpenAI(
        base_url=llm_cfg["base_url"],
        api_key=llm_cfg["api_key"],
        model=llm_cfg["model"],
    )
    llm_with_tools = llm.bind_tools(tools_list)

    async def agent_node(state: AgentState) -> dict:
        iteration = state.get("iteration", 0)

        # A skill / MCP route already produced the answer: pass it through.
        if state.get("skill_output"):
            return {"messages": [AIMessage(content=state["skill_output"])]}

        system_content = SYSTEM_PROMPT.format(skill_list=skills_reg.format_list())
        if state.get("thinking"):
            thinking = state["thinking"][:300]
            # If the thinking mentions tool names, stress that they must be called.
            tool_hints = [t.name for t in tools_list if t.name in thinking]
            if tool_hints:
                thinking += f"\n\n[重要:必须调用 {', '.join(tool_hints)} 工具来回答]"
            system_content += f"\n\n[内部思考]\n{thinking}"

        messages = [SystemMessage(content=system_content)]
        messages.extend(state["messages"])

        resp = await llm_with_tools.ainvoke(messages)

        # Iteration guard: past the limit, drop any further tool calls.
        if iteration > max_iter and hasattr(resp, "tool_calls") and resp.tool_calls:
            resp = AIMessage(content=resp.content or "任务完成(已达最大迭代次数)")

        return {"messages": [resp], "iteration": iteration}

    return agent_node
|
||
|
||
|
||
# --- Routing function ---
def route_from_agent(state: AgentState) -> str:
    """Pick the edge after the agent node: 'tools' when tool calls are pending, else 'end'."""
    if state.get("skill_output"):
        return "end"
    # Only the most recent AI message matters for pending tool calls.
    for message in reversed(state["messages"]):
        if not isinstance(message, AIMessage):
            continue
        pending = getattr(message, "tool_calls", None)
        return "tools" if pending else "end"
    return "end"
|
||
|
||
|
||
# ════════════════════════════════════════════
# Graph construction
# ════════════════════════════════════════════
async def build_graph(config, skills_reg, mcp_mgr, tools_list):
    """Wire all nodes into a compiled LangGraph state machine.

    Flow: think -> skill_route -> (skill_exec ->)? agent <-> tools -> END.
    """
    think_node = await make_think_node(config, skills_reg, tools_list)
    skill_route_node = await make_skill_route_node(config, skills_reg, mcp_mgr)
    skill_exec_node = await make_skill_exec_node(config, skills_reg, tools_list)
    agent_node = await make_agent_node(config, skills_reg, tools_list)

    graph = StateGraph(AgentState)

    node_table = {
        "think": think_node,
        "skill_route": skill_route_node,
        "skill_exec": skill_exec_node,
        "agent": agent_node,
        "tools": ToolNode(tools_list),
    }
    for node_name, node_fn in node_table.items():
        graph.add_node(node_name, node_fn)

    graph.add_edge(START, "think")
    graph.add_edge("think", "skill_route")
    graph.add_conditional_edges(
        "skill_route",
        lambda s: "skill_exec" if s.get("active_skill") else "agent",
        {"skill_exec": "skill_exec", "agent": "agent"},
    )
    graph.add_edge("skill_exec", "agent")
    graph.add_conditional_edges("agent", route_from_agent, {"tools": "tools", "end": END})
    graph.add_edge("tools", "agent")

    return graph.compile()
|
||
|
||
|
||
# ════════════════════════════════════════════
# Run entry points
# ════════════════════════════════════════════
async def run_agent(user_input: str, graph):
    """Run one user turn through the graph; return reply, thinking and skill."""
    initial_state = {
        "messages": [HumanMessage(content=user_input)],
        "thinking": "",
        "active_skill": None,
        "skill_output": None,
        "iteration": 0,
    }
    result = await graph.ainvoke(initial_state)
    last = result["messages"][-1]
    reply = last.content if hasattr(last, "content") else str(last)
    return {
        "reply": reply,
        "thinking": result.get("thinking", ""),
        "skill": result.get("active_skill"),
    }
|
||
|
||
async def interactive_mode(graph):
    """Interactive REPL loop over the compiled graph.

    Reads the module-level ``skills_registry`` / ``all_tools`` for the
    banner, so main() must have initialized them first.
    NOTE(review): input() blocks the event loop while waiting for the
    user — acceptable for this single-user CLI.
    """
    print("=" * 60)
    print(" 黄庄三号 Agent v2.0 - 配置驱动版")
    print(" FC | MCP | 思考模式 | Skill")
    print("=" * 60)
    print(" 技能:", [s.name for s in skills_registry.list_skills()])
    print(" 工具:", [t.name for t in all_tools])
    print(" 输入 quit 退出")
    print("=" * 60)

    while True:
        try:
            user_input = input("\n你> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C ends the session cleanly.
            break
        if not user_input:
            continue
        if user_input.lower() in ("quit", "exit", "q"):
            break

        result = await run_agent(user_input, graph)
        if result["thinking"]:
            print(f"\n[思考] {result['thinking'][:150]}...")
        if result["skill"]:
            print(f"[技能] {result['skill']}")
        print(f"\n黄庄三号> {result['reply']}")
|
||
|
||
|
||
# ════════════════════════════════════════════
# Module-level state (initialized by main)
# ════════════════════════════════════════════
all_tools = []  # tools scanned from tools/ plus MCP-loaded tools
skills_registry = SkillRegistry()  # replaced by scan_skills() in main()
mcp_manager = None  # set only when --mcp is passed
|
||
|
||
|
||
async def main():
    """CLI entry point.

    Loads config, scans tools/skills, optionally connects MCP servers,
    builds the graph, then runs either the automatic test suite (--test)
    or the interactive REPL.  MCP connections are always closed on exit,
    even when a test or the REPL raises (fixed: cleanup now in finally).
    """
    global all_tools, skills_registry, mcp_manager

    parser = argparse.ArgumentParser(description="黄庄三号 Agent v2.0")
    parser.add_argument("--mcp", action="store_true", help="启用MCP")
    parser.add_argument("--test", action="store_true", help="自动测试")
    args = parser.parse_args()

    print("=" * 60)
    print(" 黄庄三号 Agent v2.0 - 配置驱动版")
    print("=" * 60)

    # ── Load configuration ──
    print("\n[配置] 加载 config.yaml ...")
    config = load_config()
    print(f" 模型: {config['llm']['model']}")
    print(f" MCP服务器: {len(config.get('mcp_servers', []))} 个")
    print(f" 技能关键词: {len(config.get('skill_keywords', {}))} 个")

    # ── Scan tools ──
    print("\n[工具] 扫描 tools/ 目录 ...")
    all_tools = scan_tools()
    print(f" 工具总数: {len(all_tools)}")

    # ── Scan skills ──
    print("\n[技能] 扫描 skills/ 目录 ...")
    skills_registry = scan_skills()
    print(f" 技能总数: {len(skills_registry.list_skills())}")

    try:
        # ── Connect MCP servers ──
        if args.mcp and config.get("mcp_servers"):
            print("\n[MCP] 连接服务器 ...")
            mcp_manager = MCPManager()
            await mcp_manager.connect_all(config["mcp_servers"])
            print(f" MCP工具总数: {len(mcp_manager.mcp_tools)}")

        print(f"\n 全部工具总数: {len(all_tools)}")

        # ── Build graph ──
        graph = await build_graph(config, skills_registry, mcp_manager, all_tools)

        if args.test:
            # Automatic smoke tests covering each capability.
            tests = [
                ("FC+思考+Skill", "黄庄天气怎么样?"),
                ("FC+Skill", "算一下 99*88+77"),
                ("知识搜索", "MCP是什么?"),
                ("身份", "你好你是谁?"),
            ]

            if args.mcp and mcp_manager:
                tests.extend([
                    ("MCP:时间", "现在几点了?"),
                    ("MCP:字符统计", "统计'黄庄三号是AI助手'的字符数"),
                    ("MCP:UUID", "生成一个UUID"),
                ])

            for label, query in tests:
                print(f"\n{'─'*55}")
                print(f"[测试:{label}] {query}")
                r = await run_agent(query, graph)
                print(f" 思考: {r['thinking'][:80]}...")
                print(f" 技能: {r['skill']}")
                print(f" 回复: {r['reply'][:150]}...")

            print(f"\n{'='*60}")
            print(" 验证完成!")
            caps = ["FC", "思考", "Skill"]
            if args.mcp:
                caps.append("MCP")
            print(" " + " ✅ | ".join(caps) + " ✅")
            print("=" * 60)
        else:
            await interactive_mode(graph)
    finally:
        # ── Cleanup: always close MCP connections, even on error / Ctrl-C ──
        if mcp_manager:
            await mcp_manager.close()
|
||
|
||
|
||
# Script entry point: run the async main() under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())
|