MaiBot/pytests/prompt_test/test_prompt_manager.py

608 lines
17 KiB
Python

# File: tests/test_prompt_manager.py
import asyncio
import inspect
from pathlib import Path
from typing import Any
import sys
import pytest
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prompt_manager # noqa
@pytest.mark.parametrize(
"prompt_name, template",
[
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
pytest.param("brace-escaping", "Use {{ and }} around {field}", id="template-with-escaped-braces"),
],
)
def test_prompt_init_happy_paths(prompt_name: str, template: str):
# Act
prompt = Prompt(prompt_name=prompt_name, template=template)
# Assert
assert prompt.prompt_name == prompt_name
assert prompt.template == template
@pytest.mark.parametrize(
"prompt_name, template, expected_exception, expected_msg_substring",
[
pytest.param("", "Hello {name}", ValueError, "prompt_name 不能为空", id="empty-prompt-name"),
pytest.param("valid-name", "", ValueError, "template 不能为空", id="empty-template"),
pytest.param(
"unnamed-placeholder",
"Hello {}",
ValueError,
"模板中不允许使用未命名的占位符",
id="unnamed-placeholder-not-allowed",
),
pytest.param(
"unnamed-placeholder-with-escaped-brace",
"Value {{}} and {}",
ValueError,
"模板中不允许使用未命名的占位符",
id="unnamed-placeholder-mixed-with-escaped",
),
],
)
def test_prompt_init_error_cases(prompt_name, template, expected_exception, expected_msg_substring):
# Act / Assert
with pytest.raises(expected_exception) as exc_info:
Prompt(prompt_name=prompt_name, template=template)
# Assert
assert expected_msg_substring in str(exc_info.value)
@pytest.mark.parametrize(
"initial_context, name, func, expected_value, expected_exception, expected_msg_substring, case_id",
[
(
{},
"const_str",
"constant",
"constant",
None,
None,
"add-context-from-string-creates-wrapper",
),
(
{},
"callable_str",
lambda prompt_name: f"hello-{prompt_name}",
"hello-my_prompt",
None,
None,
"add-context-from-callable",
),
(
{"dup": lambda _: "x"},
"dup",
"y",
None,
KeyError,
"Context function name 'dup' 已存在于 Prompt 'my_prompt'",
"add-context-duplicate-key-error",
),
],
)
def test_prompt_add_context(
initial_context,
name,
func,
expected_value,
expected_exception,
expected_msg_substring,
case_id,
):
# Arrange
prompt = Prompt(prompt_name="my_prompt", template="template")
prompt.prompt_render_context = dict(initial_context)
# Act
if expected_exception:
with pytest.raises(expected_exception) as exc_info:
prompt.add_context(name, func)
# Assert
assert expected_msg_substring in str(exc_info.value)
else:
prompt.add_context(name, func)
# Assert
assert name in prompt.prompt_render_context
result = prompt.prompt_render_context[name]("my_prompt")
assert result == expected_value
def test_prompt_manager_add_prompt_happy_and_error():
# Arrange
manager = PromptManager()
prompt1 = Prompt(prompt_name="p1", template="T1")
manager.add_prompt(prompt1, need_save=True)
# Act
prompt2 = Prompt(prompt_name="p2", template="T2")
manager.add_prompt(prompt2, need_save=False)
# Assert
assert "p1" in manager._prompt_to_save
assert "p2" not in manager._prompt_to_save
# Arrange
prompt_dup = Prompt(prompt_name="p1", template="T-dup")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.add_prompt(prompt_dup)
# Assert
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
def test_prompt_manager_get_prompt_is_copy():
# Arrange
manager = PromptManager()
prompt = Prompt(prompt_name="original", template="T")
manager.add_prompt(prompt)
# Act
retrieved_prompt = manager.get_prompt("original")
# Assert
assert retrieved_prompt is not prompt
assert retrieved_prompt.prompt_name == prompt.prompt_name
assert retrieved_prompt.template == prompt.template
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
def test_prompt_manager_add_prompt_conflict_with_context_name():
# Arrange
manager = PromptManager()
manager.add_context_construct_function("ctx_name", lambda _: "value")
prompt_conflict = Prompt(prompt_name="ctx_name", template="T")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.add_prompt(prompt_conflict)
# Assert
assert "Prompt name 'ctx_name' 已存在" in str(exc_info.value)
def test_prompt_manager_add_context_construct_function_happy():
# Arrange
manager = PromptManager()
def ctx_func(prompt_name: str) -> str:
return f"ctx-{prompt_name}"
# Act
manager.add_context_construct_function("ctx", ctx_func)
# Assert
assert "ctx" in manager._context_construct_functions
stored_func, module = manager._context_construct_functions["ctx"]
assert stored_func is ctx_func
assert module == __name__
def test_prompt_manager_add_context_construct_function_duplicate():
# Arrange
manager = PromptManager()
def f(_):
return "x"
manager.add_context_construct_function("dup", f)
manager.add_prompt(Prompt(prompt_name="dup_prompt", template="T"))
# Act / Assert
with pytest.raises(KeyError) as exc_info1:
manager.add_context_construct_function("dup", f)
# Assert
assert "Construct function name 'dup' 已存在" in str(exc_info1.value)
# Act / Assert
with pytest.raises(KeyError) as exc_info2:
manager.add_context_construct_function("dup_prompt", f)
# Assert
assert "Construct function name 'dup_prompt' 已存在" in str(exc_info2.value)
def test_prompt_manager_get_prompt_not_exist():
# Arrange
manager = PromptManager()
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.get_prompt("no_such_prompt")
# Assert
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
@pytest.mark.parametrize(
"template, inner_context, global_context, expected, case_id",
[
pytest.param(
"Hello {name}",
{"name": lambda p: f"name-for-{p}"},
{},
"Hello name-for-main",
"render-with-inner-context",
),
pytest.param(
"Global {block}",
{},
{"block": lambda p: f"block-{p}"},
"Global block-main",
"render-with-global-context",
),
pytest.param(
"Mix {inner} and {global}",
{"inner": lambda p: f"inner-{p}"},
{"global": lambda p: f"global-{p}"},
"Mix inner-main and global-main",
"render-with-inner-and-global-context",
),
pytest.param(
"Escaped {{ and }} and {field}",
{"field": lambda _: "X"},
{},
"Escaped { and } and X",
"render-with-escaped-braces",
),
],
)
@pytest.mark.asyncio
async def test_prompt_manager_render_contexts(template, inner_context, global_context, expected, case_id):
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template=template)
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
for name, fn in inner_context.items():
prompt.add_context(name, fn)
for name, fn in global_context.items():
manager.add_context_construct_function(name, fn)
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == expected
@pytest.mark.asyncio
async def test_prompt_manager_render_nested_prompts():
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p1", template="P1-{x}")
p2 = Prompt(prompt_name="p2", template="P2-{p1}")
p3_tmp = Prompt(prompt_name="p3", template="{p2}-end")
manager.add_prompt(p1)
manager.add_prompt(p2)
manager.add_prompt(p3_tmp)
p3 = manager.get_prompt("p3")
p3.add_context("x", lambda _: "X")
# Act
rendered = await manager.render_prompt(p3)
# Assert
assert rendered == "P2-P1-X-end"
@pytest.mark.asyncio
async def test_prompt_manager_render_recursive_limit():
# Arrange
manager = PromptManager()
p1_tmp = Prompt(prompt_name="p1", template="{p2}")
p2_tmp = Prompt(prompt_name="p2", template="{p1}")
manager.add_prompt(p1_tmp)
manager.add_prompt(p2_tmp)
p1 = manager.get_prompt("p1")
# Act / Assert
with pytest.raises(RecursionError) as exc_info:
await manager.render_prompt(p1)
# Assert
assert "递归层级过深" in str(exc_info.value)
@pytest.mark.asyncio
async def test_prompt_manager_render_missing_field_error():
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template="Hello {missing}")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
# Act / Assert
with pytest.raises(KeyError) as exc_info:
await manager.render_prompt(prompt)
# Assert
assert "Prompt 'main' 中缺少必要的内容块或构建函数: 'missing'" in str(exc_info.value)
@pytest.mark.asyncio
async def test_prompt_manager_render_prefers_inner_context_over_global():
# Arrange
manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template="{field}")
manager.add_context_construct_function("field", lambda _: "global")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
prompt.add_context("field", lambda _: "inner")
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "inner"
@pytest.mark.asyncio
async def test_prompt_manager_render_with_coroutine_context_function():
# Arrange
manager = PromptManager()
async def async_inner(prompt_name: str) -> str:
await asyncio.sleep(0)
return f"async-{prompt_name}"
tmp_prompt = Prompt(prompt_name="main", template="{inner}")
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
prompt.add_context("inner", async_inner)
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "async-main"
@pytest.mark.asyncio
async def test_prompt_manager_render_with_coroutine_global_context_function():
# Arrange
manager = PromptManager()
async def async_global(prompt_name: str) -> str:
await asyncio.sleep(0)
return f"g-{prompt_name}"
tmp_prompt = Prompt(prompt_name="main", template="{g}")
manager.add_context_construct_function("g", async_global)
manager.add_prompt(tmp_prompt)
prompt = manager.get_prompt("main")
# Act
rendered = await manager.render_prompt(prompt)
# Assert
assert rendered == "g-main"
@pytest.mark.parametrize(
"is_prompt_context, use_coroutine, case_id",
[
pytest.param(True, False, "prompt-context-sync-error"),
pytest.param(False, False, "global-context-sync-error"),
pytest.param(True, True, "prompt-context-async-error"),
pytest.param(False, True, "global-context-async-error"),
],
)
@pytest.mark.asyncio
async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_prompt_context, use_coroutine, case_id):
# Arrange
manager = PromptManager()
class DummyError(Exception):
pass
def sync_func(_name: str) -> str:
raise DummyError("sync-error")
async def async_func(_name: str) -> str:
await asyncio.sleep(0)
raise DummyError("async-error")
func = async_func if use_coroutine else sync_func
logged_messages: list[str] = []
def fake_error(msg: Any) -> None:
logged_messages.append(str(msg))
fake_logger = type("FakeLogger", (), {"error": staticmethod(fake_error)})
monkeypatch.setattr("src.prompt.prompt_manager.logger", fake_logger)
# Act / Assert
with pytest.raises(DummyError):
await manager._get_function_result(
func=func,
prompt_name="P",
field_name="field",
is_prompt_context=is_prompt_context,
module="mod",
)
# Assert
assert logged_messages
log = logged_messages[0]
if is_prompt_context:
assert "调用 Prompt 'P' 内部上下文构造函数 'field' 时出错" in log
else:
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
# Arrange
manager = PromptManager()
def fake_currentframe() -> None:
return None
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
def f(_):
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc_info:
manager.add_context_construct_function("x", f)
# Assert
assert "无法获取调用栈" in str(exc_info.value)
def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monkeypatch):
# Arrange
manager = PromptManager()
real_currentframe = inspect.currentframe
class FakeFrame:
f_back = None
def fake_currentframe():
return FakeFrame()
monkeypatch.setattr("inspect.currentframe", fake_currentframe)
def f(_):
return "x"
# Act / Assert
with pytest.raises(RuntimeError) as exc_info:
manager.add_context_construct_function("x", f)
# Assert
assert "无法获取调用栈的上一级" in str(exc_info.value)
# Cleanup
monkeypatch.setattr("inspect.currentframe", real_currentframe)
def test_prompt_manager_save_and_load_prompts(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}")
p1.add_context("x", "X")
manager.add_prompt(p1, need_save=True)
# Act
manager.save_prompts()
# Assert
saved_file = test_dir / f"save_me{SUFFIX_PROMPT}"
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
# Arrange
new_manager = PromptManager()
# Act
new_manager.load_prompts()
# Assert
loaded = new_manager.get_prompt("save_me")
assert loaded.template == "Template {x}"
assert "save_me" in new_manager._prompt_to_save
def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True)
class FakeFile:
def __enter__(self):
raise OSError("disk error")
def __exit__(self, exc_type, exc, tb):
return False
def fake_open(*_args, **_kwargs):
return FakeFile()
monkeypatch.setattr("builtins.open", fake_open)
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.save_prompts()
# Assert
assert "disk error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
prompt_file = test_dir / f"bad{SUFFIX_PROMPT}"
prompt_file.write_text("content", encoding="utf-8")
class FakeFile:
def __enter__(self):
raise OSError("read error")
def __exit__(self, exc_type, exc, tb):
return False
def fake_open(*_args, **_kwargs):
return FakeFile()
monkeypatch.setattr("builtins.open", fake_open)
manager = PromptManager()
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.load_prompts()
# Assert
assert "read error" in str(exc_info.value)
def test_prompt_manager_global_instance_access():
# Act
pm = prompt_manager
# Assert
assert isinstance(pm, PromptManager)
def test_formatter_parsing_named_fields_only():
# Arrange
manager = PromptManager()
prompt = Prompt(prompt_name="main", template="A {x} B {y} C")
manager.add_prompt(prompt)
# Act
fields = {field_name for _, field_name, _, _ in manager._formatter.parse(prompt.template) if field_name}
# Assert
assert fields == {"x", "y"}