更好的Prompt管理系统,增加用户自定义Prompt与覆盖功能

pull/1496/head
UnCLAS-Prommer 2026-02-02 20:53:42 +08:00
parent 0d0f5a9cdb
commit b793a3d62b
No known key found for this signature in database
3 changed files with 402 additions and 46 deletions

View File

@ -125,8 +125,28 @@ version 0.3.0 - 2026-01-11
- [x] 使用C模块库提升相似度计算效率 - [x] 使用C模块库提升相似度计算效率
- [ ] 移除了定时表情包完整性检查,改为启动时检查(依然保留为独立方法,以防之后恢复定时检查系统) - [ ] 移除了定时表情包完整性检查,改为启动时检查(依然保留为独立方法,以防之后恢复定时检查系统)
## Prompt 管理系统
- [ ] 官方Prompt全部独立
- [x] 用户自定义Prompt系统
- [x] 用户可以创建删除自己的Prompt
- [x] 用户可以覆盖官方Prompt
- [x] Prompt构建系统
- [x] Prompt文件交互
- [x] 读取Prompt文件
- [x] 读取官方Prompt文件
- [x] 读取用户Prompt文件
- [x] 用户Prompt覆盖官方Prompt
- [x] 保存Prompt文件
- [x] Prompt管理方法
- [x] Prompt添加
- [x] Prompt删除
- [x] **只保存被标记为需要保存的Prompt其他的Prompt文件全部删除**
## 一些细枝末节的东西 ## 一些细枝末节的东西
- [ ] 将`stream_id`和`chat_id`统一命名为`session_id` - [ ] 将`stream_id`和`chat_id`统一命名为`session_id`
- [ ] 映射表 - [ ] 映射表
- [ ] `platform_group_user_session_id_map` `平台_群组_用户`-`会话ID` 映射表 - [ ] `platform_group_user_session_id_map` `平台_群组_用户`-`会话ID` 映射表
- [ ] 将大部分的数据模型均以`Mai`开头命名 - [ ] 将大部分的数据模型均以`Mai`开头命名
### 细节说明
1. Prompt管理系统中保存用户自定义Prompt的时候会只保存被标记为需要保存的Prompt其他的Prompt文件会全部删除以防止用户删除Prompt后文件依然存在的问题。因此如果想在运行时通过修改文件的方式来添加Prompt需要确保通过对应方法标记该Prompt为需要保存否则在下一次保存时会被删除。

View File

