mirror of https://github.com/Mai-with-u/MaiBot.git
Prompt管理器与测试;调整部分方法名称
parent
3a66bfeac1
commit
c15b77907e
|
|
@ -0,0 +1,607 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
||||
|
||||
from src.prompt.prompt_manager import PromptManager
|
||||
|
||||
# --- Minimal stubs / constants matching the production module ---
|
||||
|
||||
# These imports/definitions are here only to make the tests self‑contained
|
||||
# In the real project, they already exist in `prompt_manager.py`'s module.
|
||||
# We mirror them here to control behavior via monkeypatch.
|
||||
|
||||
|
||||
class Prompt:
|
||||
def __init__(self, prompt_name: str, template: str, prompt_render_context: Optional[dict[str, Callable]] = None):
|
||||
self.prompt_name = prompt_name
|
||||
self.template = template
|
||||
self.prompt_render_context = prompt_render_context or {}
|
||||
|
||||
|
||||
class DummyLogger:
|
||||
def __init__(self):
|
||||
self.errors: list[str] = []
|
||||
self.warnings: list[str] = []
|
||||
|
||||
def error(self, msg: str) -> None:
|
||||
self.errors.append(msg)
|
||||
|
||||
def warning(self, msg: str) -> None:
|
||||
self.warnings.append(msg)
|
||||
|
||||
|
||||
# --- Fixtures to patch module-level objects in prompt_manager ---
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_logger(monkeypatch):
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
logger = DummyLogger()
|
||||
monkeypatch.setattr(pm, "logger", logger, raising=False)
|
||||
return logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_prompts_dir(tmp_path, monkeypatch):
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
prompts_dir = tmp_path / "prompts"
|
||||
monkeypatch.setattr(pm, "PROMPTS_DIR", prompts_dir, raising=False)
|
||||
return prompts_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def brace_constants(monkeypatch):
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
# emulate the placeholders used in the manager
|
||||
monkeypatch.setattr(pm, "_LEFT_BRACE", "__LEFT__", raising=False)
|
||||
monkeypatch.setattr(pm, "_RIGHT_BRACE", "__RIGHT__", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def suffix_prompt(monkeypatch):
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
monkeypatch.setattr(pm, "SUFFIX_PROMPT", ".prompt", raising=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def manager(temp_prompts_dir, brace_constants, suffix_prompt):
|
||||
# PromptManager.__init__ uses patched PROMPTS_DIR
|
||||
return PromptManager()
|
||||
|
||||
|
||||
# --- Helper to run async methods in tests (for non-async pytest) ---
|
||||
|
||||
|
||||
def run(coro):
|
||||
return asyncio.get_event_loop().run_until_complete(coro)
|
||||
|
||||
|
||||
# --- add_prompt tests --------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save",
|
||||
[
|
||||
pytest.param(
|
||||
{},
|
||||
{},
|
||||
"greeting",
|
||||
False,
|
||||
False,
|
||||
id="add_prompt_simple_not_saved",
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
{},
|
||||
"system",
|
||||
True,
|
||||
True,
|
||||
id="add_prompt_marked_for_save",
|
||||
),
|
||||
pytest.param(
|
||||
{"existing": Prompt("existing", "tmpl")},
|
||||
{},
|
||||
"new",
|
||||
True,
|
||||
True,
|
||||
id="add_prompt_with_existing_other_prompt",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_add_prompt_happy_path(manager, existing_prompts, existing_funcs, name_to_add, need_save, expect_in_save):
|
||||
# Arrange
|
||||
|
||||
manager.prompts.update(existing_prompts)
|
||||
manager._context_construct_functions.update(existing_funcs)
|
||||
prompt = Prompt(name_to_add, "template")
|
||||
|
||||
# Act
|
||||
|
||||
manager.add_prompt(prompt, need_save=need_save)
|
||||
|
||||
# Assert
|
||||
|
||||
assert manager.prompts[name_to_add] is prompt
|
||||
assert (name_to_add in manager._prompt_to_save) is expect_in_save
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"existing_prompts, existing_funcs, new_name, conflict_type",
|
||||
[
|
||||
pytest.param(
|
||||
{"dup": Prompt("dup", "tmpl")},
|
||||
{},
|
||||
"dup",
|
||||
"prompt_conflict",
|
||||
id="add_prompt_conflict_with_existing_prompt",
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
{"dup": (lambda x: x, "mod")},
|
||||
"dup",
|
||||
"func_conflict",
|
||||
id="add_prompt_conflict_with_existing_context_function",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_add_prompt_conflict_raises_key_error(manager, existing_prompts, existing_funcs, new_name, conflict_type):
|
||||
# Arrange
|
||||
|
||||
manager.prompts.update(existing_prompts)
|
||||
manager._context_construct_functions.update(existing_funcs)
|
||||
prompt = Prompt(new_name, "template")
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(KeyError) as exc:
|
||||
manager.add_prompt(prompt)
|
||||
|
||||
assert new_name in str(exc.value)
|
||||
|
||||
|
||||
# --- add_context_construct_function tests -----------------------------------
|
||||
|
||||
|
||||
def test_add_context_construct_function_happy_path(manager):
|
||||
# Arrange
|
||||
|
||||
def builder(prompt_name: str) -> str:
|
||||
return f"ctx_for_{prompt_name}"
|
||||
|
||||
# Act
|
||||
|
||||
manager.add_context_construct_function("ctx", builder)
|
||||
|
||||
# Assert
|
||||
|
||||
assert "ctx" in manager._context_construct_functions
|
||||
stored_func, module = manager._context_construct_functions["ctx"]
|
||||
assert stored_func is builder
|
||||
# module is caller's module name
|
||||
assert isinstance(module, str)
|
||||
assert module != ""
|
||||
|
||||
|
||||
def test_add_context_construct_function_logs_unknown_module(manager, dummy_logger, monkeypatch):
|
||||
# Arrange
|
||||
|
||||
def builder(prompt_name: str) -> str:
|
||||
return f"v_{prompt_name}"
|
||||
|
||||
def fake_currentframe():
|
||||
class FakeCallerFrame:
|
||||
f_globals = {"__name__": "unknown"}
|
||||
|
||||
class FakeFrame:
|
||||
f_back = FakeCallerFrame()
|
||||
|
||||
return FakeFrame()
|
||||
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
monkeypatch.setattr(pm.inspect, "currentframe", fake_currentframe)
|
||||
|
||||
# Act
|
||||
manager.add_context_construct_function("unknown_ctx", builder)
|
||||
|
||||
# Assert
|
||||
|
||||
assert any("无法获取调用函数的模块名" in msg for msg in dummy_logger.warnings)
|
||||
assert "unknown_ctx" in manager._context_construct_functions
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"existing_prompts, existing_funcs, name_to_add",
|
||||
[
|
||||
pytest.param(
|
||||
{"p": Prompt("p", "tmpl")},
|
||||
{},
|
||||
"p",
|
||||
id="add_context_construct_function_conflict_with_prompt",
|
||||
),
|
||||
pytest.param(
|
||||
{},
|
||||
{"f": (lambda x: x, "mod")},
|
||||
"f",
|
||||
id="add_context_construct_function_conflict_with_existing_func",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_add_context_construct_function_conflict_raises_key_error(
|
||||
manager, existing_prompts, existing_funcs, name_to_add
|
||||
):
|
||||
# Arrange
|
||||
|
||||
manager.prompts.update(existing_prompts)
|
||||
manager._context_construct_functions.update(existing_funcs)
|
||||
|
||||
def func(prompt_name: str) -> str:
|
||||
return "x"
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(KeyError) as exc:
|
||||
manager.add_context_construct_function(name_to_add, func)
|
||||
|
||||
assert name_to_add in str(exc.value)
|
||||
|
||||
|
||||
def test_add_context_construct_function_no_frame_raises_runtime_error(manager, monkeypatch):
|
||||
# Arrange
|
||||
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
monkeypatch.setattr(pm.inspect, "currentframe", lambda: None)
|
||||
|
||||
def func(prompt_name: str) -> str:
|
||||
return "x"
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
manager.add_context_construct_function("ctx", func)
|
||||
|
||||
assert "无法获取调用栈" in str(exc.value)
|
||||
|
||||
|
||||
def test_add_context_construct_function_no_caller_frame_raises_runtime_error(manager, monkeypatch):
|
||||
# Arrange
|
||||
|
||||
from src.prompt import prompt_manager as pm
|
||||
|
||||
class FakeFrame:
|
||||
f_back = None
|
||||
|
||||
monkeypatch.setattr(pm.inspect, "currentframe", lambda: FakeFrame())
|
||||
|
||||
def func(prompt_name: str) -> str:
|
||||
return "x"
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(RuntimeError) as exc:
|
||||
manager.add_context_construct_function("ctx", func)
|
||||
|
||||
assert "无法获取调用栈的上一级" in str(exc.value)
|
||||
|
||||
|
||||
# --- get_prompt tests --------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"existing_name, requested_name, should_raise",
|
||||
[
|
||||
pytest.param("p1", "p1", False, id="get_existing_prompt"),
|
||||
pytest.param("p1", "missing", True, id="get_missing_prompt_raises"),
|
||||
],
|
||||
)
|
||||
def test_get_prompt(manager, existing_name, requested_name, should_raise):
|
||||
# Arrange
|
||||
|
||||
manager.prompts[existing_name] = Prompt(existing_name, "tmpl")
|
||||
|
||||
# Act / Assert
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(KeyError) as exc:
|
||||
manager.get_prompt(requested_name)
|
||||
assert requested_name in str(exc.value)
|
||||
else:
|
||||
prompt = manager.get_prompt(requested_name)
|
||||
assert prompt.prompt_name == existing_name
|
||||
|
||||
|
||||
# --- render_prompt and _render tests ----------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"template, prompts_setup, ctx_funcs_setup, prompt_ctx, expected",
|
||||
[
|
||||
pytest.param(
|
||||
"Hello {name}",
|
||||
{},
|
||||
{},
|
||||
{"name": lambda p: "World"},
|
||||
"Hello World",
|
||||
id="render_with_prompt_context_sync",
|
||||
),
|
||||
pytest.param(
|
||||
"Hello {name}",
|
||||
{},
|
||||
{},
|
||||
{
|
||||
"name": lambda p: asyncio.sleep(0, result=f"Async-{p}"),
|
||||
},
|
||||
"Hello Async-main",
|
||||
id="render_with_prompt_context_async",
|
||||
),
|
||||
pytest.param(
|
||||
"Outer {inner}",
|
||||
{
|
||||
"inner": Prompt("inner", "Inner {value}", {"value": lambda p: "42"}),
|
||||
},
|
||||
{},
|
||||
{},
|
||||
"Outer Inner 42",
|
||||
id="render_with_nested_prompt_reference",
|
||||
),
|
||||
pytest.param(
|
||||
"Module says {ext}",
|
||||
{},
|
||||
{
|
||||
"ext": (lambda p: f"external-{p}", "test_module"),
|
||||
},
|
||||
{},
|
||||
"Module says external-main",
|
||||
id="render_with_external_context_function_sync",
|
||||
),
|
||||
pytest.param(
|
||||
"Module async {ext}",
|
||||
{},
|
||||
{
|
||||
"ext": (lambda p: asyncio.sleep(0, result=f"ext_async-{p}"), "test_module"),
|
||||
},
|
||||
{},
|
||||
"Module async ext_async-main",
|
||||
id="render_with_external_context_function_async",
|
||||
),
|
||||
pytest.param(
|
||||
"Escaped {{ and }} literal plus {value}",
|
||||
{},
|
||||
{},
|
||||
{"value": lambda p: "X"},
|
||||
"Escaped { and } literal plus X",
|
||||
id="render_with_escaped_braces",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_render_prompt_happy_path(
|
||||
manager,
|
||||
template,
|
||||
prompts_setup,
|
||||
ctx_funcs_setup,
|
||||
prompt_ctx,
|
||||
expected,
|
||||
):
|
||||
# Arrange
|
||||
|
||||
main_prompt = Prompt("main", template, prompt_ctx)
|
||||
manager.add_prompt(main_prompt)
|
||||
for name, prompt in prompts_setup.items():
|
||||
manager.add_prompt(prompt)
|
||||
manager._context_construct_functions.update(ctx_funcs_setup)
|
||||
|
||||
# Act
|
||||
|
||||
rendered = run(manager.render_prompt(main_prompt))
|
||||
|
||||
# Assert
|
||||
|
||||
assert rendered == expected
|
||||
|
||||
|
||||
def test_render_prompt_missing_field_raises_key_error(manager):
|
||||
# Arrange
|
||||
|
||||
prompt = Prompt("main", "Hello {missing}")
|
||||
manager.add_prompt(prompt)
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(KeyError) as exc:
|
||||
run(manager.render_prompt(prompt))
|
||||
|
||||
assert "缺少必要的内容块或构建函数" in str(exc.value)
|
||||
assert "missing" in str(exc.value)
|
||||
|
||||
|
||||
def test_render_prompt_recursion_limit_exceeded(manager):
|
||||
# Arrange
|
||||
|
||||
# Create mutual recursion between two prompts
|
||||
p1 = Prompt("p1", "P1 uses {p2}")
|
||||
p2 = Prompt("p2", "P2 uses {p1}")
|
||||
manager.add_prompt(p1)
|
||||
manager.add_prompt(p2)
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(RecursionError):
|
||||
run(manager.render_prompt(p1))
|
||||
|
||||
|
||||
# --- _get_function_result tests ---------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, is_prompt_context, expect_async",
|
||||
[
|
||||
pytest.param(
|
||||
lambda p: f"sync_{p}",
|
||||
True,
|
||||
False,
|
||||
id="get_function_result_sync_prompt_context",
|
||||
),
|
||||
pytest.param(
|
||||
lambda p: asyncio.sleep(0, result=f"async_{p}"),
|
||||
False,
|
||||
True,
|
||||
id="get_function_result_async_external_context",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_function_result_happy_path(manager, dummy_logger, func, is_prompt_context, expect_async):
|
||||
# Act
|
||||
|
||||
res = run(
|
||||
manager._get_function_result(
|
||||
func=func,
|
||||
prompt_name="prompt",
|
||||
field_name="f",
|
||||
is_prompt_context=is_prompt_context,
|
||||
module="mod",
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
||||
assert res in {"sync_prompt", "async_prompt"}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"is_prompt_context, expected_message_part",
|
||||
[
|
||||
pytest.param(True, "内部上下文构造函数", id="get_function_result_error_prompt_context_logs_internal_msg"),
|
||||
pytest.param(False, "上下文构造函数", id="get_function_result_error_external_logs_external_msg"),
|
||||
],
|
||||
)
|
||||
def test_get_function_result_error_logging(manager, dummy_logger, is_prompt_context, expected_message_part):
|
||||
# Arrange
|
||||
|
||||
def bad_func(prompt_name: str) -> str:
|
||||
raise ValueError("bad")
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
run(
|
||||
manager._get_function_result(
|
||||
func=bad_func,
|
||||
prompt_name="promptX",
|
||||
field_name="fieldX",
|
||||
is_prompt_context=is_prompt_context,
|
||||
module="modX",
|
||||
)
|
||||
)
|
||||
|
||||
# Assert
|
||||
|
||||
assert any(expected_message_part in msg for msg in dummy_logger.errors)
|
||||
assert any("promptX" in msg for msg in dummy_logger.errors) ^ (not is_prompt_context)
|
||||
assert any("modX" in msg for msg in dummy_logger.errors) ^ is_prompt_context
|
||||
assert any("fieldX" in msg for msg in dummy_logger.errors)
|
||||
|
||||
|
||||
# --- save_prompts tests ------------------------------------------------------
|
||||
|
||||
|
||||
def test_save_prompts_happy_path(manager, temp_prompts_dir):
|
||||
# Arrange
|
||||
|
||||
p1 = Prompt("p1", "Hello {{name}}")
|
||||
p2 = Prompt("p2", "Bye {{value}}")
|
||||
manager.add_prompt(p1, need_save=True)
|
||||
manager.add_prompt(p2, need_save=True)
|
||||
|
||||
# Act
|
||||
|
||||
manager.save_prompts()
|
||||
|
||||
# Assert
|
||||
|
||||
files = sorted(temp_prompts_dir.glob("*.prompt"))
|
||||
assert len(files) == 2
|
||||
contents = {f.stem: f.read_text(encoding="utf-8") for f in files}
|
||||
assert contents["p1"] == "Hello {{name}}"
|
||||
assert contents["p2"] == "Bye {{value}}"
|
||||
|
||||
|
||||
def test_save_prompts_io_error(manager, temp_prompts_dir, dummy_logger, monkeypatch):
|
||||
# Arrange
|
||||
|
||||
prompt = Prompt("p1", "Hi")
|
||||
manager.add_prompt(prompt, need_save=True)
|
||||
|
||||
def bad_open(*args, **kwargs):
|
||||
raise OSError("disk full")
|
||||
|
||||
monkeypatch.setattr("builtins.open", bad_open)
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(OSError):
|
||||
manager.save_prompts()
|
||||
|
||||
# Assert
|
||||
|
||||
assert any("保存 Prompt 'p1' 时出错" in msg for msg in dummy_logger.errors)
|
||||
|
||||
|
||||
# --- load_prompts tests ------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_prompts_happy_path(manager, temp_prompts_dir):
|
||||
# Arrange
|
||||
|
||||
file1 = temp_prompts_dir / "greet.prompt"
|
||||
file2 = temp_prompts_dir / "farewell.prompt"
|
||||
temp_prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
file1.write_text("Hello {{name}}", encoding="utf-8")
|
||||
file2.write_text("Bye {{name}}", encoding="utf-8")
|
||||
|
||||
# Act
|
||||
|
||||
manager.load_prompts()
|
||||
|
||||
# Assert
|
||||
|
||||
assert "greet" in manager.prompts
|
||||
assert "farewell" in manager.prompts
|
||||
assert "greet" in manager._prompt_to_save
|
||||
assert "farewell" in manager._prompt_to_save
|
||||
assert manager.prompts["greet"].template == "Hello {{name}}"
|
||||
assert manager.prompts["farewell"].template == "Bye {{name}}"
|
||||
|
||||
|
||||
def test_load_prompts_error(manager, temp_prompts_dir, dummy_logger, monkeypatch):
|
||||
# Arrange
|
||||
|
||||
file1 = temp_prompts_dir / "broken.prompt"
|
||||
temp_prompts_dir.mkdir(parents=True, exist_ok=True)
|
||||
file1.write_text("whatever", encoding="utf-8")
|
||||
|
||||
def bad_open(*args, **kwargs):
|
||||
raise OSError("cannot read")
|
||||
|
||||
monkeypatch.setattr("builtins.open", bad_open)
|
||||
|
||||
# Act / Assert
|
||||
|
||||
with pytest.raises(OSError):
|
||||
manager.load_prompts()
|
||||
|
||||
# Assert
|
||||
|
||||
assert any("加载 Prompt 文件" in msg for msg in dummy_logger.errors)
|
||||
|
|
@ -173,17 +173,17 @@ class ConfigManager:
|
|||
def initialize(self):
|
||||
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
|
||||
logger.info("正在品鉴配置文件...")
|
||||
self.global_config: Config = self._load_global_config()
|
||||
self.model_config: ModelConfig = self._load_model_config()
|
||||
self.global_config: Config = self.load_global_config()
|
||||
self.model_config: ModelConfig = self.load_model_config()
|
||||
logger.info("非常的新鲜,非常的美味!")
|
||||
|
||||
def _load_global_config(self) -> Config:
|
||||
def load_global_config(self) -> Config:
|
||||
config, updated = load_config_from_file(Config, self.bot_config_path, CONFIG_VERSION)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
return config
|
||||
|
||||
def _load_model_config(self) -> ModelConfig:
|
||||
def load_model_config(self) -> ModelConfig:
|
||||
config, updated = load_config_from_file(ModelConfig, self.model_config_path, MODEL_CONFIG_VERSION, True)
|
||||
if updated:
|
||||
sys.exit(0) # 先直接退出
|
||||
|
|
|
|||
|
|
@ -0,0 +1,157 @@
|
|||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any, Optional
|
||||
from string import Formatter
|
||||
from pathlib import Path
|
||||
import inspect
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
|
||||
logger = get_logger("Prompt")
|
||||
|
||||
_LEFT_BRACE = "\ufde9"
|
||||
_RIGHT_BRACE = "\ufdea"
|
||||
|
||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||
PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SUFFIX_PROMPT = ".prompt"
|
||||
|
||||
|
||||
class Prompt:
|
||||
prompt_name: str
|
||||
template: str
|
||||
prompt_render_context: dict[str, Callable[[str], str | Coroutine[Any, Any, str]]] = {}
|
||||
|
||||
def __init__(self, prompt_name: str, template: str) -> None:
|
||||
self.prompt_name = prompt_name
|
||||
self.template = template
|
||||
self.__post_init__()
|
||||
|
||||
def add_context(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
if name in self.prompt_render_context:
|
||||
raise KeyError(f"Context function name '{name}' 已存在于 Prompt '{self.prompt_name}' 中")
|
||||
self.prompt_render_context[name] = func
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.prompt_name:
|
||||
raise ValueError("prompt_name 不能为空")
|
||||
if not self.template:
|
||||
raise ValueError("template 不能为空")
|
||||
tmp = self.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE)
|
||||
if "{}" in tmp:
|
||||
raise ValueError(r"模板中不允许使用未命名的占位符 '{}'")
|
||||
|
||||
|
||||
class PromptManager:
|
||||
def __init__(self) -> None:
|
||||
PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在
|
||||
self.prompts: dict[str, Prompt] = {}
|
||||
"""存储 Prompt 实例"""
|
||||
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
|
||||
"""存储上下文构造函数及其所属模块"""
|
||||
self._formatter = Formatter() # 仅用来解析模板
|
||||
"""模板解析器"""
|
||||
self._prompt_to_save: set[str] = set()
|
||||
"""需要保存的 Prompt 名称集合"""
|
||||
|
||||
def add_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
|
||||
if prompt.prompt_name in self.prompts or prompt.prompt_name in self._context_construct_functions:
|
||||
raise KeyError(f"Prompt name '{prompt.prompt_name}' 已存在")
|
||||
self.prompts[prompt.prompt_name] = prompt
|
||||
if need_save:
|
||||
self._prompt_to_save.add(prompt.prompt_name)
|
||||
|
||||
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||
if name in self._context_construct_functions or name in self.prompts:
|
||||
raise KeyError(f"Construct function name '{name}' 已存在")
|
||||
# 获取调用栈
|
||||
frame = inspect.currentframe()
|
||||
if not frame:
|
||||
# 不应该出现的情况
|
||||
raise RuntimeError("无法获取调用栈")
|
||||
caller_frame = frame.f_back
|
||||
if not caller_frame:
|
||||
# 不应该出现的情况
|
||||
raise RuntimeError("无法获取调用栈的上一级")
|
||||
caller_module = caller_frame.f_globals.get("__name__", "unknown")
|
||||
if caller_module == "unknown":
|
||||
logger.warning("无法获取调用函数的模块名,使用 'unknown' 作为默认值")
|
||||
|
||||
self._context_construct_functions[name] = func, caller_module
|
||||
|
||||
def get_prompt(self, prompt_name: str) -> Prompt:
|
||||
if prompt_name not in self.prompts:
|
||||
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
|
||||
return self.prompts[prompt_name]
|
||||
|
||||
async def render_prompt(self, prompt: Prompt) -> str:
|
||||
return await self._render(prompt)
|
||||
|
||||
async def _render(self, prompt: Prompt, recursive_level: int = 0) -> str:
|
||||
prompt.template = prompt.template.replace("{{", _LEFT_BRACE).replace("}}", _RIGHT_BRACE)
|
||||
if recursive_level > 10:
|
||||
raise RecursionError("递归层级过深,可能存在循环引用")
|
||||
field_block = {field_name for _, field_name, _, _ in self._formatter.parse(prompt.template) if field_name}
|
||||
rendered_fields: dict[str, str] = {}
|
||||
for field_name in field_block:
|
||||
if field_name in self.prompts:
|
||||
rendered_fields[field_name] = await self._render(self.prompts[field_name], recursive_level + 1)
|
||||
elif field_name in prompt.prompt_render_context:
|
||||
func = prompt.prompt_render_context[field_name]
|
||||
rendered_fields[field_name] = await self._get_function_result(
|
||||
func, prompt.prompt_name, field_name, is_prompt_context=True
|
||||
)
|
||||
elif field_name in self._context_construct_functions:
|
||||
func, module = self._context_construct_functions[field_name]
|
||||
rendered_fields[field_name] = await self._get_function_result(
|
||||
func, prompt.prompt_name, field_name, is_prompt_context=False, module=module
|
||||
)
|
||||
else:
|
||||
raise KeyError(f"Prompt '{prompt.prompt_name}' 中缺少必要的内容块或构建函数: '{field_name}'")
|
||||
rendered_template = prompt.template.format(**rendered_fields)
|
||||
return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}")
|
||||
|
||||
def save_prompts(self) -> None:
|
||||
for prompt_name in self._prompt_to_save:
|
||||
prompt = self.prompts[prompt_name]
|
||||
file_path = PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||
try:
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
f.write(prompt.template)
|
||||
except Exception as e:
|
||||
logger.error(f"保存 Prompt '{prompt_name}' 时出错,文件路径: '{file_path}',错误信息: {e}")
|
||||
raise e
|
||||
|
||||
def load_prompts(self) -> None:
|
||||
for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||
try:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
template = f.read()
|
||||
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||
raise e
|
||||
|
||||
async def _get_function_result(
|
||||
self,
|
||||
func: Callable[[str], str | Coroutine[Any, Any, str]],
|
||||
prompt_name: str,
|
||||
field_name: str,
|
||||
is_prompt_context: bool,
|
||||
module: Optional[str] = None,
|
||||
) -> str:
|
||||
try:
|
||||
res = func(prompt_name)
|
||||
if isinstance(res, Coroutine):
|
||||
res = await res
|
||||
return res
|
||||
except Exception as e:
|
||||
if is_prompt_context:
|
||||
logger.error(f"调用 Prompt '{prompt_name}' 内部上下文构造函数 '{field_name}' 时出错,错误信息: {e}")
|
||||
else:
|
||||
logger.error(f"调用上下文构造函数 '{field_name}' 时出错,所属模块: '{module}',错误信息: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
Loading…
Reference in New Issue