diff --git a/agent_core/tool_registry.py b/agent_core/tool_registry.py index 270737f..89fe15a 100644 --- a/agent_core/tool_registry.py +++ b/agent_core/tool_registry.py @@ -1,40 +1,70 @@ +from collections.abc import Callable + from .tools.builtin_tools import BUILTIN_TOOLS -def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]: - results = [] - for tool_name in tool_names: - tool = BUILTIN_TOOLS.get(tool_name) +class ToolRegistry: + """ + 统一管理工具注册、查询和执行。 + + 设计目标: + - 让 Orchestrator 只关心“声明了哪些工具”,不关心工具如何存放。 + - 固化统一的工具调用结果结构,便于页面展示和审计日志保存。 + - 后续新增业务工具时,只需要注册函数,不必改调用协议。 + """ + + def __init__(self, initial_tools: dict[str, Callable] | None = None): + self._tools: dict[str, Callable] = dict(initial_tools or {}) + + def register(self, tool_name: str, tool_func: Callable) -> None: + """注册一个可通过名称调用的工具函数。""" + self._tools[tool_name] = tool_func + + def get(self, tool_name: str) -> Callable | None: + """按名称返回工具函数;未注册时返回 None。""" + return self._tools.get(tool_name) + + def run(self, tool_name: str, **kwargs) -> dict: + """ + 执行单个工具,并返回统一结果结构。 + + 统一返回值是审计日志、页面展示和后续 Agent 编排共享的协议。 + 即使工具不存在或执行失败,也返回可消费的失败结果,而不是抛异常。 + """ + tool = self.get(tool_name) if tool is None: - results.append( - { - "tool_name": tool_name, - "success": False, - "arguments": {"user_input": user_input}, - "result": {}, - "error": "工具未注册", - } - ) - continue + return { + "tool_name": tool_name, + "success": False, + "arguments": kwargs, + "result": {}, + "error": "工具未注册", + } try: - result = tool(user_input=user_input) - results.append( - { - "tool_name": tool_name, - "success": True, - "arguments": {"user_input": user_input}, - "result": result, - "error": "", - } - ) + return { + "tool_name": tool_name, + "success": True, + "arguments": kwargs, + "result": tool(**kwargs), + "error": "", + } except Exception as exc: - results.append( - { - "tool_name": tool_name, - "success": False, - "arguments": {"user_input": user_input}, - "result": {}, - "error": str(exc), - } - ) - return results + return { + "tool_name": tool_name, + "success": False, + "arguments": kwargs, + "result": {}, + "error": str(exc), + } + + +# 默认注册表承载项目内置工具,便于当前 V1 直接复用。 +DEFAULT_TOOL_REGISTRY = ToolRegistry(BUILTIN_TOOLS) + + +def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]: + """按场景声明顺序执行工具,保证结果顺序与配置顺序一致。""" + return [ + DEFAULT_TOOL_REGISTRY.run(tool_name, user_input=user_input) + for tool_name in tool_names + ] diff --git a/agent_core/tools/builtin_tools.py b/agent_core/tools/builtin_tools.py index 6c7464d..038bab3 100644 --- a/agent_core/tools/builtin_tools.py +++ b/agent_core/tools/builtin_tools.py @@ -1,8 +1,47 @@ +import re + + def calculate_rate(user_input: str) -> dict: - return {"rate": 1.0, "note": "模拟比例计算结果"} + """ + 从自然语言中提取两个数值并计算比例。 + + V1 目标不是构建复杂公式引擎,而是提供一个可演示的“业务工具”示例: + 只要输入中出现两个数字,就将其解释为“已完成值 / 总数”。 + """ + numbers = [float(item) for item in re.findall(r"\d+(?:\.\d+)?", user_input)] + if len(numbers) < 2: + return { + "success": False, + "rate": 0.0, + "numerator": 0.0, + "denominator": 0.0, + "note": "未能从输入中提取两个数字,无法计算比例。", + } + numerator, denominator = numbers[0], numbers[1] + if denominator == 0: + return { + "success": False, + "rate": 0.0, + "numerator": numerator, + "denominator": denominator, + "note": "分母为 0,无法计算比例。", + } + return { + "success": True, + "numerator": numerator, + "denominator": denominator, + "rate": round(numerator / denominator, 4), + "note": "已按输入中的前两个数字完成比例计算。", + } def query_demo_records(user_input: str) -> dict: + """ + 查询示例业务记录。 + + 该工具依赖 Audit 模块中的 DemoBusinessRecord 演示表,用于证明 + “场景 + 结构化数据 + 工具调用”可以组成更可信的业务 Agent。 + """ try: from apps.audit.models import DemoBusinessRecord except Exception as exc: @@ -32,11 +71,49 @@ def query_demo_records(user_input: str) -> dict: def check_required_fields(user_input: str) -> dict: - return {"missing_fields": [], "note": "模拟必填项检查结果"} + """ + 检查输入中声明的必填项是否全部出现。 + + 约定格式示例: + “请检查必填项:合同编号、供应商、金额。当前只提供了合同编号和金额。” + """ + required_match = re.search(r"必填项[::](.+?)(?:。|\.)", user_input) + provided_match = re.search(r"(?:当前|已|仅)?提供了(.+?)(?:。|\.)", user_input) + required_fields = _split_cn_items(required_match.group(1) if required_match else "") + provided_fields = set(_split_cn_items(provided_match.group(1) if provided_match else "")) + missing_fields = [field for field in required_fields if field not in provided_fields] + return { + "required_fields": required_fields, + "provided_fields": list(provided_fields), + "missing_fields": missing_fields, + "note": "已根据输入中的“必填项/提供了”描述完成检查。", + } def generate_action_items(user_input: str) -> dict: - return {"items": [f"围绕问题继续核实:{user_input}"]} + """ + 生成最小可执行行动项。 + + 该工具主要用于演示“模型回答之外,还可以得到结构化待办建议”。 + """ + return { + "items": [ + "先确认问题背景和适用场景。", + f"围绕当前问题继续核实:{user_input}", + "根据知识库和审计结果安排下一步处理动作。", + ] + } + + +def _split_cn_items(raw_text: str) -> list[str]: + """将中文顿号、逗号和连接词分隔的字段串切分为列表。""" + normalized = ( + raw_text.replace("和", "、") + .replace("以及", "、") + .replace(",", "、") + .replace(",", "、") + ) + return [item.strip(" 。.") for item in normalized.split("、") if item.strip(" 。.")] BUILTIN_TOOLS = { diff --git a/tests/test_tool_registry.py b/tests/test_tool_registry.py new file mode 100644 index 0000000..7acd436 --- /dev/null +++ b/tests/test_tool_registry.py @@ -0,0 +1,54 @@ +from agent_core.tool_registry import ToolRegistry, run_declared_tools +from agent_core.tools.builtin_tools import calculate_rate, check_required_fields + + +def test_tool_registry_register_get_and_run(): + registry = ToolRegistry() + + def hello_tool(user_input: str) -> dict: + return {"echo": user_input} + + registry.register("hello", hello_tool) + + assert registry.get("hello") is hello_tool + assert registry.run("hello", user_input="demo") == { + "tool_name": "hello", + "success": True, + "arguments": {"user_input": "demo"}, + "result": {"echo": "demo"}, + "error": "", + } + + +def test_tool_registry_returns_failed_result_for_missing_tool(): + registry = ToolRegistry() + + result = registry.run("missing", user_input="demo") + + assert result["tool_name"] == "missing" + assert result["success"] is False + assert result["error"] == "工具未注册" + + +def test_run_declared_tools_executes_multiple_tools_in_order(): + results = run_declared_tools(["generate_action_items", "missing_tool"], "请生成行动项") + + assert [item["tool_name"] for item in results] == ["generate_action_items", "missing_tool"] + assert results[0]["success"] is True + assert results[1]["success"] is False + + +def test_calculate_rate_extracts_fraction_like_numbers(): + result = calculate_rate("产线合格率,已完成 18 件,总数 24 件") + + assert result["success"] is True + assert result["numerator"] == 18.0 + assert result["denominator"] == 24.0 + assert result["rate"] == 0.75 + + +def test_check_required_fields_reports_missing_fields(): + result = check_required_fields("请检查必填项:合同编号、供应商、金额。当前只提供了合同编号和金额。") + + assert "供应商" in result["missing_fields"] + assert "合同编号" not in result["missing_fields"]