@ -1,4 +1,4 @@
# File: tests/test_prompt_manager.py # File: pytests/prompt_test/test_prompt_manager.py
import asyncio import asyncio
import inspect import inspect
@ -12,7 +12,15 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
sys.path.insert(0, str(PROJECT_ROOT)) sys.path.insert(0, str(PROJECT_ROOT))
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config")) sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prompt_manager # noqa from src.prompt.prompt_manager import ( # noqa
SUFFIX_PROMPT,
Prompt,
PromptManager,
prompt_manager,
)
# ========= Prompt 基础行为 =========
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -20,7 +28,11 @@ from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prom
[ [
pytest.param("simple", "Hello {name}", id="simple-template-with-field"), pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
pytest.param("no-fields", "Just a static template", id="template-without-fields"), pytest.param("no-fields", "Just a static template", id="template-without-fields"),
pytest.param("brace-escaping", "Use {{ and }} around {field}", id="template-with-escaped-braces"), pytest.param(
"brace-escaping",
"Use {{ and }} around {field}",
id="template-with-escaped-braces",
),
], ],
) )
def test_prompt_init_happy_paths(prompt_name: str, template: str): def test_prompt_init_happy_paths(prompt_name: str, template: str):
@ -53,7 +65,12 @@ def test_prompt_init_happy_paths(prompt_name: str, template: str):
), ),
], ],
) )
def test_prompt_init_error_cases(prompt_name, template, expected_exception, expected_msg_substring): def test_prompt_init_error_cases(
prompt_name,
template,
expected_exception,
expected_msg_substring,
):
# Act / Assert # Act / Assert
with pytest.raises(expected_exception) as exc_info: with pytest.raises(expected_exception) as exc_info:
Prompt(prompt_name=prompt_name, template=template) Prompt(prompt_name=prompt_name, template=template)
@ -123,6 +140,25 @@ def test_prompt_add_context(
assert result == expected_value assert result == expected_value
def test_prompt_clone_independent_instance():
# Arrange
prompt = Prompt(prompt_name="p", template="T {x}")
prompt.add_context("x", "X")
# Act
cloned = prompt.clone()
# Assert
assert cloned is not prompt
assert cloned.prompt_name == prompt.prompt_name
assert cloned.template == prompt.template
# 当前实现 clone 不复制 context
assert cloned.prompt_render_context == {}
# ========= PromptManager添加/获取/删除/替换 =========
def test_prompt_manager_add_prompt_happy_and_error(): def test_prompt_manager_add_prompt_happy_and_error():
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
@ -147,6 +183,59 @@ def test_prompt_manager_add_prompt_happy_and_error():
# Assert # Assert
assert "Prompt name 'p1' 已存在" in str(exc_info.value) assert "Prompt name 'p1' 已存在" in str(exc_info.value)
def test_prompt_manager_remove_prompt_happy_and_error():
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p1", template="T")
manager.add_prompt(p1, need_save=True)
# Act
manager.remove_prompt("p1")
# Assert
assert "p1" not in manager.prompts
assert "p1" not in manager._prompt_to_save
# Act / Assert
with pytest.raises(KeyError) as exc_info:
manager.remove_prompt("no_such")
assert "Prompt name 'no_such' 不存在" in str(exc_info.value)
def test_prompt_manager_replace_prompt_happy_and_error():
# sourcery skip: extract-duplicate-method
# Arrange
manager = PromptManager()
p1 = Prompt(prompt_name="p", template="Old")
manager.add_prompt(p1, need_save=True)
p_new = Prompt(prompt_name="p", template="New")
# Act: 替换且保持 need_save
manager.replace_prompt(p_new, need_save=True)
# Assert
assert manager.prompts["p"].template == "New"
assert "p" in manager._prompt_to_save
# Act: 再次替换,且不需要保存
p_new2 = Prompt(prompt_name="p", template="New2")
manager.replace_prompt(p_new2, need_save=False)
# Assert
assert manager.prompts["p"].template == "New2"
assert "p" not in manager._prompt_to_save
# Error: 不存在的 prompt
p_unknown = Prompt(prompt_name="unknown", template="T")
with pytest.raises(KeyError) as exc_info:
manager.replace_prompt(p_unknown)
assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value)
def test_prompt_manager_get_prompt_is_copy(): def test_prompt_manager_get_prompt_is_copy():
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
@ -162,6 +251,7 @@ def test_prompt_manager_get_prompt_is_copy():
assert retrieved_prompt.template == prompt.template assert retrieved_prompt.template == prompt.template
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
def test_prompt_manager_add_prompt_conflict_with_context_name(): def test_prompt_manager_add_prompt_conflict_with_context_name():
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
@ -230,6 +320,9 @@ def test_prompt_manager_get_prompt_not_exist():
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value) assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
# ========= 渲染逻辑 =========
@pytest.mark.parametrize( @pytest.mark.parametrize(
"template, inner_context, global_context, expected, case_id", "template, inner_context, global_context, expected, case_id",
[ [
@ -264,7 +357,13 @@ def test_prompt_manager_get_prompt_not_exist():
], ],
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_manager_render_contexts(template, inner_context, global_context, expected, case_id): async def test_prompt_manager_render_contexts(
template,
inner_context,
global_context,
expected,
case_id,
):
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
tmp_prompt = Prompt(prompt_name="main", template=template) tmp_prompt = Prompt(prompt_name="main", template=template)
@ -274,7 +373,6 @@ async def test_prompt_manager_render_contexts(template, inner_context, global_co
prompt.add_context(name, fn) prompt.add_context(name, fn)
for name, fn in global_context.items(): for name, fn in global_context.items():
manager.add_context_construct_function(name, fn) manager.add_context_construct_function(name, fn)
# Act # Act
rendered = await manager.render_prompt(prompt) rendered = await manager.render_prompt(prompt)
@ -396,6 +494,20 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
assert rendered == "g-main" assert rendered == "g-main"
@pytest.mark.asyncio
async def test_prompt_manager_render_only_cloned_instance():
# Arrange
manager = PromptManager()
p = Prompt(prompt_name="p", template="T")
manager.add_prompt(p)
# Act / Assert: 直接用原始 p 渲染会报错
with pytest.raises(ValueError) as exc_info:
await manager.render_prompt(p)
assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"is_prompt_context, use_coroutine, case_id", "is_prompt_context, use_coroutine, case_id",
[ [
@ -406,7 +518,12 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
], ],
) )
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_prompt_context, use_coroutine, case_id): async def test_prompt_manager_get_function_result_error_logging(
monkeypatch,
is_prompt_context,
use_coroutine,
case_id,
):
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
@ -449,6 +566,9 @@ async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
# ========= add_context_construct_function 边界 =========
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch): def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
# Arrange # Arrange
manager = PromptManager() manager = PromptManager()
@ -496,50 +616,68 @@ def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monk
monkeypatch.setattr("inspect.currentframe", real_currentframe) monkeypatch.setattr("inspect.currentframe", real_currentframe)
def test_prompt_manager_save_and_load_prompts(tmp_path, monkeypatch): # ========= save/load & 目录逻辑 =========
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch):
"""
save_prompts 现在的逻辑
1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件
2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR
这里模拟删除已有自定义 prompt 文件时发生 IO 错误
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 先在自定义目录写入一个 prompt 文件,触发 unlink 路径
old_file = custom_dir / f"old{SUFFIX_PROMPT}"
old_file.write_text("old", encoding="utf-8")
manager = PromptManager() manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}") p1 = Prompt(prompt_name="save_error", template="T")
p1.add_context("x", "X")
manager.add_prompt(p1, need_save=True) manager.add_prompt(p1, need_save=True)
# Act # 打桩 Path.unlink使删除文件时报错
manager.save_prompts() def fake_unlink(self):
raise OSError("disk unlink error")
monkeypatch.setattr("pathlib.Path.unlink", fake_unlink)
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.save_prompts()
# Assert # Assert
saved_file = test_dir / f"save_me{SUFFIX_PROMPT}" assert "disk unlink error" in str(exc_info.value)
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch):
"""
模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误
"""
# Arrange # Arrange
new_manager = PromptManager() prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
# Act monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
new_manager.load_prompts() monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# Assert
loaded = new_manager.get_prompt("save_me")
assert loaded.template == "Template {x}"
assert "save_me" in new_manager._prompt_to_save
def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
# Arrange
test_dir = tmp_path / "prompts_dir"
test_dir.mkdir()
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
manager = PromptManager() manager = PromptManager()
p1 = Prompt(prompt_name="save_error", template="T") p1 = Prompt(prompt_name="save_error", template="T")
manager.add_prompt(p1, need_save=True) manager.add_prompt(p1, need_save=True)
class FakeFile: class FakeFile:
def __enter__(self): def __enter__(self):
raise OSError("disk error") raise OSError("disk write error")
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
@ -554,15 +692,23 @@ def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
manager.save_prompts() manager.save_prompts()
# Assert # Assert
assert "disk error" in str(exc_info.value) assert "disk write error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch): def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch):
"""
模拟从 PROMPTS_DIR 读取 prompt 时发生 IO 错误
"""
# Arrange # Arrange
test_dir = tmp_path / "prompts_dir" prompts_dir = tmp_path / "prompts"
test_dir.mkdir() custom_dir = tmp_path / "data" / "custom_prompts"
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False) prompts_dir.mkdir(parents=True)
prompt_file = test_dir / f"bad{SUFFIX_PROMPT}" custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
prompt_file = prompts_dir / f"bad{SUFFIX_PROMPT}"
prompt_file.write_text("content", encoding="utf-8") prompt_file.write_text("content", encoding="utf-8")
class FakeFile: class FakeFile:
@ -572,8 +718,12 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
def __exit__(self, exc_type, exc, tb): def __exit__(self, exc_type, exc, tb):
return False return False
def fake_open(*_args, **_kwargs): def fake_open(*args, **kwargs):
return FakeFile() # 只对 default 目录下的文件触发错误,其余正常(如果有)
file_path = Path(args[0])
if file_path == prompt_file:
return FakeFile()
return open(*args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open) monkeypatch.setattr("builtins.open", fake_open)
manager = PromptManager() manager = PromptManager()
@ -586,6 +736,151 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
assert "read error" in str(exc_info.value) assert "read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch):
"""
模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误
包含两种路径
1. default custom 同名load_prompts 会优先读取 custom
2. custom 有文件 default 无同名文件
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# default 与 custom 同名的文件
same_name = f"same{SUFFIX_PROMPT}"
base_file = prompts_dir / same_name
base_file.write_text("base", encoding="utf-8")
custom_file_same = custom_dir / same_name
custom_file_same.write_text("custom", encoding="utf-8")
# 仅 custom 下存在的文件
only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}"
only_custom_file.write_text("only", encoding="utf-8")
class FakeFile:
def __enter__(self):
raise OSError("custom read error")
def __exit__(self, exc_type, exc, tb):
return False
def fake_open(*args, **kwargs):
file_path = Path(args[0])
# 对 custom 目录下的 prompt 文件统一触发错误
if file_path.parent == custom_dir:
return FakeFile()
return open(*args, **kwargs)
monkeypatch.setattr("builtins.open", fake_open)
manager = PromptManager()
# Act / Assert
with pytest.raises(OSError) as exc_info:
manager.load_prompts()
# Assert
assert "custom read error" in str(exc_info.value)
def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch):
"""
load_prompts 逻辑
- 遍历 PROMPTS_DIR/*.prompt
- 如果 CUSTOM_PROMPTS_DIR 下存在同名文件则优先使用自定义目录
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 默认目录 prompt
base_file = prompts_dir / f"testp{SUFFIX_PROMPT}"
base_file.write_text("BaseTemplate {x}", encoding="utf-8")
# 自定义目录同名 prompt应当覆盖默认
custom_file = custom_dir / base_file.name
custom_file.write_text("CustomTemplate {x}", encoding="utf-8")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("testp")
assert p.template == "CustomTemplate {x}"
# 从自定义目录加载的 prompt 应标记为 need_save加入 _prompt_to_save
assert "testp" in manager._prompt_to_save
def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch):
"""
PROMPTS_DIR 加载且没有同名自定义 prompt need_save 应为 False不进入 _prompt_to_save
"""
# Arrange
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
# 仅默认目录有 prompt自定义目录中无同名文件
base_file = prompts_dir / f"only_default{SUFFIX_PROMPT}"
base_file.write_text("DefaultTemplate {x}", encoding="utf-8")
manager = PromptManager()
# Act
manager.load_prompts()
# Assert
p = manager.get_prompt("only_default")
assert p.template == "DefaultTemplate {x}"
# 从默认目录加载的 prompt 不应标记为 need_save
assert "only_default" not in manager._prompt_to_save
def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch):
"""
save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存
"""
prompts_dir = tmp_path / "prompts"
custom_dir = tmp_path / "data" / "custom_prompts"
prompts_dir.mkdir(parents=True)
custom_dir.mkdir(parents=True)
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
manager = PromptManager()
p1 = Prompt(prompt_name="save_me", template="Template {x}")
p1.add_context("x", "X")
manager.add_prompt(p1, need_save=True)
# Act
manager.save_prompts()
# Assert: 文件应保存在 custom_dir 中
saved_file = custom_dir / f"save_me{SUFFIX_PROMPT}"
assert saved_file.exists()
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
# ========= 其它 =========
def test_prompt_manager_global_instance_access(): def test_prompt_manager_global_instance_access():
# Act # Act
pm = prompt_manager pm = prompt_manager

View File

@ -14,7 +14,10 @@ _RIGHT_BRACE = "\ufdea"
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve() PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
PROMPTS_DIR = PROJECT_ROOT / "prompts" PROMPTS_DIR = PROJECT_ROOT / "prompts"
DATA_DIR = PROJECT_ROOT / "data"
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
PROMPTS_DIR.mkdir(parents=True, exist_ok=True) PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
SUFFIX_PROMPT = ".prompt" SUFFIX_PROMPT = ".prompt"
@ -54,7 +57,6 @@ class Prompt:
class PromptManager: class PromptManager:
def __init__(self) -> None: def __init__(self) -> None:
PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在
self.prompts: dict[str, Prompt] = {} self.prompts: dict[str, Prompt] = {}
"""存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果""" """存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果"""
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {} self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
@ -72,6 +74,22 @@ class PromptManager:
if need_save: if need_save:
self._prompt_to_save.add(prompt.prompt_name) self._prompt_to_save.add(prompt.prompt_name)
def remove_prompt(self, prompt_name: str) -> None:
if prompt_name not in self.prompts:
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
del self.prompts[prompt_name]
if prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt_name)
def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
if prompt.prompt_name not in self.prompts:
raise KeyError(f"Prompt name '{prompt.prompt_name}' 不存在,无法替换")
self.prompts[prompt.prompt_name] = prompt
if need_save:
self._prompt_to_save.add(prompt.prompt_name)
elif prompt.prompt_name in self._prompt_to_save:
self._prompt_to_save.remove(prompt.prompt_name)
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None: def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
if name in self._context_construct_functions or name in self.prompts: if name in self._context_construct_functions or name in self.prompts:
raise KeyError(f"Construct function name '{name}' 已存在") raise KeyError(f"Construct function name '{name}' 已存在")
@ -159,9 +177,16 @@ class PromptManager:
return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}") return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}")
def save_prompts(self) -> None: def save_prompts(self) -> None:
# 先清空自定义目录下的所有 Prompt 文件
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
try:
prompt_file.unlink()
except Exception as e:
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e
for prompt_name in self._prompt_to_save: for prompt_name in self._prompt_to_save:
prompt = self.prompts[prompt_name] prompt = self.prompts[prompt_name]
file_path = PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}" file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
try: try:
with open(file_path, "w", encoding="utf-8") as f: with open(file_path, "w", encoding="utf-8") as f:
f.write(prompt.template) f.write(prompt.template)
@ -171,12 +196,28 @@ class PromptManager:
def load_prompts(self) -> None: def load_prompts(self) -> None:
for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"): for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
try:
prompt_to_load = prompt_file
need_save = False
if (CUSTOM_PROMPTS_DIR / prompt_file.name).exists():
# 优先加载自定义目录下的 Prompt 文件
prompt_to_load = CUSTOM_PROMPTS_DIR / prompt_file.name
need_save = True
with open(prompt_to_load, "r", encoding="utf-8") as f:
template = f.read()
self.add_prompt(Prompt(prompt_name=prompt_to_load.stem, template=template), need_save=need_save)
except Exception as e:
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
if (PROMPTS_DIR / prompt_file.name).exists():
continue # 已经加载过了,跳过
try: try:
with open(prompt_file, "r", encoding="utf-8") as f: with open(prompt_file, "r", encoding="utf-8") as f:
template = f.read() template = f.read()
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True) self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
except Exception as e: except Exception as e:
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}") logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
raise e raise e
async def _get_function_result( async def _get_function_result(