From b793a3d62b807127737361a0bdbba2902c395040 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Mon, 2 Feb 2026 20:53:42 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E5=A5=BD=E7=9A=84Prompt=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E7=B3=BB=E7=BB=9F=EF=BC=8C=E5=A2=9E=E5=8A=A0=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E8=87=AA=E5=AE=9A=E4=B9=89Prompt=E4=B8=8E=E8=A6=86?= =?UTF-8?q?=E7=9B=96=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- changelogs/mai_next_todo.md | 22 +- pytests/prompt_test/test_prompt_manager.py | 379 ++++++++++++++++++--- src/prompt/prompt_manager.py | 47 ++- 3 files changed, 402 insertions(+), 46 deletions(-) diff --git a/changelogs/mai_next_todo.md b/changelogs/mai_next_todo.md index 7fa73e17..40c7bdcb 100644 --- a/changelogs/mai_next_todo.md +++ b/changelogs/mai_next_todo.md @@ -125,8 +125,28 @@ version 0.3.0 - 2026-01-11 - [x] 使用C模块库提升相似度计算效率 - [ ] 移除了定时表情包完整性检查,改为启动时检查(依然保留为独立方法,以防之后恢复定时检查系统) +## Prompt 管理系统 +- [ ] 官方Prompt全部独立 +- [x] 用户自定义Prompt系统 + - [x] 用户可以创建,删除自己的Prompt + - [x] 用户可以覆盖官方Prompt +- [x] Prompt构建系统 +- [x] Prompt文件交互 + - [x] 读取Prompt文件 + - [x] 读取官方Prompt文件 + - [x] 读取用户Prompt文件 + - [x] 用户Prompt覆盖官方Prompt + - [x] 保存Prompt文件 +- [x] Prompt管理方法 + - [x] Prompt添加 + - [x] Prompt删除 + - [x] **只保存被标记为需要保存的Prompt,其他的Prompt文件全部删除** + ## 一些细枝末节的东西 - [ ] 将`stream_id`和`chat_id`统一命名为`session_id` - [ ] 映射表 - [ ] `platform_group_user_session_id_map` `平台_群组_用户`-`会话ID` 映射表 -- [ ] 将大部分的数据模型均以`Mai`开头命名 \ No newline at end of file +- [ ] 将大部分的数据模型均以`Mai`开头命名 + +### 细节说明 +1. Prompt管理系统中保存用户自定义Prompt的时候会只保存被标记为需要保存的Prompt,其他的Prompt文件会全部删除,以防止用户删除Prompt后文件依然存在的问题。因此,如果想在运行时通过修改文件的方式来添加Prompt,需要确保通过对应方法标记该Prompt为需要保存,否则在下一次保存时会被删除。 \ No newline at end of file diff --git a/pytests/prompt_test/test_prompt_manager.py b/pytests/prompt_test/test_prompt_manager.py index f16d3dfe..53d2c622 100644 --- a/pytests/prompt_test/test_prompt_manager.py +++ b/pytests/prompt_test/test_prompt_manager.py @@ -1,4 +1,4 @@ -# File: tests/test_prompt_manager.py +# File: pytests/prompt_test/test_prompt_manager.py import asyncio import inspect @@ -12,7 +12,15 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve() sys.path.insert(0, str(PROJECT_ROOT)) sys.path.insert(0, str(PROJECT_ROOT / "src" / "config")) -from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prompt_manager # noqa +from src.prompt.prompt_manager import ( # noqa + SUFFIX_PROMPT, + Prompt, + PromptManager, + prompt_manager, +) + + +# ========= Prompt 基础行为 ========= @pytest.mark.parametrize( @@ -20,7 +28,11 @@ from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prom [ pytest.param("simple", "Hello {name}", id="simple-template-with-field"), pytest.param("no-fields", "Just a static template", id="template-without-fields"), - pytest.param("brace-escaping", "Use {{ and }} around {field}", id="template-with-escaped-braces"), + pytest.param( + "brace-escaping", + "Use {{ and }} around {field}", + id="template-with-escaped-braces", + ), ], ) def test_prompt_init_happy_paths(prompt_name: str, template: str): @@ -53,7 +65,12 @@ def test_prompt_init_happy_paths(prompt_name: str, template: str): ), ], ) -def test_prompt_init_error_cases(prompt_name, template, expected_exception, expected_msg_substring): +def test_prompt_init_error_cases( + prompt_name, + template, + expected_exception, + expected_msg_substring, +): # Act / Assert with pytest.raises(expected_exception) as exc_info: Prompt(prompt_name=prompt_name, template=template) @@ -123,6 +140,25 @@ def test_prompt_add_context( assert result == expected_value +def test_prompt_clone_independent_instance(): + # Arrange + prompt = Prompt(prompt_name="p", template="T {x}") + prompt.add_context("x", "X") + + # Act + cloned = prompt.clone() + + # Assert + assert cloned is not prompt + assert cloned.prompt_name == prompt.prompt_name + assert cloned.template == prompt.template + # 当前实现 clone 不复制 context + assert cloned.prompt_render_context == {} + + +# ========= PromptManager:添加/获取/删除/替换 ========= + + def test_prompt_manager_add_prompt_happy_and_error(): # Arrange manager = PromptManager() @@ -147,6 +183,59 @@ def test_prompt_manager_add_prompt_happy_and_error(): # Assert assert "Prompt name 'p1' 已存在" in str(exc_info.value) + +def test_prompt_manager_remove_prompt_happy_and_error(): + # Arrange + manager = PromptManager() + p1 = Prompt(prompt_name="p1", template="T") + manager.add_prompt(p1, need_save=True) + + # Act + manager.remove_prompt("p1") + + # Assert + assert "p1" not in manager.prompts + assert "p1" not in manager._prompt_to_save + + # Act / Assert + with pytest.raises(KeyError) as exc_info: + manager.remove_prompt("no_such") + + assert "Prompt name 'no_such' 不存在" in str(exc_info.value) + + +def test_prompt_manager_replace_prompt_happy_and_error(): + # sourcery skip: extract-duplicate-method + # Arrange + manager = PromptManager() + p1 = Prompt(prompt_name="p", template="Old") + manager.add_prompt(p1, need_save=True) + + p_new = Prompt(prompt_name="p", template="New") + + # Act: 替换且保持 need_save + manager.replace_prompt(p_new, need_save=True) + + # Assert + assert manager.prompts["p"].template == "New" + assert "p" in manager._prompt_to_save + + # Act: 再次替换,且不需要保存 + p_new2 = Prompt(prompt_name="p", template="New2") + manager.replace_prompt(p_new2, need_save=False) + + # Assert + assert manager.prompts["p"].template == "New2" + assert "p" not in manager._prompt_to_save + + # Error: 不存在的 prompt + p_unknown = Prompt(prompt_name="unknown", template="T") + with pytest.raises(KeyError) as exc_info: + manager.replace_prompt(p_unknown) + + assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value) + + def test_prompt_manager_get_prompt_is_copy(): # Arrange manager = PromptManager() @@ -162,6 +251,7 @@ def test_prompt_manager_get_prompt_is_copy(): assert retrieved_prompt.template == prompt.template assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context + def test_prompt_manager_add_prompt_conflict_with_context_name(): # Arrange manager = PromptManager() @@ -230,6 +320,9 @@ def test_prompt_manager_get_prompt_not_exist(): assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value) +# ========= 渲染逻辑 ========= + + @pytest.mark.parametrize( "template, inner_context, global_context, expected, case_id", [ @@ -264,7 +357,13 @@ def test_prompt_manager_get_prompt_not_exist(): ], ) @pytest.mark.asyncio -async def test_prompt_manager_render_contexts(template, inner_context, global_context, expected, case_id): +async def test_prompt_manager_render_contexts( + template, + inner_context, + global_context, + expected, + case_id, +): # Arrange manager = PromptManager() tmp_prompt = Prompt(prompt_name="main", template=template) @@ -274,7 +373,6 @@ async def test_prompt_manager_render_contexts(template, inner_context, global_co prompt.add_context(name, fn) for name, fn in global_context.items(): manager.add_context_construct_function(name, fn) - # Act rendered = await manager.render_prompt(prompt) @@ -396,6 +494,20 @@ async def test_prompt_manager_render_with_coroutine_global_context_function(): assert rendered == "g-main" +@pytest.mark.asyncio +async def test_prompt_manager_render_only_cloned_instance(): + # Arrange + manager = PromptManager() + p = Prompt(prompt_name="p", template="T") + manager.add_prompt(p) + + # Act / Assert: 直接用原始 p 渲染会报错 + with pytest.raises(ValueError) as exc_info: + await manager.render_prompt(p) + + assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value) + + @pytest.mark.parametrize( "is_prompt_context, use_coroutine, case_id", [ @@ -406,7 +518,12 @@ async def test_prompt_manager_render_with_coroutine_global_context_function(): ], ) @pytest.mark.asyncio -async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_prompt_context, use_coroutine, case_id): +async def test_prompt_manager_get_function_result_error_logging( + monkeypatch, + is_prompt_context, + use_coroutine, + case_id, +): # Arrange manager = PromptManager() @@ -449,6 +566,9 @@ async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_ assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log +# ========= add_context_construct_function 边界 ========= + + def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch): # Arrange manager = PromptManager() @@ -496,50 +616,68 @@ def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monk monkeypatch.setattr("inspect.currentframe", real_currentframe) -def test_prompt_manager_save_and_load_prompts(tmp_path, monkeypatch): - # Arrange - test_dir = tmp_path / "prompts_dir" - test_dir.mkdir() +# ========= save/load & 目录逻辑 ========= - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False) + +def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch): + """ + save_prompts 现在的逻辑: + 1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件; + 2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。 + + 这里模拟删除已有自定义 prompt 文件时发生 IO 错误。 + """ + # Arrange + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + # 先在自定义目录写入一个 prompt 文件,触发 unlink 路径 + old_file = custom_dir / f"old{SUFFIX_PROMPT}" + old_file.write_text("old", encoding="utf-8") manager = PromptManager() - p1 = Prompt(prompt_name="save_me", template="Template {x}") - p1.add_context("x", "X") + p1 = Prompt(prompt_name="save_error", template="T") manager.add_prompt(p1, need_save=True) - # Act - manager.save_prompts() + # 打桩 Path.unlink,使删除文件时报错 + def fake_unlink(self): + raise OSError("disk unlink error") + + monkeypatch.setattr("pathlib.Path.unlink", fake_unlink) + + # Act / Assert + with pytest.raises(OSError) as exc_info: + manager.save_prompts() # Assert - saved_file = test_dir / f"save_me{SUFFIX_PROMPT}" - assert saved_file.exists() - assert saved_file.read_text(encoding="utf-8") == "Template {x}" + assert "disk unlink error" in str(exc_info.value) + +def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch): + """ + 模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。 + """ # Arrange - new_manager = PromptManager() + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) - # Act - new_manager.load_prompts() + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) - # Assert - loaded = new_manager.get_prompt("save_me") - assert loaded.template == "Template {x}" - assert "save_me" in new_manager._prompt_to_save - - -def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch): - # Arrange - test_dir = tmp_path / "prompts_dir" - test_dir.mkdir() - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False) manager = PromptManager() p1 = Prompt(prompt_name="save_error", template="T") manager.add_prompt(p1, need_save=True) class FakeFile: def __enter__(self): - raise OSError("disk error") + raise OSError("disk write error") def __exit__(self, exc_type, exc, tb): return False @@ -554,15 +692,23 @@ def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch): manager.save_prompts() # Assert - assert "disk error" in str(exc_info.value) + assert "disk write error" in str(exc_info.value) -def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch): +def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch): + """ + 模拟从 PROMPTS_DIR 读取 prompt 时发生 IO 错误。 + """ # Arrange - test_dir = tmp_path / "prompts_dir" - test_dir.mkdir() - monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False) - prompt_file = test_dir / f"bad{SUFFIX_PROMPT}" + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + prompt_file = prompts_dir / f"bad{SUFFIX_PROMPT}" prompt_file.write_text("content", encoding="utf-8") class FakeFile: @@ -572,8 +718,12 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch): def __exit__(self, exc_type, exc, tb): return False - def fake_open(*_args, **_kwargs): - return FakeFile() + def fake_open(*args, **kwargs): + # 只对 default 目录下的文件触发错误,其余正常(如果有) + file_path = Path(args[0]) + if file_path == prompt_file: + return FakeFile() + return open(*args, **kwargs) monkeypatch.setattr("builtins.open", fake_open) manager = PromptManager() @@ -586,6 +736,151 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch): assert "read error" in str(exc_info.value) +def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch): + """ + 模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。 + 包含两种路径: + 1. default 与 custom 同名,load_prompts 会优先读取 custom; + 2. 仅 custom 有文件,且 default 无同名文件。 + """ + # Arrange + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + # default 与 custom 同名的文件 + same_name = f"same{SUFFIX_PROMPT}" + base_file = prompts_dir / same_name + base_file.write_text("base", encoding="utf-8") + custom_file_same = custom_dir / same_name + custom_file_same.write_text("custom", encoding="utf-8") + + # 仅 custom 下存在的文件 + only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}" + only_custom_file.write_text("only", encoding="utf-8") + + class FakeFile: + def __enter__(self): + raise OSError("custom read error") + + def __exit__(self, exc_type, exc, tb): + return False + + def fake_open(*args, **kwargs): + file_path = Path(args[0]) + # 对 custom 目录下的 prompt 文件统一触发错误 + if file_path.parent == custom_dir: + return FakeFile() + return open(*args, **kwargs) + + monkeypatch.setattr("builtins.open", fake_open) + manager = PromptManager() + + # Act / Assert + with pytest.raises(OSError) as exc_info: + manager.load_prompts() + + # Assert + assert "custom read error" in str(exc_info.value) + + +def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch): + """ + load_prompts 逻辑: + - 遍历 PROMPTS_DIR/*.prompt + - 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录 + """ + # Arrange + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + # 默认目录 prompt + base_file = prompts_dir / f"testp{SUFFIX_PROMPT}" + base_file.write_text("BaseTemplate {x}", encoding="utf-8") + + # 自定义目录同名 prompt,应当覆盖默认 + custom_file = custom_dir / base_file.name + custom_file.write_text("CustomTemplate {x}", encoding="utf-8") + + manager = PromptManager() + + # Act + manager.load_prompts() + + # Assert + p = manager.get_prompt("testp") + assert p.template == "CustomTemplate {x}" + # 从自定义目录加载的 prompt 应标记为 need_save(加入 _prompt_to_save) + assert "testp" in manager._prompt_to_save + + +def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch): + """ + 从 PROMPTS_DIR 加载、且没有同名自定义 prompt 时,need_save 应为 False(不进入 _prompt_to_save)。 + """ + # Arrange + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + # 仅默认目录有 prompt,自定义目录中无同名文件 + base_file = prompts_dir / f"only_default{SUFFIX_PROMPT}" + base_file.write_text("DefaultTemplate {x}", encoding="utf-8") + + manager = PromptManager() + + # Act + manager.load_prompts() + + # Assert + p = manager.get_prompt("only_default") + assert p.template == "DefaultTemplate {x}" + # 从默认目录加载的 prompt 不应标记为 need_save + assert "only_default" not in manager._prompt_to_save + + +def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch): + """ + save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。 + """ + prompts_dir = tmp_path / "prompts" + custom_dir = tmp_path / "data" / "custom_prompts" + prompts_dir.mkdir(parents=True) + custom_dir.mkdir(parents=True) + + monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False) + monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False) + + manager = PromptManager() + p1 = Prompt(prompt_name="save_me", template="Template {x}") + p1.add_context("x", "X") + manager.add_prompt(p1, need_save=True) + + # Act + manager.save_prompts() + + # Assert: 文件应保存在 custom_dir 中 + saved_file = custom_dir / f"save_me{SUFFIX_PROMPT}" + assert saved_file.exists() + assert saved_file.read_text(encoding="utf-8") == "Template {x}" + + +# ========= 其它 ========= + + def test_prompt_manager_global_instance_access(): # Act pm = prompt_manager diff --git a/src/prompt/prompt_manager.py b/src/prompt/prompt_manager.py index a3d21f44..63a36eb5 100644 --- a/src/prompt/prompt_manager.py +++ b/src/prompt/prompt_manager.py @@ -14,7 +14,10 @@ _RIGHT_BRACE = "\ufdea" PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve() PROMPTS_DIR = PROJECT_ROOT / "prompts" +DATA_DIR = PROJECT_ROOT / "data" +CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts" PROMPTS_DIR.mkdir(parents=True, exist_ok=True) +CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True) SUFFIX_PROMPT = ".prompt" @@ -54,7 +57,6 @@ class Prompt: class PromptManager: def __init__(self) -> None: - PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在 self.prompts: dict[str, Prompt] = {} """存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果""" self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {} @@ -72,6 +74,22 @@ class PromptManager: if need_save: self._prompt_to_save.add(prompt.prompt_name) + def remove_prompt(self, prompt_name: str) -> None: + if prompt_name not in self.prompts: + raise KeyError(f"Prompt name '{prompt_name}' 不存在") + del self.prompts[prompt_name] + if prompt_name in self._prompt_to_save: + self._prompt_to_save.remove(prompt_name) + + def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None: + if prompt.prompt_name not in self.prompts: + raise KeyError(f"Prompt name '{prompt.prompt_name}' 不存在,无法替换") + self.prompts[prompt.prompt_name] = prompt + if need_save: + self._prompt_to_save.add(prompt.prompt_name) + elif prompt.prompt_name in self._prompt_to_save: + self._prompt_to_save.remove(prompt.prompt_name) + def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: if name in self._context_construct_functions or name in self.prompts: raise KeyError(f"Construct function name '{name}' 已存在") @@ -159,9 +177,16 @@ class PromptManager: return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}") def save_prompts(self) -> None: + # 先清空自定义目录下的所有 Prompt 文件 + for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): + try: + prompt_file.unlink() + except Exception as e: + logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") + raise e for prompt_name in self._prompt_to_save: prompt = self.prompts[prompt_name] - file_path = PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}" + file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}" try: with open(file_path, "w", encoding="utf-8") as f: f.write(prompt.template) @@ -171,12 +196,28 @@ class PromptManager: def load_prompts(self) -> None: for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): + try: + prompt_to_load = prompt_file + need_save = False + if (CUSTOM_PROMPTS_DIR / prompt_file.name).exists(): + # 优先加载自定义目录下的 Prompt 文件 + prompt_to_load = CUSTOM_PROMPTS_DIR / prompt_file.name + need_save = True + with open(prompt_to_load, "r", encoding="utf-8") as f: + template = f.read() + self.add_prompt(Prompt(prompt_name=prompt_to_load.stem, template=template), need_save=need_save) + except Exception as e: + logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") + raise e + for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): + if (PROMPTS_DIR / prompt_file.name).exists(): + continue # 已经加载过了,跳过 try: with open(prompt_file, "r", encoding="utf-8") as f: template = f.read() self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True) except Exception as e: - logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") + logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") raise e async def _get_function_result(