mirror of https://github.com/Mai-with-u/MaiBot.git
更好的Prompt管理系统,增加用户自定义Prompt与覆盖功能
parent
0d0f5a9cdb
commit
b793a3d62b
|
|
@ -125,8 +125,28 @@ version 0.3.0 - 2026-01-11
|
||||||
- [x] 使用C模块库提升相似度计算效率
|
- [x] 使用C模块库提升相似度计算效率
|
||||||
- [ ] 移除了定时表情包完整性检查,改为启动时检查(依然保留为独立方法,以防之后恢复定时检查系统)
|
- [ ] 移除了定时表情包完整性检查,改为启动时检查(依然保留为独立方法,以防之后恢复定时检查系统)
|
||||||
|
|
||||||
|
## Prompt 管理系统
|
||||||
|
- [ ] 官方Prompt全部独立
|
||||||
|
- [x] 用户自定义Prompt系统
|
||||||
|
- [x] 用户可以创建,删除自己的Prompt
|
||||||
|
- [x] 用户可以覆盖官方Prompt
|
||||||
|
- [x] Prompt构建系统
|
||||||
|
- [x] Prompt文件交互
|
||||||
|
- [x] 读取Prompt文件
|
||||||
|
- [x] 读取官方Prompt文件
|
||||||
|
- [x] 读取用户Prompt文件
|
||||||
|
- [x] 用户Prompt覆盖官方Prompt
|
||||||
|
- [x] 保存Prompt文件
|
||||||
|
- [x] Prompt管理方法
|
||||||
|
- [x] Prompt添加
|
||||||
|
- [x] Prompt删除
|
||||||
|
- [x] **只保存被标记为需要保存的Prompt,其他的Prompt文件全部删除**
|
||||||
|
|
||||||
## 一些细枝末节的东西
|
## 一些细枝末节的东西
|
||||||
- [ ] 将`stream_id`和`chat_id`统一命名为`session_id`
|
- [ ] 将`stream_id`和`chat_id`统一命名为`session_id`
|
||||||
- [ ] 映射表
|
- [ ] 映射表
|
||||||
- [ ] `platform_group_user_session_id_map` `平台_群组_用户`-`会话ID` 映射表
|
- [ ] `platform_group_user_session_id_map` `平台_群组_用户`-`会话ID` 映射表
|
||||||
- [ ] 将大部分的数据模型均以`Mai`开头命名
|
- [ ] 将大部分的数据模型均以`Mai`开头命名
|
||||||
|
|
||||||
|
### 细节说明
|
||||||
|
1. Prompt管理系统中保存用户自定义Prompt的时候会只保存被标记为需要保存的Prompt,其他的Prompt文件会全部删除,以防止用户删除Prompt后文件依然存在的问题。因此,如果想在运行时通过修改文件的方式来添加Prompt,需要确保通过对应方法标记该Prompt为需要保存,否则在下一次保存时会被删除。
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
# File: tests/test_prompt_manager.py
|
# File: pytests/prompt_test/test_prompt_manager.py
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -12,7 +12,15 @@ PROJECT_ROOT: Path = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||||
sys.path.insert(0, str(PROJECT_ROOT))
|
sys.path.insert(0, str(PROJECT_ROOT))
|
||||||
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
sys.path.insert(0, str(PROJECT_ROOT / "src" / "config"))
|
||||||
|
|
||||||
from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prompt_manager # noqa
|
from src.prompt.prompt_manager import ( # noqa
|
||||||
|
SUFFIX_PROMPT,
|
||||||
|
Prompt,
|
||||||
|
PromptManager,
|
||||||
|
prompt_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ========= Prompt 基础行为 =========
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
|
@ -20,7 +28,11 @@ from src.prompt.prompt_manager import SUFFIX_PROMPT, Prompt, PromptManager, prom
|
||||||
[
|
[
|
||||||
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
|
pytest.param("simple", "Hello {name}", id="simple-template-with-field"),
|
||||||
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
|
pytest.param("no-fields", "Just a static template", id="template-without-fields"),
|
||||||
pytest.param("brace-escaping", "Use {{ and }} around {field}", id="template-with-escaped-braces"),
|
pytest.param(
|
||||||
|
"brace-escaping",
|
||||||
|
"Use {{ and }} around {field}",
|
||||||
|
id="template-with-escaped-braces",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_prompt_init_happy_paths(prompt_name: str, template: str):
|
def test_prompt_init_happy_paths(prompt_name: str, template: str):
|
||||||
|
|
@ -53,7 +65,12 @@ def test_prompt_init_happy_paths(prompt_name: str, template: str):
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_prompt_init_error_cases(prompt_name, template, expected_exception, expected_msg_substring):
|
def test_prompt_init_error_cases(
|
||||||
|
prompt_name,
|
||||||
|
template,
|
||||||
|
expected_exception,
|
||||||
|
expected_msg_substring,
|
||||||
|
):
|
||||||
# Act / Assert
|
# Act / Assert
|
||||||
with pytest.raises(expected_exception) as exc_info:
|
with pytest.raises(expected_exception) as exc_info:
|
||||||
Prompt(prompt_name=prompt_name, template=template)
|
Prompt(prompt_name=prompt_name, template=template)
|
||||||
|
|
@ -123,6 +140,25 @@ def test_prompt_add_context(
|
||||||
assert result == expected_value
|
assert result == expected_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_clone_independent_instance():
|
||||||
|
# Arrange
|
||||||
|
prompt = Prompt(prompt_name="p", template="T {x}")
|
||||||
|
prompt.add_context("x", "X")
|
||||||
|
|
||||||
|
# Act
|
||||||
|
cloned = prompt.clone()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert cloned is not prompt
|
||||||
|
assert cloned.prompt_name == prompt.prompt_name
|
||||||
|
assert cloned.template == prompt.template
|
||||||
|
# 当前实现 clone 不复制 context
|
||||||
|
assert cloned.prompt_render_context == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ========= PromptManager:添加/获取/删除/替换 =========
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_prompt_happy_and_error():
|
def test_prompt_manager_add_prompt_happy_and_error():
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
@ -147,6 +183,59 @@ def test_prompt_manager_add_prompt_happy_and_error():
|
||||||
# Assert
|
# Assert
|
||||||
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
|
assert "Prompt name 'p1' 已存在" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_remove_prompt_happy_and_error():
|
||||||
|
# Arrange
|
||||||
|
manager = PromptManager()
|
||||||
|
p1 = Prompt(prompt_name="p1", template="T")
|
||||||
|
manager.add_prompt(p1, need_save=True)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
manager.remove_prompt("p1")
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "p1" not in manager.prompts
|
||||||
|
assert "p1" not in manager._prompt_to_save
|
||||||
|
|
||||||
|
# Act / Assert
|
||||||
|
with pytest.raises(KeyError) as exc_info:
|
||||||
|
manager.remove_prompt("no_such")
|
||||||
|
|
||||||
|
assert "Prompt name 'no_such' 不存在" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_replace_prompt_happy_and_error():
|
||||||
|
# sourcery skip: extract-duplicate-method
|
||||||
|
# Arrange
|
||||||
|
manager = PromptManager()
|
||||||
|
p1 = Prompt(prompt_name="p", template="Old")
|
||||||
|
manager.add_prompt(p1, need_save=True)
|
||||||
|
|
||||||
|
p_new = Prompt(prompt_name="p", template="New")
|
||||||
|
|
||||||
|
# Act: 替换且保持 need_save
|
||||||
|
manager.replace_prompt(p_new, need_save=True)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert manager.prompts["p"].template == "New"
|
||||||
|
assert "p" in manager._prompt_to_save
|
||||||
|
|
||||||
|
# Act: 再次替换,且不需要保存
|
||||||
|
p_new2 = Prompt(prompt_name="p", template="New2")
|
||||||
|
manager.replace_prompt(p_new2, need_save=False)
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert manager.prompts["p"].template == "New2"
|
||||||
|
assert "p" not in manager._prompt_to_save
|
||||||
|
|
||||||
|
# Error: 不存在的 prompt
|
||||||
|
p_unknown = Prompt(prompt_name="unknown", template="T")
|
||||||
|
with pytest.raises(KeyError) as exc_info:
|
||||||
|
manager.replace_prompt(p_unknown)
|
||||||
|
|
||||||
|
assert "Prompt name 'unknown' 不存在,无法替换" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_get_prompt_is_copy():
|
def test_prompt_manager_get_prompt_is_copy():
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
@ -162,6 +251,7 @@ def test_prompt_manager_get_prompt_is_copy():
|
||||||
assert retrieved_prompt.template == prompt.template
|
assert retrieved_prompt.template == prompt.template
|
||||||
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
|
assert retrieved_prompt.prompt_render_context == prompt.prompt_render_context
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_prompt_conflict_with_context_name():
|
def test_prompt_manager_add_prompt_conflict_with_context_name():
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
@ -230,6 +320,9 @@ def test_prompt_manager_get_prompt_not_exist():
|
||||||
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
|
assert "Prompt name 'no_such_prompt' 不存在" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 渲染逻辑 =========
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"template, inner_context, global_context, expected, case_id",
|
"template, inner_context, global_context, expected, case_id",
|
||||||
[
|
[
|
||||||
|
|
@ -264,7 +357,13 @@ def test_prompt_manager_get_prompt_not_exist():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_manager_render_contexts(template, inner_context, global_context, expected, case_id):
|
async def test_prompt_manager_render_contexts(
|
||||||
|
template,
|
||||||
|
inner_context,
|
||||||
|
global_context,
|
||||||
|
expected,
|
||||||
|
case_id,
|
||||||
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
tmp_prompt = Prompt(prompt_name="main", template=template)
|
tmp_prompt = Prompt(prompt_name="main", template=template)
|
||||||
|
|
@ -274,7 +373,6 @@ async def test_prompt_manager_render_contexts(template, inner_context, global_co
|
||||||
prompt.add_context(name, fn)
|
prompt.add_context(name, fn)
|
||||||
for name, fn in global_context.items():
|
for name, fn in global_context.items():
|
||||||
manager.add_context_construct_function(name, fn)
|
manager.add_context_construct_function(name, fn)
|
||||||
|
|
||||||
|
|
||||||
# Act
|
# Act
|
||||||
rendered = await manager.render_prompt(prompt)
|
rendered = await manager.render_prompt(prompt)
|
||||||
|
|
@ -396,6 +494,20 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
|
||||||
assert rendered == "g-main"
|
assert rendered == "g-main"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_prompt_manager_render_only_cloned_instance():
|
||||||
|
# Arrange
|
||||||
|
manager = PromptManager()
|
||||||
|
p = Prompt(prompt_name="p", template="T")
|
||||||
|
manager.add_prompt(p)
|
||||||
|
|
||||||
|
# Act / Assert: 直接用原始 p 渲染会报错
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
await manager.render_prompt(p)
|
||||||
|
|
||||||
|
assert "只能渲染通过 PromptManager.get_prompt 方法获取的 Prompt 实例" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"is_prompt_context, use_coroutine, case_id",
|
"is_prompt_context, use_coroutine, case_id",
|
||||||
[
|
[
|
||||||
|
|
@ -406,7 +518,12 @@ async def test_prompt_manager_render_with_coroutine_global_context_function():
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_prompt_context, use_coroutine, case_id):
|
async def test_prompt_manager_get_function_result_error_logging(
|
||||||
|
monkeypatch,
|
||||||
|
is_prompt_context,
|
||||||
|
use_coroutine,
|
||||||
|
case_id,
|
||||||
|
):
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
||||||
|
|
@ -449,6 +566,9 @@ async def test_prompt_manager_get_function_result_error_logging(monkeypatch, is_
|
||||||
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
|
assert "调用上下文构造函数 'field' 时出错,所属模块: 'mod'" in log
|
||||||
|
|
||||||
|
|
||||||
|
# ========= add_context_construct_function 边界 =========
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
|
def test_prompt_manager_add_context_construct_function_unknown_frame(monkeypatch):
|
||||||
# Arrange
|
# Arrange
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
@ -496,50 +616,68 @@ def test_prompt_manager_add_context_construct_function_unknown_caller_frame(monk
|
||||||
monkeypatch.setattr("inspect.currentframe", real_currentframe)
|
monkeypatch.setattr("inspect.currentframe", real_currentframe)
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_save_and_load_prompts(tmp_path, monkeypatch):
|
# ========= save/load & 目录逻辑 =========
|
||||||
# Arrange
|
|
||||||
test_dir = tmp_path / "prompts_dir"
|
|
||||||
test_dir.mkdir()
|
|
||||||
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
|
|
||||||
|
def test_prompt_manager_save_prompts_io_error_on_unlink(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
save_prompts 现在的逻辑:
|
||||||
|
1. 先删除 CUSTOM_PROMPTS_DIR 下的所有 *.prompt 文件;
|
||||||
|
2. 再将 _prompt_to_save 中的 prompt 写入 CUSTOM_PROMPTS_DIR。
|
||||||
|
|
||||||
|
这里模拟删除已有自定义 prompt 文件时发生 IO 错误。
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
# 先在自定义目录写入一个 prompt 文件,触发 unlink 路径
|
||||||
|
old_file = custom_dir / f"old{SUFFIX_PROMPT}"
|
||||||
|
old_file.write_text("old", encoding="utf-8")
|
||||||
|
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
p1 = Prompt(prompt_name="save_me", template="Template {x}")
|
p1 = Prompt(prompt_name="save_error", template="T")
|
||||||
p1.add_context("x", "X")
|
|
||||||
manager.add_prompt(p1, need_save=True)
|
manager.add_prompt(p1, need_save=True)
|
||||||
|
|
||||||
# Act
|
# 打桩 Path.unlink,使删除文件时报错
|
||||||
manager.save_prompts()
|
def fake_unlink(self):
|
||||||
|
raise OSError("disk unlink error")
|
||||||
|
|
||||||
|
monkeypatch.setattr("pathlib.Path.unlink", fake_unlink)
|
||||||
|
|
||||||
|
# Act / Assert
|
||||||
|
with pytest.raises(OSError) as exc_info:
|
||||||
|
manager.save_prompts()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
saved_file = test_dir / f"save_me{SUFFIX_PROMPT}"
|
assert "disk unlink error" in str(exc_info.value)
|
||||||
assert saved_file.exists()
|
|
||||||
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_save_prompts_io_error_on_write(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
模拟 save_prompts 在写入新 prompt 文件时发生 IO 错误。
|
||||||
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
new_manager = PromptManager()
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
# Act
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
new_manager.load_prompts()
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
# Assert
|
|
||||||
loaded = new_manager.get_prompt("save_me")
|
|
||||||
assert loaded.template == "Template {x}"
|
|
||||||
assert "save_me" in new_manager._prompt_to_save
|
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
|
|
||||||
# Arrange
|
|
||||||
test_dir = tmp_path / "prompts_dir"
|
|
||||||
test_dir.mkdir()
|
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
|
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
p1 = Prompt(prompt_name="save_error", template="T")
|
p1 = Prompt(prompt_name="save_error", template="T")
|
||||||
manager.add_prompt(p1, need_save=True)
|
manager.add_prompt(p1, need_save=True)
|
||||||
|
|
||||||
class FakeFile:
|
class FakeFile:
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
raise OSError("disk error")
|
raise OSError("disk write error")
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
@ -554,15 +692,23 @@ def test_prompt_manager_save_prompts_io_error(tmp_path, monkeypatch):
|
||||||
manager.save_prompts()
|
manager.save_prompts()
|
||||||
|
|
||||||
# Assert
|
# Assert
|
||||||
assert "disk error" in str(exc_info.value)
|
assert "disk write error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
|
def test_prompt_manager_load_prompts_io_error_from_default_dir(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
模拟从 PROMPTS_DIR 读取 prompt 时发生 IO 错误。
|
||||||
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
test_dir = tmp_path / "prompts_dir"
|
prompts_dir = tmp_path / "prompts"
|
||||||
test_dir.mkdir()
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", test_dir, raising=False)
|
prompts_dir.mkdir(parents=True)
|
||||||
prompt_file = test_dir / f"bad{SUFFIX_PROMPT}"
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
prompt_file = prompts_dir / f"bad{SUFFIX_PROMPT}"
|
||||||
prompt_file.write_text("content", encoding="utf-8")
|
prompt_file.write_text("content", encoding="utf-8")
|
||||||
|
|
||||||
class FakeFile:
|
class FakeFile:
|
||||||
|
|
@ -572,8 +718,12 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
|
||||||
def __exit__(self, exc_type, exc, tb):
|
def __exit__(self, exc_type, exc, tb):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def fake_open(*_args, **_kwargs):
|
def fake_open(*args, **kwargs):
|
||||||
return FakeFile()
|
# 只对 default 目录下的文件触发错误,其余正常(如果有)
|
||||||
|
file_path = Path(args[0])
|
||||||
|
if file_path == prompt_file:
|
||||||
|
return FakeFile()
|
||||||
|
return open(*args, **kwargs)
|
||||||
|
|
||||||
monkeypatch.setattr("builtins.open", fake_open)
|
monkeypatch.setattr("builtins.open", fake_open)
|
||||||
manager = PromptManager()
|
manager = PromptManager()
|
||||||
|
|
@ -586,6 +736,151 @@ def test_prompt_manager_load_prompts_io_error(tmp_path, monkeypatch):
|
||||||
assert "read error" in str(exc_info.value)
|
assert "read error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_load_prompts_io_error_from_custom_dir(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
模拟从 CUSTOM_PROMPTS_DIR 读取 prompt 时发生 IO 错误。
|
||||||
|
包含两种路径:
|
||||||
|
1. default 与 custom 同名,load_prompts 会优先读取 custom;
|
||||||
|
2. 仅 custom 有文件,且 default 无同名文件。
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
# default 与 custom 同名的文件
|
||||||
|
same_name = f"same{SUFFIX_PROMPT}"
|
||||||
|
base_file = prompts_dir / same_name
|
||||||
|
base_file.write_text("base", encoding="utf-8")
|
||||||
|
custom_file_same = custom_dir / same_name
|
||||||
|
custom_file_same.write_text("custom", encoding="utf-8")
|
||||||
|
|
||||||
|
# 仅 custom 下存在的文件
|
||||||
|
only_custom_file = custom_dir / f"only_custom{SUFFIX_PROMPT}"
|
||||||
|
only_custom_file.write_text("only", encoding="utf-8")
|
||||||
|
|
||||||
|
class FakeFile:
|
||||||
|
def __enter__(self):
|
||||||
|
raise OSError("custom read error")
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc, tb):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def fake_open(*args, **kwargs):
|
||||||
|
file_path = Path(args[0])
|
||||||
|
# 对 custom 目录下的 prompt 文件统一触发错误
|
||||||
|
if file_path.parent == custom_dir:
|
||||||
|
return FakeFile()
|
||||||
|
return open(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr("builtins.open", fake_open)
|
||||||
|
manager = PromptManager()
|
||||||
|
|
||||||
|
# Act / Assert
|
||||||
|
with pytest.raises(OSError) as exc_info:
|
||||||
|
manager.load_prompts()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
assert "custom read error" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_load_prompts_custom_overrides_default(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
load_prompts 逻辑:
|
||||||
|
- 遍历 PROMPTS_DIR/*.prompt
|
||||||
|
- 如果 CUSTOM_PROMPTS_DIR 下存在同名文件,则优先使用自定义目录
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
# 默认目录 prompt
|
||||||
|
base_file = prompts_dir / f"testp{SUFFIX_PROMPT}"
|
||||||
|
base_file.write_text("BaseTemplate {x}", encoding="utf-8")
|
||||||
|
|
||||||
|
# 自定义目录同名 prompt,应当覆盖默认
|
||||||
|
custom_file = custom_dir / base_file.name
|
||||||
|
custom_file.write_text("CustomTemplate {x}", encoding="utf-8")
|
||||||
|
|
||||||
|
manager = PromptManager()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
manager.load_prompts()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
p = manager.get_prompt("testp")
|
||||||
|
assert p.template == "CustomTemplate {x}"
|
||||||
|
# 从自定义目录加载的 prompt 应标记为 need_save(加入 _prompt_to_save)
|
||||||
|
assert "testp" in manager._prompt_to_save
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_load_prompts_default_dir_not_mark_need_save(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
从 PROMPTS_DIR 加载、且没有同名自定义 prompt 时,need_save 应为 False(不进入 _prompt_to_save)。
|
||||||
|
"""
|
||||||
|
# Arrange
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
# 仅默认目录有 prompt,自定义目录中无同名文件
|
||||||
|
base_file = prompts_dir / f"only_default{SUFFIX_PROMPT}"
|
||||||
|
base_file.write_text("DefaultTemplate {x}", encoding="utf-8")
|
||||||
|
|
||||||
|
manager = PromptManager()
|
||||||
|
|
||||||
|
# Act
|
||||||
|
manager.load_prompts()
|
||||||
|
|
||||||
|
# Assert
|
||||||
|
p = manager.get_prompt("only_default")
|
||||||
|
assert p.template == "DefaultTemplate {x}"
|
||||||
|
# 从默认目录加载的 prompt 不应标记为 need_save
|
||||||
|
assert "only_default" not in manager._prompt_to_save
|
||||||
|
|
||||||
|
|
||||||
|
def test_prompt_manager_save_prompts_use_custom_dir(tmp_path, monkeypatch):
|
||||||
|
"""
|
||||||
|
save_prompts 使用 CUSTOM_PROMPTS_DIR 进行保存。
|
||||||
|
"""
|
||||||
|
prompts_dir = tmp_path / "prompts"
|
||||||
|
custom_dir = tmp_path / "data" / "custom_prompts"
|
||||||
|
prompts_dir.mkdir(parents=True)
|
||||||
|
custom_dir.mkdir(parents=True)
|
||||||
|
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.PROMPTS_DIR", prompts_dir, raising=False)
|
||||||
|
monkeypatch.setattr("src.prompt.prompt_manager.CUSTOM_PROMPTS_DIR", custom_dir, raising=False)
|
||||||
|
|
||||||
|
manager = PromptManager()
|
||||||
|
p1 = Prompt(prompt_name="save_me", template="Template {x}")
|
||||||
|
p1.add_context("x", "X")
|
||||||
|
manager.add_prompt(p1, need_save=True)
|
||||||
|
|
||||||
|
# Act
|
||||||
|
manager.save_prompts()
|
||||||
|
|
||||||
|
# Assert: 文件应保存在 custom_dir 中
|
||||||
|
saved_file = custom_dir / f"save_me{SUFFIX_PROMPT}"
|
||||||
|
assert saved_file.exists()
|
||||||
|
assert saved_file.read_text(encoding="utf-8") == "Template {x}"
|
||||||
|
|
||||||
|
|
||||||
|
# ========= 其它 =========
|
||||||
|
|
||||||
|
|
||||||
def test_prompt_manager_global_instance_access():
|
def test_prompt_manager_global_instance_access():
|
||||||
# Act
|
# Act
|
||||||
pm = prompt_manager
|
pm = prompt_manager
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,10 @@ _RIGHT_BRACE = "\ufdea"
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
PROJECT_ROOT = Path(__file__).parent.parent.parent.absolute().resolve()
|
||||||
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
PROMPTS_DIR = PROJECT_ROOT / "prompts"
|
||||||
|
DATA_DIR = PROJECT_ROOT / "data"
|
||||||
|
CUSTOM_PROMPTS_DIR = DATA_DIR / "custom_prompts"
|
||||||
PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
CUSTOM_PROMPTS_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
SUFFIX_PROMPT = ".prompt"
|
SUFFIX_PROMPT = ".prompt"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -54,7 +57,6 @@ class Prompt:
|
||||||
|
|
||||||
class PromptManager:
|
class PromptManager:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
PROMPTS_DIR.mkdir(parents=True, exist_ok=True) # 确保提示词目录存在
|
|
||||||
self.prompts: dict[str, Prompt] = {}
|
self.prompts: dict[str, Prompt] = {}
|
||||||
"""存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果"""
|
"""存储 Prompt 实例,禁止直接从外部访问,否则将引起不可知后果"""
|
||||||
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
|
self._context_construct_functions: dict[str, tuple[Callable[[str], str | Coroutine[Any, Any, str]], str]] = {}
|
||||||
|
|
@ -72,6 +74,22 @@ class PromptManager:
|
||||||
if need_save:
|
if need_save:
|
||||||
self._prompt_to_save.add(prompt.prompt_name)
|
self._prompt_to_save.add(prompt.prompt_name)
|
||||||
|
|
||||||
|
def remove_prompt(self, prompt_name: str) -> None:
|
||||||
|
if prompt_name not in self.prompts:
|
||||||
|
raise KeyError(f"Prompt name '{prompt_name}' 不存在")
|
||||||
|
del self.prompts[prompt_name]
|
||||||
|
if prompt_name in self._prompt_to_save:
|
||||||
|
self._prompt_to_save.remove(prompt_name)
|
||||||
|
|
||||||
|
def replace_prompt(self, prompt: Prompt, need_save: bool = False) -> None:
|
||||||
|
if prompt.prompt_name not in self.prompts:
|
||||||
|
raise KeyError(f"Prompt name '{prompt.prompt_name}' 不存在,无法替换")
|
||||||
|
self.prompts[prompt.prompt_name] = prompt
|
||||||
|
if need_save:
|
||||||
|
self._prompt_to_save.add(prompt.prompt_name)
|
||||||
|
elif prompt.prompt_name in self._prompt_to_save:
|
||||||
|
self._prompt_to_save.remove(prompt.prompt_name)
|
||||||
|
|
||||||
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
def add_context_construct_function(self, name: str, func: Callable[[str], str | Coroutine[Any, Any, str]]) -> None:
|
||||||
if name in self._context_construct_functions or name in self.prompts:
|
if name in self._context_construct_functions or name in self.prompts:
|
||||||
raise KeyError(f"Construct function name '{name}' 已存在")
|
raise KeyError(f"Construct function name '{name}' 已存在")
|
||||||
|
|
@ -159,9 +177,16 @@ class PromptManager:
|
||||||
return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}")
|
return rendered_template.replace(_LEFT_BRACE, "{").replace(_RIGHT_BRACE, "}")
|
||||||
|
|
||||||
def save_prompts(self) -> None:
|
def save_prompts(self) -> None:
|
||||||
|
# 先清空自定义目录下的所有 Prompt 文件
|
||||||
|
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||||
|
try:
|
||||||
|
prompt_file.unlink()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||||
|
raise e
|
||||||
for prompt_name in self._prompt_to_save:
|
for prompt_name in self._prompt_to_save:
|
||||||
prompt = self.prompts[prompt_name]
|
prompt = self.prompts[prompt_name]
|
||||||
file_path = PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
file_path = CUSTOM_PROMPTS_DIR / f"{prompt_name}{SUFFIX_PROMPT}"
|
||||||
try:
|
try:
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
with open(file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(prompt.template)
|
f.write(prompt.template)
|
||||||
|
|
@ -171,12 +196,28 @@ class PromptManager:
|
||||||
|
|
||||||
def load_prompts(self) -> None:
|
def load_prompts(self) -> None:
|
||||||
for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
for prompt_file in PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||||
|
try:
|
||||||
|
prompt_to_load = prompt_file
|
||||||
|
need_save = False
|
||||||
|
if (CUSTOM_PROMPTS_DIR / prompt_file.name).exists():
|
||||||
|
# 优先加载自定义目录下的 Prompt 文件
|
||||||
|
prompt_to_load = CUSTOM_PROMPTS_DIR / prompt_file.name
|
||||||
|
need_save = True
|
||||||
|
with open(prompt_to_load, "r", encoding="utf-8") as f:
|
||||||
|
template = f.read()
|
||||||
|
self.add_prompt(Prompt(prompt_name=prompt_to_load.stem, template=template), need_save=need_save)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||||
|
raise e
|
||||||
|
for prompt_file in CUSTOM_PROMPTS_DIR.glob(f"*{SUFFIX_PROMPT}"):
|
||||||
|
if (PROMPTS_DIR / prompt_file.name).exists():
|
||||||
|
continue # 已经加载过了,跳过
|
||||||
try:
|
try:
|
||||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||||
template = f.read()
|
template = f.read()
|
||||||
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
self.add_prompt(Prompt(prompt_name=prompt_file.stem, template=template), need_save=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
logger.error(f"加载自定义 Prompt 文件 '{prompt_file}' 时出错,错误信息: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def _get_function_result(
|
async def _get_function_result(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue