From c15b77907ed5111b537cf58a0807d79083991347 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Tue, 20 Jan 2026 14:00:15 +0800 Subject: [PATCH] =?UTF-8?q?Prompt=E7=AE=A1=E7=90=86=E5=99=A8=E4=B8=8E?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=EF=BC=9B=E8=B0=83=E6=95=B4=E9=83=A8=E5=88=86?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pytests/prompt_test/test_prompt_manager.py | 607 +++++++++++++++++++++ src/config/config.py | 8 +- src/prompt/prompt_manager.py | 157 ++++++ 3 files changed, 768 insertions(+), 4 deletions(-) create mode 100644 pytests/prompt_test/test_prompt_manager.py create mode 100644 src/prompt/prompt_manager.py diff --git a/pytests/prompt_test/test_prompt_manager.py b/pytests/prompt_test/test_prompt_manager.py new file mode 100644 index 00000000..e4852936 --- /dev/null +++ b/pytests/prompt_test/test_prompt_manager.py @@ -0,0 +1,607 @@ +import asyncio +import sys +from collections.abc import Callable +from pathlib import Path +from typing import Optional + +import pytest + +PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve() +sys.path.insert(0, str(PROJECT_ROOT)) +sys.path.insert(0, str(PROJECT_ROOT / "src" / "config")) + +from src.prompt.prompt_manager import PromptManager + +# --- Minimal stubs / constants matching the production module --- + +# These imports/definitions are here only to make the tests self‑contained +# In the real project, they already exist in `prompt_manager.py`'s module. +# We mirror them here to control behavior via monkeypatch. + + +class Prompt: + def __init__(self, prompt_name: str, template: str, prompt_render_context: Optional[dict[str, Callable]] = None): + self.prompt_name = prompt_name + self.template = template + self.prompt_render_context = prompt_render_context or {} + + +class DummyLogger: + def __init__(self): + self.errors: list[str] = [] + self.warnings: list[str] = [] + + def error(self, msg: str) -> None: + self.errors.append(msg) + + def warning(self, msg: str) -> None: + self.warnings.append(msg) + + +# --- Fixtures to patch module-level objects in prompt_manager --- + + +@pytest.fixture +def dummy_logger(monkeypatch): + from src.prompt import prompt_manager as pm + + logger = DummyLogger() + monkeypatch.setattr(pm, "logger", logger, raising=False) + return logger + + +@pytest.fixture +def temp_prompts_dir(tmp_path, monkeypatch): + from src.prompt import prompt_manager as pm + + prompts_dir = tmp_path / "prompts" + monkeypatch.setattr(pm, "PROMPTS_DIR", prompts_dir, raising=False) + return prompts_dir + + +@pytest.fixture +def brace_constants(monkeypatch): + from src.prompt import prompt_manager as pm + + # emulate the placeholders used in the manager + monkeypatch.setattr(pm, "_LEFT_BRACE", "__LEFT__", raising=False) + monkeypatch.setattr(pm, "_RIGHT_BRACE", "__RIGHT__", raising=False) + + +@pytest.fixture +def suffix_prompt(monkeypatch): + from src.prompt import prompt_manager as pm + + monkeypatch.setattr(pm, "SUFFIX_PROMPT", ".prompt", raising=False) + + +@pytest.fixture +def manager(temp_prompts_dir, brace_constants, suffix_prompt): + # PromptManager.__init__ uses patched PROMPTS_DIR + return PromptManager() + + +# --- Helper to run async methods in tests (for non-async pytest) --- + + +def run(coro): + return asyncio.get_event_loop().run_until_complete(coro) + + +# --- add_prompt tests -------------------------------------------------------- + + +@pytest.mark.parametrize( + "existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save", + [ + pytest.param( + {}, + {}, + "greeting", + False, + False, + id="add_prompt_simple_not_saved", + ), + pytest.param( + {}, + {}, + "system", + True, + True, + id="add_prompt_marked_for_save", + ), + pytest.param( + {"existing": Prompt("existing", "tmpl")}, + {}, + "new", + True, + True, + id="add_prompt_with_existing_other_prompt", + ), + ], +) +def test_add_prompt_happy_path(manager, existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save): + # Arrange + + manager.prompts.update(existing_prompts) + manager._context_construct_functions.update(existing_funcs) + prompt = Prompt(name_to_add, "template") + + # Act + + manager.add_prompt(prompt, need_save=need_save) + + # Assert + + assert manager.prompts[name_to_add] is prompt + assert (name_to_add in manager._prompt_to_save) is expect_in_save + + +@pytest.mark.parametrize( + "existing_prompts, existing_funcs, new_name, conflict_type", + [ + pytest.param( + {"dup": Prompt("dup", "tmpl")}, + {}, + "dup", + "prompt_conflict", + id="add_prompt_conflict_with_existing_prompt", + ), + pytest.param( + {}, + {"dup": (lambda x: x, "mod")}, + "dup", + "func_conflict", + id="add_prompt_conflict_with_existing_context_function", + ), + ], +) +def test_add_prompt_conflict_raises_key_error(manager, existing_prompts, existing_funcs, new_name, conflict_type): + # Arrange + + manager.prompts.update(existing_prompts) + manager._context_construct_functions.update(existing_funcs) + prompt = Prompt(new_name, "template") + + # Act / Assert + + with pytest.raises(KeyError) as exc: + manager.add_prompt(prompt) + + assert new_name in str(exc.value) + + +# --- add_context_construct_function tests ----------------------------------- + + +def test_add_context_construct_function_happy_path(manager): + # Arrange + + def builder(prompt_name: str) -> str: + return f"ctx_for_{prompt_name}" + + # Act + + manager.add_context_construct_function("ctx", builder) + + # Assert + + assert "ctx" in manager._context_construct_functions + stored_func, module = manager._context_construct_functions["ctx"] + assert stored_func is builder + # module is caller's module name + assert isinstance(module, str) + assert module != "" + + +def test_add_context_construct_function_logs_unknown_module(manager, dummy_logger, monkeypatch): + # Arrange + + def builder(prompt_name: str) -> str: + return f"v_{prompt_name}" + + def fake_currentframe(): + class FakeCallerFrame: + f_globals = {"__name__": "unknown"} + + class FakeFrame: + f_back = FakeCallerFrame() + + return FakeFrame() + + from src.prompt import prompt_manager as pm + + monkeypatch.setattr(pm.inspect, "currentframe", fake_currentframe) + + # Act + manager.add_context_construct_function("unknown_ctx", builder) + + # Assert + + assert any("无法获取调用函数的模块名" in msg for msg in dummy_logger.warnings) + assert "unknown_ctx" in manager._context_construct_functions + + +@pytest.mark.parametrize( + "existing_prompts, existing_funcs, name_to_add", + [ + pytest.param( + {"p": Prompt("p", "tmpl")}, + {}, + "p", + id="add_context_construct_function_conflict_with_prompt", + ), + pytest.param( + {}, + {"f": (lambda x: x, "mod")}, + "f", + id="add_context_construct_function_conflict_with_existing_func", + ), + ], +) +def test_add_context_construct_function_conflict_raises_key_error( + manager, existing_prompts, existing_funcs, name_to_add +): + # Arrange + + manager.prompts.update(existing_prompts) + manager._context_construct_functions.update(existing_funcs) + + def func(prompt_name: str) -> str: + return "x" + + # Act / Assert + + with pytest.raises(KeyError) as exc: + manager.add_context_construct_function(name_to_add, func) + + assert name_to_add in str(exc.value) + + +def test_add_context_construct_function_no_frame_raises_runtime_error(manager, monkeypatch): + # Arrange + + from src.prompt import prompt_manager as pm + + monkeypatch.setattr(pm.inspect, "currentframe", lambda: None) + + def func(prompt_name: str) -> str: + return "x" + + # Act / Assert + + with pytest.raises(RuntimeError) as exc: + manager.add_context_construct_function("ctx", func) + + assert "无法获取调用栈" in str(exc.value) + + +def test_add_context_construct_function_no_caller_frame_raises_runtime_error(manager, monkeypatch): + # Arrange + + from src.prompt import prompt_manager as pm + + class FakeFrame: + f_back = None + + monkeypatch.setattr(pm.inspect, "currentframe", lambda: FakeFrame()) + + def func(prompt_name: str) -> str: + return "x" + + # Act / Assert + + with pytest.raises(RuntimeError) as exc: + manager.add_context_construct_function("ctx", func) + + assert "无法获取调用栈的上一级" in str(exc.value) + + +# --- get_prompt tests -------------------------------------------------------- + + +@pytest.mark.parametrize( + "existing_name, requested_name, should_raise", + [ + pytest.param("p1", "p1", False, id="get_existing_prompt"), + pytest.param("p1", "missing", True, id="get_missing_prompt_raises"), + ], +) +def test_get_prompt(manager, existing_name, requested_name, should_raise): + # Arrange + + manager.prompts[existing_name] = Prompt(existing_name, "tmpl") + + # Act / Assert + + if should_raise: + with pytest.raises(KeyError) as exc: + manager.get_prompt(requested_name) + assert requested_name in str(exc.value) + else: + prompt = manager.get_prompt(requested_name) + assert prompt.prompt_name == existing_name + + +# --- render_prompt and _render tests ---------------------------------------- + + +@pytest.mark.parametrize( + "template, prompts_setup, ctx_funcs_setup, prompt_ctx, expected", + [ + pytest.param( + "Hello {name}", + {}, + {}, + {"name": lambda p: "World"}, + "Hello World", + id="render_with_prompt_context_sync", + ), + pytest.param( + "Hello {name}", + {}, + {}, + { + "name": lambda p: asyncio.sleep(0, result=f"Async-{p}"), + }, + "Hello Async-main", + id="render_with_prompt_context_async", + ), + pytest.param( + "Outer {inner}", + { + "inner": Prompt("inner", "Inner {value}", {"value": lambda p: "42"}), + }, + {}, + {}, + "Outer Inner 42", + id="render_with_nested_prompt_reference", + ), + pytest.param( + "Module says {ext}", + {}, + { + "ext": (lambda p: f"external-{p}", "test_module"), + }, + {}, + "Module says external-main", + id="render_with_external_context_function_sync", + ), + pytest.param( + "Module async {ext}", + {}, + { + "ext": (lambda p: asyncio.sleep(0, result=f"ext_async-{p}"), "test_module"), + }, + {}, + "Module async ext_async-main", + id="render_with_external_context_function_async", + ), + pytest.param( + "Escaped {{ and }} literal plus {value}", + {}, + {}, + {"value": lambda p: "X"}, + "Escaped { and } literal plus X", + id="render_with_escaped_braces", + ), + ], +) +def test_render_prompt_happy_path( + manager, + template, + prompts_setup, + ctx_funcs_setup, + prompt_ctx, + expected, +): + # Arrange + + main_prompt = Prompt("main", template, prompt_ctx) + manager.add_prompt(main_prompt) + for name, prompt in prompts_setup.items(): + manager.add_prompt(prompt) + manager._context_construct_functions.update(ctx_funcs_setup) + + # Act + + rendered = run(manager.render_prompt(main_prompt)) + + # Assert + + assert rendered == expected + + +def test_render_prompt_missing_field_raises_key_error(manager): + # Arrange + + prompt = Prompt("main", "Hello {missing}") + manager.add_prompt(prompt) + + # Act / Assert + + with pytest.raises(KeyError) as exc: + run(manager.render_prompt(prompt)) + + assert "缺少必要的内容块或构建函数" in str(exc.value) + assert "missing" in str(exc.value) + + +def test_render_prompt_recursion_limit_exceeded(manager): + # Arrange + + # Create mutual recursion between two prompts + p1 = Prompt("p1", "P1 uses {p2}") + p2 = Prompt("p2", "P2 uses {p1}") + manager.add_prompt(p1) + manager.add_prompt(p2) + + # Act / Assert + + with pytest.raises(RecursionError): + run(manager.render_prompt(p1)) + + +# --- _get_function_result tests --------------------------------------------- + + +@pytest.mark.parametrize( + "func, is_prompt_context, expect_async", + [ + pytest.param( + lambda p: f"sync_{p}", + True, + False, + id="get_function_result_sync_prompt_context", + ), + pytest.param( + lambda p: asyncio.sleep(0, result=f"async_{p}"), + False, + True, + id="get_function_result_async_external_context", + ), + ], +) +def test_get_function_result_happy_path(manager, dummy_logger, func, is_prompt_context, expect_async): + # Act + + res = run( + manager._get_function_result( + func=func, + prompt_name="prompt", + field_name="f", + is_prompt_context=is_prompt_context, + module="mod", + ) + ) + + # Assert + + assert res in {"sync_prompt", "async_prompt"} + + +@pytest.mark.parametrize( + "is_prompt_context, expected_message_part", + [ + pytest.param(True, "内部上下文构造函数", id="get_function_result_error_prompt_context_logs_internal_msg"), + pytest.param(False, "上下文构造函数", id="get_function_result_error_external_logs_external_msg"), + ], +) +def test_get_function_result_error_logging(manager, dummy_logger, is_prompt_context, expected_message_part): + # Arrange + + def bad_func(prompt_name: str) -> str: + raise ValueError("bad") + + # Act / Assert + + with pytest.raises(ValueError): + run( + manager._get_function_result( + func=bad_func, + prompt_name="promptX", + field_name="fieldX", + is_prompt_context=is_prompt_context, + module="modX", + ) + ) + + # Assert + + assert any(expected_message_part in msg for msg in dummy_logger.errors) + assert any("promptX" in msg for msg in dummy_logger.errors) ^ (not is_prompt_context) + assert any("modX" in msg for msg in dummy_logger.errors) ^ is_prompt_context + assert any("fieldX" in msg for msg in dummy_logger.errors) + + +# --- save_prompts tests ------------------------------------------------------ + + +def test_save_prompts_happy_path(manager, temp_prompts_dir): + # Arrange + + p1 = Prompt("p1", "Hello {{name}}") + p2 = Prompt("p2", "Bye {{value}}") + manager.add_prompt(p1, need_save=True) + manager.add_prompt(p2, need_save=True) + + # Act + + manager.save_prompts() + + # Assert + + files = sorted(temp_prompts_dir.glob("*.prompt")) + assert len(files) == 2 + contents = {f.stem: f.read_text(encoding="utf-8") for f in files} + assert contents["p1"] == "Hello {{name}}" + assert contents["p2"] == "Bye {{value}}" + + +def test_save_prompts_io_error(manager, temp_prompts_dir, dummy_logger, monkeypatch): + # Arrange + + prompt = Prompt("p1", "Hi") + manager.add_prompt(prompt, need_save=True) + + def bad_open(*args, **kwargs): + raise OSError("disk full") + + monkeypatch.setattr("builtins.open", bad_open) + + # Act / Assert + + with pytest.raises(OSError): + manager.save_prompts() + + # Assert + + assert any("保存 Prompt 'p1' 时出错" in msg for msg in dummy_logger.errors) + + +# --- load_prompts tests ------------------------------------------------------ + + +def test_load_prompts_happy_path(manager, temp_prompts_dir): + # Arrange + + file1 = temp_prompts_dir / "greet.prompt" + file2 = temp_prompts_dir / "farewell.prompt" + temp_prompts_dir.mkdir(parents=True, exist_ok=True) + file1.write_text("Hello {{name}}", encoding="utf-8") + file2.write_text("Bye {{name}}", encoding="utf-8") + + # Act + + manager.load_prompts() + + # Assert + + assert "greet" in manager.prompts + assert "farewell" in manager.prompts + assert "greet" in manager._prompt_to_save + assert "farewell" in manager._prompt_to_save + assert manager.prompts["greet"].template == "Hello {{name}}" + assert manager.prompts["farewell"].template == "Bye {{name}}" + + +def test_load_prompts_error(manager, temp_prompts_dir, dummy_logger, monkeypatch): + # Arrange + + file1 = temp_prompts_dir / "broken.prompt" + temp_prompts_dir.mkdir(parents=True, exist_ok=True) + file1.write_text("whatever", encoding="utf-8") + + def bad_open(*args, **kwargs): + raise OSError("cannot read") + + monkeypatch.setattr("builtins.open", bad_open) + + # Act / Assert + + with pytest.raises(OSError): + manager.load_prompts() + + # Assert + + assert any("加载 Prompt 文件" in msg for msg in dummy_logger.errors) diff --git a/src/config/config.py b/src/config/config.py index 24f58bcb..0a519beb 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -173,17 +173,17 @@ class ConfigManager: def initialize(self): logger.info(f"MaiCore当前版本: {MMC_VERSION}") logger.info("正在品鉴配置文件...") - self.global_config: Config = self._load_global_config() - self.model_config: ModelConfig = self._load_model_config() + self.global_config: Config = self.load_global_config() + self.model_config: ModelConfig = self.load_model_config() logger.info("非常的新鲜,非常的美味!") - def _load_global_config(self) -> Config: + def load_global_config(self) -> Config: config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION) if updated: sys.exit(0) # 先直接退出 return config - def _load_model_config(self) -> ModelConfig: + def load_model_config(self) -> ModelConfig: config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True) if updated: sys.exit(0) # 先直接退出 diff --git a/src/prompt/prompt_manager.py b/src/prompt/prompt_manager.py new file mode 100644 index 00000000..cef2114e --- /dev/null +++ b/src/prompt/prompt_manager.py @@ -0,0 +1,157 @@ +from collections.abc import Callable, Coroutine +from typing import Any, Optional +from string import Formatter +from pathlib import Path +import inspect + +from src.common.logger import get_logger + + +logger = get_logger("Prompt") + +_LEFT_BRACE = "\ufde9" +_RIGHT_BRACE = "\ufdea" + +PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve() +PROMPTS_DIR = PROJECT_ROOT / "prompts" +PROMPTS_DIR.mkdir(parents=True, exist_ok=True) +SUFFIX_PROMPT = ".prompt" + + +class Prompt: + prompt_name: str + template: str + prompt_render_context: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {} + + def __init__(self, prompt_name: str, template: str) -> None: + self.prompt_name = prompt_name + self.template = template + self.__post_init__() + + def add_context(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: + if name in self.prompt_render_context: + raise KeyError(f"Context function name '{name}' 已存在于 Prompt '{self.prompt_name}' 中") + self.prompt_render_context[name] = func + + def __post_init__(self): + if not self.prompt_name: + raise ValueError("prompt_name 不能为空") + if not self.template: + raise ValueError("template 不能为空") + tmp = self.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE) + if "{}" in tmp: + raise ValueError(r"模板中不允许使用未命名的占位符 '{}'") + + +class PromptManager: + def __init__(self) -> None: + PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在 + self.prompts: dict[str, Prompt] = {} + """存储 Prompt 实例""" + self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {} + """存储上下文构造函数及其所属模块""" + self._formatter = Formatter() # 仅用来解析模板 + """模板解析器""" + self._prompt_to_save: set[str] = set() + """需要保存的 Prompt 名称集合""" + + def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None: + if prompt.prompt_name in self.prompts or prompt.prompt_name in self._context_construct_functions: + raise KeyError(f"Prompt name '{prompt.prompt_name}' 已存在") + self.prompts[prompt.prompt_name] = prompt + if need_save: + self._prompt_to_save.add(prompt.prompt_name) + + def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: + if name in self._context_construct_functions or name in self.prompts: + raise KeyError(f"Construct function name '{name}' 已存在") + # 获取调用栈 + frame = inspect.currentframe() + if not frame: + # 不应该出现的情况 + raise RuntimeError("无法获取调用栈") + caller_frame = frame.f_back + if not caller_frame: + # 不应该出现的情况 + raise RuntimeError("无法获取调用栈的上一级") + caller_module = caller_frame.f_globals.get("__name__", "unknown") + if caller_module == "unknown": + logger.warning("无法获取调用函数的模块名,使用 'unknown' 作为默认值") + + self._context_construct_functions[name] = func, caller_module + + def get_prompt(self, prompt_name: str) -> Prompt: + if prompt_name not in self.prompts: + raise KeyError(f"Prompt name '{prompt_name}' 不存在") + return self.prompts[prompt_name] + + async def render_prompt(self, prompt: Prompt) -> str: + return await self._render(prompt) + + async def _render(self, prompt: Prompt, recursive_level: int = 0) -> str: + prompt.template = prompt.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE) + if recursive_level > 10: + raise RecursionError("递归层级过深,可能存在循环引用") + field_block = {field_name for _, field_name, _, _ in self._formatter.parse(prompt.template) if field_name} + rendered_fields: dict[str, str] = {} + for field_name in field_block: + if field_name in self.prompts: + rendered_fields[field_name] = await self._render(self.prompts[field_name], recursive_level + 1) + elif field_name in prompt.prompt_render_context: + func = prompt.prompt_render_context[field_name] + rendered_fields[field_name] = await self._get_function_result( + func, prompt.prompt_name, field_name, is_prompt_context=True + ) + elif field_name in self._context_construct_functions: + func, module = self._context_construct_functions[field_name] + rendered_fields[field_name] = await self._get_function_result( + func, prompt.prompt_name, field_name, is_prompt_context=False, module=module + ) + else: + raise KeyError(f"Prompt '{prompt.prompt_name}' 中缺少必要的内容块或构建函数: '{field_name}'") + rendered_template = prompt.template.format(**rendered_fields) + return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}") + + def save_prompts(self) -> None: + for prompt_name in self._prompt_to_save: + prompt = self.prompts[prompt_name] + file_path = PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}" + try: + with open(file_path, "w", encoding="utf-8") as f: + f.write(prompt.template) + except Exception as e: + logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}") + raise e + + def load_prompts(self) -> None: + for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): + try: + with open(prompt_file, "r", encoding="utf-8") as f: + template = f.read() + self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True) + except Exception as e: + logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") + raise e + + async def _get_function_result( + self, + func: Callable[[str], str | Coroutine[Any, Any, str]], + prompt_name: str, + field_name: str, + is_prompt_context: bool, + module: Optional[str] = None, + ) -> str: + try: + res = func(prompt_name) + if isinstance(res, Coroutine): + res = await res + return res + except Exception as e: + if is_prompt_context: + logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}") + else: + logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}") + raise e + + +prompt_manager = PromptManager()