feat(tools): 增强工具注册表与内置工具能力
This commit is contained in:
@@ -1,40 +1,70 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from .tools.builtin_tools import BUILTIN_TOOLS
|
from .tools.builtin_tools import BUILTIN_TOOLS
|
||||||
|
|
||||||
|
|
||||||
def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]:
|
class ToolRegistry:
|
||||||
results = []
|
"""
|
||||||
for tool_name in tool_names:
|
统一管理工具注册、查询和执行。
|
||||||
tool = BUILTIN_TOOLS.get(tool_name)
|
|
||||||
|
设计目标:
|
||||||
|
- 让 Orchestrator 只关心“声明了哪些工具”,不关心工具如何存放。
|
||||||
|
- 固化统一的工具调用结果结构,便于页面展示和审计日志保存。
|
||||||
|
- 后续新增业务工具时,只需要注册函数,不必改调用协议。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, initial_tools: dict[str, Callable] | None = None):
|
||||||
|
self._tools: dict[str, Callable] = dict(initial_tools or {})
|
||||||
|
|
||||||
|
def register(self, tool_name: str, tool_func: Callable) -> None:
|
||||||
|
"""注册一个可通过名称调用的工具函数。"""
|
||||||
|
self._tools[tool_name] = tool_func
|
||||||
|
|
||||||
|
def get(self, tool_name: str) -> Callable | None:
|
||||||
|
"""按名称返回工具函数;未注册时返回 None。"""
|
||||||
|
return self._tools.get(tool_name)
|
||||||
|
|
||||||
|
def run(self, tool_name: str, **kwargs) -> dict:
|
||||||
|
"""
|
||||||
|
执行单个工具,并返回统一结果结构。
|
||||||
|
|
||||||
|
统一返回值是审计日志、页面展示和后续 Agent 编排共享的协议。
|
||||||
|
即使工具不存在或执行失败,也返回可消费的失败结果,而不是抛异常。
|
||||||
|
"""
|
||||||
|
tool = self.get(tool_name)
|
||||||
if tool is None:
|
if tool is None:
|
||||||
results.append(
|
return {
|
||||||
{
|
"tool_name": tool_name,
|
||||||
"tool_name": tool_name,
|
"success": False,
|
||||||
"success": False,
|
"arguments": kwargs,
|
||||||
"arguments": {"user_input": user_input},
|
"result": {},
|
||||||
"result": {},
|
"error": "工具未注册",
|
||||||
"error": "工具未注册",
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
try:
|
try:
|
||||||
result = tool(user_input=user_input)
|
return {
|
||||||
results.append(
|
"tool_name": tool_name,
|
||||||
{
|
"success": True,
|
||||||
"tool_name": tool_name,
|
"arguments": kwargs,
|
||||||
"success": True,
|
"result": tool(**kwargs),
|
||||||
"arguments": {"user_input": user_input},
|
"error": "",
|
||||||
"result": result,
|
}
|
||||||
"error": "",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
results.append(
|
return {
|
||||||
{
|
"tool_name": tool_name,
|
||||||
"tool_name": tool_name,
|
"success": False,
|
||||||
"success": False,
|
"arguments": kwargs,
|
||||||
"arguments": {"user_input": user_input},
|
"result": {},
|
||||||
"result": {},
|
"error": str(exc),
|
||||||
"error": str(exc),
|
}
|
||||||
}
|
|
||||||
)
|
|
||||||
return results
|
# 默认注册表承载项目内置工具,便于当前 V1 直接复用。
|
||||||
|
DEFAULT_TOOL_REGISTRY = ToolRegistry(BUILTIN_TOOLS)
|
||||||
|
|
||||||
|
|
||||||
|
def run_declared_tools(tool_names: list[str], user_input: str) -> list[dict]:
|
||||||
|
"""按场景声明顺序执行工具,保证结果顺序与配置顺序一致。"""
|
||||||
|
return [
|
||||||
|
DEFAULT_TOOL_REGISTRY.run(tool_name, user_input=user_input)
|
||||||
|
for tool_name in tool_names
|
||||||
|
]
|
||||||
|
|||||||
@@ -1,8 +1,47 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
def calculate_rate(user_input: str) -> dict:
|
def calculate_rate(user_input: str) -> dict:
|
||||||
return {"rate": 1.0, "note": "模拟比例计算结果"}
|
"""
|
||||||
|
从自然语言中提取两个数值并计算比例。
|
||||||
|
|
||||||
|
V1 目标不是构建复杂公式引擎,而是提供一个可演示的“业务工具”示例:
|
||||||
|
只要输入中出现两个数字,就将其解释为“已完成值 / 总数”。
|
||||||
|
"""
|
||||||
|
numbers = [float(item) for item in re.findall(r"\d+(?:\.\d+)?", user_input)]
|
||||||
|
if len(numbers) < 2:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"rate": 0.0,
|
||||||
|
"numerator": 0.0,
|
||||||
|
"denominator": 0.0,
|
||||||
|
"note": "未能从输入中提取两个数字,无法计算比例。",
|
||||||
|
}
|
||||||
|
numerator, denominator = numbers[0], numbers[1]
|
||||||
|
if denominator == 0:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"rate": 0.0,
|
||||||
|
"numerator": numerator,
|
||||||
|
"denominator": denominator,
|
||||||
|
"note": "分母为 0,无法计算比例。",
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"numerator": numerator,
|
||||||
|
"denominator": denominator,
|
||||||
|
"rate": round(numerator / denominator, 4),
|
||||||
|
"note": "已按输入中的前两个数字完成比例计算。",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def query_demo_records(user_input: str) -> dict:
|
def query_demo_records(user_input: str) -> dict:
|
||||||
|
"""
|
||||||
|
查询示例业务记录。
|
||||||
|
|
||||||
|
该工具依赖 Audit 模块中的 DemoBusinessRecord 演示表,用于证明
|
||||||
|
“场景 + 结构化数据 + 工具调用”可以组成更可信的业务 Agent。
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
from apps.audit.models import DemoBusinessRecord
|
from apps.audit.models import DemoBusinessRecord
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -32,11 +71,49 @@ def query_demo_records(user_input: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def check_required_fields(user_input: str) -> dict:
|
def check_required_fields(user_input: str) -> dict:
|
||||||
return {"missing_fields": [], "note": "模拟必填项检查结果"}
|
"""
|
||||||
|
检查输入中声明的必填项是否全部出现。
|
||||||
|
|
||||||
|
约定格式示例:
|
||||||
|
“请检查必填项:合同编号、供应商、金额。当前只提供了合同编号和金额。”
|
||||||
|
"""
|
||||||
|
required_match = re.search(r"必填项[::](.+?)(?:。|\.)", user_input)
|
||||||
|
provided_match = re.search(r"(?:当前|已|仅)?提供了(.+?)(?:。|\.)", user_input)
|
||||||
|
required_fields = _split_cn_items(required_match.group(1) if required_match else "")
|
||||||
|
provided_fields = set(_split_cn_items(provided_match.group(1) if provided_match else ""))
|
||||||
|
missing_fields = [field for field in required_fields if field not in provided_fields]
|
||||||
|
return {
|
||||||
|
"required_fields": required_fields,
|
||||||
|
"provided_fields": list(provided_fields),
|
||||||
|
"missing_fields": missing_fields,
|
||||||
|
"note": "已根据输入中的“必填项/提供了”描述完成检查。",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def generate_action_items(user_input: str) -> dict:
|
def generate_action_items(user_input: str) -> dict:
|
||||||
return {"items": [f"围绕问题继续核实:{user_input}"]}
|
"""
|
||||||
|
生成最小可执行行动项。
|
||||||
|
|
||||||
|
该工具主要用于演示“模型回答之外,还可以得到结构化待办建议”。
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"items": [
|
||||||
|
"先确认问题背景和适用场景。",
|
||||||
|
f"围绕当前问题继续核实:{user_input}",
|
||||||
|
"根据知识库和审计结果安排下一步处理动作。",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _split_cn_items(raw_text: str) -> list[str]:
|
||||||
|
"""将中文顿号、逗号和连接词分隔的字段串切分为列表。"""
|
||||||
|
normalized = (
|
||||||
|
raw_text.replace("和", "、")
|
||||||
|
.replace("以及", "、")
|
||||||
|
.replace(",", "、")
|
||||||
|
.replace(",", "、")
|
||||||
|
)
|
||||||
|
return [item.strip(" 。.") for item in normalized.split("、") if item.strip(" 。.")]
|
||||||
|
|
||||||
|
|
||||||
BUILTIN_TOOLS = {
|
BUILTIN_TOOLS = {
|
||||||
|
|||||||
54
tests/test_tool_registry.py
Normal file
54
tests/test_tool_registry.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from agent_core.tool_registry import ToolRegistry, run_declared_tools
|
||||||
|
from agent_core.tools.builtin_tools import calculate_rate, check_required_fields
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_registry_register_get_and_run():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
def hello_tool(user_input: str) -> dict:
|
||||||
|
return {"echo": user_input}
|
||||||
|
|
||||||
|
registry.register("hello", hello_tool)
|
||||||
|
|
||||||
|
assert registry.get("hello") is hello_tool
|
||||||
|
assert registry.run("hello", user_input="demo") == {
|
||||||
|
"tool_name": "hello",
|
||||||
|
"success": True,
|
||||||
|
"arguments": {"user_input": "demo"},
|
||||||
|
"result": {"echo": "demo"},
|
||||||
|
"error": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_registry_returns_failed_result_for_missing_tool():
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
result = registry.run("missing", user_input="demo")
|
||||||
|
|
||||||
|
assert result["tool_name"] == "missing"
|
||||||
|
assert result["success"] is False
|
||||||
|
assert result["error"] == "工具未注册"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_declared_tools_executes_multiple_tools_in_order():
|
||||||
|
results = run_declared_tools(["generate_action_items", "missing_tool"], "请生成行动项")
|
||||||
|
|
||||||
|
assert [item["tool_name"] for item in results] == ["generate_action_items", "missing_tool"]
|
||||||
|
assert results[0]["success"] is True
|
||||||
|
assert results[1]["success"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_calculate_rate_extracts_fraction_like_numbers():
|
||||||
|
result = calculate_rate("产线合格率,已完成 18 件,总数 24 件")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["numerator"] == 18.0
|
||||||
|
assert result["denominator"] == 24.0
|
||||||
|
assert result["rate"] == 0.75
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_required_fields_reports_missing_fields():
|
||||||
|
result = check_required_fields("请检查必填项:合同编号、供应商、金额。当前只提供了合同编号和金额。")
|
||||||
|
|
||||||
|
assert "供应商" in result["missing_fields"]
|
||||||
|
assert "合同编号" not in result["missing_fields"]
|
||||||
Reference in New Issue
Block a user