diff --git a/.gitignore b/.gitignore
index 4db85eab..f51b8d6f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,6 +41,7 @@ config/bot_config.toml.bak
config/lpmm_config.toml
config/lpmm_config.toml.bak
template/compare/bot_config_template.toml
+template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py
@@ -321,4 +322,5 @@ run_pet.bat
config.toml
-interested_rates.txt
\ No newline at end of file
+interested_rates.txt
+MaiBot.code-workspace
diff --git a/bot.py b/bot.py
index 72ea65d2..b8f154cd 100644
--- a/bot.py
+++ b/bot.py
@@ -74,36 +74,6 @@ def easter_egg():
print(rainbow_text)
-def scan_provider(env_config: dict):
- provider = {}
-
- # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
- # 避免 GPG_KEY 这样的变量干扰检查
- env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
-
- # 遍历 env_config 的所有键
- for key in env_config:
- # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
- if key.endswith("_BASE_URL") or key.endswith("_KEY"):
- # 提取 provider 名称
- provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
-
- # 初始化 provider 的字典(如果尚未初始化)
- if provider_name not in provider:
- provider[provider_name] = {"url": None, "key": None}
-
- # 根据键的类型填充 url 或 key
- if key.endswith("_BASE_URL"):
- provider[provider_name]["url"] = env_config[key]
- elif key.endswith("_KEY"):
- provider[provider_name]["key"] = env_config[key]
-
- # 检查每个 provider 是否同时存在 url 和 key
- for provider_name, config in provider.items():
- if config["url"] is None or config["key"] is None:
- logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
- raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
-
async def graceful_shutdown():
try:
@@ -229,9 +199,6 @@ def raw_main():
easter_egg()
- env_config = {key: os.getenv(key) for key in os.environ}
- scan_provider(env_config)
-
# 返回MainSystem实例
return MainSystem()
diff --git a/changelogs/changelog.md b/changelogs/changelog.md
index c835e684..35631077 100644
--- a/changelogs/changelog.md
+++ b/changelogs/changelog.md
@@ -1,5 +1,17 @@
# Changelog
+## [0.10.0] - 2025-7-1
+### 主要功能更改
+- 工具系统重构,现在合并到了插件系统中
+- 彻底重构了整个LLM Request了,现在支持模型轮询和更多灵活的参数
+ - 同时重构了整个模型配置系统,升级需要重新配置llm配置文件
+- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API
+ - 具体相比于之前的更改可以查看[changes.md](./changes.md)
+
+### 细节优化
+- 修复了lint爆炸的问题,代码更加规范了
+- 修改了log的颜色,更加护眼
+
## [0.9.1] - 2025-7-26
### 主要修复和优化
diff --git a/changes.md b/changelogs/changes.md
similarity index 89%
rename from changes.md
rename to changelogs/changes.md
index b776991d..db41703c 100644
--- a/changes.md
+++ b/changelogs/changes.md
@@ -25,6 +25,7 @@
- 这意味着你终于可以动态控制是否继续后续消息的处理了。
8. 移除了dependency_manager,但是依然保留了`python_dependencies`属性,等待后续重构。
- 一并移除了文档有关manager的内容。
+9. 增加了工具的有关api
# 插件系统修改
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
@@ -57,30 +58,12 @@
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
- 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作
- 同样不保存到配置文件~
+16. 把`BaseTool`一并合并进入了插件系统
# 官方插件修改
1. `HelloWorld`插件现在有一个样例的`EventHandler`。
-2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。
-
-### TODO
-把这个看起来就很别扭的config获取方式改一下
-
-
-# 吐槽
-```python
-plugin_path = Path(plugin_file)
-if plugin_path.parent.name != "plugins":
- # 插件包格式:parent_dir.plugin
- module_name = f"plugins.{plugin_path.parent.name}.plugin"
-else:
- # 单文件格式:plugins.filename
- module_name = f"plugins.{plugin_path.stem}"
-```
-```python
-plugin_path = Path(plugin_file)
-module_name = ".".join(plugin_path.parent.parts)
-```
-这两个区别很大的。
+2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。(需要自行启用)
+3. `HelloWorld`插件现在有一个样例的`CompareNumbersTool`。
### 执笔BGM
塞壬唱片!
\ No newline at end of file
diff --git a/docs/image-1.png b/docs/image-1.png
new file mode 100644
index 00000000..c7a0adc8
Binary files /dev/null and b/docs/image-1.png differ
diff --git a/docs/image.png b/docs/image.png
new file mode 100644
index 00000000..63416251
Binary files /dev/null and b/docs/image.png differ
diff --git a/docs/model_configuration_guide.md b/docs/model_configuration_guide.md
new file mode 100644
index 00000000..6bbe05af
--- /dev/null
+++ b/docs/model_configuration_guide.md
@@ -0,0 +1,331 @@
+# 模型配置指南
+
+本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
+
+## 配置文件结构
+
+配置文件主要包含以下几个部分:
+- 版本信息
+- API服务提供商配置
+- 模型配置
+- 模型任务配置
+
+## 1. 版本信息
+
+```toml
+[inner]
+version = "1.1.1"
+```
+
+用于标识配置文件的版本,遵循语义化版本规则。
+
+## 2. API服务提供商配置
+
+### 2.1 基本配置
+
+使用 `[[api_providers]]` 数组配置多个API服务提供商:
+
+```toml
+[[api_providers]]
+name = "DeepSeek" # 服务商名称(自定义)
+base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
+api_key = "your-api-key-here" # API密钥
+client_type = "openai" # 客户端类型
+max_retry = 2 # 最大重试次数
+timeout = 30 # 超时时间(秒)
+retry_interval = 10 # 重试间隔(秒)
+```
+
+### 2.2 配置参数说明
+
+| 参数 | 必填 | 说明 | 默认值 |
+|------|------|------|--------|
+| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
+| `base_url` | ✅ | API服务的基础URL | - |
+| `api_key` | ✅ | API密钥,请替换为实际密钥 | - |
+| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` |
+| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
+| `timeout` | ❌ | API请求超时时间(秒) | 30 |
+| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
+
+### 2.3 支持的服务商示例
+
+#### DeepSeek
+```toml
+[[api_providers]]
+name = "DeepSeek"
+base_url = "https://api.deepseek.cn/v1"
+api_key = "your-deepseek-api-key"
+client_type = "openai"
+```
+
+#### SiliconFlow
+```toml
+[[api_providers]]
+name = "SiliconFlow"
+base_url = "https://api.siliconflow.cn/v1"
+api_key = "your-siliconflow-api-key"
+client_type = "openai"
+```
+
+#### Google Gemini
+```toml
+[[api_providers]]
+name = "Google"
+base_url = "https://api.google.com/v1"
+api_key = "your-google-api-key"
+client_type = "gemini" # 注意:Gemini需要使用特殊客户端
+```
+
+## 3. 模型配置
+
+### 3.1 基本模型配置
+
+使用 `[[models]]` 数组配置多个模型:
+
+```toml
+[[models]]
+model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
+name = "deepseek-v3" # 自定义模型名称
+api_provider = "DeepSeek" # 引用的API服务商名称
+price_in = 2.0 # 输入价格(元/M token)
+price_out = 8.0 # 输出价格(元/M token)
+```
+
+### 3.2 高级模型配置
+
+#### 强制流式输出
+对于不支持非流式输出的模型:
+```toml
+[[models]]
+model_identifier = "some-model"
+name = "custom-name"
+api_provider = "Provider"
+force_stream_mode = true # 启用强制流式输出
+```
+
+#### 额外参数配置`extra_params`
+```toml
+[[models]]
+model_identifier = "Qwen/Qwen3-8B"
+name = "qwen3-8b"
+api_provider = "SiliconFlow"
+[models.extra_params]
+enable_thinking = false # 禁用思考
+```
+这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。
+
+比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
+
+
+
+以豆包文档为另一个例子
+
+
+
+得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
+```toml
+[[models]]
+# 你的模型
+[models.extra_params]
+thinking = {type = "disabled"} # 禁用思考
+```
+请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。
+
+### 3.3 配置参数说明
+
+| 参数 | 必填 | 说明 |
+|------|------|------|
+| `model_identifier` | ✅ | API服务商提供的模型标识符 |
+| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
+| `api_provider` | ✅ | 对应的API服务商名称 |
+| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 |
+| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 |
+| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
+| `extra_params` | ❌ | 额外的模型参数配置 |
+
+## 4. 模型任务配置
+
+### utils - 工具模型
+用于表情包模块、取名模块、关系模块等核心功能:
+```toml
+[model_task_config.utils]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+```
+
+### utils_small - 小型工具模型
+用于高频率调用的场景,建议使用速度快的小模型:
+```toml
+[model_task_config.utils_small]
+model_list = ["qwen3-8b"]
+temperature = 0.7
+max_tokens = 800
+```
+
+### replyer_1 - 主要回复模型
+首要回复模型,也用于表达器和表达方式学习:
+```toml
+[model_task_config.replyer_1]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+```
+
+### replyer_2 - 次要回复模型
+```toml
+[model_task_config.replyer_2]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.7
+max_tokens = 800
+```
+
+### planner - 决策模型
+负责决定MaiBot该做什么:
+```toml
+[model_task_config.planner]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.3
+max_tokens = 800
+```
+
+### emotion - 情绪模型
+负责MaiBot的情绪变化:
+```toml
+[model_task_config.emotion]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.3
+max_tokens = 800
+```
+
+### memory - 记忆模型
+```toml
+[model_task_config.memory]
+model_list = ["qwen3-30b"]
+temperature = 0.7
+max_tokens = 800
+```
+
+### vlm - 视觉语言模型
+用于图像识别:
+```toml
+[model_task_config.vlm]
+model_list = ["qwen2.5-vl-72b"]
+max_tokens = 800
+```
+
+### voice - 语音识别模型
+```toml
+[model_task_config.voice]
+model_list = ["sensevoice-small"]
+```
+
+### embedding - 嵌入模型
+```toml
+[model_task_config.embedding]
+model_list = ["bge-m3"]
+```
+
+### tool_use - 工具调用模型
+需要使用支持工具调用的模型:
+```toml
+[model_task_config.tool_use]
+model_list = ["qwen3-14b"]
+temperature = 0.7
+max_tokens = 800
+```
+
+### lpmm_entity_extract - 实体提取模型
+```toml
+[model_task_config.lpmm_entity_extract]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+```
+
+### lpmm_rdf_build - RDF构建模型
+```toml
+[model_task_config.lpmm_rdf_build]
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+```
+
+### lpmm_qa - 问答模型
+```toml
+[model_task_config.lpmm_qa]
+model_list = ["deepseek-r1-distill-qwen-32b"]
+temperature = 0.7
+max_tokens = 800
+```
+
+## 5. 配置建议
+
+### 5.1 Temperature 参数选择
+
+| 任务类型 | 推荐温度 | 说明 |
+|----------|----------|------|
+| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
+| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
+| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
+
+### 5.2 模型选择建议
+
+| 任务类型 | 推荐模型类型 | 示例 |
+|----------|--------------|------|
+| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
+| 高频率任务 | 小模型 | Qwen3-8B |
+| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
+| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
+
+### 5.3 成本优化
+
+1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
+2. **合理配置max_tokens**:根据实际需求设置,避免浪费
+3. **选择免费模型**:对于测试环境,优先使用price为0的模型
+
+## 6. 配置验证
+
+### 6.1 必要检查项
+
+1. ✅ API密钥是否正确配置
+2. ✅ 模型标识符是否与API服务商提供的一致
+3. ✅ 任务配置中引用的模型名称是否在models中定义
+4. ✅ 多模态任务是否配置了对应的专用模型
+
+### 6.2 测试配置
+
+建议在正式使用前:
+1. 使用少量测试数据验证配置
+2. 检查API调用是否正常
+3. 确认成本统计功能正常工作
+
+## 7. 故障排除
+
+### 7.1 常见问题
+
+**问题1**: API调用失败
+- 检查API密钥是否正确
+- 确认base_url是否可访问
+- 检查模型标识符是否正确
+
+**问题2**: 模型未找到
+- 确认模型名称在任务配置和模型定义中一致
+- 检查api_provider名称是否匹配
+
+**问题3**: 响应异常
+- 检查温度参数是否合理(0-1之间)
+- 确认max_tokens设置是否合适
+- 验证模型是否支持所需功能
+
+### 7.2 日志查看
+
+查看 `logs/` 目录下的日志文件,寻找相关错误信息。
+
+## 8. 更新和维护
+
+1. **定期更新**: 关注API服务商的模型更新,及时调整配置
+2. **性能监控**: 监控模型调用的成本和性能
+3. **备份配置**: 在修改前备份当前配置文件
+
diff --git a/docs/plugins/api/component-manage-api.md b/docs/plugins/api/component-manage-api.md
new file mode 100644
index 00000000..a857fb27
--- /dev/null
+++ b/docs/plugins/api/component-manage-api.md
@@ -0,0 +1,194 @@
+# 组件管理API
+
+组件管理API模块提供了对插件组件的查询和管理功能,使得插件能够获取和使用组件相关的信息。
+
+## 导入方式
+```python
+from src.plugin_system.apis import component_manage_api
+# 或者
+from src.plugin_system import component_manage_api
+```
+
+## 功能概述
+
+组件管理API主要提供以下功能:
+- **插件信息查询** - 获取所有插件或指定插件的信息。
+- **组件查询** - 按名称或类型查询组件信息。
+- **组件管理** - 启用或禁用组件,支持全局和局部操作。
+
+## 主要功能
+
+### 1. 获取所有插件信息
+```python
+def get_all_plugin_info() -> Dict[str, PluginInfo]:
+```
+获取所有插件的信息。
+
+**Returns:**
+- `Dict[str, PluginInfo]` - 包含所有插件信息的字典,键为插件名称,值为 `PluginInfo` 对象。
+
+### 2. 获取指定插件信息
+```python
+def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
+```
+获取指定插件的信息。
+
+**Args:**
+- `plugin_name` (str): 插件名称。
+
+**Returns:**
+- `Optional[PluginInfo]`: 插件信息对象,如果插件不存在则返回 `None`。
+
+### 3. 获取指定组件信息
+```python
+def get_component_info(component_name: str, component_type: ComponentType) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
+```
+获取指定组件的信息。
+
+**Args:**
+- `component_name` (str): 组件名称。
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 组件信息对象,如果组件不存在则返回 `None`。
+
+### 4. 获取指定类型的所有组件信息
+```python
+def get_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
+```
+获取指定类型的所有组件信息。
+
+**Args:**
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
+
+### 5. 获取指定类型的所有启用的组件信息
+```python
+def get_enabled_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
+```
+获取指定类型的所有启用的组件信息。
+
+**Args:**
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
+
+### 6. 获取指定 Action 的注册信息
+```python
+def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
+```
+获取指定 Action 的注册信息。
+
+**Args:**
+- `action_name` (str): Action 名称。
+
+**Returns:**
+- `Optional[ActionInfo]` - Action 信息对象,如果 Action 不存在则返回 `None`。
+
+### 7. 获取指定 Command 的注册信息
+```python
+def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
+```
+获取指定 Command 的注册信息。
+
+**Args:**
+- `command_name` (str): Command 名称。
+
+**Returns:**
+- `Optional[CommandInfo]` - Command 信息对象,如果 Command 不存在则返回 `None`。
+
+### 8. 获取指定 Tool 的注册信息
+```python
+def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
+```
+获取指定 Tool 的注册信息。
+
+**Args:**
+- `tool_name` (str): Tool 名称。
+
+**Returns:**
+- `Optional[ToolInfo]` - Tool 信息对象,如果 Tool 不存在则返回 `None`。
+
+### 9. 获取指定 EventHandler 的注册信息
+```python
+def get_registered_event_handler_info(event_handler_name: str) -> Optional[EventHandlerInfo]:
+```
+获取指定 EventHandler 的注册信息。
+
+**Args:**
+- `event_handler_name` (str): EventHandler 名称。
+
+**Returns:**
+- `Optional[EventHandlerInfo]` - EventHandler 信息对象,如果 EventHandler 不存在则返回 `None`。
+
+### 10. 全局启用指定组件
+```python
+def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
+```
+全局启用指定组件。
+
+**Args:**
+- `component_name` (str): 组件名称。
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `bool` - 启用成功返回 `True`,否则返回 `False`。
+
+### 11. 全局禁用指定组件
+```python
+async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
+```
+全局禁用指定组件。
+
+**此函数是异步的,确保在异步环境中调用。**
+
+**Args:**
+- `component_name` (str): 组件名称。
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `bool` - 禁用成功返回 `True`,否则返回 `False`。
+
+### 12. 局部启用指定组件
+```python
+def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
+```
+局部启用指定组件。
+
+**Args:**
+- `component_name` (str): 组件名称。
+- `component_type` (ComponentType): 组件类型。
+- `stream_id` (str): 消息流 ID。
+
+**Returns:**
+- `bool` - 启用成功返回 `True`,否则返回 `False`。
+
+### 13. 局部禁用指定组件
+```python
+def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
+```
+局部禁用指定组件。
+
+**Args:**
+- `component_name` (str): 组件名称。
+- `component_type` (ComponentType): 组件类型。
+- `stream_id` (str): 消息流 ID。
+
+**Returns:**
+- `bool` - 禁用成功返回 `True`,否则返回 `False`。
+
+### 14. 获取指定消息流中禁用的组件列表
+```python
+def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
+```
+获取指定消息流中禁用的组件列表。
+
+**Args:**
+- `stream_id` (str): 消息流 ID。
+- `component_type` (ComponentType): 组件类型。
+
+**Returns:**
+- `list[str]` - 禁用的组件名称列表。
diff --git a/docs/plugins/api/config-api.md b/docs/plugins/api/config-api.md
index 2a5691fc..2ee1cdfc 100644
--- a/docs/plugins/api/config-api.md
+++ b/docs/plugins/api/config-api.md
@@ -1,6 +1,6 @@
# 配置API
-配置API模块提供了配置读取和用户信息获取等功能,让插件能够安全地访问全局配置和用户信息。
+配置API模块提供了配置读取功能,让插件能够安全地访问全局配置和插件配置。
## 导入方式
diff --git a/docs/plugins/api/database-api.md b/docs/plugins/api/database-api.md
index 174bef15..5b6b4468 100644
--- a/docs/plugins/api/database-api.md
+++ b/docs/plugins/api/database-api.md
@@ -6,72 +6,51 @@
```python
from src.plugin_system.apis import database_api
+# 或者
+from src.plugin_system import database_api
```
## 主要功能
-### 1. 通用数据库查询
-
-#### `db_query(model_class, query_type="get", filters=None, data=None, limit=None, order_by=None, single_result=False)`
-执行数据库查询操作的通用接口
-
-**参数:**
-- `model_class`:Peewee模型类,如ActionRecords、Messages等
-- `query_type`:查询类型,可选值: "get", "create", "update", "delete", "count"
-- `filters`:过滤条件字典,键为字段名,值为要匹配的值
-- `data`:用于创建或更新的数据字典
-- `limit`:限制结果数量
-- `order_by`:排序字段列表,使用字段名,前缀'-'表示降序
-- `single_result`:是否只返回单个结果
-
-**返回:**
-根据查询类型返回不同的结果:
-- "get":返回查询结果列表或单个结果
-- "create":返回创建的记录
-- "update":返回受影响的行数
-- "delete":返回受影响的行数
-- "count":返回记录数量
-
-### 2. 便捷查询函数
-
-#### `db_save(model_class, data, key_field=None, key_value=None)`
-保存数据到数据库(创建或更新)
-
-**参数:**
-- `model_class`:Peewee模型类
-- `data`:要保存的数据字典
-- `key_field`:用于查找现有记录的字段名
-- `key_value`:用于查找现有记录的字段值
-
-**返回:**
-- `Dict[str, Any]`:保存后的记录数据,失败时返回None
-
-#### `db_get(model_class, filters=None, order_by=None, limit=None)`
-简化的查询函数
-
-**参数:**
-- `model_class`:Peewee模型类
-- `filters`:过滤条件字典
-- `order_by`:排序字段
-- `limit`:限制结果数量
-
-**返回:**
-- `Union[List[Dict], Dict, None]`:查询结果
-
-### 3. 专用函数
-
-#### `store_action_info(...)`
-存储动作信息的专用函数
-
-## 使用示例
-
-### 1. 基本查询操作
+### 1. 通用数据库操作
```python
-from src.plugin_system.apis import database_api
-from src.common.database.database_model import Messages, ActionRecords
+async def db_query(
+ model_class: Type[Model],
+ data: Optional[Dict[str, Any]] = None,
+ query_type: Optional[str] = "get",
+ filters: Optional[Dict[str, Any]] = None,
+ limit: Optional[int] = None,
+ order_by: Optional[List[str]] = None,
+ single_result: Optional[bool] = False,
+) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
+```
+执行数据库查询操作的通用接口。
-# 查询最近10条消息
+**Args:**
+- `model_class`: Peewee模型类。
+ - Peewee模型类可以在`src.common.database.database_model`模块中找到,如`ActionRecords`、`Messages`等。
+- `data`: 用于创建或更新的数据
+- `query_type`: 查询类型
+ - 可选值: `get`, `create`, `update`, `delete`, `count`。
+- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
+- `limit`: 限制结果数量。
+- `order_by`: 排序字段列表,使用字段名,前缀'-'表示降序。
+ - 排序字段,前缀`-`表示降序,例如`-time`表示按时间字段(即`time`字段)降序
+- `single_result`: 是否只返回单个结果。
+
+**Returns:**
+- 根据查询类型返回不同的结果:
+ - `get`: 返回查询结果列表或单个结果。(如果 `single_result=True`)
+ - `create`: 返回创建的记录。
+ - `update`: 返回受影响的行数。
+ - `delete`: 返回受影响的行数。
+ - `count`: 返回记录数量。
+
+#### 示例
+
+1. 查询最近10条消息
+```python
messages = await database_api.db_query(
Messages,
query_type="get",
@@ -79,180 +58,159 @@ messages = await database_api.db_query(
limit=10,
order_by=["-time"]
)
-
-# 查询单条记录
-message = await database_api.db_query(
- Messages,
- query_type="get",
- filters={"message_id": "msg_123"},
- single_result=True
-)
```
-
-### 2. 创建记录
-
+2. 创建一条记录
```python
-# 创建新的动作记录
new_record = await database_api.db_query(
ActionRecords,
+ data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
- data={
- "action_id": "action_123",
- "time": time.time(),
- "action_name": "TestAction",
- "action_done": True
- }
)
-
-print(f"创建了记录: {new_record['id']}")
```
-
-### 3. 更新记录
-
+3. 更新记录
```python
-# 更新动作状态
updated_count = await database_api.db_query(
ActionRecords,
+ data={"action_done": True},
query_type="update",
- filters={"action_id": "action_123"},
- data={"action_done": True, "completion_time": time.time()}
+ filters={"action_id": "123"},
)
-
-print(f"更新了 {updated_count} 条记录")
```
-
-### 4. 删除记录
-
+4. 删除记录
```python
-# 删除过期记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
- filters={"time__lt": time.time() - 86400} # 删除24小时前的记录
+ filters={"action_id": "123"}
)
-
-print(f"删除了 {deleted_count} 条过期记录")
```
-
-### 5. 统计查询
-
+5. 计数
```python
-# 统计消息数量
-message_count = await database_api.db_query(
+count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
-
-print(f"该聊天有 {message_count} 条消息")
```
-### 6. 使用便捷函数
-
+### 2. 数据库保存
+```python
+async def db_save(
+ model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
+) -> Optional[Dict[str, Any]]:
+```
+保存数据到数据库(创建或更新)
+
+如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
+
+如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
+
+**Args:**
+- `model_class`: Peewee模型类。
+- `data`: 要保存的数据字典。
+- `key_field`: 用于查找现有记录的字段名,例如"action_id"。
+- `key_value`: 用于查找现有记录的字段值。
+
+**Returns:**
+- `Optional[Dict[str, Any]]`: 保存后的记录数据,失败时返回None。
+
+#### 示例
+创建或更新一条记录
```python
-# 使用db_save进行创建或更新
record = await database_api.db_save(
ActionRecords,
{
- "action_id": "action_123",
+ "action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
- key_value="action_123"
+ key_value="123"
)
+```
-# 使用db_get进行简单查询
-recent_messages = await database_api.db_get(
+### 3. 数据库获取
+```python
+async def db_get(
+ model_class: Type[Model],
+ filters: Optional[Dict[str, Any]] = None,
+ limit: Optional[int] = None,
+ order_by: Optional[str] = None,
+ single_result: Optional[bool] = False,
+) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
+```
+
+从数据库获取记录
+
+这是db_query方法的简化版本,专注于数据检索操作。
+
+**Args:**
+- `model_class`: Peewee模型类。
+- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
+- `limit`: 限制结果数量。
+- `order_by`: 排序字段,使用字段名,前缀'-'表示降序。
+- `single_result`: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
+
+**Returns:**
+- `Union[List[Dict], Dict, None]`: 查询结果列表或单个结果(如果`single_result=True`),失败时返回None。
+
+#### 示例
+1. 获取单个记录
+```python
+record = await database_api.db_get(
+ ActionRecords,
+ filters={"action_id": "123"},
+ limit=1
+)
+```
+2. 获取最近10条记录
+```python
+records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
+ limit=10,
order_by="-time",
- limit=5
)
```
-## 高级用法
-
-### 复杂查询示例
-
+### 4. 动作信息存储
```python
-# 查询特定用户在特定时间段的消息
-user_messages = await database_api.db_query(
- Messages,
- query_type="get",
- filters={
- "user_id": "123456",
- "time__gte": start_time, # 大于等于开始时间
- "time__lt": end_time # 小于结束时间
- },
- order_by=["-time"],
- limit=50
+async def store_action_info(
+ chat_stream=None,
+ action_build_into_prompt: bool = False,
+ action_prompt_display: str = "",
+ action_done: bool = True,
+ thinking_id: str = "",
+ action_data: Optional[dict] = None,
+ action_name: str = "",
+) -> Optional[Dict[str, Any]]:
+```
+存储动作信息到数据库,是一种针对 Action 的 `db_save()` 的封装函数。
+
+将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
+
+**Args:**
+- `chat_stream`: 聊天流对象,包含聊天ID等信息。
+- `action_build_into_prompt`: 是否将动作信息构建到提示中。
+- `action_prompt_display`: 动作提示的显示文本。
+- `action_done`: 动作是否完成。
+- `thinking_id`: 思考过程的ID。
+- `action_data`: 动作的数据字典。
+- `action_name`: 动作的名称。
+
+**Returns:**
+- `Optional[Dict[str, Any]]`: 存储后的记录数据,失败时返回None。
+
+#### 示例
+```python
+record = await database_api.store_action_info(
+ chat_stream=chat_stream,
+ action_build_into_prompt=True,
+ action_prompt_display="执行了回复动作",
+ action_done=True,
+ thinking_id="thinking_123",
+ action_data={"content": "Hello"},
+ action_name="reply_action"
)
-
-# 批量处理
-for message in user_messages:
- print(f"消息内容: {message['plain_text']}")
- print(f"发送时间: {message['time']}")
-```
-
-### 插件中的数据持久化
-
-```python
-from src.plugin_system.base import BasePlugin
-from src.plugin_system.apis import database_api
-
-class DataPlugin(BasePlugin):
- async def handle_action(self, action_data, chat_stream):
- # 保存插件数据
- plugin_data = {
- "plugin_name": self.plugin_name,
- "chat_id": chat_stream.stream_id,
- "data": json.dumps(action_data),
- "created_time": time.time()
- }
-
- # 使用自定义表模型(需要先定义)
- record = await database_api.db_save(
- PluginData, # 假设的插件数据模型
- plugin_data,
- key_field="plugin_name",
- key_value=self.plugin_name
- )
-
- return {"success": True, "record_id": record["id"]}
-```
-
-## 数据模型
-
-### 常用模型类
-系统提供了以下常用的数据模型:
-
-- `Messages`:消息记录
-- `ActionRecords`:动作记录
-- `UserInfo`:用户信息
-- `GroupInfo`:群组信息
-
-### 字段说明
-
-#### Messages模型主要字段
-- `message_id`:消息ID
-- `chat_id`:聊天ID
-- `user_id`:用户ID
-- `plain_text`:纯文本内容
-- `time`:时间戳
-
-#### ActionRecords模型主要字段
-- `action_id`:动作ID
-- `action_name`:动作名称
-- `action_done`:是否完成
-- `time`:创建时间
-
-## 注意事项
-
-1. **异步操作**:所有数据库API都是异步的,必须使用`await`
-2. **错误处理**:函数内置错误处理,失败时返回None或空列表
-3. **数据类型**:返回的都是字典格式的数据,不是模型对象
-4. **性能考虑**:使用`limit`参数避免查询大量数据
-5. **过滤条件**:支持简单的等值过滤,复杂查询需要使用原生Peewee语法
-6. **事务**:如需事务支持,建议直接使用Peewee的事务功能
\ No newline at end of file
+```
\ No newline at end of file
diff --git a/docs/plugins/api/emoji-api.md b/docs/plugins/api/emoji-api.md
index 6dd071b9..ce9dd0c8 100644
--- a/docs/plugins/api/emoji-api.md
+++ b/docs/plugins/api/emoji-api.md
@@ -6,11 +6,13 @@
```python
from src.plugin_system.apis import emoji_api
+# 或者
+from src.plugin_system import emoji_api
```
-## 🆕 **二步走识别优化**
+## 二步走识别优化
-从最新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
+从新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
### **收到表情包时的识别流程**
1. **第一步**:VLM视觉分析 - 生成详细描述
@@ -30,217 +32,84 @@ from src.plugin_system.apis import emoji_api
## 主要功能
### 1. 表情包获取
-
-#### `get_by_description(description: str) -> Optional[Tuple[str, str, str]]`
+```python
+async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
+```
根据场景描述选择表情包
-**参数:**
-- `description`:场景描述文本,例如"开心的大笑"、"轻微的讽刺"、"表示无奈和沮丧"等
+**Args:**
+- `description`:表情包的描述文本,例如"开心"、"难过"、"愤怒"等
-**返回:**
-- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None
+**Returns:**
+- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到匹配的表情包则返回None
-**示例:**
+#### 示例
```python
-emoji_result = await emoji_api.get_by_description("开心的大笑")
+emoji_result = await emoji_api.get_by_description("大笑")
if emoji_result:
emoji_base64, description, matched_scene = emoji_result
print(f"获取到表情包: {description}, 场景: {matched_scene}")
# 可以将emoji_base64用于发送表情包
```
-#### `get_random() -> Optional[Tuple[str, str, str]]`
-随机获取表情包
-
-**返回:**
-- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 随机场景) 或 None
-
-**示例:**
+### 2. 随机获取表情包
```python
-random_emoji = await emoji_api.get_random()
-if random_emoji:
- emoji_base64, description, scene = random_emoji
- print(f"随机表情包: {description}")
+async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
```
+随机获取指定数量的表情包
-#### `get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]`
-根据场景关键词获取表情包
+**Args:**
+- `count`:要获取的表情包数量,默认为1
-**参数:**
-- `emotion`:场景关键词,如"大笑"、"讽刺"、"无奈"等
+**Returns:**
+- `List[Tuple[str, str, str]]`:一个包含多个表情包的列表,每个元素是一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到或出错则返回空列表
-**返回:**
-- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None
-
-**示例:**
+### 3. 根据情感获取表情包
```python
-emoji_result = await emoji_api.get_by_emotion("讽刺")
-if emoji_result:
- emoji_base64, description, scene = emoji_result
- # 发送讽刺表情包
+async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
```
+根据情感标签获取表情包
-### 2. 表情包信息查询
+**Args:**
+- `emotion`:情感标签,例如"开心"、"悲伤"、"愤怒"等
-#### `get_count() -> int`
-获取表情包数量
+**Returns:**
+- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到则返回None
-**返回:**
-- `int`:当前可用的表情包数量
+### 4. 获取表情包数量
+```python
+def get_count() -> int:
+```
+获取当前可用表情包的数量
-#### `get_info() -> dict`
-获取表情包系统信息
+### 5. 获取表情包系统信息
+```python
+def get_info() -> Dict[str, Any]:
+```
+获取表情包系统的基本信息
-**返回:**
-- `dict`:包含表情包数量、最大数量等信息
+**Returns:**
+- `Dict[str, Any]`:包含表情包数量、描述等信息的字典,包含以下键:
+ - `current_count`:当前表情包数量
+ - `max_count`:最大表情包数量
+ - `available_emojis`:当前可用的表情包数量
-**返回字典包含:**
-- `current_count`:当前表情包数量
-- `max_count`:最大表情包数量
-- `available_emojis`:可用表情包数量
+### 6. 获取所有可用的情感标签
+```python
+def get_emotions() -> List[str]:
+```
+获取所有可用的情感标签 **(已经去重)**
-#### `get_emotions() -> list`
-获取所有可用的场景关键词
-
-**返回:**
-- `list`:所有表情包的场景关键词列表(去重)
-
-#### `get_descriptions() -> list`
+### 7. 获取所有表情包描述
+```python
+def get_descriptions() -> List[str]:
+```
获取所有表情包的描述列表
-**返回:**
-- `list`:所有表情包的描述文本列表
-
-## 使用示例
-
-### 1. 智能表情包选择
-
-```python
-from src.plugin_system.apis import emoji_api
-
-async def send_emotion_response(message_text: str, chat_stream):
- """根据消息内容智能选择表情包回复"""
-
- # 分析消息场景
- if "哈哈" in message_text or "好笑" in message_text:
- emoji_result = await emoji_api.get_by_description("开心的大笑")
- elif "无语" in message_text or "算了" in message_text:
- emoji_result = await emoji_api.get_by_description("表示无奈和沮丧")
- elif "呵呵" in message_text or "是吗" in message_text:
- emoji_result = await emoji_api.get_by_description("轻微的讽刺")
- elif "生气" in message_text or "愤怒" in message_text:
- emoji_result = await emoji_api.get_by_description("愤怒和不满")
- else:
- # 随机选择一个表情包
- emoji_result = await emoji_api.get_random()
-
- if emoji_result:
- emoji_base64, description, scene = emoji_result
- # 使用send_api发送表情包
- from src.plugin_system.apis import send_api
- success = await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
- return success
-
- return False
-```
-
-### 2. 表情包管理功能
-
-```python
-async def show_emoji_stats():
- """显示表情包统计信息"""
-
- # 获取基本信息
- count = emoji_api.get_count()
- info = emoji_api.get_info()
- scenes = emoji_api.get_emotions() # 实际返回的是场景关键词
-
- stats = f"""
-📊 表情包统计信息:
-- 总数量: {count}
-- 可用数量: {info['available_emojis']}
-- 最大容量: {info['max_count']}
-- 支持场景: {len(scenes)}种
-
-🎭 支持的场景关键词: {', '.join(scenes[:10])}{'...' if len(scenes) > 10 else ''}
- """
-
- return stats
-```
-
-### 3. 表情包测试功能
-
-```python
-async def test_emoji_system():
- """测试表情包系统的各种功能"""
-
- print("=== 表情包系统测试 ===")
-
- # 测试场景描述查找
- test_descriptions = ["开心的大笑", "轻微的讽刺", "表示无奈和沮丧", "愤怒和不满"]
- for desc in test_descriptions:
- result = await emoji_api.get_by_description(desc)
- if result:
- _, description, scene = result
- print(f"✅ 场景'{desc}' -> {description} ({scene})")
- else:
- print(f"❌ 场景'{desc}' -> 未找到")
-
- # 测试关键词查找
- scenes = emoji_api.get_emotions()
- if scenes:
- test_scene = scenes[0]
- result = await emoji_api.get_by_emotion(test_scene)
- if result:
- print(f"✅ 关键词'{test_scene}' -> 找到匹配表情包")
-
- # 测试随机获取
- random_result = await emoji_api.get_random()
- if random_result:
- print("✅ 随机获取 -> 成功")
-
- print(f"📊 系统信息: {emoji_api.get_info()}")
-```
-
-### 4. 在Action中使用表情包
-
-```python
-from src.plugin_system.base import BaseAction
-
-class EmojiAction(BaseAction):
- async def execute(self, action_data, chat_stream):
- # 从action_data获取场景描述或关键词
- scene_keyword = action_data.get("scene", "")
- scene_description = action_data.get("description", "")
-
- emoji_result = None
-
- # 优先使用具体的场景描述
- if scene_description:
- emoji_result = await emoji_api.get_by_description(scene_description)
- # 其次使用场景关键词
- elif scene_keyword:
- emoji_result = await emoji_api.get_by_emotion(scene_keyword)
- # 最后随机选择
- else:
- emoji_result = await emoji_api.get_random()
-
- if emoji_result:
- emoji_base64, description, scene = emoji_result
- return {
- "success": True,
- "emoji_base64": emoji_base64,
- "description": description,
- "scene": scene
- }
-
- return {"success": False, "message": "未找到合适的表情包"}
-```
-
## 场景描述说明
### 常用场景描述
-表情包系统支持多种具体的场景描述,常见的包括:
+表情包系统支持多种具体的场景描述,举例如下:
- **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈
- **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头
@@ -248,8 +117,8 @@ class EmojiAction(BaseAction):
- **惊讶类场景**:震惊的表情、意外的发现、困惑的思考
- **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子
-### 场景关键词示例
-系统支持的场景关键词包括:
+### 情感关键词示例
+系统支持的情感关键词举例如下:
- 大笑、微笑、兴奋、手舞足蹈
- 无奈、沮丧、讽刺、无语、摇头
- 愤怒、不满、生气、瞪视、抓狂
@@ -263,9 +132,9 @@ class EmojiAction(BaseAction):
## 注意事项
-1. **异步函数**:获取表情包的函数都是异步的,需要使用 `await`
+1. **异步函数**:部分函数是异步的,需要使用 `await`
2. **返回格式**:表情包以base64编码返回,可直接用于发送
-3. **错误处理**:所有函数都有错误处理,失败时返回None或默认值
+3. **错误处理**:所有函数都有错误处理,失败时返回None,空列表或默认值
4. **使用统计**:系统会记录表情包的使用次数
5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在
6. **编码格式**:返回的是base64编码的图片数据,可直接用于网络传输
diff --git a/docs/plugins/api/generator-api.md b/docs/plugins/api/generator-api.md
index 964fff84..afeb6eec 100644
--- a/docs/plugins/api/generator-api.md
+++ b/docs/plugins/api/generator-api.md
@@ -6,241 +6,151 @@
```python
from src.plugin_system.apis import generator_api
+# 或者
+from src.plugin_system import generator_api
```
## 主要功能
### 1. 回复器获取
-
-#### `get_replyer(chat_stream=None, platform=None, chat_id=None, is_group=True)`
+```python
+def get_replyer(
+ chat_stream: Optional[ChatStream] = None,
+ chat_id: Optional[str] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ request_type: str = "replyer",
+) -> Optional[DefaultReplyer]:
+```
获取回复器对象
-**参数:**
-- `chat_stream`:聊天流对象(优先)
-- `platform`:平台名称,如"qq"
-- `chat_id`:聊天ID(群ID或用户ID)
-- `is_group`:是否为群聊
+优先使用chat_stream,如果没有则使用chat_id直接查找。
-**返回:**
-- `DefaultReplyer`:回复器对象,如果获取失败则返回None
+使用 ReplyerManager 来管理实例,避免重复创建。
-**示例:**
+**Args:**
+- `chat_stream`: 聊天流对象
+- `chat_id`: 聊天ID(实际上就是`stream_id`)
+- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
+- `request_type`: 请求类型,用于记录LLM使用情况,可以不写
+
+**Returns:**
+- `DefaultReplyer`: 回复器对象,如果获取失败则返回None
+
+#### 示例
```python
# 使用聊天流获取回复器
replyer = generator_api.get_replyer(chat_stream=chat_stream)
-# 使用平台和ID获取回复器
-replyer = generator_api.get_replyer(
- platform="qq",
- chat_id="123456789",
- is_group=True
-)
+# 使用平台和ID获取回复器
+replyer = generator_api.get_replyer(chat_id="123456789")
```
### 2. 回复生成
-
-#### `generate_reply(chat_stream=None, action_data=None, platform=None, chat_id=None, is_group=True)`
+```python
+async def generate_reply(
+ chat_stream: Optional[ChatStream] = None,
+ chat_id: Optional[str] = None,
+ action_data: Optional[Dict[str, Any]] = None,
+ reply_to: str = "",
+ extra_info: str = "",
+ available_actions: Optional[Dict[str, ActionInfo]] = None,
+ enable_tool: bool = False,
+ enable_splitter: bool = True,
+ enable_chinese_typo: bool = True,
+ return_prompt: bool = False,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ request_type: str = "generator_api",
+) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
+```
生成回复
-**参数:**
-- `chat_stream`:聊天流对象(优先)
-- `action_data`:动作数据
-- `platform`:平台名称(备用)
-- `chat_id`:聊天ID(备用)
-- `is_group`:是否为群聊(备用)
+优先使用chat_stream,如果没有则使用chat_id直接查找。
-**返回:**
-- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合)
+**Args:**
+- `chat_stream`: 聊天流对象
+- `chat_id`: 聊天ID(实际上就是`stream_id`)
+- `action_data`: 动作数据(向下兼容,包含`reply_to`和`extra_info`)
+- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
+- `extra_info`: 附加信息
+- `available_actions`: 可用动作字典,格式为 `{"action_name": ActionInfo}`
+- `enable_tool`: 是否启用工具
+- `enable_splitter`: 是否启用分割器
+- `enable_chinese_typo`: 是否启用中文错别字
+- `return_prompt`: 是否返回提示词
+- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
+- `request_type`: 请求类型(可选,记录LLM使用)
+- `request_type`: 请求类型,用于记录LLM使用情况
-**示例:**
+**Returns:**
+- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
+
+#### 示例
```python
-success, reply_set = await generator_api.generate_reply(
+success, reply_set, prompt = await generator_api.generate_reply(
chat_stream=chat_stream,
- action_data={"message": "你好", "intent": "greeting"}
+ action_data=action_data,
+ reply_to="麦麦:你好",
+ available_actions=action_info,
+ enable_tool=True,
+ return_prompt=True
)
-
if success:
for reply_type, reply_content in reply_set:
print(f"回复类型: {reply_type}, 内容: {reply_content}")
+ if prompt:
+ print(f"使用的提示词: {prompt}")
```
-#### `rewrite_reply(chat_stream=None, reply_data=None, platform=None, chat_id=None, is_group=True)`
-重写回复
-
-**参数:**
-- `chat_stream`:聊天流对象(优先)
-- `reply_data`:回复数据
-- `platform`:平台名称(备用)
-- `chat_id`:聊天ID(备用)
-- `is_group`:是否为群聊(备用)
-
-**返回:**
-- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合)
-
-**示例:**
+### 3. 回复重写
```python
-success, reply_set = await generator_api.rewrite_reply(
+async def rewrite_reply(
+ chat_stream: Optional[ChatStream] = None,
+ reply_data: Optional[Dict[str, Any]] = None,
+ chat_id: Optional[str] = None,
+ enable_splitter: bool = True,
+ enable_chinese_typo: bool = True,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ raw_reply: str = "",
+ reason: str = "",
+ reply_to: str = "",
+ return_prompt: bool = False,
+) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
+```
+重写回复,使用新的内容替换旧的回复内容。
+
+优先使用chat_stream,如果没有则使用chat_id直接查找。
+
+**Args:**
+- `chat_stream`: 聊天流对象
+- `reply_data`: 回复数据,包含`raw_reply`, `reason`和`reply_to`,**(向下兼容备用,当其他参数缺失时从此获取)**
+- `chat_id`: 聊天ID(实际上就是`stream_id`)
+- `enable_splitter`: 是否启用分割器
+- `enable_chinese_typo`: 是否启用中文错别字
+- `model_set_with_weight`: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
+- `raw_reply`: 原始回复内容
+- `reason`: 重写原因
+- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
+
+**Returns:**
+- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
+
+#### 示例
+```python
+success, reply_set, prompt = await generator_api.rewrite_reply(
chat_stream=chat_stream,
- reply_data={"original_text": "原始回复", "style": "more_friendly"}
+ raw_reply="原始回复内容",
+ reason="重写原因",
+ reply_to="麦麦:你好",
+ return_prompt=True
)
+if success:
+ for reply_type, reply_content in reply_set:
+ print(f"回复类型: {reply_type}, 内容: {reply_content}")
+ if prompt:
+ print(f"使用的提示词: {prompt}")
```
-## 使用示例
-
-### 1. 基础回复生成
-
-```python
-from src.plugin_system.apis import generator_api
-
-async def generate_greeting_reply(chat_stream, user_name):
- """生成问候回复"""
-
- action_data = {
- "intent": "greeting",
- "user_name": user_name,
- "context": "morning_greeting"
- }
-
- success, reply_set = await generator_api.generate_reply(
- chat_stream=chat_stream,
- action_data=action_data
- )
-
- if success and reply_set:
- # 获取第一个回复
- reply_type, reply_content = reply_set[0]
- return reply_content
-
- return "你好!" # 默认回复
-```
-
-### 2. 在Action中使用回复生成器
-
-```python
-from src.plugin_system.base import BaseAction
-
-class ChatAction(BaseAction):
- async def execute(self, action_data, chat_stream):
- # 准备回复数据
- reply_context = {
- "message_type": "response",
- "user_input": action_data.get("user_message", ""),
- "intent": action_data.get("intent", ""),
- "entities": action_data.get("entities", {}),
- "context": self.get_conversation_context(chat_stream)
- }
-
- # 生成回复
- success, reply_set = await generator_api.generate_reply(
- chat_stream=chat_stream,
- action_data=reply_context
- )
-
- if success:
- return {
- "success": True,
- "replies": reply_set,
- "generated_count": len(reply_set)
- }
-
- return {
- "success": False,
- "error": "回复生成失败",
- "fallback_reply": "抱歉,我现在无法理解您的消息。"
- }
-```
-
-### 3. 多样化回复生成
-
-```python
-async def generate_diverse_replies(chat_stream, topic, count=3):
- """生成多个不同风格的回复"""
-
- styles = ["formal", "casual", "humorous"]
- all_replies = []
-
- for i, style in enumerate(styles[:count]):
- action_data = {
- "topic": topic,
- "style": style,
- "variation": i
- }
-
- success, reply_set = await generator_api.generate_reply(
- chat_stream=chat_stream,
- action_data=action_data
- )
-
- if success and reply_set:
- all_replies.extend(reply_set)
-
- return all_replies
-```
-
-### 4. 回复重写功能
-
-```python
-async def improve_reply(chat_stream, original_reply, improvement_type="more_friendly"):
- """改进原始回复"""
-
- reply_data = {
- "original_text": original_reply,
- "improvement_type": improvement_type,
- "target_audience": "young_users",
- "tone": "positive"
- }
-
- success, improved_replies = await generator_api.rewrite_reply(
- chat_stream=chat_stream,
- reply_data=reply_data
- )
-
- if success and improved_replies:
- # 返回改进后的第一个回复
- _, improved_content = improved_replies[0]
- return improved_content
-
- return original_reply # 如果改进失败,返回原始回复
-```
-
-### 5. 条件回复生成
-
-```python
-async def conditional_reply_generation(chat_stream, user_message, user_emotion):
- """根据用户情感生成条件回复"""
-
- # 根据情感调整回复策略
- if user_emotion == "sad":
- action_data = {
- "intent": "comfort",
- "tone": "empathetic",
- "style": "supportive"
- }
- elif user_emotion == "angry":
- action_data = {
- "intent": "calm",
- "tone": "peaceful",
- "style": "understanding"
- }
- else:
- action_data = {
- "intent": "respond",
- "tone": "neutral",
- "style": "helpful"
- }
-
- action_data["user_message"] = user_message
- action_data["user_emotion"] = user_emotion
-
- success, reply_set = await generator_api.generate_reply(
- chat_stream=chat_stream,
- action_data=action_data
- )
-
- return reply_set if success else []
-```
-
-## 回复集合格式
+## 回复集合`reply_set`格式
### 回复类型
生成的回复集合包含多种类型的回复:
@@ -260,82 +170,32 @@ reply_set = [
]
```
-## 高级用法
-
-### 1. 自定义回复器配置
-
+### 4. 自定义提示词回复
```python
-async def generate_with_custom_config(chat_stream, action_data):
- """使用自定义配置生成回复"""
-
- # 获取回复器
- replyer = generator_api.get_replyer(chat_stream=chat_stream)
-
- if replyer:
- # 可以访问回复器的内部方法
- success, reply_set = await replyer.generate_reply_with_context(
- reply_data=action_data,
- # 可以传递额外的配置参数
- )
- return success, reply_set
-
- return False, []
+async def generate_response_custom(
+ chat_stream: Optional[ChatStream] = None,
+ chat_id: Optional[str] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ prompt: str = "",
+) -> Optional[str]:
```
+生成自定义提示词回复
-### 2. 回复质量评估
+优先使用chat_stream,如果没有则使用chat_id直接查找。
-```python
-async def generate_and_evaluate_replies(chat_stream, action_data):
- """生成回复并评估质量"""
-
- success, reply_set = await generator_api.generate_reply(
- chat_stream=chat_stream,
- action_data=action_data
- )
-
- if success:
- evaluated_replies = []
- for reply_type, reply_content in reply_set:
- # 简单的质量评估
- quality_score = evaluate_reply_quality(reply_content)
- evaluated_replies.append({
- "type": reply_type,
- "content": reply_content,
- "quality": quality_score
- })
-
- # 按质量排序
- evaluated_replies.sort(key=lambda x: x["quality"], reverse=True)
- return evaluated_replies
-
- return []
+**Args:**
+- `chat_stream`: 聊天流对象
+- `chat_id`: 聊天ID(备用)
+- `model_set_with_weight`: 模型集合配置列表
+- `prompt`: 自定义提示词
-def evaluate_reply_quality(reply_content):
- """简单的回复质量评估"""
- if not reply_content:
- return 0
-
- score = 50 # 基础分
-
- # 长度适中加分
- if 5 <= len(reply_content) <= 100:
- score += 20
-
- # 包含积极词汇加分
- positive_words = ["好", "棒", "不错", "感谢", "开心"]
- for word in positive_words:
- if word in reply_content:
- score += 10
- break
-
- return min(score, 100)
-```
+**Returns:**
+- `Optional[str]`: 生成的自定义回复内容,如果生成失败则返回None
## 注意事项
-1. **异步操作**:所有生成函数都是异步的,必须使用`await`
-2. **错误处理**:函数内置错误处理,失败时返回False和空列表
-3. **聊天流依赖**:需要有效的聊天流对象才能正常工作
-4. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时
-5. **回复格式**:返回的回复集合是元组列表,包含类型和内容
-6. **上下文感知**:生成器会考虑聊天上下文和历史消息
\ No newline at end of file
+1. **异步操作**:部分函数是异步的,须使用`await`
+2. **聊天流依赖**:需要有效的聊天流对象才能正常工作
+3. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时
+4. **回复格式**:返回的回复集合是元组列表,包含类型和内容
+5. **上下文感知**:生成器会考虑聊天上下文和历史消息,除非你用的是自定义提示词。
\ No newline at end of file
diff --git a/docs/plugins/api/llm-api.md b/docs/plugins/api/llm-api.md
index e0879ddf..9a266933 100644
--- a/docs/plugins/api/llm-api.md
+++ b/docs/plugins/api/llm-api.md
@@ -6,239 +6,34 @@ LLM API模块提供与大语言模型交互的功能,让插件能够使用系
```python
from src.plugin_system.apis import llm_api
+# 或者
+from src.plugin_system import llm_api
```
## 主要功能
-### 1. 模型管理
-
-#### `get_available_models() -> Dict[str, Any]`
-获取所有可用的模型配置
-
-**返回:**
-- `Dict[str, Any]`:模型配置字典,key为模型名称,value为模型配置
-
-**示例:**
+### 1. 查询可用模型
```python
-models = llm_api.get_available_models()
-for model_name, model_config in models.items():
- print(f"模型: {model_name}")
- print(f"配置: {model_config}")
+def get_available_models() -> Dict[str, TaskConfig]:
```
+获取所有可用的模型配置。
-### 2. 内容生成
+**Return:**
+- `Dict[str, TaskConfig]`:模型配置字典,key为模型名称,value为模型配置对象。
-#### `generate_with_model(prompt, model_config, request_type="plugin.generate", **kwargs)`
-使用指定模型生成内容
-
-**参数:**
-- `prompt`:提示词
-- `model_config`:模型配置(从 get_available_models 获取)
-- `request_type`:请求类型标识
-- `**kwargs`:其他模型特定参数,如temperature、max_tokens等
-
-**返回:**
-- `Tuple[bool, str, str, str]`:(是否成功, 生成的内容, 推理过程, 模型名称)
-
-**示例:**
+### 2. 使用模型生成内容
```python
-models = llm_api.get_available_models()
-default_model = models.get("default")
-
-if default_model:
- success, response, reasoning, model_name = await llm_api.generate_with_model(
- prompt="请写一首关于春天的诗",
- model_config=default_model,
- temperature=0.7,
- max_tokens=200
- )
-
- if success:
- print(f"生成内容: {response}")
- print(f"使用模型: {model_name}")
+async def generate_with_model(
+ prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
+) -> Tuple[bool, str, str, str]:
```
+使用指定模型生成内容。
-## 使用示例
+**Args:**
+- `prompt`:提示词。
+- `model_config`:模型配置对象(从 `get_available_models` 获取)。
+- `request_type`:请求类型标识,默认为 `"plugin.generate"`。
+- `**kwargs`:其他模型特定参数,如 `temperature`、`max_tokens` 等。
-### 1. 基础文本生成
-
-```python
-from src.plugin_system.apis import llm_api
-
-async def generate_story(topic: str):
- """生成故事"""
- models = llm_api.get_available_models()
- model = models.get("default")
-
- if not model:
- return "未找到可用模型"
-
- prompt = f"请写一个关于{topic}的短故事,大约100字左右。"
-
- success, story, reasoning, model_name = await llm_api.generate_with_model(
- prompt=prompt,
- model_config=model,
- request_type="story.generate",
- temperature=0.8,
- max_tokens=150
- )
-
- return story if success else "故事生成失败"
-```
-
-### 2. 在Action中使用LLM
-
-```python
-from src.plugin_system.base import BaseAction
-
-class LLMAction(BaseAction):
- async def execute(self, action_data, chat_stream):
- # 获取用户输入
- user_input = action_data.get("user_message", "")
- intent = action_data.get("intent", "chat")
-
- # 获取模型配置
- models = llm_api.get_available_models()
- model = models.get("default")
-
- if not model:
- return {"success": False, "error": "未配置LLM模型"}
-
- # 构建提示词
- prompt = self.build_prompt(user_input, intent)
-
- # 生成回复
- success, response, reasoning, model_name = await llm_api.generate_with_model(
- prompt=prompt,
- model_config=model,
- request_type=f"plugin.{self.plugin_name}",
- temperature=0.7
- )
-
- if success:
- return {
- "success": True,
- "response": response,
- "model_used": model_name,
- "reasoning": reasoning
- }
-
- return {"success": False, "error": response}
-
- def build_prompt(self, user_input: str, intent: str) -> str:
- """构建提示词"""
- base_prompt = "你是一个友善的AI助手。"
-
- if intent == "question":
- return f"{base_prompt}\n\n用户问题:{user_input}\n\n请提供准确、有用的回答:"
- elif intent == "chat":
- return f"{base_prompt}\n\n用户说:{user_input}\n\n请进行自然的对话:"
- else:
- return f"{base_prompt}\n\n用户输入:{user_input}\n\n请回复:"
-```
-
-### 3. 多模型对比
-
-```python
-async def compare_models(prompt: str):
- """使用多个模型生成内容并对比"""
- models = llm_api.get_available_models()
- results = {}
-
- for model_name, model_config in models.items():
- success, response, reasoning, actual_model = await llm_api.generate_with_model(
- prompt=prompt,
- model_config=model_config,
- request_type="comparison.test"
- )
-
- results[model_name] = {
- "success": success,
- "response": response,
- "model": actual_model,
- "reasoning": reasoning
- }
-
- return results
-```
-
-### 4. 智能对话插件
-
-```python
-class ChatbotPlugin(BasePlugin):
- async def handle_action(self, action_data, chat_stream):
- user_message = action_data.get("message", "")
-
- # 获取历史对话上下文
- context = self.get_conversation_context(chat_stream)
-
- # 构建对话提示词
- prompt = self.build_conversation_prompt(user_message, context)
-
- # 获取模型配置
- models = llm_api.get_available_models()
- chat_model = models.get("chat", models.get("default"))
-
- if not chat_model:
- return {"success": False, "message": "聊天模型未配置"}
-
- # 生成回复
- success, response, reasoning, model_name = await llm_api.generate_with_model(
- prompt=prompt,
- model_config=chat_model,
- request_type="chat.conversation",
- temperature=0.8,
- max_tokens=500
- )
-
- if success:
- # 保存对话历史
- self.save_conversation(chat_stream, user_message, response)
-
- return {
- "success": True,
- "reply": response,
- "model": model_name
- }
-
- return {"success": False, "message": "回复生成失败"}
-
- def build_conversation_prompt(self, user_message: str, context: list) -> str:
- """构建对话提示词"""
- prompt = "你是一个有趣、友善的聊天机器人。请自然地回复用户的消息。\n\n"
-
- # 添加历史对话
- if context:
- prompt += "对话历史:\n"
- for msg in context[-5:]: # 只保留最近5条
- prompt += f"用户: {msg['user']}\n机器人: {msg['bot']}\n"
- prompt += "\n"
-
- prompt += f"用户: {user_message}\n机器人: "
- return prompt
-```
-
-## 模型配置说明
-
-### 常用模型类型
-- `default`:默认模型
-- `chat`:聊天专用模型
-- `creative`:创意生成模型
-- `code`:代码生成模型
-
-### 配置参数
-LLM模型支持的常用参数:
-- `temperature`:控制输出随机性(0.0-1.0)
-- `max_tokens`:最大生成长度
-- `top_p`:核采样参数
-- `frequency_penalty`:频率惩罚
-- `presence_penalty`:存在惩罚
-
-## 注意事项
-
-1. **异步操作**:LLM生成是异步的,必须使用`await`
-2. **错误处理**:生成失败时返回False和错误信息
-3. **配置依赖**:需要正确配置模型才能使用
-4. **请求类型**:建议为不同用途设置不同的request_type
-5. **性能考虑**:LLM调用可能较慢,考虑超时和缓存
-6. **成本控制**:注意控制max_tokens以控制成本
\ No newline at end of file
+**Return:**
+- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。
\ No newline at end of file
diff --git a/docs/plugins/api/logging-api.md b/docs/plugins/api/logging-api.md
new file mode 100644
index 00000000..5576bf5c
--- /dev/null
+++ b/docs/plugins/api/logging-api.md
@@ -0,0 +1,29 @@
+# Logging API
+
+Logging API模块提供了获取本体logger的功能,允许插件记录日志信息。
+
+## 导入方式
+
+```python
+from src.plugin_system.apis import get_logger
+# 或者
+from src.plugin_system import get_logger
+```
+
+## 主要功能
+### 1. 获取本体logger
+```python
+def get_logger(name: str) -> structlog.stdlib.BoundLogger:
+```
+获取本体logger实例。
+
+**Args:**
+- `name` (str): 日志记录器的名称。
+
+**Returns:**
+- 一个logger实例,有以下方法:
+ - `debug`
+ - `info`
+ - `warning`
+ - `error`
+ - `critical`
\ No newline at end of file
diff --git a/docs/plugins/api/message-api.md b/docs/plugins/api/message-api.md
index c95a9cc6..85d83a9b 100644
--- a/docs/plugins/api/message-api.md
+++ b/docs/plugins/api/message-api.md
@@ -1,11 +1,13 @@
# 消息API
-> 消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。
+消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。
## 导入方式
```python
from src.plugin_system.apis import message_api
+# 或者
+from src.plugin_system import message_api
```
## 功能概述
@@ -15,297 +17,356 @@ from src.plugin_system.apis import message_api
- **消息计数** - 统计新消息数量
- **消息格式化** - 将消息转换为可读格式
----
+## 主要功能
-## 消息查询API
+### 1. 按照事件查询消息
+```python
+def get_messages_by_time(
+ start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
+) -> List[Dict[str, Any]]:
+```
+获取指定时间范围内的消息。
-### 按时间查询消息
-
-#### `get_messages_by_time(start_time, end_time, limit=0, limit_mode="latest")`
-
-获取指定时间范围内的消息
-
-**参数:**
+**Args:**
- `start_time` (float): 开始时间戳
-- `end_time` (float): 结束时间戳
+- `end_time` (float): 结束时间戳
- `limit` (int): 限制返回消息数量,0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
-**返回:** `List[Dict[str, Any]]` - 消息列表
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
-**示例:**
+消息列表中包含的键与`Messages`类的属性一致。(位于`src.common.database.database_model`)
+
+### 2. 获取指定聊天中指定时间范围内的信息
```python
-import time
-
-# 获取最近24小时的消息
-now = time.time()
-yesterday = now - 24 * 3600
-messages = message_api.get_messages_by_time(yesterday, now, limit=50)
+def get_messages_by_time_in_chat(
+ chat_id: str,
+ start_time: float,
+ end_time: float,
+ limit: int = 0,
+ limit_mode: str = "latest",
+ filter_mai: bool = False,
+) -> List[Dict[str, Any]]:
```
+获取指定聊天中指定时间范围内的消息。
-### 按聊天查询消息
-
-#### `get_messages_by_time_in_chat(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
-
-获取指定聊天中指定时间范围内的消息
-
-**参数:**
-- `chat_id` (str): 聊天ID
-- 其他参数同上
-
-**示例:**
-```python
-# 获取某个群聊最近的100条消息
-messages = message_api.get_messages_by_time_in_chat(
- chat_id="123456789",
- start_time=yesterday,
- end_time=now,
- limit=100
-)
-```
-
-#### `get_messages_by_time_in_chat_inclusive(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
-
-获取指定聊天中指定时间范围内的消息(包含边界时间点)
-
-与 `get_messages_by_time_in_chat` 类似,但包含边界时间戳的消息。
-
-#### `get_recent_messages(chat_id, hours=24.0, limit=100, limit_mode="latest")`
-
-获取指定聊天中最近一段时间的消息(便捷方法)
-
-**参数:**
-- `chat_id` (str): 聊天ID
-- `hours` (float): 最近多少小时,默认24小时
-- `limit` (int): 限制返回消息数量,默认100条
-- `limit_mode` (str): 限制模式
-
-**示例:**
-```python
-# 获取最近6小时的消息
-recent_messages = message_api.get_recent_messages(
- chat_id="123456789",
- hours=6.0,
- limit=50
-)
-```
-
-### 按用户查询消息
-
-#### `get_messages_by_time_in_chat_for_users(chat_id, start_time, end_time, person_ids, limit=0, limit_mode="latest")`
-
-获取指定聊天中指定用户在指定时间范围内的消息
-
-**参数:**
+**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
-- `person_ids` (list): 用户ID列表
-- `limit` (int): 限制返回消息数量
-- `limit_mode` (str): 限制模式
+- `limit` (int): 限制返回消息数量,0为不限制
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
-**示例:**
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
+
+
+### 3. 获取指定聊天中指定时间范围内的信息(包含边界)
```python
-# 获取特定用户的消息
-user_messages = message_api.get_messages_by_time_in_chat_for_users(
- chat_id="123456789",
- start_time=yesterday,
- end_time=now,
- person_ids=["user1", "user2"]
-)
+def get_messages_by_time_in_chat_inclusive(
+ chat_id: str,
+ start_time: float,
+ end_time: float,
+ limit: int = 0,
+ limit_mode: str = "latest",
+ filter_mai: bool = False,
+ filter_command: bool = False,
+) -> List[Dict[str, Any]]:
```
+获取指定聊天中指定时间范围内的消息(包含边界)。
-#### `get_messages_by_time_for_users(start_time, end_time, person_ids, limit=0, limit_mode="latest")`
+**Args:**
+- `chat_id` (str): 聊天ID
+- `start_time` (float): 开始时间戳(包含)
+- `end_time` (float): 结束时间戳(包含)
+- `limit` (int): 限制返回消息数量,0为不限制
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
+- `filter_command` (bool): 是否过滤命令消息,默认False
-获取指定用户在所有聊天中指定时间范围内的消息
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
-### 其他查询方法
-#### `get_random_chat_messages(start_time, end_time, limit=0, limit_mode="latest")`
+### 4. 获取指定聊天中指定用户在指定时间范围内的消息
+```python
+def get_messages_by_time_in_chat_for_users(
+ chat_id: str,
+ start_time: float,
+ end_time: float,
+ person_ids: List[str],
+ limit: int = 0,
+ limit_mode: str = "latest",
+) -> List[Dict[str, Any]]:
+```
+获取指定聊天中指定用户在指定时间范围内的消息。
-随机选择一个聊天,返回该聊天在指定时间范围内的消息
-
-#### `get_messages_before_time(timestamp, limit=0)`
-
-获取指定时间戳之前的消息
-
-#### `get_messages_before_time_in_chat(chat_id, timestamp, limit=0)`
-
-获取指定聊天中指定时间戳之前的消息
-
-#### `get_messages_before_time_for_users(timestamp, person_ids, limit=0)`
-
-获取指定用户在指定时间戳之前的消息
-
----
-
-## 消息计数API
-
-### `count_new_messages(chat_id, start_time=0.0, end_time=None)`
-
-计算指定聊天中从开始时间到结束时间的新消息数量
-
-**参数:**
+**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
-- `end_time` (float): 结束时间戳,如果为None则使用当前时间
+- `end_time` (float): 结束时间戳
+- `person_ids` (List[str]): 用户ID列表
+- `limit` (int): 限制返回消息数量,0为不限制
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
-**返回:** `int` - 新消息数量
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
-**示例:**
+
+### 5. 随机选择一个聊天,返回该聊天在指定时间范围内的消息
```python
-# 计算最近1小时的新消息数
-import time
-now = time.time()
-hour_ago = now - 3600
-new_count = message_api.count_new_messages("123456789", hour_ago, now)
-print(f"最近1小时有{new_count}条新消息")
+def get_random_chat_messages(
+ start_time: float,
+ end_time: float,
+ limit: int = 0,
+ limit_mode: str = "latest",
+ filter_mai: bool = False,
+) -> List[Dict[str, Any]]:
```
+随机选择一个聊天,返回该聊天在指定时间范围内的消息。
-### `count_new_messages_for_users(chat_id, start_time, end_time, person_ids)`
+**Args:**
+- `start_time` (float): 开始时间戳
+- `end_time` (float): 结束时间戳
+- `limit` (int): 限制返回消息数量,0为不限制
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
-计算指定聊天中指定用户从开始时间到结束时间的新消息数量
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
----
-## 消息格式化API
+### 6. 获取指定用户在所有聊天中指定时间范围内的消息
+```python
+def get_messages_by_time_for_users(
+ start_time: float,
+ end_time: float,
+ person_ids: List[str],
+ limit: int = 0,
+ limit_mode: str = "latest",
+) -> List[Dict[str, Any]]:
+```
+获取指定用户在所有聊天中指定时间范围内的消息。
-### `build_readable_messages_to_str(messages, **options)`
+**Args:**
+- `start_time` (float): 开始时间戳
+- `end_time` (float): 结束时间戳
+- `person_ids` (List[str]): 用户ID列表
+- `limit` (int): 限制返回消息数量,0为不限制
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
-将消息列表构建成可读的字符串
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
-**参数:**
+
+### 7. 获取指定时间戳之前的消息
+```python
+def get_messages_before_time(
+ timestamp: float,
+ limit: int = 0,
+ filter_mai: bool = False,
+) -> List[Dict[str, Any]]:
+```
+获取指定时间戳之前的消息。
+
+**Args:**
+- `timestamp` (float): 时间戳
+- `limit` (int): 限制返回消息数量,0为不限制
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
+
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
+
+
+### 8. 获取指定聊天中指定时间戳之前的消息
+```python
+def get_messages_before_time_in_chat(
+ chat_id: str,
+ timestamp: float,
+ limit: int = 0,
+ filter_mai: bool = False,
+) -> List[Dict[str, Any]]:
+```
+获取指定聊天中指定时间戳之前的消息。
+
+**Args:**
+- `chat_id` (str): 聊天ID
+- `timestamp` (float): 时间戳
+- `limit` (int): 限制返回消息数量,0为不限制
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
+
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
+
+
+### 9. 获取指定用户在指定时间戳之前的消息
+```python
+def get_messages_before_time_for_users(
+ timestamp: float,
+ person_ids: List[str],
+ limit: int = 0,
+) -> List[Dict[str, Any]]:
+```
+获取指定用户在指定时间戳之前的消息。
+
+**Args:**
+- `timestamp` (float): 时间戳
+- `person_ids` (List[str]): 用户ID列表
+- `limit` (int): 限制返回消息数量,0为不限制
+
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
+
+
+### 10. 获取指定聊天中最近一段时间的消息
+```python
+def get_recent_messages(
+ chat_id: str,
+ hours: float = 24.0,
+ limit: int = 100,
+ limit_mode: str = "latest",
+ filter_mai: bool = False,
+) -> List[Dict[str, Any]]:
+```
+获取指定聊天中最近一段时间的消息。
+
+**Args:**
+- `chat_id` (str): 聊天ID
+- `hours` (float): 最近多少小时,默认24小时
+- `limit` (int): 限制返回消息数量,默认100条
+- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
+- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False
+
+**Returns:**
+- `List[Dict[str, Any]]` - 消息列表
+
+
+### 11. 计算指定聊天中从开始时间到结束时间的新消息数量
+```python
+def count_new_messages(
+ chat_id: str,
+ start_time: float = 0.0,
+ end_time: Optional[float] = None,
+) -> int:
+```
+计算指定聊天中从开始时间到结束时间的新消息数量。
+
+**Args:**
+- `chat_id` (str): 聊天ID
+- `start_time` (float): 开始时间戳
+- `end_time` (Optional[float]): 结束时间戳,如果为None则使用当前时间
+
+**Returns:**
+- `int` - 新消息数量
+
+
+### 12. 计算指定聊天中指定用户从开始时间到结束时间的新消息数量
+```python
+def count_new_messages_for_users(
+ chat_id: str,
+ start_time: float,
+ end_time: float,
+ person_ids: List[str],
+) -> int:
+```
+计算指定聊天中指定用户从开始时间到结束时间的新消息数量。
+
+**Args:**
+- `chat_id` (str): 聊天ID
+- `start_time` (float): 开始时间戳
+- `end_time` (float): 结束时间戳
+- `person_ids` (List[str]): 用户ID列表
+
+**Returns:**
+- `int` - 新消息数量
+
+
+### 13. 将消息列表构建成可读的字符串
+```python
+def build_readable_messages_to_str(
+ messages: List[Dict[str, Any]],
+ replace_bot_name: bool = True,
+ merge_messages: bool = False,
+ timestamp_mode: str = "relative",
+ read_mark: float = 0.0,
+ truncate: bool = False,
+ show_actions: bool = False,
+) -> str:
+```
+将消息列表构建成可读的字符串。
+
+**Args:**
- `messages` (List[Dict[str, Any]]): 消息列表
-- `replace_bot_name` (bool): 是否将机器人的名称替换为"你",默认True
-- `merge_messages` (bool): 是否合并连续消息,默认False
-- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`,默认`"relative"`
-- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息,默认0.0
-- `truncate` (bool): 是否截断长消息,默认False
-- `show_actions` (bool): 是否显示动作记录,默认False
+- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
+- `merge_messages` (bool): 是否合并连续消息
+- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
+- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息
+- `truncate` (bool): 是否截断长消息
+- `show_actions` (bool): 是否显示动作记录
-**返回:** `str` - 格式化后的可读字符串
+**Returns:**
+- `str` - 格式化后的可读字符串
-**示例:**
+
+### 14. 将消息列表构建成可读的字符串,并返回详细信息
```python
-# 获取消息并格式化为可读文本
-messages = message_api.get_recent_messages("123456789", hours=2)
-readable_text = message_api.build_readable_messages_to_str(
- messages,
- replace_bot_name=True,
- merge_messages=True,
- timestamp_mode="relative"
-)
-print(readable_text)
+async def build_readable_messages_with_details(
+ messages: List[Dict[str, Any]],
+ replace_bot_name: bool = True,
+ merge_messages: bool = False,
+ timestamp_mode: str = "relative",
+ truncate: bool = False,
+) -> Tuple[str, List[Tuple[float, str, str]]]:
```
+将消息列表构建成可读的字符串,并返回详细信息。
-### `build_readable_messages_with_details(messages, **options)` 异步
+**Args:**
+- `messages` (List[Dict[str, Any]]): 消息列表
+- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
+- `merge_messages` (bool): 是否合并连续消息
+- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
+- `truncate` (bool): 是否截断长消息
-将消息列表构建成可读的字符串,并返回详细信息
+**Returns:**
+- `Tuple[str, List[Tuple[float, str, str]]]` - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
-**参数:** 与 `build_readable_messages_to_str` 类似,但不包含 `read_mark` 和 `show_actions`
-**返回:** `Tuple[str, List[Tuple[float, str, str]]]` - 格式化字符串和详细信息元组列表(时间戳, 昵称, 内容)
-
-**示例:**
+### 15. 从消息列表中提取不重复的用户ID列表
```python
-# 异步获取详细格式化信息
-readable_text, details = await message_api.build_readable_messages_with_details(
- messages,
- timestamp_mode="absolute"
-)
-
-for timestamp, nickname, content in details:
- print(f"{timestamp}: {nickname} 说: {content}")
+async def get_person_ids_from_messages(
+ messages: List[Dict[str, Any]],
+) -> List[str]:
```
+从消息列表中提取不重复的用户ID列表。
-### `get_person_ids_from_messages(messages)` 异步
-
-从消息列表中提取不重复的用户ID列表
-
-**参数:**
+**Args:**
- `messages` (List[Dict[str, Any]]): 消息列表
-**返回:** `List[str]` - 用户ID列表
+**Returns:**
+- `List[str]` - 用户ID列表
-**示例:**
+
+### 16. 从消息列表中移除机器人的消息
```python
-# 获取参与对话的所有用户ID
-messages = message_api.get_recent_messages("123456789")
-person_ids = await message_api.get_person_ids_from_messages(messages)
-print(f"参与对话的用户: {person_ids}")
+def filter_mai_messages(
+ messages: List[Dict[str, Any]],
+) -> List[Dict[str, Any]]:
```
+从消息列表中移除机器人的消息。
----
+**Args:**
+- `messages` (List[Dict[str, Any]]): 消息列表,每个元素是消息字典
-## 完整使用示例
-
-### 场景1:统计活跃度
-
-```python
-import time
-from src.plugin_system.apis import message_api
-
-async def analyze_chat_activity(chat_id: str):
- """分析聊天活跃度"""
- now = time.time()
- day_ago = now - 24 * 3600
-
- # 获取最近24小时的消息
- messages = message_api.get_recent_messages(chat_id, hours=24)
-
- # 统计消息数量
- total_count = len(messages)
-
- # 获取参与用户
- person_ids = await message_api.get_person_ids_from_messages(messages)
-
- # 格式化消息内容
- readable_text = message_api.build_readable_messages_to_str(
- messages[-10:], # 最后10条消息
- merge_messages=True,
- timestamp_mode="relative"
- )
-
- return {
- "total_messages": total_count,
- "active_users": len(person_ids),
- "recent_chat": readable_text
- }
-```
-
-### 场景2:查看特定用户的历史消息
-
-```python
-def get_user_history(chat_id: str, user_id: str, days: int = 7):
- """获取用户最近N天的消息历史"""
- now = time.time()
- start_time = now - days * 24 * 3600
-
- # 获取特定用户的消息
- user_messages = message_api.get_messages_by_time_in_chat_for_users(
- chat_id=chat_id,
- start_time=start_time,
- end_time=now,
- person_ids=[user_id],
- limit=100
- )
-
- # 格式化为可读文本
- readable_history = message_api.build_readable_messages_to_str(
- user_messages,
- replace_bot_name=False,
- timestamp_mode="absolute"
- )
-
- return readable_history
-```
-
----
+**Returns:**
+- `List[Dict[str, Any]]` - 过滤后的消息列表
## 注意事项
1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型)
-2. **异步函数**:`build_readable_messages_with_details` 和 `get_person_ids_from_messages` 是异步函数,需要使用 `await`
+2. **异步函数**:部分函数是异步函数,需要使用 `await`
3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数
4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息
5. **用户ID**:`person_ids` 参数接受字符串列表,用于筛选特定用户的消息
\ No newline at end of file
diff --git a/docs/plugins/api/person-api.md b/docs/plugins/api/person-api.md
index 3e1bafaf..f97498dc 100644
--- a/docs/plugins/api/person-api.md
+++ b/docs/plugins/api/person-api.md
@@ -6,59 +6,65 @@
```python
from src.plugin_system.apis import person_api
+# 或者
+from src.plugin_system import person_api
```
## 主要功能
-### 1. Person ID管理
-
-#### `get_person_id(platform: str, user_id: int) -> str`
+### 1. Person ID 获取
+```python
+def get_person_id(platform: str, user_id: int) -> str:
+```
根据平台和用户ID获取person_id
-**参数:**
+**Args:**
- `platform`:平台名称,如 "qq", "telegram" 等
- `user_id`:用户ID
-**返回:**
+**Returns:**
- `str`:唯一的person_id(MD5哈希值)
-**示例:**
+#### 示例
```python
person_id = person_api.get_person_id("qq", 123456)
-print(f"Person ID: {person_id}")
```
### 2. 用户信息查询
+```python
+async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
+```
+查询单个用户信息字段值
-#### `get_person_value(person_id: str, field_name: str, default: Any = None) -> Any`
-根据person_id和字段名获取某个值
-
-**参数:**
+**Args:**
- `person_id`:用户的唯一标识ID
-- `field_name`:要获取的字段名,如 "nickname", "impression" 等
-- `default`:当字段不存在或获取失败时返回的默认值
+- `field_name`:要获取的字段名
+- `default`:字段值不存在时的默认值
-**返回:**
+**Returns:**
- `Any`:字段值或默认值
-**示例:**
+#### 示例
```python
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression")
```
-#### `get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict`
+### 3. 批量用户信息查询
+```python
+async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
+```
批量获取用户信息字段值
-**参数:**
+**Args:**
- `person_id`:用户的唯一标识ID
- `field_names`:要获取的字段名列表
- `default_dict`:默认值字典,键为字段名,值为默认值
-**返回:**
+**Returns:**
- `dict`:字段名到值的映射字典
-**示例:**
+#### 示例
```python
values = await person_api.get_person_values(
person_id,
@@ -67,204 +73,31 @@ values = await person_api.get_person_values(
)
```
-### 3. 用户状态查询
-
-#### `is_person_known(platform: str, user_id: int) -> bool`
+### 4. 判断用户是否已知
+```python
+async def is_person_known(platform: str, user_id: int) -> bool:
+```
判断是否认识某个用户
-**参数:**
+**Args:**
- `platform`:平台名称
- `user_id`:用户ID
-**返回:**
+**Returns:**
- `bool`:是否认识该用户
-**示例:**
+### 5. 根据用户名获取Person ID
```python
-known = await person_api.is_person_known("qq", 123456)
-if known:
- print("这个用户我认识")
+def get_person_id_by_name(person_name: str) -> str:
```
-
-### 4. 用户名查询
-
-#### `get_person_id_by_name(person_name: str) -> str`
根据用户名获取person_id
-**参数:**
+**Args:**
- `person_name`:用户名
-**返回:**
+**Returns:**
- `str`:person_id,如果未找到返回空字符串
-**示例:**
-```python
-person_id = person_api.get_person_id_by_name("张三")
-if person_id:
- print(f"找到用户: {person_id}")
-```
-
-## 使用示例
-
-### 1. 基础用户信息获取
-
-```python
-from src.plugin_system.apis import person_api
-
-async def get_user_info(platform: str, user_id: int):
- """获取用户基本信息"""
-
- # 获取person_id
- person_id = person_api.get_person_id(platform, user_id)
-
- # 获取用户信息
- user_info = await person_api.get_person_values(
- person_id,
- ["nickname", "impression", "know_times", "last_seen"],
- {
- "nickname": "未知用户",
- "impression": "",
- "know_times": 0,
- "last_seen": 0
- }
- )
-
- return {
- "person_id": person_id,
- "nickname": user_info["nickname"],
- "impression": user_info["impression"],
- "know_times": user_info["know_times"],
- "last_seen": user_info["last_seen"]
- }
-```
-
-### 2. 在Action中使用用户信息
-
-```python
-from src.plugin_system.base import BaseAction
-
-class PersonalizedAction(BaseAction):
- async def execute(self, action_data, chat_stream):
- # 获取发送者信息
- user_id = chat_stream.user_info.user_id
- platform = chat_stream.platform
-
- # 获取person_id
- person_id = person_api.get_person_id(platform, user_id)
-
- # 获取用户昵称和印象
- nickname = await person_api.get_person_value(person_id, "nickname", "朋友")
- impression = await person_api.get_person_value(person_id, "impression", "")
-
- # 根据用户信息个性化回复
- if impression:
- response = f"你好 {nickname}!根据我对你的了解:{impression}"
- else:
- response = f"你好 {nickname}!很高兴见到你。"
-
- return {
- "success": True,
- "response": response,
- "user_info": {
- "nickname": nickname,
- "impression": impression
- }
- }
-```
-
-### 3. 用户识别和欢迎
-
-```python
-async def welcome_user(chat_stream):
- """欢迎用户,区分新老用户"""
-
- user_id = chat_stream.user_info.user_id
- platform = chat_stream.platform
-
- # 检查是否认识这个用户
- is_known = await person_api.is_person_known(platform, user_id)
-
- if is_known:
- # 老用户,获取详细信息
- person_id = person_api.get_person_id(platform, user_id)
- nickname = await person_api.get_person_value(person_id, "nickname", "老朋友")
- know_times = await person_api.get_person_value(person_id, "know_times", 0)
-
- welcome_msg = f"欢迎回来,{nickname}!我们已经聊过 {know_times} 次了。"
- else:
- # 新用户
- welcome_msg = "你好!很高兴认识你,我是MaiBot。"
-
- return welcome_msg
-```
-
-### 4. 用户搜索功能
-
-```python
-async def find_user_by_name(name: str):
- """根据名字查找用户"""
-
- person_id = person_api.get_person_id_by_name(name)
-
- if not person_id:
- return {"found": False, "message": f"未找到名为 '{name}' 的用户"}
-
- # 获取用户详细信息
- user_info = await person_api.get_person_values(
- person_id,
- ["nickname", "platform", "user_id", "impression", "know_times"],
- {}
- )
-
- return {
- "found": True,
- "person_id": person_id,
- "info": user_info
- }
-```
-
-### 5. 用户印象分析
-
-```python
-async def analyze_user_relationship(chat_stream):
- """分析用户关系"""
-
- user_id = chat_stream.user_info.user_id
- platform = chat_stream.platform
- person_id = person_api.get_person_id(platform, user_id)
-
- # 获取关系相关信息
- relationship_info = await person_api.get_person_values(
- person_id,
- ["nickname", "impression", "know_times", "relationship_level", "last_interaction"],
- {
- "nickname": "未知",
- "impression": "",
- "know_times": 0,
- "relationship_level": "stranger",
- "last_interaction": 0
- }
- )
-
- # 分析关系程度
- know_times = relationship_info["know_times"]
- if know_times == 0:
- relationship = "陌生人"
- elif know_times < 5:
- relationship = "新朋友"
- elif know_times < 20:
- relationship = "熟人"
- else:
- relationship = "老朋友"
-
- return {
- "nickname": relationship_info["nickname"],
- "relationship": relationship,
- "impression": relationship_info["impression"],
- "interaction_count": know_times
- }
-```
-
## 常用字段说明
### 基础信息字段
@@ -274,69 +107,13 @@ async def analyze_user_relationship(chat_stream):
### 关系信息字段
- `impression`:对用户的印象
-- `know_times`:交互次数
-- `relationship_level`:关系等级
-- `last_seen`:最后见面时间
-- `last_interaction`:最后交互时间
+- `points`: 用户特征点
-### 个性化字段
-- `preferences`:用户偏好
-- `interests`:兴趣爱好
-- `mood_history`:情绪历史
-- `topic_interests`:话题兴趣
-
-## 最佳实践
-
-### 1. 错误处理
-```python
-async def safe_get_user_info(person_id: str, field: str):
- """安全获取用户信息"""
- try:
- value = await person_api.get_person_value(person_id, field)
- return value if value is not None else "未设置"
- except Exception as e:
- logger.error(f"获取用户信息失败: {e}")
- return "获取失败"
-```
-
-### 2. 批量操作
-```python
-async def get_complete_user_profile(person_id: str):
- """获取完整用户档案"""
-
- # 一次性获取所有需要的字段
- fields = [
- "nickname", "impression", "know_times",
- "preferences", "interests", "relationship_level"
- ]
-
- defaults = {
- "nickname": "用户",
- "impression": "",
- "know_times": 0,
- "preferences": "{}",
- "interests": "[]",
- "relationship_level": "stranger"
- }
-
- profile = await person_api.get_person_values(person_id, fields, defaults)
-
- # 处理JSON字段
- try:
- profile["preferences"] = json.loads(profile["preferences"])
- profile["interests"] = json.loads(profile["interests"])
- except:
- profile["preferences"] = {}
- profile["interests"] = []
-
- return profile
-```
+其他字段可以参考`PersonInfo`类的属性(位于`src.common.database.database_model`)
## 注意事项
-1. **异步操作**:大部分查询函数都是异步的,需要使用`await`
-2. **错误处理**:所有函数都有错误处理,失败时记录日志并返回默认值
-3. **数据类型**:返回的数据可能是字符串、数字或JSON,需要适当处理
-4. **性能考虑**:批量查询优于单个查询
-5. **隐私保护**:确保用户信息的使用符合隐私政策
-6. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用
\ No newline at end of file
+1. **异步操作**:部分查询函数都是异步的,需要使用`await`
+2. **性能考虑**:批量查询优于单个查询
+3. **隐私保护**:确保用户信息的使用符合隐私政策
+4. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用
\ No newline at end of file
diff --git a/docs/plugins/api/plugin-manage-api.md b/docs/plugins/api/plugin-manage-api.md
new file mode 100644
index 00000000..688ea9ef
--- /dev/null
+++ b/docs/plugins/api/plugin-manage-api.md
@@ -0,0 +1,105 @@
+# 插件管理API
+
+插件管理API模块提供了对插件的加载、卸载、重新加载以及目录管理功能。
+
+## 导入方式
+```python
+from src.plugin_system.apis import plugin_manage_api
+# 或者
+from src.plugin_system import plugin_manage_api
+```
+
+## 功能概述
+
+插件管理API主要提供以下功能:
+- **插件查询** - 列出当前加载的插件或已注册的插件。
+- **插件管理** - 加载、卸载、重新加载插件。
+- **插件目录管理** - 添加插件目录并重新扫描。
+
+## 主要功能
+
+### 1. 列出当前加载的插件
+```python
+def list_loaded_plugins() -> List[str]:
+```
+列出所有当前加载的插件。
+
+**Returns:**
+- `List[str]` - 当前加载的插件名称列表。
+
+### 2. 列出所有已注册的插件
+```python
+def list_registered_plugins() -> List[str]:
+```
+列出所有已注册的插件。
+
+**Returns:**
+- `List[str]` - 已注册的插件名称列表。
+
+### 3. 获取插件路径
+```python
+def get_plugin_path(plugin_name: str) -> str:
+```
+获取指定插件的路径。
+
+**Args:**
+- `plugin_name` (str): 要查询的插件名称。
+**Returns:**
+- `str` - 插件的路径,如果插件不存在则 raise ValueError。
+
+### 4. 卸载指定的插件
+```python
+async def remove_plugin(plugin_name: str) -> bool:
+```
+卸载指定的插件。
+
+**Args:**
+- `plugin_name` (str): 要卸载的插件名称。
+
+**Returns:**
+- `bool` - 卸载是否成功。
+
+### 5. 重新加载指定的插件
+```python
+async def reload_plugin(plugin_name: str) -> bool:
+```
+重新加载指定的插件。
+
+**Args:**
+- `plugin_name` (str): 要重新加载的插件名称。
+
+**Returns:**
+- `bool` - 重新加载是否成功。
+
+### 6. 加载指定的插件
+```python
+def load_plugin(plugin_name: str) -> Tuple[bool, int]:
+```
+加载指定的插件。
+
+**Args:**
+- `plugin_name` (str): 要加载的插件名称。
+
+**Returns:**
+- `Tuple[bool, int]` - 加载是否成功,成功或失败的个数。
+
+### 7. 添加插件目录
+```python
+def add_plugin_directory(plugin_directory: str) -> bool:
+```
+添加插件目录。
+
+**Args:**
+- `plugin_directory` (str): 要添加的插件目录路径。
+
+**Returns:**
+- `bool` - 添加是否成功。
+
+### 8. 重新扫描插件目录
+```python
+def rescan_plugin_directory() -> Tuple[int, int]:
+```
+重新扫描插件目录,加载新插件。
+
+**Returns:**
+- `Tuple[int, int]` - 成功加载的插件数量和失败的插件数量。
\ No newline at end of file
diff --git a/docs/plugins/api/send-api.md b/docs/plugins/api/send-api.md
index 79335c61..8b3c607f 100644
--- a/docs/plugins/api/send-api.md
+++ b/docs/plugins/api/send-api.md
@@ -6,86 +6,108 @@
```python
from src.plugin_system.apis import send_api
+# 或者
+from src.plugin_system import send_api
```
## 主要功能
-### 1. 文本消息发送
+### 1. 发送文本消息
+```python
+async def text_to_stream(
+ text: str,
+ stream_id: str,
+ typing: bool = False,
+ reply_to: str = "",
+ storage_message: bool = True,
+) -> bool:
+```
+发送文本消息到指定的流
-#### `text_to_group(text, group_id, platform="qq", typing=False, reply_to="", storage_message=True)`
-向群聊发送文本消息
+**Args:**
+- `text` (str): 要发送的文本内容
+- `stream_id` (str): 聊天流ID
+- `typing` (bool): 是否显示正在输入
+- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
+- `storage_message` (bool): 是否存储消息到数据库
-**参数:**
-- `text`:要发送的文本内容
-- `group_id`:群聊ID
-- `platform`:平台,默认为"qq"
-- `typing`:是否显示正在输入
-- `reply_to`:回复消息的格式,如"发送者:消息内容"
-- `storage_message`:是否存储到数据库
+**Returns:**
+- `bool` - 是否发送成功
-**返回:**
-- `bool`:是否发送成功
+### 2. 发送表情包
+```python
+async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
+```
+向指定流发送表情包。
-#### `text_to_user(text, user_id, platform="qq", typing=False, reply_to="", storage_message=True)`
-向用户发送私聊文本消息
+**Args:**
+- `emoji_base64` (str): 表情包的base64编码
+- `stream_id` (str): 聊天流ID
+- `storage_message` (bool): 是否存储消息到数据库
-**参数与返回值同上**
+**Returns:**
+- `bool` - 是否发送成功
-### 2. 表情包发送
+### 3. 发送图片
+```python
+async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
+```
+向指定流发送图片。
-#### `emoji_to_group(emoji_base64, group_id, platform="qq", storage_message=True)`
-向群聊发送表情包
+**Args:**
+- `image_base64` (str): 图片的base64编码
+- `stream_id` (str): 聊天流ID
+- `storage_message` (bool): 是否存储消息到数据库
-**参数:**
-- `emoji_base64`:表情包的base64编码
-- `group_id`:群聊ID
-- `platform`:平台,默认为"qq"
-- `storage_message`:是否存储到数据库
+**Returns:**
+- `bool` - 是否发送成功
-#### `emoji_to_user(emoji_base64, user_id, platform="qq", storage_message=True)`
-向用户发送表情包
+### 4. 发送命令
+```python
+async def command_to_stream(command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "") -> bool:
+```
+向指定流发送命令。
-### 3. 图片发送
+**Args:**
+- `command` (Union[str, dict]): 命令内容
+- `stream_id` (str): 聊天流ID
+- `storage_message` (bool): 是否存储消息到数据库
+- `display_message` (str): 显示消息
-#### `image_to_group(image_base64, group_id, platform="qq", storage_message=True)`
-向群聊发送图片
+**Returns:**
+- `bool` - 是否发送成功
-#### `image_to_user(image_base64, user_id, platform="qq", storage_message=True)`
-向用户发送图片
+### 5. 发送自定义类型消息
+```python
+async def custom_to_stream(
+ message_type: str,
+ content: str,
+ stream_id: str,
+ display_message: str = "",
+ typing: bool = False,
+ reply_to: str = "",
+ storage_message: bool = True,
+ show_log: bool = True,
+) -> bool:
+```
+向指定流发送自定义类型消息。
-### 4. 命令发送
+**Args:**
+- `message_type` (str): 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
+- `content` (str): 消息内容(通常是base64编码或文本)
+- `stream_id` (str): 聊天流ID
+- `display_message` (str): 显示消息
+- `typing` (bool): 是否显示正在输入
+- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
+- `storage_message` (bool): 是否存储消息到数据库
+- `show_log` (bool): 是否显示日志
-#### `command_to_group(command, group_id, platform="qq", storage_message=True)`
-向群聊发送命令
-
-#### `command_to_user(command, user_id, platform="qq", storage_message=True)`
-向用户发送命令
-
-### 5. 自定义消息发送
-
-#### `custom_to_group(message_type, content, group_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
-向群聊发送自定义类型消息
-
-#### `custom_to_user(message_type, content, user_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
-向用户发送自定义类型消息
-
-#### `custom_message(message_type, content, target_id, is_group=True, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
-通用的自定义消息发送
-
-**参数:**
-- `message_type`:消息类型,如"text"、"image"、"emoji"等
-- `content`:消息内容
-- `target_id`:目标ID(群ID或用户ID)
-- `is_group`:是否为群聊
-- `platform`:平台
-- `display_message`:显示消息
-- `typing`:是否显示正在输入
-- `reply_to`:回复消息
-- `storage_message`:是否存储
+**Returns:**
+- `bool` - 是否发送成功
## 使用示例
-### 1. 基础文本发送
+### 1. 基础文本发送,并回复消息
```python
from src.plugin_system.apis import send_api
@@ -93,57 +115,23 @@ from src.plugin_system.apis import send_api
async def send_hello(chat_stream):
"""发送问候消息"""
- if chat_stream.group_info:
- # 群聊
- success = await send_api.text_to_group(
- text="大家好!",
- group_id=chat_stream.group_info.group_id,
- typing=True
- )
- else:
- # 私聊
- success = await send_api.text_to_user(
- text="你好!",
- user_id=chat_stream.user_info.user_id,
- typing=True
- )
+ success = await send_api.text_to_stream(
+ text="Hello, world!",
+ stream_id=chat_stream.stream_id,
+ typing=True,
+ reply_to="User:How are you?",
+ storage_message=True
+ )
return success
```
-### 2. 回复特定消息
-
-```python
-async def reply_to_message(chat_stream, reply_text, original_sender, original_message):
- """回复特定消息"""
-
- # 构建回复格式
- reply_to = f"{original_sender}:{original_message}"
-
- if chat_stream.group_info:
- success = await send_api.text_to_group(
- text=reply_text,
- group_id=chat_stream.group_info.group_id,
- reply_to=reply_to
- )
- else:
- success = await send_api.text_to_user(
- text=reply_text,
- user_id=chat_stream.user_info.user_id,
- reply_to=reply_to
- )
-
- return success
-```
-
-### 3. 发送表情包
+### 2. 发送表情包
```python
+from src.plugin_system.apis import emoji_api
async def send_emoji_reaction(chat_stream, emotion):
"""根据情感发送表情包"""
-
- from src.plugin_system.apis import emoji_api
-
# 获取表情包
emoji_result = await emoji_api.get_by_emotion(emotion)
if not emoji_result:
@@ -152,107 +140,10 @@ async def send_emoji_reaction(chat_stream, emotion):
emoji_base64, description, matched_emotion = emoji_result
# 发送表情包
- if chat_stream.group_info:
- success = await send_api.emoji_to_group(
- emoji_base64=emoji_base64,
- group_id=chat_stream.group_info.group_id
- )
- else:
- success = await send_api.emoji_to_user(
- emoji_base64=emoji_base64,
- user_id=chat_stream.user_info.user_id
- )
-
- return success
-```
-
-### 4. 在Action中发送消息
-
-```python
-from src.plugin_system.base import BaseAction
-
-class MessageAction(BaseAction):
- async def execute(self, action_data, chat_stream):
- message_type = action_data.get("type", "text")
- content = action_data.get("content", "")
-
- if message_type == "text":
- success = await self.send_text(chat_stream, content)
- elif message_type == "emoji":
- success = await self.send_emoji(chat_stream, content)
- elif message_type == "image":
- success = await self.send_image(chat_stream, content)
- else:
- success = False
-
- return {"success": success}
-
- async def send_text(self, chat_stream, text):
- if chat_stream.group_info:
- return await send_api.text_to_group(text, chat_stream.group_info.group_id)
- else:
- return await send_api.text_to_user(text, chat_stream.user_info.user_id)
-
- async def send_emoji(self, chat_stream, emoji_base64):
- if chat_stream.group_info:
- return await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
- else:
- return await send_api.emoji_to_user(emoji_base64, chat_stream.user_info.user_id)
-
- async def send_image(self, chat_stream, image_base64):
- if chat_stream.group_info:
- return await send_api.image_to_group(image_base64, chat_stream.group_info.group_id)
- else:
- return await send_api.image_to_user(image_base64, chat_stream.user_info.user_id)
-```
-
-### 5. 批量发送消息
-
-```python
-async def broadcast_message(message: str, target_groups: list):
- """向多个群组广播消息"""
-
- results = {}
-
- for group_id in target_groups:
- try:
- success = await send_api.text_to_group(
- text=message,
- group_id=group_id,
- typing=True
- )
- results[group_id] = success
- except Exception as e:
- results[group_id] = False
- print(f"发送到群 {group_id} 失败: {e}")
-
- return results
-```
-
-### 6. 智能消息发送
-
-```python
-async def smart_send(chat_stream, message_data):
- """智能发送不同类型的消息"""
-
- message_type = message_data.get("type", "text")
- content = message_data.get("content", "")
- options = message_data.get("options", {})
-
- # 根据聊天流类型选择发送方法
- target_id = (chat_stream.group_info.group_id if chat_stream.group_info
- else chat_stream.user_info.user_id)
- is_group = chat_stream.group_info is not None
-
- # 使用通用发送方法
- success = await send_api.custom_message(
- message_type=message_type,
- content=content,
- target_id=target_id,
- is_group=is_group,
- typing=options.get("typing", False),
- reply_to=options.get("reply_to", ""),
- display_message=options.get("display_message", "")
+ success = await send_api.emoji_to_stream(
+ emoji_base64=emoji_base64,
+ stream_id=chat_stream.stream_id,
+ storage_message=False # 不存储到数据库
)
return success
@@ -273,90 +164,6 @@ async def smart_send(chat_stream, message_data):
系统会自动查找匹配的原始消息并进行回复。
-## 高级用法
-
-### 1. 消息发送队列
-
-```python
-import asyncio
-
-class MessageQueue:
- def __init__(self):
- self.queue = asyncio.Queue()
- self.running = False
-
- async def add_message(self, chat_stream, message_type, content, options=None):
- """添加消息到队列"""
- message_item = {
- "chat_stream": chat_stream,
- "type": message_type,
- "content": content,
- "options": options or {}
- }
- await self.queue.put(message_item)
-
- async def process_queue(self):
- """处理消息队列"""
- self.running = True
-
- while self.running:
- try:
- message_item = await asyncio.wait_for(self.queue.get(), timeout=1.0)
-
- # 发送消息
- success = await smart_send(
- message_item["chat_stream"],
- {
- "type": message_item["type"],
- "content": message_item["content"],
- "options": message_item["options"]
- }
- )
-
- # 标记任务完成
- self.queue.task_done()
-
- # 发送间隔
- await asyncio.sleep(0.5)
-
- except asyncio.TimeoutError:
- continue
- except Exception as e:
- print(f"处理消息队列出错: {e}")
-```
-
-### 2. 消息模板系统
-
-```python
-class MessageTemplate:
- def __init__(self):
- self.templates = {
- "welcome": "欢迎 {nickname} 加入群聊!",
- "goodbye": "{nickname} 离开了群聊。",
- "notification": "🔔 通知:{message}",
- "error": "❌ 错误:{error_message}",
- "success": "✅ 成功:{message}"
- }
-
- def format_message(self, template_name: str, **kwargs) -> str:
- """格式化消息模板"""
- template = self.templates.get(template_name, "{message}")
- return template.format(**kwargs)
-
- async def send_template(self, chat_stream, template_name: str, **kwargs):
- """发送模板消息"""
- message = self.format_message(template_name, **kwargs)
-
- if chat_stream.group_info:
- return await send_api.text_to_group(message, chat_stream.group_info.group_id)
- else:
- return await send_api.text_to_user(message, chat_stream.user_info.user_id)
-
-# 使用示例
-template_system = MessageTemplate()
-await template_system.send_template(chat_stream, "welcome", nickname="张三")
-```
-
## 注意事项
1. **异步操作**:所有发送函数都是异步的,必须使用`await`
diff --git a/docs/plugins/api/tool-api.md b/docs/plugins/api/tool-api.md
new file mode 100644
index 00000000..d86734fc
--- /dev/null
+++ b/docs/plugins/api/tool-api.md
@@ -0,0 +1,55 @@
+# 工具API
+
+工具API模块提供了获取和管理工具实例的功能,让插件能够访问系统中注册的工具。
+
+## 导入方式
+
+```python
+from src.plugin_system.apis import tool_api
+# 或者
+from src.plugin_system import tool_api
+```
+
+## 主要功能
+
+### 1. 获取工具实例
+
+```python
+def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
+```
+
+获取指定名称的工具实例。
+
+**Args**:
+- `tool_name`: 工具名称字符串
+
+**Returns**:
+- `Optional[BaseTool]`: 工具实例,如果工具不存在则返回 None
+
+### 2. 获取LLM可用的工具定义
+
+```python
+def get_llm_available_tool_definitions():
+```
+
+获取所有LLM可用的工具定义列表。
+
+**Returns**:
+- `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组
+ - 其具体定义请参照[tool-components.md](../tool-components.md)中的工具定义格式。
+#### 示例:
+
+```python
+# 获取所有LLM可用的工具定义
+tools = tool_api.get_llm_available_tool_definitions()
+for tool_name, tool_definition in tools:
+ print(f"工具: {tool_name}")
+ print(f"定义: {tool_definition}")
+```
+
+## 注意事项
+
+1. **工具存在性检查**:使用前请检查工具实例是否为 None
+2. **权限控制**:某些工具可能有使用权限限制
+3. **异步调用**:大多数工具方法是异步的,需要使用 await
+4. **错误处理**:调用工具时请做好异常处理
diff --git a/docs/plugins/api/utils-api.md b/docs/plugins/api/utils-api.md
deleted file mode 100644
index bbab092e..00000000
--- a/docs/plugins/api/utils-api.md
+++ /dev/null
@@ -1,435 +0,0 @@
-# 工具API
-
-工具API模块提供了各种辅助功能,包括文件操作、时间处理、唯一ID生成等常用工具函数。
-
-## 导入方式
-
-```python
-from src.plugin_system.apis import utils_api
-```
-
-## 主要功能
-
-### 1. 文件操作
-
-#### `get_plugin_path(caller_frame=None) -> str`
-获取调用者插件的路径
-
-**参数:**
-- `caller_frame`:调用者的栈帧,默认为None(自动获取)
-
-**返回:**
-- `str`:插件目录的绝对路径
-
-**示例:**
-```python
-plugin_path = utils_api.get_plugin_path()
-print(f"插件路径: {plugin_path}")
-```
-
-#### `read_json_file(file_path: str, default: Any = None) -> Any`
-读取JSON文件
-
-**参数:**
-- `file_path`:文件路径,可以是相对于插件目录的路径
-- `default`:如果文件不存在或读取失败时返回的默认值
-
-**返回:**
-- `Any`:JSON数据或默认值
-
-**示例:**
-```python
-# 读取插件配置文件
-config = utils_api.read_json_file("config.json", {})
-settings = utils_api.read_json_file("data/settings.json", {"enabled": True})
-```
-
-#### `write_json_file(file_path: str, data: Any, indent: int = 2) -> bool`
-写入JSON文件
-
-**参数:**
-- `file_path`:文件路径,可以是相对于插件目录的路径
-- `data`:要写入的数据
-- `indent`:JSON缩进
-
-**返回:**
-- `bool`:是否写入成功
-
-**示例:**
-```python
-data = {"name": "test", "value": 123}
-success = utils_api.write_json_file("output.json", data)
-```
-
-### 2. 时间相关
-
-#### `get_timestamp() -> int`
-获取当前时间戳
-
-**返回:**
-- `int`:当前时间戳(秒)
-
-#### `format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str`
-格式化时间
-
-**参数:**
-- `timestamp`:时间戳,如果为None则使用当前时间
-- `format_str`:时间格式字符串
-
-**返回:**
-- `str`:格式化后的时间字符串
-
-#### `parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int`
-解析时间字符串为时间戳
-
-**参数:**
-- `time_str`:时间字符串
-- `format_str`:时间格式字符串
-
-**返回:**
-- `int`:时间戳(秒)
-
-### 3. 其他工具
-
-#### `generate_unique_id() -> str`
-生成唯一ID
-
-**返回:**
-- `str`:唯一ID
-
-## 使用示例
-
-### 1. 插件数据管理
-
-```python
-from src.plugin_system.apis import utils_api
-
-class DataPlugin(BasePlugin):
- def __init__(self):
- self.plugin_path = utils_api.get_plugin_path()
- self.data_file = "plugin_data.json"
- self.load_data()
-
- def load_data(self):
- """加载插件数据"""
- default_data = {
- "users": {},
- "settings": {"enabled": True},
- "stats": {"message_count": 0}
- }
- self.data = utils_api.read_json_file(self.data_file, default_data)
-
- def save_data(self):
- """保存插件数据"""
- return utils_api.write_json_file(self.data_file, self.data)
-
- async def handle_action(self, action_data, chat_stream):
- # 更新统计信息
- self.data["stats"]["message_count"] += 1
- self.data["stats"]["last_update"] = utils_api.get_timestamp()
-
- # 保存数据
- if self.save_data():
- return {"success": True, "message": "数据已保存"}
- else:
- return {"success": False, "message": "数据保存失败"}
-```
-
-### 2. 日志记录系统
-
-```python
-class PluginLogger:
- def __init__(self, plugin_name: str):
- self.plugin_name = plugin_name
- self.log_file = f"{plugin_name}_log.json"
- self.logs = utils_api.read_json_file(self.log_file, [])
-
- def log_event(self, event_type: str, message: str, data: dict = None):
- """记录事件"""
- log_entry = {
- "id": utils_api.generate_unique_id(),
- "timestamp": utils_api.get_timestamp(),
- "formatted_time": utils_api.format_time(),
- "event_type": event_type,
- "message": message,
- "data": data or {}
- }
-
- self.logs.append(log_entry)
-
- # 保持最新的100条记录
- if len(self.logs) > 100:
- self.logs = self.logs[-100:]
-
- # 保存到文件
- utils_api.write_json_file(self.log_file, self.logs)
-
- def get_logs_by_type(self, event_type: str) -> list:
- """获取指定类型的日志"""
- return [log for log in self.logs if log["event_type"] == event_type]
-
- def get_recent_logs(self, count: int = 10) -> list:
- """获取最近的日志"""
- return self.logs[-count:]
-
-# 使用示例
-logger = PluginLogger("my_plugin")
-logger.log_event("user_action", "用户发送了消息", {"user_id": "123", "message": "hello"})
-```
-
-### 3. 配置管理系统
-
-```python
-class ConfigManager:
- def __init__(self, config_file: str = "plugin_config.json"):
- self.config_file = config_file
- self.default_config = {
- "enabled": True,
- "debug": False,
- "max_users": 100,
- "response_delay": 1.0,
- "features": {
- "auto_reply": True,
- "logging": True
- }
- }
- self.config = self.load_config()
-
- def load_config(self) -> dict:
- """加载配置"""
- return utils_api.read_json_file(self.config_file, self.default_config)
-
- def save_config(self) -> bool:
- """保存配置"""
- return utils_api.write_json_file(self.config_file, self.config, indent=4)
-
- def get(self, key: str, default=None):
- """获取配置值,支持嵌套访问"""
- keys = key.split('.')
- value = self.config
-
- for k in keys:
- if isinstance(value, dict) and k in value:
- value = value[k]
- else:
- return default
-
- return value
-
- def set(self, key: str, value):
- """设置配置值,支持嵌套设置"""
- keys = key.split('.')
- config = self.config
-
- for k in keys[:-1]:
- if k not in config:
- config[k] = {}
- config = config[k]
-
- config[keys[-1]] = value
-
- def update_config(self, updates: dict):
- """批量更新配置"""
- def deep_update(base, updates):
- for key, value in updates.items():
- if isinstance(value, dict) and key in base and isinstance(base[key], dict):
- deep_update(base[key], value)
- else:
- base[key] = value
-
- deep_update(self.config, updates)
-
-# 使用示例
-config = ConfigManager()
-print(f"调试模式: {config.get('debug', False)}")
-print(f"自动回复: {config.get('features.auto_reply', True)}")
-
-config.set('features.new_feature', True)
-config.save_config()
-```
-
-### 4. 缓存系统
-
-```python
-class PluginCache:
- def __init__(self, cache_file: str = "plugin_cache.json", ttl: int = 3600):
- self.cache_file = cache_file
- self.ttl = ttl # 缓存过期时间(秒)
- self.cache = self.load_cache()
-
- def load_cache(self) -> dict:
- """加载缓存"""
- return utils_api.read_json_file(self.cache_file, {})
-
- def save_cache(self):
- """保存缓存"""
- return utils_api.write_json_file(self.cache_file, self.cache)
-
- def get(self, key: str):
- """获取缓存值"""
- if key not in self.cache:
- return None
-
- item = self.cache[key]
- current_time = utils_api.get_timestamp()
-
- # 检查是否过期
- if current_time - item["timestamp"] > self.ttl:
- del self.cache[key]
- return None
-
- return item["value"]
-
- def set(self, key: str, value):
- """设置缓存值"""
- self.cache[key] = {
- "value": value,
- "timestamp": utils_api.get_timestamp()
- }
- self.save_cache()
-
- def clear_expired(self):
- """清理过期缓存"""
- current_time = utils_api.get_timestamp()
- expired_keys = []
-
- for key, item in self.cache.items():
- if current_time - item["timestamp"] > self.ttl:
- expired_keys.append(key)
-
- for key in expired_keys:
- del self.cache[key]
-
- if expired_keys:
- self.save_cache()
-
- return len(expired_keys)
-
-# 使用示例
-cache = PluginCache(ttl=1800) # 30分钟过期
-cache.set("user_data_123", {"name": "张三", "score": 100})
-user_data = cache.get("user_data_123")
-```
-
-### 5. 时间处理工具
-
-```python
-class TimeHelper:
- @staticmethod
- def get_time_info():
- """获取当前时间的详细信息"""
- timestamp = utils_api.get_timestamp()
- return {
- "timestamp": timestamp,
- "datetime": utils_api.format_time(timestamp),
- "date": utils_api.format_time(timestamp, "%Y-%m-%d"),
- "time": utils_api.format_time(timestamp, "%H:%M:%S"),
- "year": utils_api.format_time(timestamp, "%Y"),
- "month": utils_api.format_time(timestamp, "%m"),
- "day": utils_api.format_time(timestamp, "%d"),
- "weekday": utils_api.format_time(timestamp, "%A")
- }
-
- @staticmethod
- def time_ago(timestamp: int) -> str:
- """计算时间差"""
- current = utils_api.get_timestamp()
- diff = current - timestamp
-
- if diff < 60:
- return f"{diff}秒前"
- elif diff < 3600:
- return f"{diff // 60}分钟前"
- elif diff < 86400:
- return f"{diff // 3600}小时前"
- else:
- return f"{diff // 86400}天前"
-
- @staticmethod
- def parse_duration(duration_str: str) -> int:
- """解析时间段字符串,返回秒数"""
- import re
-
- pattern = r'(\d+)([smhd])'
- matches = re.findall(pattern, duration_str.lower())
-
- total_seconds = 0
- for value, unit in matches:
- value = int(value)
- if unit == 's':
- total_seconds += value
- elif unit == 'm':
- total_seconds += value * 60
- elif unit == 'h':
- total_seconds += value * 3600
- elif unit == 'd':
- total_seconds += value * 86400
-
- return total_seconds
-
-# 使用示例
-time_info = TimeHelper.get_time_info()
-print(f"当前时间: {time_info['datetime']}")
-
-last_seen = 1699000000
-print(f"最后见面: {TimeHelper.time_ago(last_seen)}")
-
-duration = TimeHelper.parse_duration("1h30m") # 1小时30分钟 = 5400秒
-```
-
-## 最佳实践
-
-### 1. 错误处理
-```python
-def safe_file_operation(file_path: str, data: dict):
- """安全的文件操作"""
- try:
- success = utils_api.write_json_file(file_path, data)
- if not success:
- logger.warning(f"文件写入失败: {file_path}")
- return success
- except Exception as e:
- logger.error(f"文件操作出错: {e}")
- return False
-```
-
-### 2. 路径处理
-```python
-import os
-
-def get_data_path(filename: str) -> str:
- """获取数据文件的完整路径"""
- plugin_path = utils_api.get_plugin_path()
- data_dir = os.path.join(plugin_path, "data")
-
- # 确保数据目录存在
- os.makedirs(data_dir, exist_ok=True)
-
- return os.path.join(data_dir, filename)
-```
-
-### 3. 定期清理
-```python
-async def cleanup_old_files():
- """清理旧文件"""
- plugin_path = utils_api.get_plugin_path()
- current_time = utils_api.get_timestamp()
-
- for filename in os.listdir(plugin_path):
- if filename.endswith('.tmp'):
- file_path = os.path.join(plugin_path, filename)
- file_time = os.path.getmtime(file_path)
-
- # 删除超过24小时的临时文件
- if current_time - file_time > 86400:
- os.remove(file_path)
-```
-
-## 注意事项
-
-1. **相对路径**:文件路径支持相对于插件目录的路径
-2. **自动创建目录**:写入文件时会自动创建必要的目录
-3. **错误处理**:所有函数都有错误处理,失败时返回默认值
-4. **编码格式**:文件读写使用UTF-8编码
-5. **时间格式**:时间戳使用秒为单位
-6. **JSON格式**:JSON文件使用可读性好的缩进格式
\ No newline at end of file
diff --git a/docs/plugins/index.md b/docs/plugins/index.md
index af8fad85..2454c98a 100644
--- a/docs/plugins/index.md
+++ b/docs/plugins/index.md
@@ -10,6 +10,7 @@
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
+- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
@@ -43,24 +44,24 @@ Command vs Action 选择指南
- [LLM API](api/llm-api.md) - 大语言模型交互接口,可以使用内置LLM生成内容
- [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器
-### 表情包api
+### 表情包API
- [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口
-### 关系系统api
+### 关系系统API
- [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口
### 数据与配置API
- [🗄️ 数据库API](api/database-api.md) - 数据库操作接口
- [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口
+### 插件和组件管理API
+- [🔌 插件API](api/plugin-manage-api.md) - 插件加载和管理接口
+- [🧩 组件API](api/component-manage-api.md) - 组件注册和管理接口
+
+### 日志API
+- [📜 日志API](api/logging-api.md) - logger实例获取接口
### 工具API
-- [工具API](api/utils-api.md) - 文件操作、时间处理等工具函数
-
-
-## 实验性
-
-这些功能将在未来重构或移除
-- [🔧 工具系统详解](tool-system.md) - 工具系统的使用和开发
+- [🔧 工具API](api/tool-api.md) - tool获取接口
diff --git a/docs/plugins/tool-system.md b/docs/plugins/tool-components.md
similarity index 58%
rename from docs/plugins/tool-system.md
rename to docs/plugins/tool-components.md
index baa43528..059656aa 100644
--- a/docs/plugins/tool-system.md
+++ b/docs/plugins/tool-components.md
@@ -1,10 +1,10 @@
-# 🔧 工具系统详解
+# 🔧 工具组件详解
-## 📖 什么是工具系统
+## 📖 什么是工具
-工具系统是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
+工具是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
-### 🎯 工具系统的特点
+### 🎯 工具的特点
- 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力
- 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据
@@ -20,14 +20,11 @@
| **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 |
| **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 |
-## 🏗️ 工具基本结构
-
-### 必要组件
+## 🏗️ Tool组件的基本结构
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
-
```python
-from src.tools.tool_can_use.base_tool import BaseTool, register_tool
+from src.plugin_system import BaseTool, ToolParamType
class MyTool(BaseTool):
# 工具名称,必须唯一
@@ -36,21 +33,29 @@ class MyTool(BaseTool):
# 工具描述,告诉LLM这个工具的用途
description = "这个工具用于获取特定类型的信息"
- # 参数定义,遵循JSONSchema格式
- parameters = {
- "type": "object",
- "properties": {
- "query": {
- "type": "string",
- "description": "查询参数"
- },
- "limit": {
- "type": "integer",
- "description": "结果数量限制"
- }
- },
- "required": ["query"]
- }
+ # 参数定义,仅定义参数
+ # 比如想要定义一个类似下面的openai格式的参数表,则可以这么定义:
+ # {
+ # "type": "object",
+ # "properties": {
+ # "query": {
+ # "type": "string",
+ # "description": "查询参数"
+ # },
+ # "limit": {
+ # "type": "integer",
+ # "description": "结果数量限制"
+ # "enum": [10, 20, 50] # 可选值
+ # }
+ # },
+ # "required": ["query"]
+ # }
+ parameters = [
+ ("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数
+ ("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数
+ ]
+
+ available_for_llm = True # 是否对LLM可用
async def execute(self, function_args: Dict[str, Any]):
"""执行工具逻辑"""
@@ -69,7 +74,12 @@ class MyTool(BaseTool):
|-----|------|------|
| `name` | str | 工具的唯一标识名称 |
| `description` | str | 工具功能描述,帮助LLM理解用途 |
-| `parameters` | dict | JSONSchema格式的参数定义 |
+| `parameters` | list[tuple] | 参数定义 |
+
+其构造而成的工具定义为:
+```python
+{"name": cls.name, "description": cls.description, "parameters": cls.parameters}
+```
### 方法说明
@@ -77,15 +87,6 @@ class MyTool(BaseTool):
|-----|------|--------|------|
| `execute` | `function_args` | `dict` | 执行工具核心逻辑 |
-## 🔄 自动注册机制
-
-工具系统采用自动发现和注册机制:
-
-1. **文件扫描**:系统自动遍历 `tool_can_use` 目录中的所有Python文件
-2. **类识别**:寻找继承自 `BaseTool` 的工具类
-3. **自动注册**:只需要实现对应的类并把文件放在正确文件夹中就可自动注册
-4. **即用即加载**:工具在需要时被实例化和调用
-
---
## 🎨 完整工具示例
@@ -93,7 +94,7 @@ class MyTool(BaseTool):
完成一个天气查询工具
```python
-from src.tools.tool_can_use.base_tool import BaseTool, register_tool
+from src.plugin_system import BaseTool
import aiohttp
import json
@@ -102,23 +103,13 @@ class WeatherTool(BaseTool):
name = "weather_query"
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
+ available_for_llm = True # 允许LLM调用此工具
+ parameters = [
+ ("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None),
+ ("country", ToolParamType.STRING, "国家代码,如:CN、US,可选参数", False, None)
+ ]
- parameters = {
- "type": "object",
- "properties": {
- "city": {
- "type": "string",
- "description": "要查询天气的城市名称,如:北京、上海、纽约"
- },
- "country": {
- "type": "string",
- "description": "国家代码,如:CN、US,可选参数"
- }
- },
- "required": ["city"]
- }
-
- async def execute(self, function_args, message_txt=""):
+ async def execute(self, function_args: dict):
"""执行天气查询"""
try:
city = function_args.get("city")
@@ -177,55 +168,12 @@ class WeatherTool(BaseTool):
---
-## 📊 工具开发步骤
-
-### 1. 创建工具文件
-
-在 `src/tools/tool_can_use/` 目录下创建新的Python文件:
-
-```bash
-# 例如创建 my_new_tool.py
-touch src/tools/tool_can_use/my_new_tool.py
-```
-
-### 2. 实现工具类
-
-```python
-from src.tools.tool_can_use.base_tool import BaseTool, register_tool
-
-class MyNewTool(BaseTool):
- name = "my_new_tool"
- description = "新工具的功能描述"
-
- parameters = {
- "type": "object",
- "properties": {
- # 定义参数
- },
- "required": []
- }
-
- async def execute(self, function_args, message_txt=""):
- # 实现工具逻辑
- return {
- "name": self.name,
- "content": "执行结果"
- }
-```
-
-### 3. 系统集成
-
-工具创建完成后,系统会自动发现和注册,无需额外配置。
-
----
-
## 🚨 注意事项和限制
### 当前限制
-1. **独立开发**:需要单独编写,暂未完全融入插件系统
-2. **适用范围**:主要适用于信息获取场景
-3. **配置要求**:必须开启工具处理器
+1. **适用范围**:主要适用于信息获取场景
+2. **配置要求**:必须开启工具处理器
### 开发建议
@@ -238,66 +186,49 @@ class MyNewTool(BaseTool):
## 🎯 最佳实践
### 1. 工具命名规范
-
+#### ✅ 好的命名
```python
-# ✅ 好的命名
name = "weather_query" # 清晰表达功能
name = "knowledge_search" # 描述性强
name = "stock_price_check" # 功能明确
-
-# ❌ 避免的命名
+```
+#### ❌ 避免的命名
+```python
name = "tool1" # 无意义
name = "wq" # 过于简短
name = "weather_and_news" # 功能过于复杂
```
### 2. 描述规范
-
+#### ✅ 良好的描述
```python
-# ✅ 好的描述
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况"
-
-# ❌ 避免的描述
+```
+#### ❌ 避免的描述
+```python
description = "天气" # 过于简单
description = "获取信息" # 不够具体
```
### 3. 参数设计
+#### ✅ 合理的参数设计
```python
-# ✅ 合理的参数设计
-parameters = {
- "type": "object",
- "properties": {
- "city": {
- "type": "string",
- "description": "城市名称,如:北京、上海"
- },
- "unit": {
- "type": "string",
- "description": "温度单位:celsius(摄氏度) 或 fahrenheit(华氏度)",
- "enum": ["celsius", "fahrenheit"]
- }
- },
- "required": ["city"]
-}
-
-# ❌ 避免的参数设计
-parameters = {
- "type": "object",
- "properties": {
- "data": {
- "type": "string",
- "description": "数据" # 描述不清晰
- }
- }
-}
+parameters = [
+ ("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None),
+ ("unit", ToolParamType.STRING, "温度单位:celsius 或 fahrenheit", False, ["celsius", "fahrenheit"])
+]
+```
+#### ❌ 避免的参数设计
+```python
+parameters = [
+ ("data", "string", "数据", True) # 参数过于模糊
+]
```
### 4. 结果格式化
-
+#### ✅ 良好的结果格式
```python
-# ✅ 良好的结果格式
def _format_result(self, data):
return f"""
🔍 查询结果
@@ -307,12 +238,9 @@ def _format_result(self, data):
📝 说明: {data['description']}
━━━━━━━━━━━━
""".strip()
-
-# ❌ 避免的结果格式
+```
+#### ❌ 避免的结果格式
+```python
def _format_result(self, data):
return str(data) # 直接返回原始数据
```
-
----
-
-🎉 **工具系统为麦麦提供了强大的信息获取能力!合理使用工具可以让麦麦变得更加智能和博学。**
\ No newline at end of file
diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py
index 8ede9616..f9855481 100644
--- a/plugins/hello_world_plugin/plugin.py
+++ b/plugins/hello_world_plugin/plugin.py
@@ -1,18 +1,55 @@
-from typing import List, Tuple, Type
+from typing import List, Tuple, Type, Any
from src.plugin_system import (
BasePlugin,
register_plugin,
BaseAction,
BaseCommand,
+ BaseTool,
ComponentInfo,
ActionActivationType,
ConfigField,
BaseEventHandler,
EventType,
MaiMessages,
+ ToolParamType
)
+class CompareNumbersTool(BaseTool):
+ """比较两个数大小的工具"""
+
+ name = "compare_numbers"
+ description = "使用工具 比较两个数的大小,返回较大的数"
+ parameters = [
+ ("num1", ToolParamType.FLOAT, "第一个数字", True, None),
+ ("num2", ToolParamType.FLOAT, "第二个数字", True, None),
+ ]
+
+ async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
+ """执行比较两个数的大小
+
+ Args:
+ function_args: 工具参数
+
+ Returns:
+ dict: 工具执行结果
+ """
+ num1: int | float = function_args.get("num1") # type: ignore
+ num2: int | float = function_args.get("num2") # type: ignore
+
+ try:
+ if num1 > num2:
+ result = f"{num1} 大于 {num2}"
+ elif num1 < num2:
+ result = f"{num1} 小于 {num2}"
+ else:
+ result = f"{num1} 等于 {num2}"
+
+ return {"name": self.name, "content": result}
+ except Exception as e:
+ return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
+
+
# ===== Action组件 =====
class HelloAction(BaseAction):
"""问候Action - 简单的问候动作"""
@@ -132,7 +169,9 @@ class HelloWorldPlugin(BasePlugin):
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
},
"greeting": {
- "message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"),
+ "message": ConfigField(
+ type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息"
+ ),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
},
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
@@ -142,6 +181,7 @@ class HelloWorldPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [
(HelloAction.get_action_info(), HelloAction),
+ (CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage),
diff --git a/requirements.txt b/requirements.txt
index a09637a9..999bd5fd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,6 +15,7 @@ matplotlib
networkx
numpy
openai
+google-genai
pandas
peewee
pyarrow
diff --git a/scripts/import_openie.py b/scripts/import_openie.py
index 63a4d985..1177650d 100644
--- a/scripts/import_openie.py
+++ b/scripts/import_openie.py
@@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入")
-ENV_FILE = os.path.join(ROOT_PATH, ".env")
-
-if os.path.exists(".env"):
- load_dotenv(".env", override=True)
- print("成功加载环境变量配置")
-else:
- print("未找到.env文件,请确保程序所需的环境变量被正确设置")
- raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
-
-env_mask = {key: os.getenv(key) for key in os.environ}
-def scan_provider(env_config: dict):
- provider = {}
-
- # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
- # 避免 GPG_KEY 这样的变量干扰检查
- env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
-
- # 遍历 env_config 的所有键
- for key in env_config:
- # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
- if key.endswith("_BASE_URL") or key.endswith("_KEY"):
- # 提取 provider 名称
- provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
-
- # 初始化 provider 的字典(如果尚未初始化)
- if provider_name not in provider:
- provider[provider_name] = {"url": None, "key": None}
-
- # 根据键的类型填充 url 或 key
- if key.endswith("_BASE_URL"):
- provider[provider_name]["url"] = env_config[key]
- elif key.endswith("_KEY"):
- provider[provider_name]["key"] = env_config[key]
-
- # 检查每个 provider 是否同时存在 url 和 key
- for provider_name, config in provider.items():
- if config["url"] is None or config["key"] is None:
- logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
- raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
-
def ensure_openie_dir():
"""确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR):
@@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
def main(): # sourcery skip: dict-comprehension
# 新增确认提示
- env_config = {key: os.getenv(key) for key in os.environ}
- scan_provider(env_config)
print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py
index c36a7789..47ad55a8 100644
--- a/scripts/info_extraction.py
+++ b/scripts/info_extraction.py
@@ -25,9 +25,8 @@ from rich.progress import (
TextColumn,
)
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
-from dotenv import load_dotenv
logger = get_logger("LPMM知识库-信息提取")
@@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
-ENV_FILE = os.path.join(ROOT_PATH, ".env")
-
-if os.path.exists(".env"):
- load_dotenv(".env", override=True)
- print("成功加载环境变量配置")
-else:
- print("未找到.env文件,请确保程序所需的环境变量被正确设置")
- raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
-
-env_mask = {key: os.getenv(key) for key in os.environ}
-def scan_provider(env_config: dict):
- provider = {}
-
- # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
- # 避免 GPG_KEY 这样的变量干扰检查
- env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
-
- # 遍历 env_config 的所有键
- for key in env_config:
- # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
- if key.endswith("_BASE_URL") or key.endswith("_KEY"):
- # 提取 provider 名称
- provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
-
- # 初始化 provider 的字典(如果尚未初始化)
- if provider_name not in provider:
- provider[provider_name] = {"url": None, "key": None}
-
- # 根据键的类型填充 url 或 key
- if key.endswith("_BASE_URL"):
- provider[provider_name]["url"] = env_config[key]
- elif key.endswith("_KEY"):
- provider[provider_name]["key"] = env_config[key]
-
- # 检查每个 provider 是否同时存在 url 和 key
- for provider_name, config in provider.items():
- if config["url"] is None or config["key"] is None:
- logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
- raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
@@ -96,11 +56,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
- model=global_config.model.lpmm_entity_extract,
+ model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
- model=global_config.model.lpmm_rdf_build,
+ model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build"
)
def process_single_text(pg_hash, raw_data):
@@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
# 设置信号处理器
signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # 确保目录存在
- env_config = {key: os.getenv(key) for key in os.environ}
- scan_provider(env_config)
# 新增用户确认提示
print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py
index efa8f69b..d51fa96b 100644
--- a/src/chat/chat_loop/heartFC_chat.py
+++ b/src/chat/chat_loop/heartFC_chat.py
@@ -414,7 +414,7 @@ class HeartFChatting:
else:
logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容")
- action_message: Dict[str, Any] = message_data or target_message # type: ignore
+ action_message = message_data or target_message
if action_type == "reply":
# 等待回复生成完毕
if self.loop_mode == ChatMode.NORMAL:
diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py
index 918b8396..6d50d890 100644
--- a/src/chat/emoji_system/emoji_manager.py
+++ b/src/chat/emoji_system/emoji_manager.py
@@ -8,15 +8,15 @@ import traceback
import io
import re
import binascii
+
from typing import Optional, Tuple, List, Any
from PIL import Image
from rich.traceback import install
-
from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest
@@ -379,9 +379,9 @@ class EmojiManager:
self._scan_task = None
- self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji")
+ self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest(
- model=global_config.model.utils, max_tokens=600, request_type="emoji"
+ model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
self.emoji_num = 0
@@ -492,6 +492,7 @@ class EmojiManager:
return None
def _levenshtein_distance(self, s1: str, s2: str) -> int:
+ # sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
"""计算两个字符串的编辑距离
Args:
@@ -629,11 +630,11 @@ class EmojiManager:
if success:
# 注册成功则跳出循环
break
- else:
- # 注册失败则删除对应文件
- file_path = os.path.join(EMOJI_DIR, filename)
- os.remove(file_path)
- logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
+
+ # 注册失败则删除对应文件
+ file_path = os.path.join(EMOJI_DIR, filename)
+ os.remove(file_path)
+ logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
@@ -694,6 +695,7 @@ class EmojiManager:
return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
+ # sourcery skip: use-next
"""从内存中的 emoji_objects 列表获取表情包
参数:
@@ -709,10 +711,10 @@ class EmojiManager:
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述
-
+
Args:
emoji_hash: 表情包的哈希值
-
+
Returns:
Optional[str]: 表情包描述,如果未找到则返回None
"""
@@ -722,7 +724,7 @@ class EmojiManager:
if emoji and emoji.description:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
return emoji.description
-
+
# 如果内存中没有,从数据库查找
self._ensure_db()
try:
@@ -732,9 +734,9 @@ class EmojiManager:
return emoji_record.description
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")
-
+
return None
-
+
except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
return None
@@ -779,6 +781,7 @@ class EmojiManager:
return False
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
+ # sourcery skip: use-getitem-for-re-match-groups
"""替换一个表情包
Args:
@@ -820,7 +823,7 @@ class EmojiManager:
)
# 调用大模型进行决策
- decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8)
+ decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
logger.info(f"[决策] 结果: {decision}")
# 解析决策结果
@@ -828,9 +831,7 @@ class EmojiManager:
logger.info("[决策] 不删除任何表情包")
return False
- # 尝试从决策中提取表情包编号
- match = re.search(r"删除编号(\d+)", decision)
- if match:
+ if match := re.search(r"删除编号(\d+)", decision):
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效
@@ -889,6 +890,7 @@ class EmojiManager:
existing_description = None
try:
from src.common.database.database_model import Images
+
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
if existing_image and existing_image.description:
existing_description = existing_image.description
@@ -902,15 +904,21 @@ class EmojiManager:
logger.info("[优化] 复用已有的详细描述,跳过VLM调用")
else:
logger.info("[VLM分析] 生成新的详细描述")
- if image_format == "gif" or image_format == "GIF":
+ if image_format in ["gif", "GIF"]:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64:
raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
- description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg")
+ description, _ = await self.vlm.generate_response_for_image(
+ prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
+ )
else:
- prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
- description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+ prompt = (
+ "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
+ )
+ description, _ = await self.vlm.generate_response_for_image(
+ prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
+ )
# 审核表情包
if global_config.emoji.content_filtration:
@@ -922,7 +930,9 @@ class EmojiManager:
4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容
'''
- content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+ content, _ = await self.vlm.generate_response_for_image(
+ prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
+ )
if content == "否":
return "", []
@@ -933,7 +943,9 @@ class EmojiManager:
你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析
请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔
"""
- emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7)
+ emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
+ emotion_prompt, temperature=0.7, max_tokens=600
+ )
# 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py
index 1870c470..a9808503 100644
--- a/src/chat/express/expression_learner.py
+++ b/src/chat/express/expression_learner.py
@@ -7,12 +7,12 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
+from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import model_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
-from src.common.database.database_model import Expression
MAX_EXPRESSION_COUNT = 300
@@ -80,11 +80,8 @@ def init_prompt() -> None:
class ExpressionLearner:
def __init__(self) -> None:
- # TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest(
- model=global_config.model.replyer_1,
- temperature=0.3,
- request_type="expressor.learner",
+ model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
)
self.llm_model = None
self._ensure_expression_directories()
@@ -101,7 +98,7 @@ class ExpressionLearner:
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
-
+
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
@@ -116,7 +113,7 @@ class ExpressionLearner:
"""
base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done")
-
+
# 确保基础目录存在
try:
os.makedirs(base_dir, exist_ok=True)
@@ -124,28 +121,28 @@ class ExpressionLearner:
except Exception as e:
logger.error(f"创建表达方式目录失败: {e}")
return
-
+
if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移,无需重复迁移。")
return
-
+
logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0
-
+
for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}")
continue
-
+
try:
chat_ids = os.listdir(type_dir)
logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}")
continue
-
+
for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file):
@@ -153,24 +150,24 @@ class ExpressionLearner:
try:
with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f)
-
+
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue
-
+
for expr in expressions:
if not isinstance(expr, dict):
continue
-
+
situation = expr.get("situation")
style_val = expr.get("style")
count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time())
-
+
if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue
-
+
# 查重:同chat_id+type+situation+style
from src.common.database.database_model import Expression
@@ -201,7 +198,7 @@ class ExpressionLearner:
logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
-
+
# 标记迁移完成
try:
# 确保done.done文件的父目录存在
@@ -209,7 +206,7 @@ class ExpressionLearner:
if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
-
+
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
@@ -229,13 +226,13 @@ class ExpressionLearner:
# 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null())
updated_count = 0
-
+
for expr in old_expressions:
# 使用last_active_time作为create_date
expr.create_date = expr.last_active_time
expr.save()
updated_count += 1
-
+
if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e:
@@ -287,25 +284,29 @@ class ExpressionLearner:
获取指定chat_id的表达方式创建信息,按创建日期排序
"""
try:
- expressions = (Expression.select()
- .where(Expression.chat_id == chat_id)
- .order_by(Expression.create_date.desc())
- .limit(limit))
-
+ expressions = (
+ Expression.select()
+ .where(Expression.chat_id == chat_id)
+ .order_by(Expression.create_date.desc())
+ .limit(limit)
+ )
+
result = []
for expr in expressions:
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
- result.append({
- "situation": expr.situation,
- "style": expr.style,
- "type": expr.type,
- "count": expr.count,
- "create_date": create_date,
- "create_date_formatted": format_create_date(create_date),
- "last_active_time": expr.last_active_time,
- "last_active_formatted": format_create_date(expr.last_active_time),
- })
-
+ result.append(
+ {
+ "situation": expr.situation,
+ "style": expr.style,
+ "type": expr.type,
+ "count": expr.count,
+ "create_date": create_date,
+ "create_date_formatted": format_create_date(create_date),
+ "last_active_time": expr.last_active_time,
+ "last_active_formatted": format_create_date(expr.last_active_time),
+ }
+ )
+
return result
except Exception as e:
logger.error(f"获取表达方式创建信息失败: {e}")
@@ -355,19 +356,19 @@ class ExpressionLearner:
try:
# 获取所有表达方式
all_expressions = Expression.select()
-
+
updated_count = 0
deleted_count = 0
-
+
for expr in all_expressions:
# 计算时间差
last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
-
+
# 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value)
-
+
if new_count <= 0.01:
# 如果count太小,删除这个表达方式
expr.delete_instance()
@@ -377,10 +378,10 @@ class ExpressionLearner:
expr.count = new_count
expr.save()
updated_count += 1
-
+
if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
-
+
except Exception as e:
logger.error(f"数据库全局衰减失败: {e}")
@@ -527,7 +528,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的prompt: {prompt}")
try:
- response, _ = await self.express_learn_model.generate_response_async(prompt)
+ response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习{type_str}失败: {e}")
return None
diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py
index 910b43c2..111225c8 100644
--- a/src/chat/express/expression_selector.py
+++ b/src/chat/express/expression_selector.py
@@ -1,16 +1,17 @@
import json
import time
import random
+import hashlib
from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.common.logger import get_logger
+from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner
-from src.common.database.database_model import Expression
logger = get_logger("expression_selector")
@@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector:
def __init__(self):
self.expression_learner = get_expression_learner()
- # TODO: API-Adapter修改标记
self.llm_model = LLMRequest(
- model=global_config.model.utils_small,
- request_type="expression.selector",
+ model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
)
@staticmethod
@@ -92,7 +91,6 @@ class ExpressionSelector:
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
- import hashlib
if is_group:
components = [platform, str(id_str)]
else:
@@ -108,8 +106,7 @@ class ExpressionSelector:
for group in groups:
group_chat_ids = []
for stream_config_str in group:
- chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str)
- if chat_id_candidate:
+ if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids:
return group_chat_ids
@@ -118,9 +115,10 @@ class ExpressionSelector:
def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
+ # sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
-
+
# 优化:一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
@@ -128,7 +126,7 @@ class ExpressionSelector:
grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
)
-
+
style_exprs = [
{
"situation": expr.situation,
@@ -138,9 +136,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
- } for expr in style_query
+ }
+ for expr in style_query
]
-
+
grammar_exprs = [
{
"situation": expr.situation,
@@ -150,9 +149,10 @@ class ExpressionSelector:
"source_id": expr.chat_id,
"type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
- } for expr in grammar_query
+ }
+ for expr in grammar_query
]
-
+
style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage)
# 按权重抽样(使用count作为权重)
@@ -174,22 +174,22 @@ class ExpressionSelector:
return
updates_by_key = {}
for expr in expressions_to_update:
- source_id = expr.get("source_id")
- expr_type = expr.get("type", "style")
- situation = expr.get("situation")
- style = expr.get("style")
+ source_id: str = expr.get("source_id") # type: ignore
+ expr_type: str = expr.get("type", "style")
+ situation: str = expr.get("situation") # type: ignore
+ style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue
key = (source_id, expr_type, situation, style)
if key not in updates_by_key:
updates_by_key[key] = expr
- for (chat_id, expr_type, situation, style), _expr in updates_by_key.items():
+ for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where(
- (Expression.chat_id == chat_id) &
- (Expression.type == expr_type) &
- (Expression.situation == situation) &
- (Expression.style == style)
+ (Expression.chat_id == chat_id)
+ & (Expression.type == expr_type)
+ & (Expression.situation == situation)
+ & (Expression.style == style)
)
if query.exists():
expr_obj = query.get()
@@ -264,7 +264,7 @@ class ExpressionSelector:
# 4. 调用LLM
try:
- content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt)
+ content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} LLM返回结果: {content}")
diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py
index d732683a..d0f6e774 100644
--- a/src/chat/knowledge/embedding_store.py
+++ b/src/chat/knowledge/embedding_store.py
@@ -3,6 +3,7 @@ import json
import os
import math
import asyncio
+from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple
import numpy as np
@@ -11,8 +12,6 @@ import pandas as pd
# import tqdm
import faiss
-# from .llm_client import LLMClient
-# from .lpmmconfig import global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
@@ -26,12 +25,20 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
-from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
install(extra_lines=3)
+
+# 多线程embedding配置常量
+DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
+DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
+MIN_CHUNK_SIZE = 1 # 最小分块大小
+MAX_CHUNK_SIZE = 50 # 最大分块大小
+MIN_WORKERS = 1 # 最小线程数
+MAX_WORKERS = 20 # 最大线程数
+
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
@@ -87,13 +94,23 @@ class EmbeddingStoreItem:
class EmbeddingStore:
- def __init__(self, namespace: str, dir_path: str):
+ def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
self.namespace = namespace
self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
+ # 多线程配置参数验证和设置
+ self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
+ self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
+
+ # 如果配置值被调整,记录日志
+ if self.max_workers != max_workers:
+ logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
+ if self.chunk_size != chunk_size:
+ logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
+
self.store = {}
self.faiss_index = None
@@ -125,16 +142,134 @@ class EmbeddingStore:
return []
return result
+ def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
+ """使用多线程批量获取嵌入向量
+
+ Args:
+ strs: 要获取嵌入的字符串列表
+ chunk_size: 每个线程处理的数据块大小
+ max_workers: 最大线程数
+ progress_callback: 进度回调函数,接收一个参数表示完成的数量
+
+ Returns:
+ 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
+ """
+ if not strs:
+ return []
+
+ # 分块
+ chunks = []
+ for i in range(0, len(strs), chunk_size):
+ chunk = strs[i:i + chunk_size]
+ chunks.append((i, chunk)) # 保存起始索引以维持顺序
+
+ # 结果存储,使用字典按索引存储以保证顺序
+ results = {}
+
+ def process_chunk(chunk_data):
+ """处理单个数据块的函数"""
+ start_idx, chunk_strs = chunk_data
+ chunk_results = []
+
+ # 为每个线程创建独立的LLMRequest实例
+ from src.llm_models.utils_model import LLMRequest
+ from src.config.config import model_config
+
+ try:
+ # 创建线程专用的LLM实例
+ llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+
+ for i, s in enumerate(chunk_strs):
+ try:
+ # 直接使用异步函数
+ embedding = asyncio.run(llm.get_embedding(s))
+ if embedding and len(embedding) > 0:
+ chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
+ else:
+ logger.error(f"获取嵌入失败: {s}")
+ chunk_results.append((start_idx + i, s, []))
+
+ # 每完成一个嵌入立即更新进度
+ if progress_callback:
+ progress_callback(1)
+
+ except Exception as e:
+ logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
+ chunk_results.append((start_idx + i, s, []))
+
+ # 即使失败也要更新进度
+ if progress_callback:
+ progress_callback(1)
+
+ except Exception as e:
+ logger.error(f"创建LLM实例失败: {e}")
+ # 如果创建LLM实例失败,返回空结果
+ for i, s in enumerate(chunk_strs):
+ chunk_results.append((start_idx + i, s, []))
+ # 即使失败也要更新进度
+ if progress_callback:
+ progress_callback(1)
+
+ return chunk_results
+
+ # 使用线程池处理
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ # 提交所有任务
+ future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
+
+ # 收集结果(进度已在process_chunk中实时更新)
+ for future in as_completed(future_to_chunk):
+ try:
+ chunk_results = future.result()
+ for idx, s, embedding in chunk_results:
+ results[idx] = (s, embedding)
+ except Exception as e:
+ chunk = future_to_chunk[future]
+ logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
+ # 为失败的块添加空结果
+ start_idx, chunk_strs = chunk
+ for i, s in enumerate(chunk_strs):
+ results[start_idx + i] = (s, [])
+
+ # 按原始顺序返回结果
+ ordered_results = []
+ for i in range(len(strs)):
+ if i in results:
+ ordered_results.append(results[i])
+ else:
+ # 防止遗漏
+ ordered_results.append((strs[i], []))
+
+ return ordered_results
+
def get_test_file_path(self):
return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self):
- """保存测试字符串的嵌入到本地"""
+ """保存测试字符串的嵌入到本地(使用多线程优化)"""
+ logger.info("开始保存测试字符串的嵌入向量...")
+
+ # 使用多线程批量获取测试字符串的嵌入
+ embedding_results = self._get_embeddings_batch_threaded(
+ EMBEDDING_TEST_STRINGS,
+ chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
+ max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+ )
+
+ # 构建测试向量字典
test_vectors = {}
- for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
- test_vectors[str(idx)] = self._get_embedding(s)
+ for idx, (s, embedding) in enumerate(embedding_results):
+ if embedding:
+ test_vectors[str(idx)] = embedding
+ else:
+ logger.error(f"获取测试字符串嵌入失败: {s}")
+ # 使用原始单线程方法作为后备
+ test_vectors[str(idx)] = self._get_embedding(s)
+
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2)
+
+ logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self):
"""加载本地保存的测试字符串嵌入"""
@@ -145,29 +280,64 @@ class EmbeddingStore:
return json.load(f)
def check_embedding_model_consistency(self):
- """校验当前模型与本地嵌入模型是否一致"""
+ """校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
local_vectors = self.load_embedding_test_vectors()
if local_vectors is None:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors()
return True
- for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
- local_emb = local_vectors.get(str(idx))
- if local_emb is None:
+
+ # 检查本地向量完整性
+ for idx in range(len(EMBEDDING_TEST_STRINGS)):
+ if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors()
return True
- new_emb = self._get_embedding(s)
+
+ logger.info("开始检验嵌入模型一致性...")
+
+ # 使用多线程批量获取当前模型的嵌入
+ embedding_results = self._get_embeddings_batch_threaded(
+ EMBEDDING_TEST_STRINGS,
+ chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
+ max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+ )
+
+ # 检查一致性
+ for idx, (s, new_emb) in enumerate(embedding_results):
+ local_emb = local_vectors.get(str(idx))
+ if not new_emb:
+ logger.error(f"获取测试字符串嵌入失败: {s}")
+ return False
+
sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD:
- logger.error("嵌入模型一致性校验失败")
+ logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False
+
logger.info("嵌入模型一致性校验通过。")
return True
def batch_insert_strs(self, strs: List[str], times: int) -> None:
- """向库中存入字符串"""
+ """向库中存入字符串(使用多线程优化)"""
+ if not strs:
+ return
+
total = len(strs)
+
+ # 过滤已存在的字符串
+ new_strs = []
+ for s in strs:
+ item_hash = self.namespace + "-" + get_sha256(s)
+ if item_hash not in self.store:
+ new_strs.append(s)
+
+ if not new_strs:
+ logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
+ return
+
+ logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
+
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
@@ -181,19 +351,38 @@ class EmbeddingStore:
transient=False,
) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
- for s in strs:
- # 计算hash去重
- item_hash = self.namespace + "-" + get_sha256(s)
- if item_hash in self.store:
- progress.update(task, advance=1)
- continue
-
- # 获取embedding
- embedding = self._get_embedding(s)
-
- # 存入
- self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
- progress.update(task, advance=1)
+
+ # 首先更新已存在项的进度
+ already_processed = total - len(new_strs)
+ if already_processed > 0:
+ progress.update(task, advance=already_processed)
+
+ if new_strs:
+ # 使用实例配置的参数,智能调整分块和线程数
+ optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
+ optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
+
+ logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
+
+ # 定义进度更新回调函数
+ def update_progress(count):
+ progress.update(task, advance=count)
+
+ # 批量获取嵌入,并实时更新进度
+ embedding_results = self._get_embeddings_batch_threaded(
+ new_strs,
+ chunk_size=optimal_chunk_size,
+ max_workers=optimal_max_workers,
+ progress_callback=update_progress
+ )
+
+ # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
+ for s, embedding in embedding_results:
+ item_hash = self.namespace + "-" + get_sha256(s)
+ if embedding: # 只有成功获取到嵌入才存入
+ self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+ else:
+ logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
def save_to_file(self) -> None:
"""保存到文件"""
@@ -316,31 +505,37 @@ class EmbeddingStore:
class EmbeddingManager:
- def __init__(self):
+ def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
+ """
+ 初始化EmbeddingManager
+
+ Args:
+ max_workers: 最大线程数
+ chunk_size: 每个线程处理的数据块大小
+ """
self.paragraphs_embedding_store = EmbeddingStore(
- local_storage["pg_namespace"], # type: ignore
+ "paragraph", # type: ignore
EMBEDDING_DATA_DIR_STR,
+ max_workers=max_workers,
+ chunk_size=chunk_size,
)
self.entities_embedding_store = EmbeddingStore(
- local_storage["pg_namespace"], # type: ignore
+ "entity", # type: ignore
EMBEDDING_DATA_DIR_STR,
+ max_workers=max_workers,
+ chunk_size=chunk_size,
)
self.relation_embedding_store = EmbeddingStore(
- local_storage["pg_namespace"], # type: ignore
+ "relation", # type: ignore
EMBEDDING_DATA_DIR_STR,
+ max_workers=max_workers,
+ chunk_size=chunk_size,
)
self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self):
"""对所有嵌入库做模型一致性校验"""
- for store in [
- self.paragraphs_embedding_store,
- self.entities_embedding_store,
- self.relation_embedding_store,
- ]:
- if not store.check_embedding_model_consistency():
- return False
- return True
+ return self.paragraphs_embedding_store.check_embedding_model_consistency()
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库"""
diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py
index 16d4e080..340a678d 100644
--- a/src/chat/knowledge/ie_process.py
+++ b/src/chat/knowledge/ie_process.py
@@ -8,12 +8,15 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
+
+
def _extract_json_from_text(text: str):
+ # sourcery skip: assign-if-exp, extract-method
"""从文本中提取JSON数据的高容错方法"""
if text is None:
logger.error("输入文本为None")
return []
-
+
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
@@ -24,7 +27,7 @@ def _extract_json_from_text(text: str):
# 如果是列表,直接返回
if isinstance(parsed_json, list):
return parsed_json
-
+
# 如果是字典且只有一个项目,可能包装了列表
if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表
@@ -33,7 +36,7 @@ def _extract_json_from_text(text: str):
if isinstance(value, list):
return value
return parsed_json
-
+
# 其他情况,尝试转换为列表
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
return []
@@ -42,44 +45,40 @@ def _extract_json_from_text(text: str):
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return []
+
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
+ # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
-
+
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
- future = asyncio.run_coroutine_threadsafe(
- llm_req.generate_response_async(entity_extract_context), loop
- )
- response, (reasoning_content, model_name) = future.result()
+ future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
+ response, _ = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
- response, (reasoning_content, model_name) = asyncio.run(
- llm_req.generate_response_async(entity_extract_context)
- )
+ response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
# 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}")
-
+
entity_extract_result = _extract_json_from_text(response)
-
+
# 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list):
- # 如果不是列表,可能是字典格式,尝试从中提取列表
- if isinstance(entity_extract_result, dict):
- # 尝试常见的键名
- for key in ['entities', 'result', 'data', 'items']:
- if key in entity_extract_result and isinstance(entity_extract_result[key], list):
- entity_extract_result = entity_extract_result[key]
- break
- else:
- # 如果找不到合适的列表,抛出异常
- raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
+ if not isinstance(entity_extract_result, dict):
+ raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
+
+ # 尝试常见的键名
+ for key in ["entities", "result", "data", "items"]:
+ if key in entity_extract_result and isinstance(entity_extract_result[key], list):
+ entity_extract_result = entity_extract_result[key]
+ break
else:
- raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
-
+ # 如果找不到合适的列表,抛出异常
+ raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体
entity_extract_result = [
entity
@@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
]
- if len(entity_extract_result) == 0:
- raise Exception("实体提取结果为空")
+ if not entity_extract_result:
+ raise ValueError("实体提取结果为空")
return entity_extract_result
@@ -98,45 +97,44 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
-
+
# 使用 asyncio.run 来运行异步方法
try:
# 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop()
- future = asyncio.run_coroutine_threadsafe(
- llm_req.generate_response_async(rdf_extract_context), loop
- )
- response, (reasoning_content, model_name) = future.result()
+ future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
+ response, _ = future.result()
except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run
- response, (reasoning_content, model_name) = asyncio.run(
- llm_req.generate_response_async(rdf_extract_context)
- )
+ response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
# 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}")
-
+
rdf_triple_result = _extract_json_from_text(response)
-
+
# 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list):
- # 如果不是列表,可能是字典格式,尝试从中提取列表
- if isinstance(rdf_triple_result, dict):
- # 尝试常见的键名
- for key in ['triples', 'result', 'data', 'items']:
- if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
- rdf_triple_result = rdf_triple_result[key]
- break
- else:
- # 如果找不到合适的列表,抛出异常
- raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
+ if not isinstance(rdf_triple_result, dict):
+ raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
+
+ # 尝试常见的键名
+ for key in ["triples", "result", "data", "items"]:
+ if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
+ rdf_triple_result = rdf_triple_result[key]
+ break
else:
- raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
-
+ # 如果找不到合适的列表,抛出异常
+ raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式
for triple in rdf_triple_result:
- if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
- raise Exception("RDF提取结果格式错误")
+ if (
+ not isinstance(triple, list)
+ or len(triple) != 3
+ or (triple[0] is None or triple[1] is None or triple[2] is None)
+ or "" in triple
+ ):
+ raise ValueError("RDF提取结果格式错误")
return rdf_triple_result
diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py
index 083a741d..da082e39 100644
--- a/src/chat/knowledge/kg_manager.py
+++ b/src/chat/knowledge/kg_manager.py
@@ -20,8 +20,7 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
-from .lpmmconfig import global_config
-from src.manager.local_store_manager import local_storage
+from src.config.config import global_config
from .global_logger import logger
@@ -30,19 +29,9 @@ def _get_kg_dir():
"""
安全地获取KG数据目录路径
"""
- root_path: str = local_storage["root_path"]
- if root_path is None:
- # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
- current_dir = os.path.dirname(os.path.abspath(__file__))
- root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
- logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
-
- # 获取RAG数据目录
- rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
- if rag_data_dir is None:
- kg_dir = os.path.join(root_path, "data/rag")
- else:
- kg_dir = os.path.join(root_path, rag_data_dir)
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
+ kg_dir = os.path.join(root_path, "data/rag")
return str(kg_dir).replace("\\", "/")
@@ -65,9 +54,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str()
- self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
- self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
- self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
+ self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
+ self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
+ self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
def save_to_file(self):
"""将KG数据保存到文件"""
@@ -122,8 +111,8 @@ class KGManager:
# 避免自连接
continue
# 一个triple就是一条边(同时构建双向联系)
- hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
- hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
+ hash_key1 = "entity" + "-" + get_sha256(triple[0])
+ hash_key2 = "entity" + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1)
@@ -141,8 +130,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
for triple in triple_list_data[idx]:
- ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
- pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
+ ent_hash_key = "entity" + "-" + get_sha256(triple[0])
+ pg_hash_key = "paragraph" + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod
@@ -157,12 +146,12 @@ class KGManager:
ent_hash_list = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
- ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
- ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
+ ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
+ ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
- synonym_result = dict()
+ synonym_result = {}
# rich 进度条
total = len(ent_hash_list)
@@ -190,14 +179,14 @@ class KGManager:
assert isinstance(ent, EmbeddingStoreItem)
# 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
- ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
+ ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k
)
res_ent = [] # Debug
for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash:
# 避免自连接
continue
- if similarity < global_config["rag"]["params"]["synonym_threshold"]:
+ if similarity < global_config.lpmm_knowledge.rag_synonym_threshold:
# 相似度阈值
continue
node_to_node[(res_ent_hash, ent_hash)] = similarity
@@ -263,7 +252,7 @@ class KGManager:
for src_tgt in node_to_node.keys():
for node_hash in src_tgt:
if node_hash not in existed_nodes:
- if node_hash.startswith(local_storage["ent_namespace"]):
+ if node_hash.startswith("entity"):
# 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None:
@@ -275,7 +264,7 @@ class KGManager:
node_item["type"] = "ent"
node_item["create_time"] = now_time
self.graph.update_node(node_item)
- elif node_hash.startswith(local_storage["pg_namespace"]):
+ elif node_hash.startswith("paragraph"):
# 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None:
@@ -359,7 +348,7 @@ class KGManager:
# 关系三元组
triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]:
- ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
+ ent_hash = "entity" + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = []
@@ -380,7 +369,7 @@ class KGManager:
for ent_hash in ent_weights.keys():
ent_weights[ent_hash] = 1.0
else:
- down_edge = global_config["qa"]["params"]["paragraph_node_weight"]
+ down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight
# 缩放取值区间至[down_edge, 1]
for ent_hash, score in ent_weights.items():
# 缩放相似度
@@ -389,7 +378,7 @@ class KGManager:
) + down_edge
# 取平均相似度的top_k实体
- top_k = global_config["qa"]["params"]["ent_filter_top_k"]
+ top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
if len(ent_mean_scores) > top_k:
# 从大到小排序,取后len - k个
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
@@ -418,7 +407,7 @@ class KGManager:
for pg_hash, score in pg_sim_scores.items():
pg_weights[pg_hash] = (
- score * global_config["qa"]["params"]["paragraph_node_weight"]
+ score * global_config.lpmm_knowledge.qa_paragraph_node_weight
) # 文段权重 = 归一化相似度 * 文段节点权重参数
del pg_sim_scores
@@ -431,7 +420,7 @@ class KGManager:
self.graph,
personalization=ppr_node_weights,
max_iter=100,
- alpha=global_config["qa"]["params"]["ppr_damping"],
+ alpha=global_config.lpmm_knowledge.qa_ppr_damping,
)
# 获取最终结果
@@ -439,7 +428,7 @@ class KGManager:
passage_node_res = [
(node_key, score)
for node_key, score in ppr_res.items()
- if node_key.startswith(local_storage["pg_namespace"])
+ if node_key.startswith("paragraph")
]
del ppr_res
diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py
index 1e87d382..13629f18 100644
--- a/src/chat/knowledge/knowledge_lib.py
+++ b/src/chat/knowledge/knowledge_lib.py
@@ -1,12 +1,8 @@
-from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager
-from src.chat.knowledge.llm_client import LLMClient
-from src.chat.knowledge.mem_active_manager import MemoryActiveManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
-from src.config.config import global_config as bot_global_config
-from src.manager.local_store_manager import local_storage
+from src.config.config import global_config
import os
INVALID_ENTITY = [
@@ -21,9 +17,6 @@ INVALID_ENTITY = [
"她们",
"它们",
]
-PG_NAMESPACE = "paragraph"
-ENT_NAMESPACE = "entity"
-REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
@@ -34,67 +27,13 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..",
DATA_PATH = os.path.join(ROOT_PATH, "data")
-def _initialize_knowledge_local_storage():
- """
- 初始化知识库相关的本地存储配置
- 使用字典批量设置,避免重复的if判断
- """
- # 定义所有需要初始化的配置项
- default_configs = {
- # 路径配置
- "root_path": ROOT_PATH,
- "data_path": f"{ROOT_PATH}/data",
- # 实体和命名空间配置
- "lpmm_invalid_entity": INVALID_ENTITY,
- "pg_namespace": PG_NAMESPACE,
- "ent_namespace": ENT_NAMESPACE,
- "rel_namespace": REL_NAMESPACE,
- # RAG相关命名空间配置
- "rag_graph_namespace": RAG_GRAPH_NAMESPACE,
- "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
- "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
- }
-
- # 日志级别映射:重要配置用info,其他用debug
- important_configs = {"root_path", "data_path"}
-
- # 批量设置配置项
- initialized_count = 0
- for key, default_value in default_configs.items():
- if local_storage[key] is None:
- local_storage[key] = default_value
-
- # 根据重要性选择日志级别
- if key in important_configs:
- logger.info(f"设置{key}: {default_value}")
- else:
- logger.debug(f"设置{key}: {default_value}")
-
- initialized_count += 1
-
- if initialized_count > 0:
- logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
- else:
- logger.debug("知识库本地存储配置已存在,跳过初始化")
-
-
-# 初始化本地存储路径
-# sourcery skip: dict-comprehension
-_initialize_knowledge_local_storage()
-
qa_manager = None
inspire_manager = None
# 检查LPMM知识库是否启用
-if bot_global_config.lpmm_knowledge.enable:
+if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端")
- llm_client_list = {}
- for key in global_config["llm_providers"]:
- llm_client_list[key] = LLMClient(
- global_config["llm_providers"][key]["base_url"], # type: ignore
- global_config["llm_providers"][key]["api_key"], # type: ignore
- )
# 初始化Embedding库
embed_manager = EmbeddingManager()
@@ -120,7 +59,7 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对:Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes:
- key = f"{PG_NAMESPACE}-{pg_hash}"
+ key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
@@ -130,11 +69,11 @@ if bot_global_config.lpmm_knowledge.enable:
kg_manager,
)
- # 记忆激活(用于记忆库)
- inspire_manager = MemoryActiveManager(
- embed_manager,
- llm_client_list[global_config["embedding"]["provider"]],
- )
+ # # 记忆激活(用于记忆库)
+ # inspire_manager = MemoryActiveManager(
+ # embed_manager,
+ # llm_client_list[global_config["embedding"]["provider"]],
+ # )
else:
logger.info("LPMM知识库已禁用,跳过初始化")
# 创建空的占位符对象,避免导入错误
diff --git a/src/chat/knowledge/llm_client.py b/src/chat/knowledge/llm_client.py
deleted file mode 100644
index 52d0dca0..00000000
--- a/src/chat/knowledge/llm_client.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from openai import OpenAI
-
-
-class LLMMessage:
- def __init__(self, role, content):
- self.role = role
- self.content = content
-
- def to_dict(self):
- return {"role": self.role, "content": self.content}
-
-
-class LLMClient:
- """LLM客户端,对应一个API服务商"""
-
- def __init__(self, url, api_key):
- self.client = OpenAI(
- base_url=url,
- api_key=api_key,
- )
-
- def send_chat_request(self, model, messages):
- """发送对话请求,等待返回结果"""
- response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
- if hasattr(response.choices[0].message, "reasoning_content"):
- # 有单独的推理内容块
- reasoning_content = response.choices[0].message.reasoning_content
- content = response.choices[0].message.content
- else:
- # 无单独的推理内容块
- response = response.choices[0].message.content.split("")[-1].split("")
- # 如果有推理内容,则分割推理内容和内容
- if len(response) == 2:
- reasoning_content = response[0]
- content = response[1]
- else:
- reasoning_content = None
- content = response[0]
-
- return reasoning_content, content
-
- def send_embedding_request(self, model, text):
- """发送嵌入请求,等待返回结果"""
- text = text.replace("\n", " ")
- return self.client.embeddings.create(input=[text], model=model).data[0].embedding
diff --git a/src/chat/knowledge/lpmmconfig.py b/src/chat/knowledge/lpmmconfig.py
deleted file mode 100644
index 49f77725..00000000
--- a/src/chat/knowledge/lpmmconfig.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import os
-import toml
-import sys
-
-# import argparse
-from .global_logger import logger
-
-PG_NAMESPACE = "paragraph"
-ENT_NAMESPACE = "entity"
-REL_NAMESPACE = "relation"
-
-RAG_GRAPH_NAMESPACE = "rag-graph"
-RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
-RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
-
-# 无效实体
-INVALID_ENTITY = [
- "",
- "你",
- "他",
- "她",
- "它",
- "我们",
- "你们",
- "他们",
- "她们",
- "它们",
-]
-
-
-def _load_config(config, config_file_path):
- """读取TOML格式的配置文件"""
- if not os.path.exists(config_file_path):
- return
- with open(config_file_path, "r", encoding="utf-8") as f:
- file_config = toml.load(f)
-
- # Check if all top-level keys from default config exist in the file config
- for key in config.keys():
- if key not in file_config:
- logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
- logger.critical("请通过template/lpmm_config_template.toml文件进行更新")
- sys.exit(1)
-
- if "llm_providers" in file_config:
- for provider in file_config["llm_providers"]:
- if provider["name"] not in config["llm_providers"]:
- config["llm_providers"][provider["name"]] = {}
- config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
- config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
-
- if "entity_extract" in file_config:
- config["entity_extract"] = file_config["entity_extract"]
-
- if "rdf_build" in file_config:
- config["rdf_build"] = file_config["rdf_build"]
-
- if "embedding" in file_config:
- config["embedding"] = file_config["embedding"]
-
- if "rag" in file_config:
- config["rag"] = file_config["rag"]
-
- if "qa" in file_config:
- config["qa"] = file_config["qa"]
-
- if "persistence" in file_config:
- config["persistence"] = file_config["persistence"]
- # print(config)
- logger.info(f"从文件中读取配置: {config_file_path}")
-
-
-global_config = dict(
- {
- "lpmm": {
- "version": "0.1.0",
- },
- "llm_providers": {
- "localhost": {
- "base_url": "https://api.siliconflow.cn/v1",
- "api_key": "sk-ospynxadyorf",
- }
- },
- "entity_extract": {
- "llm": {
- "provider": "localhost",
- "model": "Pro/deepseek-ai/DeepSeek-V3",
- }
- },
- "rdf_build": {
- "llm": {
- "provider": "localhost",
- "model": "Pro/deepseek-ai/DeepSeek-V3",
- }
- },
- "embedding": {
- "provider": "localhost",
- "model": "Pro/BAAI/bge-m3",
- "dimension": 1024,
- },
- "rag": {
- "params": {
- "synonym_search_top_k": 10,
- "synonym_threshold": 0.75,
- }
- },
- "qa": {
- "params": {
- "relation_search_top_k": 10,
- "relation_threshold": 0.75,
- "paragraph_search_top_k": 10,
- "paragraph_node_weight": 0.05,
- "ent_filter_top_k": 10,
- "ppr_damping": 0.8,
- "res_top_k": 10,
- },
- "llm": {
- "provider": "localhost",
- "model": "qa",
- },
- },
- "persistence": {
- "data_root_path": "data",
- "raw_data_path": "data/raw.json",
- "openie_data_path": "data/openie.json",
- "embedding_data_dir": "data/embedding",
- "rag_data_dir": "data/rag",
- },
- "info_extraction": {
- "workers": 10,
- },
- }
-)
-
-ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
-config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
-_load_config(global_config, config_path)
diff --git a/src/chat/knowledge/mem_active_manager.py b/src/chat/knowledge/mem_active_manager.py
index 3998c066..a55b929f 100644
--- a/src/chat/knowledge/mem_active_manager.py
+++ b/src/chat/knowledge/mem_active_manager.py
@@ -1,3 +1,4 @@
+raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py
index c83683b7..5354447a 100644
--- a/src/chat/knowledge/qa_manager.py
+++ b/src/chat/knowledge/qa_manager.py
@@ -2,16 +2,14 @@ import time
from typing import Tuple, List, Dict, Optional
from .global_logger import logger
-
-# from . import prompt_template
from .embedding_store import EmbeddingManager
-# from .llm_client import LLMClient
from .kg_manager import KGManager
+
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
-from src.config.config import global_config
+from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -21,17 +19,12 @@ class QAManager:
self,
embed_manager: EmbeddingManager,
kg_manager: KGManager,
-
):
self.embed_manager = embed_manager
self.kg_manager = kg_manager
- # TODO: API-Adapter修改标记
- self.qa_model = LLMRequest(
- model=global_config.model.lpmm_qa,
- request_type="lpmm.qa"
- )
+ self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
- async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
+ async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
"""处理查询"""
# 生成问题的Embedding
@@ -49,66 +42,70 @@ class QAManager:
question_embedding,
global_config.lpmm_knowledge.qa_relation_search_top_k,
)
- if relation_search_res is not None:
- # 过滤阈值
- # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
- relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
- if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
- # 未找到相关关系
- logger.debug("未找到相关关系,跳过关系检索")
- relation_search_res = []
+ if relation_search_res is None:
+ return None
+ # 过滤阈值
+ # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
+ relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
+ if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
+ # 未找到相关关系
+ logger.debug("未找到相关关系,跳过关系检索")
+ relation_search_res = []
- part_end_time = time.perf_counter()
- logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
+ part_end_time = time.perf_counter()
+ logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
- for res in relation_search_res:
- rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
- print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
+ for res in relation_search_res:
+ rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
+ print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
- # TODO: 使用LLM过滤三元组结果
- # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
- # part_start_time = time.time()
+ # TODO: 使用LLM过滤三元组结果
+ # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s")
+ # part_start_time = time.time()
- # 根据问题Embedding查询Paragraph Embedding库
+ # 根据问题Embedding查询Paragraph Embedding库
+ part_start_time = time.perf_counter()
+ paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
+ question_embedding,
+ global_config.lpmm_knowledge.qa_paragraph_search_top_k,
+ )
+ part_end_time = time.perf_counter()
+ logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
+
+ if len(relation_search_res) != 0:
+ logger.info("找到相关关系,将使用RAG进行检索")
+ # 使用KG检索
part_start_time = time.perf_counter()
- paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
- question_embedding,
- global_config.lpmm_knowledge.qa_paragraph_search_top_k,
+ result, ppr_node_weights = self.kg_manager.kg_search(
+ relation_search_res, paragraph_search_res, self.embed_manager
)
part_end_time = time.perf_counter()
- logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
-
- if len(relation_search_res) != 0:
- logger.info("找到相关关系,将使用RAG进行检索")
- # 使用KG检索
- part_start_time = time.perf_counter()
- result, ppr_node_weights = self.kg_manager.kg_search(
- relation_search_res, paragraph_search_res, self.embed_manager
- )
- part_end_time = time.perf_counter()
- logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
- else:
- logger.info("未找到相关关系,将使用文段检索结果")
- result = paragraph_search_res
- ppr_node_weights = None
-
- # 过滤阈值
- result = dyn_select_top_k(result, 0.5, 1.0)
-
- for res in result:
- raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
- print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
-
- return result, ppr_node_weights
+ logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s")
else:
- return None
+ logger.info("未找到相关关系,将使用文段检索结果")
+ result = paragraph_search_res
+ ppr_node_weights = None
- async def get_knowledge(self, question: str) -> str:
+ # 过滤阈值
+ result = dyn_select_top_k(result, 0.5, 1.0)
+
+ for res in result:
+ raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
+ print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
+
+ return result, ppr_node_weights
+
+ async def get_knowledge(self, question: str) -> Optional[str]:
"""获取知识"""
# 处理查询
processed_result = await self.process_query(question)
if processed_result is not None:
query_res = processed_result[0]
+ # 检查查询结果是否为空
+ if not query_res:
+ logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
+ return None
+
knowledge = [
(
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
diff --git a/src/chat/knowledge/raw_processing.py b/src/chat/knowledge/raw_processing.py
deleted file mode 100644
index 98b1f168..00000000
--- a/src/chat/knowledge/raw_processing.py
+++ /dev/null
@@ -1,48 +0,0 @@
-import json
-import os
-
-from .global_logger import logger
-from .lpmmconfig import global_config
-from src.chat.knowledge.utils.hash import get_sha256
-
-
-def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
- """加载原始数据文件
-
- 读取原始数据文件,将原始数据加载到内存中
-
- Args:
- path: 可选,指定要读取的json文件绝对路径
-
- Returns:
- - raw_data: 原始数据列表
- - sha256_list: 原始数据的SHA256集合
- """
- # 读取指定路径或默认路径的json文件
- json_path = path if path else global_config["persistence"]["raw_data_path"]
- if os.path.exists(json_path):
- with open(json_path, "r", encoding="utf-8") as f:
- import_json = json.loads(f.read())
- else:
- raise Exception(f"原始数据文件读取失败: {json_path}")
- """
- import_json 内容示例:
- import_json = ["The capital of China is Beijing. The capital of France is Paris.",]
- """
- raw_data = []
- sha256_list = []
- sha256_set = set()
- for item in import_json:
- if not isinstance(item, str):
- logger.warning("数据类型错误:{}".format(item))
- continue
- pg_hash = get_sha256(item)
- if pg_hash in sha256_set:
- logger.warning("重复数据:{}".format(item))
- continue
- sha256_set.add(pg_hash)
- sha256_list.append(pg_hash)
- raw_data.append(item)
- logger.info("共读取到{}条数据".format(len(raw_data)))
-
- return sha256_list, raw_data
diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py
index eb40ef3a..5304934f 100644
--- a/src/chat/knowledge/utils/dyn_topk.py
+++ b/src/chat/knowledge/utils/dyn_topk.py
@@ -5,6 +5,10 @@ def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]:
"""动态TopK选择"""
+ # 检查输入列表是否为空
+ if not score:
+ return []
+
# 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True)
diff --git a/src/chat/knowledge/utils/visualize_graph.py b/src/chat/knowledge/utils/visualize_graph.py
deleted file mode 100644
index 7ca9b7e6..00000000
--- a/src/chat/knowledge/utils/visualize_graph.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import networkx as nx
-from matplotlib import pyplot as plt
-
-
-def draw_graph_and_show(graph):
- """绘制图并显示,画布大小1280*1280"""
- fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
- nx.draw_networkx(
- graph,
- node_size=100,
- width=0.5,
- with_labels=True,
- labels=nx.get_node_attributes(graph, "content"),
- font_family="Sarasa Mono SC",
- font_size=8,
- )
- fig.show()
diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py
index 26660e5c..fe3c2562 100644
--- a/src/chat/memory_system/Hippocampus.py
+++ b/src/chat/memory_system/Hippocampus.py
@@ -5,25 +5,27 @@ import random
import time
import re
import json
-from itertools import combinations
-
import jieba
import networkx as nx
import numpy as np
+
+from itertools import combinations
+from typing import List, Tuple, Coroutine, Any, Set
from collections import Counter
-from ...llm_models.utils_model import LLMRequest
+from rich.traceback import install
+
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
-from ..utils.chat_message_builder import (
+from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
) # 导入 build_readable_messages
-from ..utils.utils import translate_timestamp_to_human_readable
-from rich.traceback import install
+from src.chat.utils.utils import translate_timestamp_to_human_readable
-from ...config.config import global_config
-from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3)
@@ -198,8 +200,7 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db()
- # TODO: API-Adapter修改标记
- self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
+ self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表"""
@@ -339,9 +340,7 @@ class Hippocampus:
else:
topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
- topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async(
- self.find_topic_llm(text, topic_num)
- )
+ topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
# 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response)
@@ -353,12 +352,11 @@ class Hippocampus:
for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",")
if keyword.strip()
]
-
+
if keywords:
logger.info(f"提取关键词: {keywords}")
-
- return keywords
-
+
+ return keywords
async def get_memory_from_text(
self,
@@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
# 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
- topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async(
+ topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
self.hippocampus.find_topic_llm(input_text, topic_num)
)
@@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
logger.debug(f"过滤后话题: {filtered_topics}")
# 4. 创建所有话题的摘要生成任务
- tasks = []
+ tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = []
for topic in filtered_topics:
# 调用修改后的 topic_what,不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
@@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
continue
# 等待所有任务完成
- compressed_memory = set()
+ compressed_memory: Set[Tuple[str, str]] = set()
similar_topics_dict = {}
for topic, task in tasks:
diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py
index f7e54f8e..a702a87e 100644
--- a/src/chat/memory_system/instant_memory.py
+++ b/src/chat/memory_system/instant_memory.py
@@ -3,13 +3,16 @@ import time
import re
import json
import ast
-from json_repair import repair_json
-from src.llm_models.utils_model import LLMRequest
-from src.common.logger import get_logger
import traceback
-from src.config.config import global_config
+from json_repair import repair_json
+from datetime import datetime, timedelta
+
+from src.llm_models.utils_model import LLMRequest
+from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
+from src.config.config import model_config
+
logger = get_logger(__name__)
@@ -35,8 +38,7 @@ class InstantMemory:
self.chat_id = chat_id
self.last_view_time = time.time()
self.summary_model = LLMRequest(
- model=global_config.model.memory,
- temperature=0.5,
+ model_set=model_config.model_task_config.memory,
request_type="memory.summary",
)
@@ -48,14 +50,11 @@ class InstantMemory:
"""
try:
- response, _ = await self.summary_model.generate_response_async(prompt)
+ response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
- if "1" in response:
- return True
- else:
- return False
+ return "1" in response
except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False
@@ -71,9 +70,9 @@ class InstantMemory:
}}
"""
try:
- response, _ = await self.summary_model.generate_response_async(prompt)
- print(prompt)
- print(response)
+ response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
+ # print(prompt)
+ # print(response)
if not response:
return None
try:
@@ -142,7 +141,7 @@ class InstantMemory:
请只输出json格式,不要输出其他多余内容
"""
try:
- response, _ = await self.summary_model.generate_response_async(prompt)
+ response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if not response:
@@ -177,7 +176,7 @@ class InstantMemory:
for mem in query:
# 对每条记忆
- mem_keywords = mem.keywords or []
+ mem_keywords = mem.keywords or ""
parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list):
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
@@ -201,6 +200,7 @@ class InstantMemory:
return None
def _parse_time_range(self, time_str):
+ # sourcery skip: extract-duplicate-method, use-contextlib-suppress
"""
支持解析如下格式:
- 具体日期时间:YYYY-MM-DD HH:MM:SS
@@ -208,8 +208,6 @@ class InstantMemory:
- 相对时间:今天,昨天,前天,N天前,N个月前
- 空字符串:返回(None, None)
"""
- from datetime import datetime, timedelta
-
now = datetime.now()
if not time_str:
return 0, now
@@ -239,14 +237,12 @@ class InstantMemory:
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
- m = re.match(r"(\d+)天前", time_str)
- if m:
+ if m := re.match(r"(\d+)天前", time_str):
days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1)
return start, end
- m = re.match(r"(\d+)个月前", time_str)
- if m:
+ if m := re.match(r"(\d+)个月前", time_str):
months = int(m.group(1))
# 近似每月30天
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)
diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py
index 715d9c06..d3cbb5d7 100644
--- a/src/chat/memory_system/memory_activator.py
+++ b/src/chat/memory_system/memory_activator.py
@@ -1,13 +1,15 @@
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
-from src.common.logger import get_logger
-from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from datetime import datetime
-from src.chat.memory_system.Hippocampus import hippocampus_manager
-from typing import List, Dict
import difflib
import json
+
from json_repair import repair_json
+from typing import List, Dict
+from datetime import datetime
+
+from src.llm_models.utils_model import LLMRequest
+from src.config.config import global_config, model_config
+from src.common.logger import get_logger
+from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
+from src.chat.memory_system.Hippocampus import hippocampus_manager
logger = get_logger("memory_activator")
@@ -61,11 +63,8 @@ def init_prompt():
class MemoryActivator:
def __init__(self):
- # TODO: API-Adapter修改标记
-
self.key_words_model = LLMRequest(
- model=global_config.model.utils_small,
- temperature=0.5,
+ model_set=model_config.model_task_config.utils_small,
request_type="memory.activator",
)
@@ -92,7 +91,9 @@ class MemoryActivator:
# logger.debug(f"prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt)
+ response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
+ prompt, temperature=0.5
+ )
keywords = list(get_keywords_from_json(response))
diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py
index 7a18dcf0..58dd6d68 100644
--- a/src/chat/message_receive/message.py
+++ b/src/chat/message_receive/message.py
@@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False
self.gift_info = None
self.gift_name = None
- self.gift_count = None
+ self.gift_count: Optional[str] = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
@@ -444,7 +444,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
- reply_to: str = None, # type: ignore
+ reply_to: Optional[str] = None,
):
# 调用父类初始化
super().__init__(
diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py
index 21d47c75..267b7a8f 100644
--- a/src/chat/planner_actions/action_manager.py
+++ b/src/chat/planner_actions/action_manager.py
@@ -1,9 +1,10 @@
from typing import Dict, Optional, Type
-from src.plugin_system.base.base_action import BaseAction
+
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo
+from src.plugin_system.base.base_action import BaseAction
logger = get_logger("action_manager")
diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py
index da11c54f..dfa4c79c 100644
--- a/src/chat/planner_actions/action_modifier.py
+++ b/src/chat/planner_actions/action_modifier.py
@@ -5,7 +5,7 @@ import time
from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager
@@ -36,10 +36,7 @@ class ActionModifier:
self.action_manager = action_manager
# 用于LLM判定的小模型
- self.llm_judge = LLMRequest(
- model=global_config.model.utils_small,
- request_type="action.judge",
- )
+ self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
# 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果
@@ -438,4 +435,4 @@ class ActionModifier:
return True
else:
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
- return False
\ No newline at end of file
+ return False
diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py
index 0b26a97d..85dd5e63 100644
--- a/src/chat/planner_actions/planner.py
+++ b/src/chat/planner_actions/planner.py
@@ -7,7 +7,7 @@ from datetime import datetime
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
@@ -36,8 +36,6 @@ def init_prompt():
{chat_context_description},以下是具体的聊天内容
{chat_content_block}
-
-
{moderation_prompt}
现在请你根据{by_what}选择合适的action和触发action的消息:
@@ -73,10 +71,7 @@ class ActionPlanner:
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.action_manager = action_manager
# LLM规划器配置
- self.planner_llm = LLMRequest(
- model=global_config.model.planner,
- request_type="planner", # 用于动作规划
- )
+ self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划
self.last_obs_time_mark = 0.0
@@ -140,7 +135,7 @@ class ActionPlanner:
# --- 调用 LLM (普通文本生成) ---
llm_content = None
try:
- llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt)
+ llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py
index cab6a2b4..c2b6e1cb 100644
--- a/src/chat/replyer/default_generator.py
+++ b/src/chat/replyer/default_generator.py
@@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
+from src.config.api_ada_configs import TaskConfig
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@@ -23,14 +24,13 @@ from src.chat.utils.chat_message_builder import (
replace_user_references_sync,
)
from src.chat.express.expression_selector import expression_selector
-from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager
-from src.tools.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo
+from src.plugin_system.apis import llm_api
logger = get_logger("replyer")
@@ -40,7 +40,7 @@ def init_prompt():
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2")
Prompt("和{sender_name}聊天", "chat_target_private2")
-
+
Prompt(
"""
{expression_habits_block}
@@ -102,36 +102,57 @@ def init_prompt():
"s4u_style_prompt",
)
+ Prompt(
+ """
+你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。
+群里正在进行的聊天内容:
+{chat_history}
+
+现在,{sender}发送了内容:{target_message},你想要回复ta。
+请仔细分析聊天内容,考虑以下几点:
+1. 内容中是否包含需要查询信息的问题
+2. 是否有明确的知识获取指令
+
+If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
+""",
+ name="lpmm_get_knowledge_prompt",
+ )
+
class DefaultReplyer:
def __init__(
self,
chat_stream: ChatStream,
- model_configs: Optional[List[Dict[str, Any]]] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "focus.replyer",
):
self.request_type = request_type
- if model_configs:
- self.express_model_configs = model_configs
+ if model_set_with_weight:
+ # self.express_model_configs = model_configs
+ self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight
else:
# 当未提供配置时,使用默认配置并赋予默认权重
- model_config_1 = global_config.model.replyer_1.copy()
- model_config_2 = global_config.model.replyer_2.copy()
+ # model_config_1 = global_config.model.replyer_1.copy()
+ # model_config_2 = global_config.model.replyer_2.copy()
prob_first = global_config.chat.replyer_random_probability
- model_config_1["weight"] = prob_first
- model_config_2["weight"] = 1.0 - prob_first
+ # model_config_1["weight"] = prob_first
+ # model_config_2["weight"] = 1.0 - prob_first
- self.express_model_configs = [model_config_1, model_config_2]
+ # self.express_model_configs = [model_config_1, model_config_2]
+ self.model_set = [
+ (model_config.model_task_config.replyer_1, prob_first),
+ (model_config.model_task_config.replyer_2, 1.0 - prob_first),
+ ]
- if not self.express_model_configs:
- logger.warning("未找到有效的模型配置,回复生成可能会失败。")
- # 提供一个最终的回退,以防止在空列表上调用 random.choice
- fallback_config = global_config.model.replyer_1.copy()
- fallback_config.setdefault("weight", 1.0)
- self.express_model_configs = [fallback_config]
+ # if not self.express_model_configs:
+ # logger.warning("未找到有效的模型配置,回复生成可能会失败。")
+ # # 提供一个最终的回退,以防止在空列表上调用 random.choice
+ # fallback_config = global_config.model.replyer_1.copy()
+ # fallback_config.setdefault("weight", 1.0)
+ # self.express_model_configs = [fallback_config]
self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
@@ -139,13 +160,16 @@ class DefaultReplyer:
self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator()
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
+
+ from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖
+
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
- def _select_weighted_model_config(self) -> Dict[str, Any]:
+ def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]:
"""使用加权随机选择来挑选一个模型配置"""
- configs = self.express_model_configs
+ configs = self.model_set
# 提取权重,如果模型配置中没有'weight'键,则默认为1.0
- weights = [config.get("weight", 1.0) for config in configs]
+ weights = [weight for _, weight in configs]
return random.choices(population=configs, weights=weights, k=1)[0]
@@ -155,18 +179,16 @@ class DefaultReplyer:
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True,
- enable_timeout: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]:
"""
回复器 (Replier): 负责生成回复文本的核心逻辑。
-
+
Args:
reply_to: 回复对象,格式为 "发送者:消息内容"
extra_info: 额外信息,用于补充上下文
available_actions: 可用的动作信息字典
enable_tool: 是否启用工具调用
- enable_timeout: 是否启用超时处理
-
+
Returns:
Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt)
"""
@@ -177,13 +199,12 @@ class DefaultReplyer:
# 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context(
- reply_to = reply_to,
+ reply_to=reply_to,
extra_info=extra_info,
available_actions=available_actions,
- enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
-
+
if not prompt:
logger.warning("构建prompt失败,跳过回复生成")
return False, None, None
@@ -194,26 +215,8 @@ class DefaultReplyer:
model_name = "unknown_model"
try:
- with Timer("LLM生成", {}): # 内部计时器,可选保留
- # 加权随机选择一个模型配置
- selected_model_config = self._select_weighted_model_config()
- logger.info(
- f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
- )
-
- express_model = LLMRequest(
- model=selected_model_config,
- request_type=self.request_type,
- )
-
- if global_config.debug.show_prompt:
- logger.info(f"\n{prompt}\n")
- else:
- logger.debug(f"\n{prompt}\n")
-
- content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
-
- logger.debug(f"replyer生成内容: {content}")
+ content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
+ logger.debug(f"replyer生成内容: {content}")
except Exception as llm_e:
# 精简报错信息
@@ -232,22 +235,21 @@ class DefaultReplyer:
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
- ) -> Tuple[bool, Optional[str]]:
+ return_prompt: bool = False,
+ ) -> Tuple[bool, Optional[str], Optional[str]]:
"""
表达器 (Expressor): 负责重写和优化回复文本。
-
+
Args:
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象,格式为 "发送者:消息内容"
relation_info: 关系信息
-
+
Returns:
Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
"""
try:
-
-
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_rewrite_context(
raw_reply=raw_reply,
@@ -260,36 +262,23 @@ class DefaultReplyer:
model_name = "unknown_model"
if not prompt:
logger.error("Prompt 构建失败,无法生成回复。")
- return False, None
+ return False, None, None
try:
- with Timer("LLM生成", {}): # 内部计时器,可选保留
- # 加权随机选择一个模型配置
- selected_model_config = self._select_weighted_model_config()
- logger.info(
- f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
- )
-
- express_model = LLMRequest(
- model=selected_model_config,
- request_type=self.request_type,
- )
-
- content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
-
- logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
+ content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
+ logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
except Exception as llm_e:
# 精简报错信息
logger.error(f"LLM 生成失败: {llm_e}")
- return False, None # LLM 调用失败则无法生成回复
+ return False, None, prompt if return_prompt else None # LLM 调用失败则无法生成回复
- return True, content
+ return True, content, prompt if return_prompt else None
except Exception as e:
logger.error(f"回复生成意外失败: {e}")
traceback.print_exc()
- return False, None
+ return False, None, prompt if return_prompt else None
async def build_relation_info(self, reply_to: str = ""):
if not global_config.relationship.enable_relationship:
@@ -313,11 +302,11 @@ class DefaultReplyer:
async def build_expression_habits(self, chat_history: str, target: str) -> str:
"""构建表达习惯块
-
+
Args:
chat_history: 聊天历史记录
target: 目标消息内容
-
+
Returns:
str: 表达习惯信息字符串
"""
@@ -366,17 +355,15 @@ class DefaultReplyer:
if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
- expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}"
-
- return expression_habits_block
+ return f"{expression_habits_title}\n{expression_habits_block}"
async def build_memory_block(self, chat_history: str, target: str) -> str:
"""构建记忆块
-
+
Args:
chat_history: 聊天历史记录
target: 目标消息内容
-
+
Returns:
str: 记忆信息字符串
"""
@@ -441,7 +428,7 @@ class DefaultReplyer:
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
- result_type = tool_result.get("type", "info")
+ result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n"
@@ -459,10 +446,10 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息
-
+
Args:
target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容"
-
+
Returns:
Tuple[str, str]: (发送者名称, 消息内容)
"""
@@ -481,10 +468,10 @@ class DefaultReplyer:
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示
-
+
Args:
target: 目标消息内容
-
+
Returns:
str: 关键词反应提示字符串
"""
@@ -523,11 +510,11 @@ class DefaultReplyer:
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
"""计时并运行异步任务的辅助函数
-
+
Args:
coroutine: 要执行的协程
name: 任务名称
-
+
Returns:
Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时)
"""
@@ -537,7 +524,9 @@ class DefaultReplyer:
duration = end_time - start_time
return name, result, duration
- def build_s4u_chat_history_prompts(self, message_list_before_now: List[Dict[str, Any]], target_user_id: str) -> Tuple[str, str]:
+ def build_s4u_chat_history_prompts(
+ self, message_list_before_now: List[Dict[str, Any]], target_user_id: str
+ ) -> Tuple[str, str]:
"""
构建 s4u 风格的分离对话 prompt
@@ -612,7 +601,7 @@ class DefaultReplyer:
chat_info: str,
) -> Any:
"""构建 mai_think 上下文信息
-
+
Args:
chat_id: 聊天ID
memory_block: 记忆块内容
@@ -625,7 +614,7 @@ class DefaultReplyer:
sender: 发送者名称
target: 目标消息内容
chat_info: 聊天信息
-
+
Returns:
Any: mai_think 实例
"""
@@ -647,19 +636,17 @@ class DefaultReplyer:
reply_to: str,
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
- enable_timeout: bool = False,
enable_tool: bool = True,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
"""
构建回复器上下文
Args:
- reply_data: 回复数据
- replay_data 包含以下字段:
- structured_info: 结构化信息,一般是工具调用获得的信息
- reply_to: 回复对象
- extra_info/extra_info_block: 额外信息
+ reply_to: 回复对象,格式为 "发送者:消息内容"
+ extra_info: 额外信息,用于补充上下文
available_actions: 可用动作
+ enable_timeout: 是否启用超时处理
+ enable_tool: 是否启用工具调用
Returns:
str: 构建好的上下文
@@ -727,7 +714,7 @@ class DefaultReplyer:
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
),
- self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"),
+ self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"),
)
# 任务名称中英文映射
@@ -877,7 +864,7 @@ class DefaultReplyer:
raw_reply: str,
reason: str,
reply_to: str,
- ) -> str:
+ ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
@@ -1011,6 +998,81 @@ class DefaultReplyer:
display_message=display_message,
)
+ async def llm_generate_content(self, prompt: str):
+ with Timer("LLM生成", {}): # 内部计时器,可选保留
+ # 加权随机选择一个模型配置
+ selected_model_config, weight = self._select_weighted_models_config()
+ logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})")
+
+ express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type)
+
+ if global_config.debug.show_prompt:
+ logger.info(f"\n{prompt}\n")
+ else:
+ logger.debug(f"\n{prompt}\n")
+
+ content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt)
+
+ logger.debug(f"replyer生成内容: {content}")
+ return content, reasoning_content, model_name, tool_calls
+
+ async def get_prompt_info(self, message: str, reply_to: str):
+ related_info = ""
+ start_time = time.time()
+ from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
+ if not reply_to:
+ logger.debug("没有回复对象,跳过获取知识库内容")
+ return ""
+ sender, content = self._parse_reply_target(reply_to)
+ if not content:
+ logger.debug("回复对象内容为空,跳过获取知识库内容")
+ return ""
+ logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
+ # 从LPMM知识库获取知识
+ try:
+ # 检查LPMM知识库是否启用
+ if not global_config.lpmm_knowledge.enable:
+ logger.debug("LPMM知识库未启用,跳过获取知识库内容")
+ return ""
+ time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+
+ bot_name = global_config.bot.nickname
+
+ prompt = await global_prompt_manager.format_prompt(
+ "lpmm_get_knowledge_prompt",
+ bot_name=bot_name,
+ time_now=time_now,
+ chat_history=message,
+ sender=sender,
+ target_message=content,
+ )
+ _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
+ prompt,
+ model_config=model_config.model_task_config.tool_use,
+ tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
+ )
+ if tool_calls:
+ result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
+ end_time = time.time()
+ if not result or not result.get("content"):
+ logger.debug("从LPMM知识库获取知识失败,返回空知识...")
+ return ""
+ found_knowledge_from_lpmm = result.get("content", "")
+ logger.debug(
+ f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
+ )
+ related_info += found_knowledge_from_lpmm
+ logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
+ logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
+
+ return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
+ else:
+ logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
+ return ""
+ except Exception as e:
+ logger.error(f"获取知识库内容时发生异常: {str(e)}")
+ return ""
+
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
@@ -1046,38 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
return selected
-async def get_prompt_info(message: str, threshold: float):
- related_info = ""
- start_time = time.time()
-
- logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
- # 从LPMM知识库获取知识
- try:
- # 检查LPMM知识库是否启用
- if qa_manager is None:
- logger.debug("LPMM知识库已禁用,跳过知识获取")
- return ""
-
- found_knowledge_from_lpmm = await qa_manager.get_knowledge(message)
-
- end_time = time.time()
- if found_knowledge_from_lpmm is not None:
- logger.debug(
- f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
- )
- related_info += found_knowledge_from_lpmm
- logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒")
- logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
-
- # 格式化知识信息
- formatted_prompt_info = f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
- return formatted_prompt_info
- else:
- logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...")
- return ""
- except Exception as e:
- logger.error(f"获取知识库内容时发生异常: {str(e)}")
- return ""
-
-
init_prompt()
diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py
index 3f1c731b..bb3a313b 100644
--- a/src/chat/replyer/replyer_manager.py
+++ b/src/chat/replyer/replyer_manager.py
@@ -1,6 +1,7 @@
-from typing import Dict, Any, Optional, List
+from typing import Dict, Optional, List, Tuple
from src.common.logger import get_logger
+from src.config.api_ada_configs import TaskConfig
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer
@@ -15,7 +16,7 @@ class ReplyerManager:
self,
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
- model_configs: Optional[List[Dict[str, Any]]] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""
@@ -49,7 +50,7 @@ class ReplyerManager:
# model_configs 只在此时(初始化时)生效
replyer = DefaultReplyer(
chat_stream=target_stream,
- model_configs=model_configs, # 可以是None,此时使用默认模型
+ model_set_with_weight=model_set_with_weight, # 可以是None,此时使用默认模型
request_type=request_type,
)
self._repliers[stream_id] = replyer
diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py
deleted file mode 100644
index 892deac4..00000000
--- a/src/chat/utils/json_utils.py
+++ /dev/null
@@ -1,223 +0,0 @@
-import ast
-import json
-import logging
-
-from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
-
-# 定义类型变量用于泛型类型提示
-T = TypeVar("T")
-
-# 获取logger
-logger = logging.getLogger("json_utils")
-
-
-def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
- """
- 安全地解析JSON字符串,出错时返回默认值
- 现在尝试处理单引号和标准JSON
-
- 参数:
- json_str: 要解析的JSON字符串
- default_value: 解析失败时返回的默认值
-
- 返回:
- 解析后的Python对象,或在解析失败时返回default_value
- """
- if not json_str or not isinstance(json_str, str):
- logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
- return default_value
-
- try:
- # 尝试标准的 JSON 解析
- return json.loads(json_str)
- except json.JSONDecodeError:
- # 如果标准解析失败,尝试用 ast.literal_eval 解析
- try:
- # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...")
- result = ast.literal_eval(json_str)
- if isinstance(result, dict):
- return result
- logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
- return default_value
- except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
- logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
- return default_value
- except Exception as e:
- logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
- return default_value
- except Exception as e:
- logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
- return default_value
-
-
-def extract_tool_call_arguments(
- tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
-) -> Dict[str, Any]:
- """
- 从LLM工具调用对象中提取参数
-
- 参数:
- tool_call: 工具调用对象字典
- default_value: 解析失败时返回的默认值
-
- 返回:
- 解析后的参数字典,或在解析失败时返回default_value
- """
- default_result = default_value or {}
-
- if not tool_call or not isinstance(tool_call, dict):
- logger.error(f"无效的工具调用对象: {tool_call}")
- return default_result
-
- try:
- # 提取function参数
- function_data = tool_call.get("function", {})
- if not function_data or not isinstance(function_data, dict):
- logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
- return default_result
-
- if arguments_str := function_data.get("arguments", "{}"):
- # 解析JSON
- return safe_json_loads(arguments_str, default_result)
- else:
- return default_result
-
- except Exception as e:
- logger.error(f"提取工具调用参数时出错: {e}")
- return default_result
-
-
-def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
- """
- 安全地将Python对象序列化为JSON字符串
-
- 参数:
- obj: 要序列化的Python对象
- default_value: 序列化失败时返回的默认值
- ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符)
- pretty: 是否美化输出JSON
-
- 返回:
- 序列化后的JSON字符串,或在序列化失败时返回default_value
- """
- try:
- indent = 2 if pretty else None
- return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
- except TypeError as e:
- logger.error(f"JSON序列化失败(类型错误): {e}")
- return default_value
- except Exception as e:
- logger.error(f"JSON序列化过程中发生意外错误: {e}")
- return default_value
-
-
-def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
- """
- 标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式
-
- 参数:
- response: 原始LLM响应
- log_prefix: 日志前缀
-
- 返回:
- 元组 (成功标志, 标准化后的响应列表, 错误消息)
- """
-
- logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
-
- # 检查是否为None
- if response is None:
- return False, [], "LLM响应为None"
-
- # 记录原始类型
- logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
-
- # 将元组转换为列表
- if isinstance(response, tuple):
- logger.debug(f"{log_prefix}将元组响应转换为列表")
- response = list(response)
-
- # 确保是列表类型
- if not isinstance(response, list):
- return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
-
- # 处理工具调用部分(如果存在)
- if len(response) == 3:
- content, reasoning, tool_calls = response
-
- # 将工具调用部分转换为列表(如果是元组)
- if isinstance(tool_calls, tuple):
- logger.debug(f"{log_prefix}将工具调用元组转换为列表")
- tool_calls = list(tool_calls)
- response[2] = tool_calls
-
- return True, response, ""
-
-
-def process_llm_tool_calls(
- tool_calls: List[Dict[str, Any]], log_prefix: str = ""
-) -> Tuple[bool, List[Dict[str, Any]], str]:
- """
- 处理并验证LLM响应中的工具调用列表
-
- 参数:
- tool_calls: 从LLM响应中直接获取的工具调用列表
- log_prefix: 日志前缀
-
- 返回:
- 元组 (成功标志, 验证后的工具调用列表, 错误消息)
- """
-
- # 如果列表为空,表示没有工具调用,这不是错误
- if not tool_calls:
- return True, [], "工具调用列表为空"
-
- # 验证每个工具调用的格式
- valid_tool_calls = []
- for i, tool_call in enumerate(tool_calls):
- if not isinstance(tool_call, dict):
- logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
- continue
-
- # 检查基本结构
- if tool_call.get("type") != "function":
- logger.warning(
- f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
- )
- continue
-
- if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
- logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
- continue
-
- func_details = tool_call["function"]
- if "name" not in func_details or not isinstance(func_details.get("name"), str):
- logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
- continue
-
- # 验证参数 'arguments'
- args_value = func_details.get("arguments")
-
- # 1. 检查 arguments 是否存在且是字符串
- if args_value is None or not isinstance(args_value, str):
- logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
- continue
-
- # 2. 尝试安全地解析 arguments 字符串
- parsed_args = safe_json_loads(args_value, None)
-
- # 3. 检查解析结果是否为字典
- if parsed_args is None or not isinstance(parsed_args, dict):
- logger.warning(
- f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
- f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
- )
- continue
-
- # 如果检查通过,将原始的 tool_call 加入有效列表
- valid_tool_calls.append(tool_call)
-
- if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
- return False, [], "所有工具调用格式均无效"
-
- return True, valid_tool_calls, ""
diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py
index 3ee4ae7b..0b9ec779 100644
--- a/src/chat/utils/utils.py
+++ b/src/chat/utils/utils.py
@@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest
@@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
return is_mentioned, reply_probability
-async def get_embedding(text, request_type="embedding"):
+async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量"""
- # TODO: API-Adapter修改标记
- llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
- # return llm.get_embedding_sync(text)
+ llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
try:
- embedding = await llm.get_embedding(text)
+ embedding, _ = await llm.get_embedding(text)
except Exception as e:
logger.error(f"获取embedding失败: {str(e)}")
embedding = None
diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py
index 7f14aa6d..fcf1c717 100644
--- a/src/chat/utils/utils_image.py
+++ b/src/chat/utils/utils_image.py
@@ -14,7 +14,7 @@ from rich.traceback import install
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
install(extra_lines=3)
@@ -37,7 +37,7 @@ class ImageManager:
self._ensure_image_dir()
self._initialized = True
- self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image")
+ self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
@@ -107,6 +107,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
+
emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
if cached_emoji_description:
@@ -116,13 +117,12 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述
- cached_description = self._get_description_from_db(image_hash, "emoji")
- if cached_description:
+ if cached_description := self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[表情包:{cached_description}]"
# === 二步走识别流程 ===
-
+
# 第一步:VLM视觉分析 - 生成详细描述
if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64)
@@ -130,10 +130,16 @@ class ImageManager:
logger.warning("GIF转换失败,无法获取描述")
return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
- detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg")
+ detailed_description, _ = await self.vlm.generate_response_for_image(
+ vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
+ )
else:
- vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
- detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format)
+ vlm_prompt = (
+ "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
+ )
+ detailed_description, _ = await self.vlm.generate_response_for_image(
+ vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
+ )
if detailed_description is None:
logger.warning("VLM未能生成表情包详细描述")
@@ -150,31 +156,32 @@ class ImageManager:
3. 输出简短精准,不要解释
4. 如果有多个词用逗号分隔
"""
-
+
# 使用较低温度确保输出稳定
- emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji")
- emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt)
+ emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
+ emotion_result, _ = await emotion_llm.generate_response_async(
+ emotion_prompt, temperature=0.3, max_tokens=50
+ )
if emotion_result is None:
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
# 降级处理:从详细描述中提取关键词
import jieba
+
words = list(jieba.cut(detailed_description))
emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
# 处理情感结果,取前1-2个最重要的标签
emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()]
final_emotion = emotions[0] if emotions else "表情"
-
+
# 如果有第二个情感且不重复,也包含进来
if len(emotions) > 1 and emotions[1] != emotions[0]:
final_emotion = f"{emotions[0]},{emotions[1]}"
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
- # 再次检查缓存,防止并发写入时重复生成
- cached_description = self._get_description_from_db(image_hash, "emoji")
- if cached_description:
+ if cached_description := self._get_description_from_db(image_hash, "emoji"):
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]"
@@ -242,9 +249,7 @@ class ImageManager:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]"
- # 查询ImageDescriptions表的缓存描述
- cached_description = self._get_description_from_db(image_hash, "image")
- if cached_description:
+ if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]"
@@ -252,7 +257,9 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
- description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+ description, _ = await self.vlm.generate_response_for_image(
+ prompt, image_base64, image_format, temperature=0.4, max_tokens=300
+ )
if description is None:
logger.warning("AI未能生成图片描述")
@@ -445,10 +452,7 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
- # 检查图片是否已存在
- existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
-
- if existing_image:
+ if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
@@ -524,9 +528,7 @@ class ImageManager:
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
- (Images.emoji_hash == image_hash) &
- (Images.description.is_null(False)) &
- (Images.description != "")
+ (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
@@ -538,8 +540,7 @@ class ImageManager:
return
# 检查ImageDescriptions表的缓存描述
- cached_description = self._get_description_from_db(image_hash, "image")
- if cached_description:
+ if cached_description := self._get_description_from_db(image_hash, "image"):
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
@@ -554,15 +555,15 @@ class ImageManager:
# 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
- description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format)
+ description, _ = await self.vlm.generate_response_for_image(
+ prompt, image_base64, image_format, temperature=0.4, max_tokens=300
+ )
if description is None:
logger.warning("VLM未能生成图片描述")
description = "无法生成描述"
- # 再次检查缓存,防止并发写入时重复生成
- cached_description = self._get_description_from_db(image_hash, "image")
- if cached_description:
+ if cached_description := self._get_description_from_db(image_hash, "image"):
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description
@@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str:
raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f:
- image_data = f.read()
- if not image_data:
+ if image_data := f.read():
+ return base64.b64encode(image_data).decode("utf-8")
+ else:
raise IOError(f"读取图片文件失败: {image_path}")
- return base64.b64encode(image_data).decode("utf-8")
diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py
index cf71dc56..49ec1079 100644
--- a/src/chat/utils/utils_voice.py
+++ b/src/chat/utils/utils_voice.py
@@ -1,35 +1,29 @@
-import base64
-
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from rich.traceback import install
+
install(extra_lines=3)
logger = get_logger("chat_voice")
+
async def get_voice_text(voice_base64: str) -> str:
- """获取音频文件描述"""
+ """获取音频文件转录文本"""
if not global_config.voice.enable_asr:
logger.warning("语音识别未启用,无法处理语音消息")
return "[语音]"
try:
- # 解码base64音频数据
- # 确保base64字符串只包含ASCII字符
- if isinstance(voice_base64, str):
- voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
- voice_bytes = base64.b64decode(voice_base64)
- _llm = LLMRequest(model=global_config.model.voice, request_type="voice")
- text = await _llm.generate_response_for_voice(voice_bytes)
+ _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
+ text = await _llm.generate_response_for_voice(voice_base64)
if text is None:
logger.warning("未能生成语音文本")
return "[语音(文本生成失败)]"
-
+
logger.debug(f"描述是{text}")
return f"[语音:{text}]"
except Exception as e:
logger.error(f"语音转文字失败: {str(e)}")
return "[语音]"
-
diff --git a/src/chat/willing/mode_mxp.py b/src/chat/willing/mode_mxp.py
index 5a13a628..a249cb6f 100644
--- a/src/chat/willing/mode_mxp.py
+++ b/src/chat/willing/mode_mxp.py
@@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟(@梦溪畔)
"""
-from .willing_manager import BaseWillingManager
from typing import Dict
import asyncio
import time
import math
from src.chat.message_receive.chat_stream import ChatStream
+from .willing_manager import BaseWillingManager
class MxpWillingManager(BaseWillingManager):
diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py
index 1d0b8a39..d2b3acce 100644
--- a/src/common/database/database_model.py
+++ b/src/common/database/database_model.py
@@ -281,20 +281,6 @@ class Memory(BaseModel):
table_name = "memory"
-class Knowledges(BaseModel):
- """
- 用于存储知识库条目的模型。
- """
-
- content = TextField() # 知识内容的文本
- embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
- # 可以添加其他元数据字段,如 source, create_time 等
-
- class Meta:
- # database = db # 继承自 BaseModel
- table_name = "knowledges"
-
-
class Expression(BaseModel):
"""
用于存储表达风格的模型。
@@ -382,7 +368,6 @@ def create_tables():
ImageDescriptions,
OnlineTime,
PersonInfo,
- Knowledges,
Expression,
ThinkingLog,
GraphNodes, # 添加图节点表
@@ -408,7 +393,6 @@ def initialize_database():
ImageDescriptions,
OnlineTime,
PersonInfo,
- Knowledges,
Expression,
Memory,
ThinkingLog,
diff --git a/src/common/logger.py b/src/common/logger.py
index 78446dec..e27fcb4e 100644
--- a/src/common/logger.py
+++ b/src/common/logger.py
@@ -334,7 +334,7 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m",
- "memory": "\033[34m",
+ "memory": "\033[38;5;117m", # 天蓝色
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
"action_manager": "\033[38;5;208m", # 橙色,不与replyer重复
# 关系系统
@@ -352,7 +352,7 @@ MODULE_COLORS = {
"expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块
"replyer": "\033[38;5;166m", # 橙色
- "memory_activator": "\033[34m", # 绿色
+ "memory_activator": "\033[38;5;117m", # 天蓝色
# 插件系统
"plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色
@@ -451,7 +451,7 @@ class ModuleColoredConsoleRenderer:
# 日志级别颜色
self._level_colors = {
"debug": "\033[38;5;208m", # 橙色
- "info": "\033[34m", # 蓝色
+ "info": "\033[38;5;117m", # 天蓝色
"success": "\033[32m", # 绿色
"warning": "\033[33m", # 黄色
"error": "\033[31m", # 红色
diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py
new file mode 100644
index 00000000..5f3398e0
--- /dev/null
+++ b/src/config/api_ada_configs.py
@@ -0,0 +1,142 @@
+from dataclasses import dataclass, field
+
+from .config_base import ConfigBase
+
+
+@dataclass
+class APIProvider(ConfigBase):
+ """API提供商配置类"""
+
+ name: str
+ """API提供商名称"""
+
+ base_url: str
+ """API基础URL"""
+
+ api_key: str = field(default_factory=str, repr=False)
+ """API密钥列表"""
+
+ client_type: str = field(default="openai")
+ """客户端类型(如openai/google等,默认为openai)"""
+
+ max_retry: int = 2
+ """最大重试次数(单个模型API调用失败,最多重试的次数)"""
+
+ timeout: int = 10
+ """API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)"""
+
+ retry_interval: int = 10
+ """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)"""
+
+ def get_api_key(self) -> str:
+ return self.api_key
+
+ def __post_init__(self):
+ """确保api_key在repr中不被显示"""
+ if not self.api_key:
+ raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。")
+ if not self.base_url:
+ raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。")
+ if not self.name:
+ raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。")
+
+
+@dataclass
+class ModelInfo(ConfigBase):
+ """单个模型信息配置类"""
+
+ model_identifier: str
+ """模型标识符(用于URL调用)"""
+
+ name: str
+ """模型名称(用于模块调用)"""
+
+ api_provider: str
+ """API提供商(如OpenAI、Azure等)"""
+
+ price_in: float = field(default=0.0)
+ """每M token输入价格"""
+
+ price_out: float = field(default=0.0)
+ """每M token输出价格"""
+
+ force_stream_mode: bool = field(default=False)
+ """是否强制使用流式输出模式"""
+
+ extra_params: dict = field(default_factory=dict)
+ """额外参数(用于API调用时的额外配置)"""
+
+ def __post_init__(self):
+ if not self.model_identifier:
+ raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
+ if not self.name:
+ raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
+ if not self.api_provider:
+ raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。")
+
+
+@dataclass
+class TaskConfig(ConfigBase):
+ """任务配置类"""
+
+ model_list: list[str] = field(default_factory=list)
+ """任务使用的模型列表"""
+
+ max_tokens: int = 1024
+ """任务最大输出token数"""
+
+ temperature: float = 0.3
+ """模型温度"""
+
+
+@dataclass
+class ModelTaskConfig(ConfigBase):
+ """模型配置类"""
+
+ utils: TaskConfig
+ """组件模型配置"""
+
+ utils_small: TaskConfig
+ """组件小模型配置"""
+
+ replyer_1: TaskConfig
+ """normal_chat首要回复模型模型配置"""
+
+ replyer_2: TaskConfig
+ """normal_chat次要回复模型配置"""
+
+ memory: TaskConfig
+ """记忆模型配置"""
+
+ emotion: TaskConfig
+ """情绪模型配置"""
+
+ vlm: TaskConfig
+ """视觉语言模型配置"""
+
+ voice: TaskConfig
+ """语音识别模型配置"""
+
+ tool_use: TaskConfig
+ """专注工具使用模型配置"""
+
+ planner: TaskConfig
+ """规划模型配置"""
+
+ embedding: TaskConfig
+ """嵌入模型配置"""
+
+ lpmm_entity_extract: TaskConfig
+ """LPMM实体提取模型配置"""
+
+ lpmm_rdf_build: TaskConfig
+ """LPMM RDF构建模型配置"""
+
+ lpmm_qa: TaskConfig
+ """LPMM问答模型配置"""
+
+ def get_task(self, task_name: str) -> TaskConfig:
+ """获取指定任务的配置"""
+ if hasattr(self, task_name):
+ return getattr(self, task_name)
+ raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
diff --git a/src/config/auto_update.py b/src/config/auto_update.py
deleted file mode 100644
index e6471e80..00000000
--- a/src/config/auto_update.py
+++ /dev/null
@@ -1,162 +0,0 @@
-import shutil
-import tomlkit
-from tomlkit.items import Table, KeyType
-from pathlib import Path
-from datetime import datetime
-
-
-def get_key_comment(toml_table, key):
- # 获取key的注释(如果有)
- if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
- return toml_table.trivia.comment
- if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
- item = toml_table.value.get(key)
- if item is not None and hasattr(item, "trivia"):
- return item.trivia.comment
- if hasattr(toml_table, "keys"):
- for k in toml_table.keys():
- if isinstance(k, KeyType) and k.key == key:
- return k.trivia.comment
- return None
-
-
-def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
- # 递归比较两个dict,找出新增和删减项,收集注释
- if path is None:
- path = []
- if logs is None:
- logs = []
- if new_comments is None:
- new_comments = {}
- if old_comments is None:
- old_comments = {}
- # 新增项
- for key in new:
- if key == "version":
- continue
- if key not in old:
- comment = get_key_comment(new, key)
- logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
- elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
- compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
- # 删减项
- for key in old:
- if key == "version":
- continue
- if key not in new:
- comment = get_key_comment(old, key)
- logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
- return logs
-
-
-def update_config():
- print("开始更新配置文件...")
- # 获取根目录路径
- root_dir = Path(__file__).parent.parent.parent.parent
- template_dir = root_dir / "template"
- config_dir = root_dir / "config"
- old_config_dir = config_dir / "old"
-
- # 创建old目录(如果不存在)
- old_config_dir.mkdir(exist_ok=True)
-
- # 定义文件路径
- template_path = template_dir / "bot_config_template.toml"
- old_config_path = config_dir / "bot_config.toml"
- new_config_path = config_dir / "bot_config.toml"
-
- # 读取旧配置文件
- old_config = {}
- if old_config_path.exists():
- print(f"发现旧配置文件: {old_config_path}")
- with open(old_config_path, "r", encoding="utf-8") as f:
- old_config = tomlkit.load(f)
-
- # 生成带时间戳的新文件名
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
-
- # 移动旧配置文件到old目录
- shutil.move(old_config_path, old_backup_path)
- print(f"已备份旧配置文件到: {old_backup_path}")
-
- # 复制模板文件到配置目录
- print(f"从模板文件创建新配置: {template_path}")
- shutil.copy2(template_path, new_config_path)
-
- # 读取新配置文件
- with open(new_config_path, "r", encoding="utf-8") as f:
- new_config = tomlkit.load(f)
-
- # 检查version是否相同
- if old_config and "inner" in old_config and "inner" in new_config:
- old_version = old_config["inner"].get("version") # type: ignore
- new_version = new_config["inner"].get("version") # type: ignore
- if old_version and new_version and old_version == new_version:
- print(f"检测到版本号相同 (v{old_version}),跳过更新")
- # 如果version相同,恢复旧配置文件并返回
- shutil.move(old_backup_path, old_config_path) # type: ignore
- return
- else:
- print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
-
- # 输出新增和删减项及注释
- if old_config:
- print("配置项变动如下:")
- logs = compare_dicts(new_config, old_config)
- if logs:
- for log in logs:
- print(log)
- else:
- print("无新增或删减项")
-
- # 递归更新配置
- def update_dict(target, source):
- for key, value in source.items():
- # 跳过version字段的更新
- if key == "version":
- continue
- if key in target:
- if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
- update_dict(target[key], value)
- else:
- try:
- # 对数组类型进行特殊处理
- if isinstance(value, list):
- # 如果是空数组,确保它保持为空数组
- if not value:
- target[key] = tomlkit.array()
- else:
- # 特殊处理正则表达式数组和包含正则表达式的结构
- if key == "ban_msgs_regex":
- # 直接使用原始值,不进行额外处理
- target[key] = value
- elif key == "regex_rules":
- # 对于regex_rules,需要特殊处理其中的regex字段
- target[key] = value
- else:
- # 检查是否包含正则表达式相关的字典项
- contains_regex = False
- if value and isinstance(value[0], dict) and "regex" in value[0]:
- contains_regex = True
-
- target[key] = value if contains_regex else tomlkit.array(str(value))
- else:
- # 其他类型使用item方法创建新值
- target[key] = tomlkit.item(value)
- except (TypeError, ValueError):
- # 如果转换失败,直接赋值
- target[key] = value
-
- # 将旧配置的值更新到新配置中
- print("开始合并新旧配置...")
- update_dict(new_config, old_config)
-
- # 保存更新后的配置(保留注释和格式)
- with open(new_config_path, "w", encoding="utf-8") as f:
- f.write(tomlkit.dumps(new_config))
- print("配置文件更新完成")
-
-
-if __name__ == "__main__":
- update_config()
diff --git a/src/config/config.py b/src/config/config.py
index 805a17d4..368adaa5 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -1,12 +1,14 @@
import os
import tomlkit
import shutil
+import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass
from rich.traceback import install
+from typing import List, Optional
from src.common.logger import get_logger
from src.config.config_base import ConfigBase
@@ -25,7 +27,6 @@ from src.config.official_configs import (
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
- ModelConfig,
MessageReceiveConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
@@ -36,6 +37,13 @@ from src.config.official_configs import (
CustomPromptConfig,
)
+from .api_ada_configs import (
+ ModelTaskConfig,
+ ModelInfo,
+ APIProvider,
+)
+
+
install(extra_lines=3)
@@ -49,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
-MMC_VERSION = "0.9.1"
+MMC_VERSION = "0.10.0-snapshot.4"
def get_key_comment(toml_table, key):
@@ -79,7 +87,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in old:
comment = get_key_comment(new, key)
- logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
+ logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项
@@ -88,7 +96,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue
if key not in new:
comment = get_key_comment(old, key)
- logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}")
+ logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}")
return logs
@@ -123,67 +131,110 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
- else:
- # 只要值发生变化就记录
- if new[key] != old[key]:
- logs.append(
- f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
- )
- changes.append((path + [str(key)], old[key], new[key]))
+ elif new[key] != old[key]:
+ logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
+ changes.append((path + [str(key)], old[key], new[key]))
return logs, changes
-def update_config():
+def _get_version_from_toml(toml_path) -> Optional[str]:
+ """从TOML文件中获取版本号"""
+ if not os.path.exists(toml_path):
+ return None
+ with open(toml_path, "r", encoding="utf-8") as f:
+ doc = tomlkit.load(f)
+ if "inner" in doc and "version" in doc["inner"]: # type: ignore
+ return doc["inner"]["version"] # type: ignore
+ return None
+
+
+def _version_tuple(v):
+ """将版本字符串转换为元组以便比较"""
+ if v is None:
+ return (0,)
+ return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
+
+
+def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
+ """
+ 将source字典的值更新到target字典中(如果target中存在相同的键)
+ """
+ for key, value in source.items():
+ # 跳过version字段的更新
+ if key == "version":
+ continue
+ if key in target:
+ target_value = target[key]
+ if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
+ _update_dict(target_value, value)
+ else:
+ try:
+ # 对数组类型进行特殊处理
+ if isinstance(value, list):
+ # 如果是空数组,确保它保持为空数组
+ target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
+ else:
+ # 其他类型使用item方法创建新值
+ target[key] = tomlkit.item(value)
+ except (TypeError, ValueError):
+ # 如果转换失败,直接赋值
+ target[key] = value
+
+
+def _update_config_generic(config_name: str, template_name: str):
+ """
+ 通用的配置文件更新函数
+
+ Args:
+ config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config'
+ template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template'
+ """
# 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径
- template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml")
- old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
- new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
- compare_path = os.path.join(compare_dir, "bot_config_template.toml")
+ template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
+ old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
+ new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
+ compare_path = os.path.join(compare_dir, f"{template_name}.toml")
# 创建compare目录(如果不存在)
os.makedirs(compare_dir, exist_ok=True)
- # 处理compare下的模板文件
- def get_version_from_toml(toml_path):
- if not os.path.exists(toml_path):
- return None
- with open(toml_path, "r", encoding="utf-8") as f:
- doc = tomlkit.load(f)
- if "inner" in doc and "version" in doc["inner"]: # type: ignore
- return doc["inner"]["version"] # type: ignore
- return None
+ template_version = _get_version_from_toml(template_path)
+ compare_version = _get_version_from_toml(compare_path)
- template_version = get_version_from_toml(template_path)
- compare_version = get_version_from_toml(compare_path)
+ # 检查配置文件是否存在
+ if not os.path.exists(old_config_path):
+ logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置")
+ os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
+ shutil.copy2(template_path, old_config_path) # 复制模板文件
+ logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
+ # 新创建配置文件,退出
+ sys.exit(0)
- def version_tuple(v):
- if v is None:
- return (0,)
- return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
+ compare_config = None
+ new_config = None
+ old_config = None
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f)
- else:
- compare_config = None
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
- if compare_config is not None:
+ if compare_config:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config)
if logs:
- logger.info("检测到模板默认值变动如下:")
+ logger.info(f"检测到{config_name}模板默认值变动如下:")
for log in logs:
logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值
@@ -192,33 +243,20 @@ def update_config():
if old_value == old_default:
set_value_by_path(old_config, path, new_default)
logger.info(
- f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
+ f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
)
else:
- logger.info("未检测到模板默认值变动")
- # 保存旧配置的变更(后续合并逻辑会用到 old_config)
- else:
- old_config = None
+ logger.info(f"未检测到{config_name}模板默认值变动")
# 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path)
- logger.info(f"已将模板文件复制到: {compare_path}")
+ logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
+ elif _version_tuple(template_version) > _version_tuple(compare_version):
+ shutil.copy2(template_path, compare_path)
+ logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}")
else:
- if version_tuple(template_version) > version_tuple(compare_version):
- shutil.copy2(template_path, compare_path)
- logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}")
- else:
- logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}")
-
- # 检查配置文件是否存在
- if not os.path.exists(old_config_path):
- logger.info("配置文件不存在,从模板创建新配置")
- os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
- shutil.copy2(template_path, old_config_path) # 复制模板文件
- logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
- # 如果是新创建的配置文件,直接返回
- quit()
+ logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
if old_config is None:
@@ -226,79 +264,60 @@ def update_config():
old_config = tomlkit.load(f)
# new_config 已经读取
- # 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
-
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
- logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新")
+ logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
return
else:
logger.info(
- f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
+ f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
)
else:
- logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新")
+ logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录(如果不存在)
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
- old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml")
+ old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
- logger.info(f"已备份旧配置文件到: {old_backup_path}")
+ logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path)
- logger.info(f"已创建新配置文件: {new_config_path}")
+ logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
# 输出新增和删减项及注释
if old_config:
- logger.info("配置项变动如下:\n----------------------------------------")
- logs = compare_dicts(new_config, old_config)
- if logs:
+ logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
+ if logs := compare_dicts(new_config, old_config):
for log in logs:
logger.info(log)
else:
logger.info("无新增或删减项")
- def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
- """
- 将source字典的值更新到target字典中(如果target中存在相同的键)
- """
- for key, value in source.items():
- # 跳过version字段的更新
- if key == "version":
- continue
- if key in target:
- target_value = target[key]
- if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
- update_dict(target_value, value)
- else:
- try:
- # 对数组类型进行特殊处理
- if isinstance(value, list):
- # 如果是空数组,确保它保持为空数组
- target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
- else:
- # 其他类型使用item方法创建新值
- target[key] = tomlkit.item(value)
- except (TypeError, ValueError):
- # 如果转换失败,直接赋值
- target[key] = value
-
# 将旧配置的值更新到新配置中
- logger.info("开始合并新旧配置...")
- update_dict(new_config, old_config)
+ logger.info(f"开始合并{config_name}新旧配置...")
+ _update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
- logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
- quit()
+ logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
+
+
+def update_config():
+ """更新bot_config.toml配置文件"""
+ _update_config_generic("bot_config", "bot_config_template")
+
+
+def update_model_config():
+ """更新model_config.toml配置文件"""
+ _update_config_generic("model_config", "model_config_template")
@dataclass
@@ -323,7 +342,6 @@ class Config(ConfigBase):
response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig
experimental: ExperimentalConfig
- model: ModelConfig
maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig
@@ -331,11 +349,69 @@ class Config(ConfigBase):
custom_prompt: CustomPromptConfig
voice: VoiceConfig
+
+@dataclass
+class APIAdapterConfig(ConfigBase):
+ """API Adapter配置类"""
+
+ models: List[ModelInfo]
+ """模型列表"""
+
+ model_task_config: ModelTaskConfig
+ """模型任务配置"""
+
+ api_providers: List[APIProvider] = field(default_factory=list)
+ """API提供商列表"""
+
+ def __post_init__(self):
+ if not self.models:
+ raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
+ if not self.api_providers:
+ raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
+
+ # 检查API提供商名称是否重复
+ provider_names = [provider.name for provider in self.api_providers]
+ if len(provider_names) != len(set(provider_names)):
+ raise ValueError("API提供商名称存在重复,请检查配置文件。")
+
+ # 检查模型名称是否重复
+ model_names = [model.name for model in self.models]
+ if len(model_names) != len(set(model_names)):
+ raise ValueError("模型名称存在重复,请检查配置文件。")
+
+ self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
+ self.models_dict = {model.name: model for model in self.models}
+
+ for model in self.models:
+ if not model.model_identifier:
+ raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
+ if not model.api_provider or model.api_provider not in self.api_providers_dict:
+ raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
+
+ def get_model_info(self, model_name: str) -> ModelInfo:
+ """根据模型名称获取模型信息"""
+ if not model_name:
+ raise ValueError("模型名称不能为空")
+ if model_name not in self.models_dict:
+ raise KeyError(f"模型 '{model_name}' 不存在")
+ return self.models_dict[model_name]
+
+ def get_provider(self, provider_name: str) -> APIProvider:
+ """根据提供商名称获取API提供商信息"""
+ if not provider_name:
+ raise ValueError("API提供商名称不能为空")
+ if provider_name not in self.api_providers_dict:
+ raise KeyError(f"API提供商 '{provider_name}' 不存在")
+ return self.api_providers_dict[provider_name]
+
+
def load_config(config_path: str) -> Config:
"""
加载配置文件
- :param config_path: 配置文件路径
- :return: Config对象
+ Args:
+ config_path: 配置文件路径
+ Returns:
+ Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
@@ -349,18 +425,32 @@ def load_config(config_path: str) -> Config:
raise e
-def get_config_dir() -> str:
+def api_ada_load_config(config_path: str) -> APIAdapterConfig:
"""
- 获取配置目录
- :return: 配置目录路径
+ 加载API适配器配置文件
+ Args:
+ config_path: 配置文件路径
+ Returns:
+ APIAdapterConfig对象
"""
- return CONFIG_DIR
+ # 读取配置文件
+ with open(config_path, "r", encoding="utf-8") as f:
+ config_data = tomlkit.load(f)
+
+ # 创建APIAdapterConfig对象
+ try:
+ return APIAdapterConfig.from_dict(config_data)
+ except Exception as e:
+ logger.critical("API适配器配置文件解析失败")
+ raise e
# 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config()
+update_model_config()
logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
+model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!")
diff --git a/src/config/official_configs.py b/src/config/official_configs.py
index 2c9f847c..8f34a184 100644
--- a/src/config/official_configs.py
+++ b/src/config/official_configs.py
@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass, field
-from typing import Any, Literal, Optional
+from typing import Literal, Optional
from src.config.config_base import ConfigBase
@@ -598,51 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""
-
-@dataclass
-class ModelConfig(ConfigBase):
- """模型配置类"""
-
- model_max_output_length: int = 800 # 最大回复长度
-
- utils: dict[str, Any] = field(default_factory=lambda: {})
- """组件模型配置"""
-
- utils_small: dict[str, Any] = field(default_factory=lambda: {})
- """组件小模型配置"""
-
- replyer_1: dict[str, Any] = field(default_factory=lambda: {})
- """normal_chat首要回复模型模型配置"""
-
- replyer_2: dict[str, Any] = field(default_factory=lambda: {})
- """normal_chat次要回复模型配置"""
-
- memory: dict[str, Any] = field(default_factory=lambda: {})
- """记忆模型配置"""
-
- emotion: dict[str, Any] = field(default_factory=lambda: {})
- """情绪模型配置"""
-
- vlm: dict[str, Any] = field(default_factory=lambda: {})
- """视觉语言模型配置"""
-
- voice: dict[str, Any] = field(default_factory=lambda: {})
- """语音识别模型配置"""
-
- tool_use: dict[str, Any] = field(default_factory=lambda: {})
- """专注工具使用模型配置"""
-
- planner: dict[str, Any] = field(default_factory=lambda: {})
- """规划模型配置"""
-
- embedding: dict[str, Any] = field(default_factory=lambda: {})
- """嵌入模型配置"""
-
- lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
- """LPMM实体提取模型配置"""
-
- lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
- """LPMM RDF构建模型配置"""
-
- lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
- """LPMM问答模型配置"""
diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py
index 4c8fcac5..c2655fba 100644
--- a/src/individuality/individuality.py
+++ b/src/individuality/individuality.py
@@ -4,7 +4,7 @@ import hashlib
import time
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
from rich.traceback import install
@@ -23,10 +23,7 @@ class Individuality:
self.meta_info_file_path = "data/personality/meta.json"
self.personality_data_file_path = "data/personality/personality_data.json"
- self.model = LLMRequest(
- model=global_config.model.utils,
- request_type="individuality.compress",
- )
+ self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
async def initialize(self) -> None:
"""初始化个体特征"""
@@ -35,7 +32,6 @@ class Individuality:
personality_side = global_config.personality.personality_side
identity = global_config.personality.identity
-
person_info_manager = get_person_info_manager()
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
self.name = bot_nickname
@@ -85,16 +81,16 @@ class Individuality:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else:
bot_nickname = ""
-
+
# 从文件获取 short_impression
personality, identity = self._get_personality_from_file()
-
+
# 确保short_impression是列表格式且有足够的元素
if not personality or not identity:
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
personality = "友好活泼"
identity = "人类"
-
+
prompt_personality = f"{personality}\n{identity}"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
@@ -215,7 +211,7 @@ class Individuality:
def _get_personality_from_file(self) -> tuple[str, str]:
"""从文件获取personality数据
-
+
Returns:
tuple: (personality, identity)
"""
@@ -226,7 +222,7 @@ class Individuality:
def _save_personality_to_file(self, personality: str, identity: str):
"""保存personality数据到文件
-
+
Args:
personality: 压缩后的人格描述
identity: 压缩后的身份描述
@@ -235,7 +231,7 @@ class Individuality:
"personality": personality,
"identity": identity,
"bot_nickname": self.name,
- "last_updated": int(time.time())
+ "last_updated": int(time.time()),
}
self._save_personality_data(personality_data)
@@ -269,7 +265,7 @@ class Individuality:
2. 尽量简洁,不超过30字
3. 直接输出压缩后的内容,不要解释"""
- response, (_, _) = await self.model.generate_response_async(
+ response, _ = await self.model.generate_response_async(
prompt=prompt,
)
@@ -281,7 +277,7 @@ class Individuality:
# 压缩失败时使用原始内容
if personality_side:
personality_parts.append(personality_side)
-
+
if personality_parts:
personality_result = "。".join(personality_parts)
else:
@@ -308,7 +304,7 @@ class Individuality:
2. 尽量简洁,不超过30字
3. 直接输出压缩后的内容,不要解释"""
- response, (_, _) = await self.model.generate_response_async(
+ response, _ = await self.model.generate_response_async(
prompt=prompt,
)
diff --git a/src/llm_models/LICENSE b/src/llm_models/LICENSE
new file mode 100644
index 00000000..8b3236ed
--- /dev/null
+++ b/src/llm_models/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Mai.To.The.Gate
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/src/llm_models/__init__.py b/src/llm_models/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py
new file mode 100644
index 00000000..5b04f58c
--- /dev/null
+++ b/src/llm_models/exceptions.py
@@ -0,0 +1,98 @@
+from typing import Any
+
+
+# 常见Error Code Mapping (以OpenAI API为例)
+error_code_mapping = {
+ 400: "参数不正确",
+ 401: "API-Key错误,认证失败,请检查/config/model_list.toml中的配置是否正确",
+ 402: "账号余额不足",
+ 403: "模型拒绝访问,可能需要实名或余额不足",
+ 404: "Not Found",
+ 413: "请求体过大,请尝试压缩图片或减少输入内容",
+ 429: "请求过于频繁,请稍后再试",
+ 500: "服务器内部故障",
+ 503: "服务器负载过高",
+}
+
+
+class NetworkConnectionError(Exception):
+ """连接异常,常见于网络问题或服务器不可用"""
+
+ def __init__(self):
+ super().__init__()
+
+ def __str__(self):
+ return "连接异常,请检查网络连接状态或URL是否正确"
+
+
+class ReqAbortException(Exception):
+ """请求异常退出,常见于请求被中断或取消"""
+
+ def __init__(self, message: str | None = None):
+ super().__init__(message)
+ self.message = message
+
+ def __str__(self):
+ return self.message or "请求因未知原因异常终止"
+
+
+class RespNotOkException(Exception):
+ """请求响应异常,见于请求未能成功响应(非 '200 OK')"""
+
+ def __init__(self, status_code: int, message: str | None = None):
+ super().__init__(message)
+ self.status_code = status_code
+ self.message = message
+
+ def __str__(self):
+ if self.status_code in error_code_mapping:
+ return error_code_mapping[self.status_code]
+ elif self.message:
+ return self.message
+ else:
+ return f"未知的异常响应代码:{self.status_code}"
+
+
+class RespParseException(Exception):
+ """响应解析错误,常见于响应格式不正确或解析方法不匹配"""
+
+ def __init__(self, ext_info: Any, message: str | None = None):
+ super().__init__(message)
+ self.ext_info = ext_info
+ self.message = message
+
+ def __str__(self):
+ return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
+
+
+class PayLoadTooLargeError(Exception):
+ """自定义异常类,用于处理请求体过大错误"""
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+ def __str__(self):
+ return "请求体过大,请尝试压缩图片或减少输入内容。"
+
+
+class RequestAbortException(Exception):
+ """自定义异常类,用于处理请求中断异常"""
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+ def __str__(self):
+ return self.message
+
+
+class PermissionDeniedException(Exception):
+ """自定义异常类,用于处理访问拒绝的异常"""
+
+ def __init__(self, message: str):
+ super().__init__(message)
+ self.message = message
+
+ def __str__(self):
+ return self.message
diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py
new file mode 100644
index 00000000..80f7e115
--- /dev/null
+++ b/src/llm_models/model_client/__init__.py
@@ -0,0 +1,8 @@
+from src.config.config import model_config
+
+used_client_types = {provider.client_type for provider in model_config.api_providers}
+
+if "openai" in used_client_types:
+ from . import openai_client # noqa: F401
+if "gemini" in used_client_types:
+ from . import gemini_client # noqa: F401
diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py
new file mode 100644
index 00000000..8e8affba
--- /dev/null
+++ b/src/llm_models/model_client/base_client.py
@@ -0,0 +1,172 @@
+import asyncio
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from typing import Callable, Any, Optional
+
+from src.config.api_ada_configs import ModelInfo, APIProvider
+from ..payload_content.message import Message
+from ..payload_content.resp_format import RespFormat
+from ..payload_content.tool_option import ToolOption, ToolCall
+
+
+@dataclass
+class UsageRecord:
+ """
+ 使用记录类
+ """
+
+ model_name: str
+ """模型名称"""
+
+ provider_name: str
+ """提供商名称"""
+
+ prompt_tokens: int
+ """提示token数"""
+
+ completion_tokens: int
+ """完成token数"""
+
+ total_tokens: int
+ """总token数"""
+
+
+@dataclass
+class APIResponse:
+ """
+ API响应类
+ """
+
+ content: str | None = None
+ """响应内容"""
+
+ reasoning_content: str | None = None
+ """推理内容"""
+
+ tool_calls: list[ToolCall] | None = None
+ """工具调用 [(工具名称, 工具参数), ...]"""
+
+ embedding: list[float] | None = None
+ """嵌入向量"""
+
+ usage: UsageRecord | None = None
+ """使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
+
+ raw_data: Any = None
+ """响应原始数据"""
+
+
+class BaseClient(ABC):
+ """
+ 基础客户端
+ """
+
+ api_provider: APIProvider
+
+ def __init__(self, api_provider: APIProvider):
+ self.api_provider = api_provider
+
+ @abstractmethod
+ async def get_response(
+ self,
+ model_info: ModelInfo,
+ message_list: list[Message],
+ tool_options: list[ToolOption] | None = None,
+ max_tokens: int = 1024,
+ temperature: float = 0.7,
+ response_format: RespFormat | None = None,
+ stream_response_handler: Optional[
+ Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
+ ] = None,
+ async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
+ interrupt_flag: asyncio.Event | None = None,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取对话响应
+ :param model_info: 模型信息
+ :param message_list: 对话体
+ :param tool_options: 工具选项(可选,默认为None)
+ :param max_tokens: 最大token数(可选,默认为1024)
+ :param temperature: 温度(可选,默认为0.7)
+ :param response_format: 响应格式(可选,默认为 NotGiven )
+ :param stream_response_handler: 流式响应处理函数(可选)
+ :param async_response_parser: 响应解析函数(可选)
+ :param interrupt_flag: 中断信号量(可选,默认为None)
+ :return: (响应文本, 推理文本, 工具调用, 其他数据)
+ """
+ raise NotImplementedError("'get_response' method should be overridden in subclasses")
+
+ @abstractmethod
+ async def get_embedding(
+ self,
+ model_info: ModelInfo,
+ embedding_input: str,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取文本嵌入
+ :param model_info: 模型信息
+ :param embedding_input: 嵌入输入文本
+ :return: 嵌入响应
+ """
+ raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
+
+ @abstractmethod
+ async def get_audio_transcriptions(
+ self,
+ model_info: ModelInfo,
+ audio_base64: str,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取音频转录
+ :param model_info: 模型信息
+ :param audio_base64: base64编码的音频数据
+ :extra_params: 附加的请求参数
+ :return: 音频转录响应
+ """
+ raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
+
+ @abstractmethod
+ def get_support_image_formats(self) -> list[str]:
+ """
+ 获取支持的图片格式
+ :return: 支持的图片格式列表
+ """
+ raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
+
+
+class ClientRegistry:
+ def __init__(self) -> None:
+ self.client_registry: dict[str, type[BaseClient]] = {}
+
+ def register_client_class(self, client_type: str):
+ """
+ 注册API客户端类
+ Args:
+ client_class: API客户端类
+ """
+
+ def decorator(cls: type[BaseClient]) -> type[BaseClient]:
+ if not issubclass(cls, BaseClient):
+ raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
+ self.client_registry[client_type] = cls
+ return cls
+
+ return decorator
+
+ def get_client_class(self, client_type: str) -> type[BaseClient]:
+ """
+ 获取注册的API客户端类
+ Args:
+ client_type: 客户端类型
+ Returns:
+ type[BaseClient]: 注册的API客户端类
+ """
+ if client_type not in self.client_registry:
+ raise KeyError(f"'{client_type}' 类型的 Client 未注册")
+ return self.client_registry[client_type]
+
+
+client_registry = ClientRegistry()
diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py
new file mode 100644
index 00000000..e4127029
--- /dev/null
+++ b/src/llm_models/model_client/gemini_client.py
@@ -0,0 +1,496 @@
+import asyncio
+import io
+import base64
+from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
+
+from google import genai
+from google.genai.types import (
+ Content,
+ Part,
+ FunctionDeclaration,
+ GenerateContentResponse,
+ ContentListUnion,
+ ContentUnion,
+ ThinkingConfig,
+ Tool,
+ GenerateContentConfig,
+ EmbedContentResponse,
+ EmbedContentConfig,
+)
+from google.genai.errors import (
+ ClientError,
+ ServerError,
+ UnknownFunctionCallArgumentError,
+ UnsupportedFunctionError,
+ FunctionInvocationError,
+)
+
+from src.config.api_ada_configs import ModelInfo, APIProvider
+from src.common.logger import get_logger
+
+from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
+from ..exceptions import (
+ RespParseException,
+ NetworkConnectionError,
+ RespNotOkException,
+ ReqAbortException,
+)
+from ..payload_content.message import Message, RoleType
+from ..payload_content.resp_format import RespFormat, RespFormatType
+from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
+
+logger = get_logger("Gemini客户端")
+
+
+def _convert_messages(
+ messages: list[Message],
+) -> tuple[ContentListUnion, list[str] | None]:
+ """
+ 转换消息格式 - 将消息转换为Gemini API所需的格式
+ :param messages: 消息列表
+ :return: 转换后的消息列表(和可能存在的system消息)
+ """
+
+ def _convert_message_item(message: Message) -> Content:
+ """
+ 转换单个消息格式,除了system和tool类型的消息
+ :param message: 消息对象
+ :return: 转换后的消息字典
+ """
+
+ # 将openai格式的角色重命名为gemini格式的角色
+ if message.role == RoleType.Assistant:
+ role = "model"
+ elif message.role == RoleType.User:
+ role = "user"
+
+ # 添加Content
+ if isinstance(message.content, str):
+ content = [Part.from_text(text=message.content)]
+ elif isinstance(message.content, list):
+ content: List[Part] = []
+ for item in message.content:
+ if isinstance(item, tuple):
+ content.append(
+ Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
+ )
+ elif isinstance(item, str):
+ content.append(Part.from_text(text=item))
+ else:
+ raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
+
+ return Content(role=role, parts=content)
+
+ temp_list: list[ContentUnion] = []
+ system_instructions: list[str] = []
+ for message in messages:
+ if message.role == RoleType.System:
+ if isinstance(message.content, str):
+ system_instructions.append(message.content)
+ else:
+ raise ValueError("你tm怎么往system里面塞图片base64?")
+ elif message.role == RoleType.Tool:
+ if not message.tool_call_id:
+ raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
+ else:
+ temp_list.append(_convert_message_item(message))
+ if system_instructions:
+ # 如果有system消息,就把它加上去
+ ret: tuple = (temp_list, system_instructions)
+ else:
+ # 如果没有system消息,就直接返回
+ ret: tuple = (temp_list, None)
+
+ return ret
+
+
+def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
+ """
+ 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式
+ :param tool_options: 工具选项列表
+ :return: 转换后的工具对象列表
+ """
+
+ def _convert_tool_param(tool_option_param: ToolParam) -> dict:
+ """
+ 转换单个工具参数格式
+ :param tool_option_param: 工具参数对象
+ :return: 转换后的工具参数字典
+ """
+ return_dict: dict[str, Any] = {
+ "type": tool_option_param.param_type.value,
+ "description": tool_option_param.description,
+ }
+ if tool_option_param.enum_values:
+ return_dict["enum"] = tool_option_param.enum_values
+ return return_dict
+
+ def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration:
+ """
+ 转换单个工具项格式
+ :param tool_option: 工具选项对象
+ :return: 转换后的Gemini工具选项对象
+ """
+ ret: dict[str, Any] = {
+ "name": tool_option.name,
+ "description": tool_option.description,
+ }
+ if tool_option.params:
+ ret["parameters"] = {
+ "type": "object",
+ "properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
+ "required": [param.name for param in tool_option.params if param.required],
+ }
+ ret1 = FunctionDeclaration(**ret)
+ return ret1
+
+ return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
+
+
+def _process_delta(
+ delta: GenerateContentResponse,
+ fc_delta_buffer: io.StringIO,
+ tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
+):
+ if not hasattr(delta, "candidates") or not delta.candidates:
+ raise RespParseException(delta, "响应解析失败,缺失candidates字段")
+
+ if delta.text:
+ fc_delta_buffer.write(delta.text)
+
+ if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的
+ for call in delta.function_calls:
+ try:
+ if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
+ raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
+ if not call.id or not call.name:
+ raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段")
+ tool_calls_buffer.append(
+ (
+ call.id,
+ call.name,
+ call.args or {}, # 如果args是None,则转换为一个空字典
+ )
+ )
+ except Exception as e:
+ raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e
+
+
+def _build_stream_api_resp(
+ _fc_delta_buffer: io.StringIO,
+ _tool_calls_buffer: list[tuple[str, str, dict]],
+) -> APIResponse:
+ # sourcery skip: simplify-len-comparison, use-assigned-variable
+ resp = APIResponse()
+
+ if _fc_delta_buffer.tell() > 0:
+ # 如果正式内容缓冲区不为空,则将其写入APIResponse对象
+ resp.content = _fc_delta_buffer.getvalue()
+ _fc_delta_buffer.close()
+ if len(_tool_calls_buffer) > 0:
+ # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表
+ resp.tool_calls = []
+ for call_id, function_name, arguments_buffer in _tool_calls_buffer:
+ if arguments_buffer is not None:
+ arguments = arguments_buffer
+ if not isinstance(arguments, dict):
+ raise RespParseException(
+ None,
+ f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
+ )
+ else:
+ arguments = None
+
+ resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
+
+ return resp
+
+
+async def _default_stream_response_handler(
+ resp_stream: AsyncIterator[GenerateContentResponse],
+ interrupt_flag: asyncio.Event | None,
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+ """
+ 流式响应处理函数 - 处理Gemini API的流式响应
+ :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西
+ :return: APIResponse对象
+ """
+ _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
+ _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
+ _usage_record = None # 使用情况记录
+
+ def _insure_buffer_closed():
+ if _fc_delta_buffer and not _fc_delta_buffer.closed:
+ _fc_delta_buffer.close()
+
+ async for chunk in resp_stream:
+ # 检查是否有中断量
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量被设置,则抛出ReqAbortException
+ raise ReqAbortException("请求被外部信号中断")
+
+ _process_delta(
+ chunk,
+ _fc_delta_buffer,
+ _tool_calls_buffer,
+ )
+
+ if chunk.usage_metadata:
+ # 如果有使用情况,则将其存储在APIResponse对象中
+ _usage_record = (
+ chunk.usage_metadata.prompt_token_count or 0,
+ (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
+ chunk.usage_metadata.total_token_count or 0,
+ )
+ try:
+ return _build_stream_api_resp(
+ _fc_delta_buffer,
+ _tool_calls_buffer,
+ ), _usage_record
+ except Exception:
+ # 确保缓冲区被关闭
+ _insure_buffer_closed()
+ raise
+
+
+def _default_normal_response_parser(
+ resp: GenerateContentResponse,
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+ """
+ 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象
+ :param resp: 响应对象
+ :return: APIResponse对象
+ """
+ api_response = APIResponse()
+
+ if not hasattr(resp, "candidates") or not resp.candidates:
+ raise RespParseException(resp, "响应解析失败,缺失candidates字段")
+ try:
+ if resp.candidates[0].content and resp.candidates[0].content.parts:
+ for part in resp.candidates[0].content.parts:
+ if not part.text:
+ continue
+ if part.thought:
+ api_response.reasoning_content = (
+ api_response.reasoning_content + part.text if api_response.reasoning_content else part.text
+ )
+ except Exception as e:
+ logger.warning(f"解析思考内容时发生错误: {e},跳过解析")
+
+ if resp.text:
+ api_response.content = resp.text
+
+ if resp.function_calls:
+ api_response.tool_calls = []
+ for call in resp.function_calls:
+ try:
+ if not isinstance(call.args, dict):
+ raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
+ if not call.name:
+ raise RespParseException(resp, "响应解析失败,工具调用缺失name字段")
+ api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {}))
+ except Exception as e:
+ raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
+
+ if resp.usage_metadata:
+ _usage_record = (
+ resp.usage_metadata.prompt_token_count or 0,
+ (resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
+ resp.usage_metadata.total_token_count or 0,
+ )
+ else:
+ _usage_record = None
+
+ api_response.raw_data = resp
+
+ return api_response, _usage_record
+
+
+@client_registry.register_client_class("gemini")
+class GeminiClient(BaseClient):
+ client: genai.Client
+
+ def __init__(self, api_provider: APIProvider):
+ super().__init__(api_provider)
+ self.client = genai.Client(
+ api_key=api_provider.api_key,
+ ) # 这里和openai不一样,gemini会自己决定自己是否需要retry
+
+ async def get_response(
+ self,
+ model_info: ModelInfo,
+ message_list: list[Message],
+ tool_options: list[ToolOption] | None = None,
+ max_tokens: int = 1024,
+ temperature: float = 0.7,
+ response_format: RespFormat | None = None,
+ stream_response_handler: Optional[
+ Callable[
+ [AsyncIterator[GenerateContentResponse], asyncio.Event | None],
+ Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
+ ]
+ ] = None,
+ async_response_parser: Optional[
+ Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
+ ] = None,
+ interrupt_flag: asyncio.Event | None = None,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取对话响应
+ Args:
+ model_info: 模型信息
+ message_list: 对话体
+ tool_options: 工具选项(可选,默认为None)
+ max_tokens: 最大token数(可选,默认为1024)
+ temperature: 温度(可选,默认为0.7)
+ response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入)
+ stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
+ async_response_parser: 响应解析函数(可选,默认为default_response_parser)
+ interrupt_flag: 中断信号量(可选,默认为None)
+ Returns:
+ APIResponse对象,包含响应内容、推理内容、工具调用等信息
+ """
+ if stream_response_handler is None:
+ stream_response_handler = _default_stream_response_handler
+
+ if async_response_parser is None:
+ async_response_parser = _default_normal_response_parser
+
+ # 将messages构造为Gemini API所需的格式
+ messages = _convert_messages(message_list)
+ # 将tool_options转换为Gemini API所需的格式
+ tools = _convert_tool_options(tool_options) if tool_options else None
+ # 将response_format转换为Gemini API所需的格式
+ generation_config_dict = {
+ "max_output_tokens": max_tokens,
+ "temperature": temperature,
+ "response_modalities": ["TEXT"],
+ "thinking_config": ThinkingConfig(
+ include_thoughts=True,
+ thinking_budget=(
+ extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None
+ ),
+ ),
+ }
+ if tools:
+ generation_config_dict["tools"] = Tool(function_declarations=tools)
+ if messages[1]:
+ # 如果有system消息,则将其添加到配置中
+ generation_config_dict["system_instructions"] = messages[1]
+ if response_format and response_format.format_type == RespFormatType.TEXT:
+ generation_config_dict["response_mime_type"] = "text/plain"
+ elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
+ generation_config_dict["response_mime_type"] = "application/json"
+ generation_config_dict["response_schema"] = response_format.to_dict()
+
+ generation_config = GenerateContentConfig(**generation_config_dict)
+
+ try:
+ if model_info.force_stream_mode:
+ req_task = asyncio.create_task(
+ self.client.aio.models.generate_content_stream(
+ model=model_info.model_identifier,
+ contents=messages[0],
+ config=generation_config,
+ )
+ )
+ while not req_task.done():
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量存在且被设置,则取消任务并抛出异常
+ req_task.cancel()
+ raise ReqAbortException("请求被外部信号中断")
+ await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
+ resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
+ else:
+ req_task = asyncio.create_task(
+ self.client.aio.models.generate_content(
+ model=model_info.model_identifier,
+ contents=messages[0],
+ config=generation_config,
+ )
+ )
+ while not req_task.done():
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量存在且被设置,则取消任务并抛出异常
+ req_task.cancel()
+ raise ReqAbortException("请求被外部信号中断")
+ await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
+
+ resp, usage_record = async_response_parser(req_task.result())
+ except (ClientError, ServerError) as e:
+ # 重封装ClientError和ServerError为RespNotOkException
+ raise RespNotOkException(e.code, e.message) from None
+ except (
+ UnknownFunctionCallArgumentError,
+ UnsupportedFunctionError,
+ FunctionInvocationError,
+ ) as e:
+ raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
+ except Exception as e:
+ raise NetworkConnectionError() from e
+
+ if usage_record:
+ resp.usage = UsageRecord(
+ model_name=model_info.name,
+ provider_name=model_info.api_provider,
+ prompt_tokens=usage_record[0],
+ completion_tokens=usage_record[1],
+ total_tokens=usage_record[2],
+ )
+
+ return resp
+
+ async def get_embedding(
+ self,
+ model_info: ModelInfo,
+ embedding_input: str,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取文本嵌入
+ :param model_info: 模型信息
+ :param embedding_input: 嵌入输入文本
+ :return: 嵌入响应
+ """
+ try:
+ raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
+ model=model_info.model_identifier,
+ contents=embedding_input,
+ config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
+ )
+ except (ClientError, ServerError) as e:
+ # 重封装ClientError和ServerError为RespNotOkException
+ raise RespNotOkException(e.code) from None
+ except Exception as e:
+ raise NetworkConnectionError() from e
+
+ response = APIResponse()
+
+ # 解析嵌入响应和使用情况
+ if hasattr(raw_response, "embeddings") and raw_response.embeddings:
+ response.embedding = raw_response.embeddings[0].values
+ else:
+ raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段")
+
+ response.usage = UsageRecord(
+ model_name=model_info.name,
+ provider_name=model_info.api_provider,
+ prompt_tokens=len(embedding_input),
+ completion_tokens=0,
+ total_tokens=len(embedding_input),
+ )
+
+ return response
+
+ def get_audio_transcriptions(
+ self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
+ ) -> APIResponse:
+ raise NotImplementedError("尚未实现音频转录功能")
+
+ def get_support_image_formats(self) -> list[str]:
+ """
+ 获取支持的图片格式
+ :return: 支持的图片格式列表
+ """
+ return ["png", "jpg", "jpeg", "webp", "heic", "heif"]
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
new file mode 100644
index 00000000..ad9cbf17
--- /dev/null
+++ b/src/llm_models/model_client/openai_client.py
@@ -0,0 +1,580 @@
+import asyncio
+import io
+import json
+import re
+import base64
+from collections.abc import Iterable
+from typing import Callable, Any, Coroutine, Optional
+from json_repair import repair_json
+
+from openai import (
+ AsyncOpenAI,
+ APIConnectionError,
+ APIStatusError,
+ NOT_GIVEN,
+ AsyncStream,
+)
+from openai.types.chat import (
+ ChatCompletion,
+ ChatCompletionChunk,
+ ChatCompletionMessageParam,
+ ChatCompletionToolParam,
+)
+from openai.types.chat.chat_completion_chunk import ChoiceDelta
+
+from src.config.api_ada_configs import ModelInfo, APIProvider
+from src.common.logger import get_logger
+from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
+from ..exceptions import (
+ RespParseException,
+ NetworkConnectionError,
+ RespNotOkException,
+ ReqAbortException,
+)
+from ..payload_content.message import Message, RoleType
+from ..payload_content.resp_format import RespFormat
+from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
+
+logger = get_logger("OpenAI客户端")
+
+
+def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
+ """
+ 转换消息格式 - 将消息转换为OpenAI API所需的格式
+ :param messages: 消息列表
+ :return: 转换后的消息列表
+ """
+
+ def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
+ """
+ 转换单个消息格式
+ :param message: 消息对象
+ :return: 转换后的消息字典
+ """
+
+ # 添加Content
+ content: str | list[dict[str, Any]]
+ if isinstance(message.content, str):
+ content = message.content
+ elif isinstance(message.content, list):
+ content = []
+ for item in message.content:
+ if isinstance(item, tuple):
+ content.append(
+ {
+ "type": "image_url",
+ "image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
+ }
+ )
+ elif isinstance(item, str):
+ content.append({"type": "text", "text": item})
+ else:
+ raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象")
+
+ ret = {
+ "role": message.role.value,
+ "content": content,
+ }
+
+ # 添加工具调用ID
+ if message.role == RoleType.Tool:
+ if not message.tool_call_id:
+ raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象")
+ ret["tool_call_id"] = message.tool_call_id
+
+ return ret # type: ignore
+
+ return [_convert_message_item(message) for message in messages]
+
+
+def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
+ """
+ 转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式
+ :param tool_options: 工具选项列表
+ :return: 转换后的工具选项列表
+ """
+
+ def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
+ """
+ 转换单个工具参数格式
+ :param tool_option_param: 工具参数对象
+ :return: 转换后的工具参数字典
+ """
+ return_dict: dict[str, Any] = {
+ "type": tool_option_param.param_type.value,
+ "description": tool_option_param.description,
+ }
+ if tool_option_param.enum_values:
+ return_dict["enum"] = tool_option_param.enum_values
+ return return_dict
+
+ def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
+ """
+ 转换单个工具项格式
+ :param tool_option: 工具选项对象
+ :return: 转换后的工具选项字典
+ """
+ ret: dict[str, Any] = {
+ "name": tool_option.name,
+ "description": tool_option.description,
+ }
+ if tool_option.params:
+ ret["parameters"] = {
+ "type": "object",
+ "properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
+ "required": [param.name for param in tool_option.params if param.required],
+ }
+ return ret
+
+ return [
+ {
+ "type": "function",
+ "function": _convert_tool_option_item(tool_option),
+ }
+ for tool_option in tool_options
+ ]
+
+
+def _process_delta(
+ delta: ChoiceDelta,
+ has_rc_attr_flag: bool,
+ in_rc_flag: bool,
+ rc_delta_buffer: io.StringIO,
+ fc_delta_buffer: io.StringIO,
+ tool_calls_buffer: list[tuple[str, str, io.StringIO]],
+) -> bool:
+ # 接收content
+ if has_rc_attr_flag:
+ # 有独立的推理内容块,则无需考虑content内容的判读
+ if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
+ # 如果有推理内容,则将其写入推理内容缓冲区
+ assert isinstance(delta.reasoning_content, str) # type: ignore
+ rc_delta_buffer.write(delta.reasoning_content) # type: ignore
+ elif delta.content:
+ # 如果有正式内容,则将其写入正式内容缓冲区
+ fc_delta_buffer.write(delta.content)
+ elif hasattr(delta, "content") and delta.content is not None:
+ # 没有独立的推理内容块,但有正式内容
+ if in_rc_flag:
+ # 当前在推理内容块中
+ if delta.content == "":
+ # 如果当前内容是,则将其视为推理内容的结束标记,退出推理内容块
+ in_rc_flag = False
+ else:
+ # 其他情况视为推理内容,加入推理内容缓冲区
+ rc_delta_buffer.write(delta.content)
+ elif delta.content == "" and not fc_delta_buffer.getvalue():
+ # 如果当前内容是,且正式内容缓冲区为空,说明为输出的首个token
+ # 则将其视为推理内容的开始标记,进入推理内容块
+ in_rc_flag = True
+ else:
+ # 其他情况视为正式内容,加入正式内容缓冲区
+ fc_delta_buffer.write(delta.content)
+ # 接收tool_calls
+ if hasattr(delta, "tool_calls") and delta.tool_calls:
+ tool_call_delta = delta.tool_calls[0]
+
+ if tool_call_delta.index >= len(tool_calls_buffer):
+ # 调用索引号大于等于缓冲区长度,说明是新的工具调用
+ if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
+ tool_calls_buffer.append(
+ (
+ tool_call_delta.id,
+ tool_call_delta.function.name,
+ io.StringIO(),
+ )
+ )
+ else:
+ logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。")
+
+ if tool_call_delta.function and tool_call_delta.function.arguments:
+ # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
+ tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
+
+ return in_rc_flag
+
+
+def _build_stream_api_resp(
+ _fc_delta_buffer: io.StringIO,
+ _rc_delta_buffer: io.StringIO,
+ _tool_calls_buffer: list[tuple[str, str, io.StringIO]],
+) -> APIResponse:
+ resp = APIResponse()
+
+ if _rc_delta_buffer.tell() > 0:
+ # 如果推理内容缓冲区不为空,则将其写入APIResponse对象
+ resp.reasoning_content = _rc_delta_buffer.getvalue()
+ _rc_delta_buffer.close()
+ if _fc_delta_buffer.tell() > 0:
+ # 如果正式内容缓冲区不为空,则将其写入APIResponse对象
+ resp.content = _fc_delta_buffer.getvalue()
+ _fc_delta_buffer.close()
+ if _tool_calls_buffer:
+ # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表
+ resp.tool_calls = []
+ for call_id, function_name, arguments_buffer in _tool_calls_buffer:
+ if arguments_buffer.tell() > 0:
+ # 如果参数串缓冲区不为空,则解析为JSON对象
+ raw_arg_data = arguments_buffer.getvalue()
+ arguments_buffer.close()
+ try:
+ arguments = json.loads(repair_json(raw_arg_data))
+ if not isinstance(arguments, dict):
+ raise RespParseException(
+ None,
+ f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}",
+ )
+ except json.JSONDecodeError as e:
+ raise RespParseException(
+ None,
+ f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}",
+ ) from e
+ else:
+ arguments_buffer.close()
+ arguments = None
+
+ resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
+
+ return resp
+
+
+async def _default_stream_response_handler(
+ resp_stream: AsyncStream[ChatCompletionChunk],
+ interrupt_flag: asyncio.Event | None,
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+ """
+ 流式响应处理函数 - 处理OpenAI API的流式响应
+ :param resp_stream: 流式响应对象
+ :return: APIResponse对象
+ """
+
+ _has_rc_attr_flag = False # 标记是否有独立的推理内容块
+ _in_rc_flag = False # 标记是否在推理内容块中
+ _rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容
+ _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
+ _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
+ _usage_record = None # 使用情况记录
+
+ def _insure_buffer_closed():
+ # 确保缓冲区被关闭
+ if _rc_delta_buffer and not _rc_delta_buffer.closed:
+ _rc_delta_buffer.close()
+ if _fc_delta_buffer and not _fc_delta_buffer.closed:
+ _fc_delta_buffer.close()
+ for _, _, buffer in _tool_calls_buffer:
+ if buffer and not buffer.closed:
+ buffer.close()
+
+ async for event in resp_stream:
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量被设置,则抛出ReqAbortException
+ _insure_buffer_closed()
+ raise ReqAbortException("请求被外部信号中断")
+
+ delta = event.choices[0].delta # 获取当前块的delta内容
+
+ if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
+ # 标记:有独立的推理内容块
+ _has_rc_attr_flag = True
+
+ _in_rc_flag = _process_delta(
+ delta,
+ _has_rc_attr_flag,
+ _in_rc_flag,
+ _rc_delta_buffer,
+ _fc_delta_buffer,
+ _tool_calls_buffer,
+ )
+
+ if event.usage:
+ # 如果有使用情况,则将其存储在APIResponse对象中
+ _usage_record = (
+ event.usage.prompt_tokens or 0,
+ event.usage.completion_tokens or 0,
+ event.usage.total_tokens or 0,
+ )
+
+ try:
+ return _build_stream_api_resp(
+ _fc_delta_buffer,
+ _rc_delta_buffer,
+ _tool_calls_buffer,
+ ), _usage_record
+ except Exception:
+ # 确保缓冲区被关闭
+ _insure_buffer_closed()
+ raise
+
+
+pattern = re.compile(
+ r"(?P.*?)(?P.*)|(?P.*)|(?P.+)",
+ re.DOTALL,
+)
+"""用于解析推理内容的正则表达式"""
+
+
+def _default_normal_response_parser(
+ resp: ChatCompletion,
+) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
+ """
+ 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
+ :param resp: 响应对象
+ :return: APIResponse对象
+ """
+ api_response = APIResponse()
+
+ if not hasattr(resp, "choices") or len(resp.choices) == 0:
+ raise RespParseException(resp, "响应解析失败,缺失choices字段")
+ message_part = resp.choices[0].message
+
+ if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore
+ # 有有效的推理字段
+ api_response.content = message_part.content
+ api_response.reasoning_content = message_part.reasoning_content # type: ignore
+ elif message_part.content:
+ # 提取推理和内容
+ match = pattern.match(message_part.content)
+ if not match:
+ raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
+ if match.group("think") is not None:
+ result = match.group("think").strip(), match.group("content").strip()
+ elif match.group("think_unclosed") is not None:
+ result = match.group("think_unclosed").strip(), None
+ else:
+ result = None, match.group("content_only").strip()
+ api_response.reasoning_content, api_response.content = result
+
+ # 提取工具调用
+ if message_part.tool_calls:
+ api_response.tool_calls = []
+ for call in message_part.tool_calls:
+ try:
+ arguments = json.loads(repair_json(call.function.arguments))
+ if not isinstance(arguments, dict):
+ raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
+ api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
+ except json.JSONDecodeError as e:
+ raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
+
+ # 提取Usage信息
+ if resp.usage:
+ _usage_record = (
+ resp.usage.prompt_tokens or 0,
+ resp.usage.completion_tokens or 0,
+ resp.usage.total_tokens or 0,
+ )
+ else:
+ _usage_record = None
+
+ # 将原始响应存储在原始数据中
+ api_response.raw_data = resp
+
+ return api_response, _usage_record
+
+
+@client_registry.register_client_class("openai")
+class OpenaiClient(BaseClient):
+ def __init__(self, api_provider: APIProvider):
+ super().__init__(api_provider)
+ self.client: AsyncOpenAI = AsyncOpenAI(
+ base_url=api_provider.base_url,
+ api_key=api_provider.api_key,
+ max_retries=0,
+ )
+
+ async def get_response(
+ self,
+ model_info: ModelInfo,
+ message_list: list[Message],
+ tool_options: list[ToolOption] | None = None,
+ max_tokens: int = 1024,
+ temperature: float = 0.7,
+ response_format: RespFormat | None = None,
+ stream_response_handler: Optional[
+ Callable[
+ [AsyncStream[ChatCompletionChunk], asyncio.Event | None],
+ Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
+ ]
+ ] = None,
+ async_response_parser: Optional[
+ Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
+ ] = None,
+ interrupt_flag: asyncio.Event | None = None,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取对话响应
+ Args:
+ model_info: 模型信息
+ message_list: 对话体
+ tool_options: 工具选项(可选,默认为None)
+ max_tokens: 最大token数(可选,默认为1024)
+ temperature: 温度(可选,默认为0.7)
+ response_format: 响应格式(可选,默认为 NotGiven )
+ stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler)
+ async_response_parser: 响应解析函数(可选,默认为default_response_parser)
+ interrupt_flag: 中断信号量(可选,默认为None)
+ Returns:
+ (响应文本, 推理文本, 工具调用, 其他数据)
+ """
+ if stream_response_handler is None:
+ stream_response_handler = _default_stream_response_handler
+
+ if async_response_parser is None:
+ async_response_parser = _default_normal_response_parser
+
+ # 将messages构造为OpenAI API所需的格式
+ messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
+ # 将tool_options转换为OpenAI API所需的格式
+ tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
+
+ try:
+ if model_info.force_stream_mode:
+ req_task = asyncio.create_task(
+ self.client.chat.completions.create(
+ model=model_info.model_identifier,
+ messages=messages,
+ tools=tools,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ stream=True,
+ response_format=NOT_GIVEN,
+ extra_body=extra_params,
+ )
+ )
+ while not req_task.done():
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量存在且被设置,则取消任务并抛出异常
+ req_task.cancel()
+ raise ReqAbortException("请求被外部信号中断")
+ await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
+
+ resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
+ else:
+ # 发送请求并获取响应
+ req_task = asyncio.create_task(
+ self.client.chat.completions.create(
+ model=model_info.model_identifier,
+ messages=messages,
+ tools=tools,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ stream=False,
+ response_format=NOT_GIVEN,
+ extra_body=extra_params,
+ )
+ )
+ while not req_task.done():
+ if interrupt_flag and interrupt_flag.is_set():
+ # 如果中断量存在且被设置,则取消任务并抛出异常
+ req_task.cancel()
+ raise ReqAbortException("请求被外部信号中断")
+ await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
+
+ resp, usage_record = async_response_parser(req_task.result())
+ except APIConnectionError as e:
+ # 重封装APIConnectionError为NetworkConnectionError
+ raise NetworkConnectionError() from e
+ except APIStatusError as e:
+ # 重封装APIError为RespNotOkException
+ raise RespNotOkException(e.status_code, e.message) from e
+
+ if usage_record:
+ resp.usage = UsageRecord(
+ model_name=model_info.name,
+ provider_name=model_info.api_provider,
+ prompt_tokens=usage_record[0],
+ completion_tokens=usage_record[1],
+ total_tokens=usage_record[2],
+ )
+
+ return resp
+
+ async def get_embedding(
+ self,
+ model_info: ModelInfo,
+ embedding_input: str,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取文本嵌入
+ :param model_info: 模型信息
+ :param embedding_input: 嵌入输入文本
+ :return: 嵌入响应
+ """
+ try:
+ raw_response = await self.client.embeddings.create(
+ model=model_info.model_identifier,
+ input=embedding_input,
+ extra_body=extra_params,
+ )
+ except APIConnectionError as e:
+ raise NetworkConnectionError() from e
+ except APIStatusError as e:
+ # 重封装APIError为RespNotOkException
+ raise RespNotOkException(e.status_code) from e
+
+ response = APIResponse()
+
+ # 解析嵌入响应
+ if len(raw_response.data) > 0:
+ response.embedding = raw_response.data[0].embedding
+ else:
+ raise RespParseException(
+ raw_response,
+ "响应解析失败,缺失嵌入数据。",
+ )
+
+ # 解析使用情况
+ if hasattr(raw_response, "usage"):
+ response.usage = UsageRecord(
+ model_name=model_info.name,
+ provider_name=model_info.api_provider,
+ prompt_tokens=raw_response.usage.prompt_tokens or 0,
+ completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
+ total_tokens=raw_response.usage.total_tokens or 0,
+ )
+
+ return response
+
+ async def get_audio_transcriptions(
+ self,
+ model_info: ModelInfo,
+ audio_base64: str,
+ extra_params: dict[str, Any] | None = None,
+ ) -> APIResponse:
+ """
+ 获取音频转录
+ :param model_info: 模型信息
+ :param audio_base64: base64编码的音频数据
+ :extra_params: 附加的请求参数
+ :return: 音频转录响应
+ """
+ try:
+ raw_response = await self.client.audio.transcriptions.create(
+ model=model_info.model_identifier,
+ file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
+ extra_body=extra_params,
+ )
+ except APIConnectionError as e:
+ raise NetworkConnectionError() from e
+ except APIStatusError as e:
+ # 重封装APIError为RespNotOkException
+ raise RespNotOkException(e.status_code) from e
+ response = APIResponse()
+ # 解析转录响应
+ if hasattr(raw_response, "text"):
+ response.content = raw_response.text
+ else:
+ raise RespParseException(
+ raw_response,
+ "响应解析失败,缺失转录文本。",
+ )
+ return response
+
+ def get_support_image_formats(self) -> list[str]:
+ """
+ 获取支持的图片格式
+ :return: 支持的图片格式列表
+ """
+ return ["jpg", "jpeg", "png", "webp", "gif"]
diff --git a/src/llm_models/payload_content/__init__.py b/src/llm_models/payload_content/__init__.py
new file mode 100644
index 00000000..33e43c5e
--- /dev/null
+++ b/src/llm_models/payload_content/__init__.py
@@ -0,0 +1,3 @@
+from .tool_option import ToolCall
+
+__all__ = ["ToolCall"]
\ No newline at end of file
diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py
new file mode 100644
index 00000000..f70c3ded
--- /dev/null
+++ b/src/llm_models/payload_content/message.py
@@ -0,0 +1,107 @@
+from enum import Enum
+
+
+# 设计这系列类的目的是为未来可能的扩展做准备
+
+
+class RoleType(Enum):
+ System = "system"
+ User = "user"
+ Assistant = "assistant"
+ Tool = "tool"
+
+
+SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
+
+
+class Message:
+ def __init__(
+ self,
+ role: RoleType,
+ content: str | list[tuple[str, str] | str],
+ tool_call_id: str | None = None,
+ ):
+ """
+ 初始化消息对象
+ (不应直接修改Message类,而应使用MessageBuilder类来构建对象)
+ """
+ self.role: RoleType = role
+ self.content: str | list[tuple[str, str] | str] = content
+ self.tool_call_id: str | None = tool_call_id
+
+
+class MessageBuilder:
+ def __init__(self):
+ self.__role: RoleType = RoleType.User
+ self.__content: list[tuple[str, str] | str] = []
+ self.__tool_call_id: str | None = None
+
+ def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
+ """
+ 设置角色(默认为User)
+ :param role: 角色
+ :return: MessageBuilder对象
+ """
+ self.__role = role
+ return self
+
+ def add_text_content(self, text: str) -> "MessageBuilder":
+ """
+ 添加文本内容
+ :param text: 文本内容
+ :return: MessageBuilder对象
+ """
+ self.__content.append(text)
+ return self
+
+ def add_image_content(
+ self,
+ image_format: str,
+ image_base64: str,
+ support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
+ ) -> "MessageBuilder":
+ """
+ 添加图片内容
+ :param image_format: 图片格式
+ :param image_base64: 图片的base64编码
+ :return: MessageBuilder对象
+ """
+ if image_format.lower() not in support_formats:
+ raise ValueError("不受支持的图片格式")
+ if not image_base64:
+ raise ValueError("图片的base64编码不能为空")
+ self.__content.append((image_format, image_base64))
+ return self
+
+ def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
+ """
+ 添加工具调用指令(调用时请确保已设置为Tool角色)
+ :param tool_call_id: 工具调用指令的id
+ :return: MessageBuilder对象
+ """
+ if self.__role != RoleType.Tool:
+ raise ValueError("仅当角色为Tool时才能添加工具调用ID")
+ if not tool_call_id:
+ raise ValueError("工具调用ID不能为空")
+ self.__tool_call_id = tool_call_id
+ return self
+
+ def build(self) -> Message:
+ """
+ 构建消息对象
+ :return: Message对象
+ """
+ if len(self.__content) == 0:
+ raise ValueError("内容不能为空")
+ if self.__role == RoleType.Tool and self.__tool_call_id is None:
+ raise ValueError("Tool角色的工具调用ID不能为空")
+
+ return Message(
+ role=self.__role,
+ content=(
+ self.__content[0]
+ if (len(self.__content) == 1 and isinstance(self.__content[0], str))
+ else self.__content
+ ),
+ tool_call_id=self.__tool_call_id,
+ )
diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py
new file mode 100644
index 00000000..ab2e2edf
--- /dev/null
+++ b/src/llm_models/payload_content/resp_format.py
@@ -0,0 +1,223 @@
+from enum import Enum
+from typing import Optional, Any
+
+from pydantic import BaseModel
+from typing_extensions import TypedDict, Required
+
+
+class RespFormatType(Enum):
+ TEXT = "text" # 文本
+ JSON_OBJ = "json_object" # JSON
+ JSON_SCHEMA = "json_schema" # JSON Schema
+
+
+class JsonSchema(TypedDict, total=False):
+ name: Required[str]
+ """
+ The name of the response format.
+
+ Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
+ of 64.
+ """
+
+ description: Optional[str]
+ """
+ A description of what the response format is for, used by the model to determine
+ how to respond in the format.
+ """
+
+ schema: dict[str, object]
+ """
+ The schema for the response format, described as a JSON Schema object. Learn how
+ to build JSON schemas [here](https://json-schema.org/).
+ """
+
+ strict: Optional[bool]
+ """
+ Whether to enable strict schema adherence when generating the output. If set to
+ true, the model will always follow the exact schema defined in the `schema`
+ field. Only a subset of JSON Schema is supported when `strict` is `true`. To
+ learn more, read the
+ [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
+ """
+
+
+def _json_schema_type_check(instance) -> str | None:
+ if "name" not in instance:
+ return "schema必须包含'name'字段"
+ elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
+ return "schema的'name'字段必须是非空字符串"
+ if "description" in instance and (
+ not isinstance(instance["description"], str)
+ or instance["description"].strip() == ""
+ ):
+ return "schema的'description'字段只能填入非空字符串"
+ if "schema" not in instance:
+ return "schema必须包含'schema'字段"
+ elif not isinstance(instance["schema"], dict):
+ return "schema的'schema'字段必须是字典,详见https://json-schema.org/"
+ if "strict" in instance and not isinstance(instance["strict"], bool):
+ return "schema的'strict'字段只能填入布尔值"
+
+ return None
+
+
+def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
+ """
+ 递归移除JSON Schema中的title字段
+ """
+ if isinstance(schema, list):
+ # 如果当前Schema是列表,则对所有dict/list子元素递归调用
+ for idx, item in enumerate(schema):
+ if isinstance(item, (dict, list)):
+ schema[idx] = _remove_title(item)
+ elif isinstance(schema, dict):
+ # 是字典,移除title字段,并对所有dict/list子元素递归调用
+ if "title" in schema:
+ del schema["title"]
+ for key, value in schema.items():
+ if isinstance(value, (dict, list)):
+ schema[key] = _remove_title(value)
+
+ return schema
+
+
+def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
+ """
+ 链接JSON Schema中的definitions字段
+ """
+
+ def link_definitions_recursive(
+ path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
+ ) -> dict[str, Any]:
+ """
+ 递归链接JSON Schema中的definitions字段
+ :param path: 当前路径
+ :param sub_schema: 子Schema
+ :param defs: Schema定义集
+ :return:
+ """
+ if isinstance(sub_schema, list):
+ # 如果当前Schema是列表,则遍历每个元素
+ for i in range(len(sub_schema)):
+ if isinstance(sub_schema[i], dict):
+ sub_schema[i] = link_definitions_recursive(
+ f"{path}/{str(i)}", sub_schema[i], defs
+ )
+ else:
+ # 否则为字典
+ if "$defs" in sub_schema:
+ # 如果当前Schema有$def字段,则将其添加到defs中
+ key_prefix = f"{path}/$defs/"
+ for key, value in sub_schema["$defs"].items():
+ def_key = key_prefix + key
+ if def_key not in defs:
+ defs[def_key] = value
+ del sub_schema["$defs"]
+ if "$ref" in sub_schema:
+ # 如果当前Schema有$ref字段,则将其替换为defs中的定义
+ def_key = sub_schema["$ref"]
+ if def_key in defs:
+ sub_schema = defs[def_key]
+ else:
+ raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
+ # 遍历键值对
+ for key, value in sub_schema.items():
+ if isinstance(value, (dict, list)):
+ # 如果当前值是字典或列表,则递归调用
+ sub_schema[key] = link_definitions_recursive(
+ f"{path}/{key}", value, defs
+ )
+
+ return sub_schema
+
+ return link_definitions_recursive("#", schema, {})
+
+
+def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
+ """
+ 递归移除JSON Schema中的$defs字段
+ """
+ if isinstance(schema, list):
+ # 如果当前Schema是列表,则对所有dict/list子元素递归调用
+ for idx, item in enumerate(schema):
+ if isinstance(item, (dict, list)):
+ schema[idx] = _remove_title(item)
+ elif isinstance(schema, dict):
+ # 是字典,移除title字段,并对所有dict/list子元素递归调用
+ if "$defs" in schema:
+ del schema["$defs"]
+ for key, value in schema.items():
+ if isinstance(value, (dict, list)):
+ schema[key] = _remove_title(value)
+
+ return schema
+
+
+class RespFormat:
+ """
+ 响应格式
+ """
+
+ @staticmethod
+ def _generate_schema_from_model(schema):
+ json_schema = {
+ "name": schema.__name__,
+ "schema": _remove_defs(
+ _link_definitions(_remove_title(schema.model_json_schema()))
+ ),
+ "strict": False,
+ }
+ if schema.__doc__:
+ json_schema["description"] = schema.__doc__
+ return json_schema
+
+ def __init__(
+ self,
+ format_type: RespFormatType = RespFormatType.TEXT,
+ schema: type | JsonSchema | None = None,
+ ):
+ """
+ 响应格式
+ :param format_type: 响应格式类型(默认为文本)
+ :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效)
+ """
+ self.format_type: RespFormatType = format_type
+
+ if format_type == RespFormatType.JSON_SCHEMA:
+ if schema is None:
+ raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空")
+ if isinstance(schema, dict):
+ if check_msg := _json_schema_type_check(schema):
+ raise ValueError(f"schema格式不正确,{check_msg}")
+
+ self.schema = schema
+ elif issubclass(schema, BaseModel):
+ try:
+ json_schema = self._generate_schema_from_model(schema)
+
+ self.schema = json_schema
+ except Exception as e:
+ raise ValueError(
+ f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n"
+ f"{schema.__name__}:\n"
+ ) from e
+ else:
+ raise ValueError("schema必须是BaseModel的子类或JsonSchema")
+ else:
+ self.schema = None
+
+ def to_dict(self):
+ """
+ 将响应格式转换为字典
+ :return: 字典
+ """
+ if self.schema:
+ return {
+ "format_type": self.format_type.value,
+ "schema": self.schema,
+ }
+ else:
+ return {
+ "format_type": self.format_type.value,
+ }
diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py
new file mode 100644
index 00000000..9fedbc86
--- /dev/null
+++ b/src/llm_models/payload_content/tool_option.py
@@ -0,0 +1,163 @@
+from enum import Enum
+
+
+class ToolParamType(Enum):
+ """
+ 工具调用参数类型
+ """
+
+ STRING = "string" # 字符串
+ INTEGER = "integer" # 整型
+ FLOAT = "float" # 浮点型
+ BOOLEAN = "bool" # 布尔型
+
+
+class ToolParam:
+ """
+ 工具调用参数
+ """
+
+ def __init__(
+ self,
+ name: str,
+ param_type: ToolParamType,
+ description: str,
+ required: bool,
+ enum_values: list[str] | None = None,
+ ):
+ """
+ 初始化工具调用参数
+ (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象)
+ :param name: 参数名称
+ :param param_type: 参数类型
+ :param description: 参数描述
+ :param required: 是否必填
+ """
+ self.name: str = name
+ self.param_type: ToolParamType = param_type
+ self.description: str = description
+ self.required: bool = required
+ self.enum_values: list[str] | None = enum_values
+
+
+class ToolOption:
+ """
+ 工具调用项
+ """
+
+ def __init__(
+ self,
+ name: str,
+ description: str,
+ params: list[ToolParam] | None = None,
+ ):
+ """
+ 初始化工具调用项
+ (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象)
+ :param name: 工具名称
+ :param description: 工具描述
+ :param params: 工具参数列表
+ """
+ self.name: str = name
+ self.description: str = description
+ self.params: list[ToolParam] | None = params
+
+
+class ToolOptionBuilder:
+ """
+ 工具调用项构建器
+ """
+
+ def __init__(self):
+ self.__name: str = ""
+ self.__description: str = ""
+ self.__params: list[ToolParam] = []
+
+ def set_name(self, name: str) -> "ToolOptionBuilder":
+ """
+ 设置工具名称
+ :param name: 工具名称
+ :return: ToolBuilder实例
+ """
+ if not name:
+ raise ValueError("工具名称不能为空")
+ self.__name = name
+ return self
+
+ def set_description(self, description: str) -> "ToolOptionBuilder":
+ """
+ 设置工具描述
+ :param description: 工具描述
+ :return: ToolBuilder实例
+ """
+ if not description:
+ raise ValueError("工具描述不能为空")
+ self.__description = description
+ return self
+
+ def add_param(
+ self,
+ name: str,
+ param_type: ToolParamType,
+ description: str,
+ required: bool = False,
+ enum_values: list[str] | None = None,
+ ) -> "ToolOptionBuilder":
+ """
+ 添加工具参数
+ :param name: 参数名称
+ :param param_type: 参数类型
+ :param description: 参数描述
+ :param required: 是否必填(默认为False)
+ :return: ToolBuilder实例
+ """
+ if not name or not description:
+ raise ValueError("参数名称/描述不能为空")
+
+ self.__params.append(
+ ToolParam(
+ name=name,
+ param_type=param_type,
+ description=description,
+ required=required,
+ enum_values=enum_values,
+ )
+ )
+
+ return self
+
+ def build(self):
+ """
+ 构建工具调用项
+ :return: 工具调用项
+ """
+ if self.__name == "" or self.__description == "":
+ raise ValueError("工具名称/描述不能为空")
+
+ return ToolOption(
+ name=self.__name,
+ description=self.__description,
+ params=None if len(self.__params) == 0 else self.__params,
+ )
+
+
+class ToolCall:
+ """
+ 来自模型反馈的工具调用
+ """
+
+ def __init__(
+ self,
+ call_id: str,
+ func_name: str,
+ args: dict | None = None,
+ ):
+ """
+ 初始化工具调用
+ :param call_id: 工具调用ID
+ :param func_name: 要调用的函数名称
+ :param args: 工具调用参数
+ """
+ self.call_id: str = call_id
+ self.func_name: str = func_name
+ self.args: dict | None = args
diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py
new file mode 100644
index 00000000..52a6120c
--- /dev/null
+++ b/src/llm_models/utils.py
@@ -0,0 +1,186 @@
+import base64
+import io
+
+from PIL import Image
+from datetime import datetime
+
+from src.common.logger import get_logger
+from src.common.database.database import db # 确保 db 被导入用于 create_tables
+from src.common.database.database_model import LLMUsage
+from src.config.api_ada_configs import ModelInfo
+from .payload_content.message import Message, MessageBuilder
+from .model_client.base_client import UsageRecord
+
+logger = get_logger("消息压缩工具")
+
+
+def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
+ """
+ 压缩消息列表中的图片
+ :param messages: 消息列表
+ :param img_target_size: 图片目标大小,默认1MB
+ :return: 压缩后的消息列表
+ """
+
+ def reformat_static_image(image_data: bytes) -> bytes:
+ """
+ 将静态图片转换为JPEG格式
+ :param image_data: 图片数据
+ :return: 转换后的图片数据
+ """
+ try:
+ image = Image.open(image_data)
+
+ if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
+ # 静态图像,转换为JPEG格式
+ reformated_image_data = io.BytesIO()
+ image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
+ image_data = reformated_image_data.getvalue()
+
+ return image_data
+ except Exception as e:
+ logger.error(f"图片转换格式失败: {str(e)}")
+ return image_data
+
+ def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
+ """
+ 缩放图片
+ :param image_data: 图片数据
+ :param scale: 缩放比例
+ :return: 缩放后的图片数据
+ """
+ try:
+ image = Image.open(image_data)
+
+ # 原始尺寸
+ original_size = (image.width, image.height)
+
+ # 计算新的尺寸
+ new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
+
+ output_buffer = io.BytesIO()
+
+ if getattr(image, "is_animated", False):
+ # 动态图片,处理所有帧
+ frames = []
+ new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折
+ for frame_idx in range(getattr(image, "n_frames", 1)):
+ image.seek(frame_idx)
+ new_frame = image.copy()
+ new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS)
+ frames.append(new_frame)
+
+ # 保存到缓冲区
+ frames[0].save(
+ output_buffer,
+ format="GIF",
+ save_all=True,
+ append_images=frames[1:],
+ optimize=True,
+ duration=image.info.get("duration", 100),
+ loop=image.info.get("loop", 0),
+ )
+ else:
+ # 静态图片,直接缩放保存
+ resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
+ resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
+
+ return output_buffer.getvalue(), original_size, new_size
+
+ except Exception as e:
+ logger.error(f"图片缩放失败: {str(e)}")
+ import traceback
+
+ logger.error(traceback.format_exc())
+ return image_data, None, None
+
+ def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
+ original_b64_data_size = len(base64_data) # 计算原始数据大小
+
+ image_data = base64.b64decode(base64_data)
+
+ # 先尝试转换格式为JPEG
+ image_data = reformat_static_image(image_data)
+ base64_data = base64.b64encode(image_data).decode("utf-8")
+ if len(base64_data) <= target_size:
+ # 如果转换后小于目标大小,直接返回
+ logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB")
+ return base64_data
+
+ # 如果转换后仍然大于目标大小,进行尺寸压缩
+ scale = min(1.0, target_size / len(base64_data))
+ image_data, original_size, new_size = rescale_image(image_data, scale)
+ base64_data = base64.b64encode(image_data).decode("utf-8")
+
+ if original_size and new_size:
+ logger.info(
+ f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
+ f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
+ )
+
+ return base64_data
+
+ compressed_messages = []
+ for message in messages:
+ if isinstance(message.content, list):
+ # 检查content,如有图片则压缩
+ message_builder = MessageBuilder()
+ for content_item in message.content:
+ if isinstance(content_item, tuple):
+ # 图片,进行压缩
+ message_builder.add_image_content(
+ content_item[0],
+ compress_base64_image(content_item[1], target_size=img_target_size),
+ )
+ else:
+ message_builder.add_text_content(content_item)
+ compressed_messages.append(message_builder.build())
+ else:
+ compressed_messages.append(message)
+
+ return compressed_messages
+
+
+class LLMUsageRecorder:
+ """
+ LLM使用情况记录器
+ """
+
+ def __init__(self):
+ try:
+ # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
+ db.create_tables([LLMUsage], safe=True)
+ # logger.debug("LLMUsage 表已初始化/确保存在。")
+ except Exception as e:
+ logger.error(f"创建 LLMUsage 表失败: {str(e)}")
+
+ def record_usage_to_database(
+ self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
+ ):
+ input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
+ output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
+ total_cost = round(input_cost + output_cost, 6)
+ try:
+ # 使用 Peewee 模型创建记录
+ LLMUsage.create(
+ model_name=model_info.model_identifier,
+ user_id=user_id,
+ request_type=request_type,
+ endpoint=endpoint,
+ prompt_tokens=model_usage.prompt_tokens or 0,
+ completion_tokens=model_usage.completion_tokens or 0,
+ total_tokens=model_usage.total_tokens or 0,
+ cost=total_cost or 0.0,
+ status="success",
+ timestamp=datetime.now(), # Peewee 会处理 DateTimeField
+ )
+ logger.debug(
+ f"Token使用情况 - 模型: {model_usage.model_name}, "
+ f"用户: {user_id}, 类型: {request_type}, "
+ f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
+ f"总计: {model_usage.total_tokens}"
+ )
+ except Exception as e:
+ logger.error(f"记录token使用情况失败: {str(e)}")
+
+llm_usage_recorder = LLMUsageRecorder()
\ No newline at end of file
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index 9aca329e..b6764064 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -1,65 +1,29 @@
-import asyncio
-import json
import re
-from datetime import datetime
-from typing import Tuple, Union, Dict, Any, Callable
-import aiohttp
-from aiohttp.client import ClientResponse
-from src.common.logger import get_logger
-import base64
-from PIL import Image
-import io
-import os
-import copy # 添加copy模块用于深拷贝
-from src.common.database.database import db # 确保 db 被导入用于 create_tables
-from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型
-from src.config.config import global_config
-from src.common.tcp_connector import get_tcp_connector
+import copy
+import asyncio
+
+from enum import Enum
from rich.traceback import install
+from typing import Tuple, List, Dict, Optional, Callable, Any
+
+from src.common.logger import get_logger
+from src.config.config import model_config
+from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig
+from .payload_content.message import MessageBuilder, Message
+from .payload_content.resp_format import RespFormat
+from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType
+from .model_client.base_client import BaseClient, APIResponse, client_registry
+from .utils import compress_messages, llm_usage_recorder
+from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException
install(extra_lines=3)
logger = get_logger("model_utils")
-
-class PayLoadTooLargeError(Exception):
- """自定义异常类,用于处理请求体过大错误"""
-
- def __init__(self, message: str):
- super().__init__(message)
- self.message = message
-
- def __str__(self):
- return "请求体过大,请尝试压缩图片或减少输入内容。"
-
-
-class RequestAbortException(Exception):
- """自定义异常类,用于处理请求中断异常"""
-
- def __init__(self, message: str, response: ClientResponse):
- super().__init__(message)
- self.message = message
- self.response = response
-
- def __str__(self):
- return self.message
-
-
-class PermissionDeniedException(Exception):
- """自定义异常类,用于处理访问拒绝的异常"""
-
- def __init__(self, message: str):
- super().__init__(message)
- self.message = message
-
- def __str__(self):
- return self.message
-
-
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
- 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~",
+ 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
@@ -69,1013 +33,474 @@ error_code_mapping = {
}
-async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]):
- """安全地记录请求体,用于调试日志,不会修改原始payload对象"""
- # 创建payload的深拷贝,避免修改原始对象
- safe_payload = copy.deepcopy(payload)
-
- image_base64: str = request_content.get("image_base64")
- image_format: str = request_content.get("image_format")
- if (
- image_base64
- and safe_payload
- and isinstance(safe_payload, dict)
- and "messages" in safe_payload
- and len(safe_payload["messages"]) > 0
- ):
- if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]:
- content = safe_payload["messages"][0]["content"]
- if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]:
- # 只修改拷贝的对象,用于安全的日志记录
- safe_payload["messages"][0]["content"][1]["image_url"]["url"] = (
- f"data:image/{image_format.lower() if image_format else 'jpeg'};base64,"
- f"{image_base64[:10]}...{image_base64[-10:]}"
- )
- return safe_payload
+class RequestType(Enum):
+ """请求类型枚举"""
+
+ RESPONSE = "response"
+ EMBEDDING = "embedding"
+ AUDIO = "audio"
class LLMRequest:
- # 定义需要转换的模型列表,作为类变量避免重复
- MODELS_NEEDING_TRANSFORMATION = [
- "o1",
- "o1-2024-12-17",
- "o1-mini",
- "o1-mini-2024-09-12",
- "o1-preview",
- "o1-preview-2024-09-12",
- "o1-pro",
- "o1-pro-2025-03-19",
- "o3",
- "o3-2025-04-16",
- "o3-mini",
- "o3-mini-2025-01-31",
- "o4-mini",
- "o4-mini-2025-04-16",
- ]
+ """LLM请求类"""
- def __init__(self, model: dict, **kwargs):
- # 将大写的配置键转换为小写并从config中获取实际值
- logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}")
- logger.debug(f"🔍 [模型初始化] 模型配置: {model}")
- logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}")
-
- try:
- # print(f"model['provider']: {model['provider']}")
- self.api_key = os.environ[f"{model['provider']}_KEY"]
- self.base_url = os.environ[f"{model['provider']}_BASE_URL"]
- logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL")
- except AttributeError as e:
- logger.error(f"原始 model dict 信息:{model}")
- logger.error(f"配置错误:找不到对应的配置项 - {str(e)}")
- raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
- except KeyError:
- logger.warning(
- f"找不到{model['provider']}_KEY或{model['provider']}_BASE_URL环境变量,请检查配置文件或环境变量设置。"
- )
- self.model_name: str = model["name"]
- self.params = kwargs
+ def __init__(self, model_set: TaskConfig, request_type: str = "") -> None:
+ self.task_name = request_type
+ self.model_for_task = model_set
+ self.request_type = request_type
+ self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list}
+ """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整"""
- # 记录配置文件中声明了哪些参数(不管值是什么)
- self.has_enable_thinking = "enable_thinking" in model
- self.has_thinking_budget = "thinking_budget" in model
-
- self.enable_thinking = model.get("enable_thinking", False)
- self.temp = model.get("temp", 0.7)
- self.thinking_budget = model.get("thinking_budget", 4096)
- self.stream = model.get("stream", False)
- self.pri_in = model.get("pri_in", 0)
- self.pri_out = model.get("pri_out", 0)
- self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length)
- # print(f"max_tokens: {self.max_tokens}")
-
- logger.debug("🔍 [模型初始化] 模型参数设置完成:")
- logger.debug(f" - model_name: {self.model_name}")
- logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}")
- logger.debug(f" - enable_thinking: {self.enable_thinking}")
- logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}")
- logger.debug(f" - thinking_budget: {self.thinking_budget}")
- logger.debug(f" - temp: {self.temp}")
- logger.debug(f" - stream: {self.stream}")
- logger.debug(f" - max_tokens: {self.max_tokens}")
- logger.debug(f" - base_url: {self.base_url}")
+ self.pri_in = 0
+ self.pri_out = 0
- # 获取数据库实例
- self._init_database()
-
- # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
- self.request_type = kwargs.pop("request_type", "default")
- logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}")
-
- @staticmethod
- def _init_database():
- """初始化数据库集合"""
- try:
- # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
- db.create_tables([LLMUsage], safe=True)
- # logger.debug("LLMUsage 表已初始化/确保存在。")
- except Exception as e:
- logger.error(f"创建 LLMUsage 表失败: {str(e)}")
-
- def _record_usage(
+ async def generate_response_for_image(
self,
- prompt_tokens: int,
- completion_tokens: int,
- total_tokens: int,
- user_id: str = "system",
- request_type: str = None,
- endpoint: str = "/chat/completions",
- ):
- """记录模型使用情况到数据库
- Args:
- prompt_tokens: 输入token数
- completion_tokens: 输出token数
- total_tokens: 总token数
- user_id: 用户ID,默认为system
- request_type: 请求类型
- endpoint: API端点
+ prompt: str,
+ image_base64: str,
+ image_format: str,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
"""
- # 如果 request_type 为 None,则使用实例变量中的值
- if request_type is None:
- request_type = self.request_type
-
- try:
- # 使用 Peewee 模型创建记录
- LLMUsage.create(
- model_name=self.model_name,
- user_id=user_id,
- request_type=request_type,
- endpoint=endpoint,
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- cost=self._calculate_cost(prompt_tokens, completion_tokens),
- status="success",
- timestamp=datetime.now(), # Peewee 会处理 DateTimeField
- )
- logger.debug(
- f"Token使用情况 - 模型: {self.model_name}, "
- f"用户: {user_id}, 类型: {request_type}, "
- f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
- f"总计: {total_tokens}"
- )
- except Exception as e:
- logger.error(f"记录token使用情况失败: {str(e)}")
-
- def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float:
- """计算API调用成本
- 使用模型的pri_in和pri_out价格计算输入和输出的成本
-
+ 为图像生成响应
Args:
- prompt_tokens: 输入token数量
- completion_tokens: 输出token数量
-
+ prompt (str): 提示词
+ image_base64 (str): 图像的Base64编码字符串
+ image_format (str): 图像格式(如 'png', 'jpeg' 等)
Returns:
- float: 总成本(元)
+ (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
"""
- # 使用模型的pri_in和pri_out计算成本
- input_cost = (prompt_tokens / 1000000) * self.pri_in
- output_cost = (completion_tokens / 1000000) * self.pri_out
- return round(input_cost + output_cost, 6)
+ # 模型选择
+ model_info, api_provider, client = self._select_model()
- async def _prepare_request(
- self,
- endpoint: str,
- prompt: str = None,
- image_base64: str = None,
- image_format: str = None,
- file_bytes: bytes = None,
- file_format: str = None,
- payload: dict = None,
- retry_policy: dict = None,
- ) -> Dict[str, Any]:
- """配置请求参数
+ # 请求体构建
+ message_builder = MessageBuilder()
+ message_builder.add_text_content(prompt)
+ message_builder.add_image_content(
+ image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats()
+ )
+ messages = [message_builder.build()]
+
+ # 请求并处理返回值
+ response = await self._execute_request(
+ api_provider=api_provider,
+ client=client,
+ request_type=RequestType.RESPONSE,
+ model_info=model_info,
+ message_list=messages,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ )
+ content = response.content or ""
+ reasoning_content = response.reasoning_content or ""
+ tool_calls = response.tool_calls
+ # 从内容中提取标签的推理内容(向后兼容)
+ if not reasoning_content and content:
+ content, extracted_reasoning = self._extract_reasoning(content)
+ reasoning_content = extracted_reasoning
+ if usage := response.usage:
+ llm_usage_recorder.record_usage_to_database(
+ model_info=model_info,
+ model_usage=usage,
+ user_id="system",
+ request_type=self.request_type,
+ endpoint="/chat/completions",
+ )
+ return content, (reasoning_content, model_info.name, tool_calls)
+
+ async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]:
+ """
+ 为语音生成响应
Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- file_bytes: 文件的二进制数据
- file_format: 文件格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- request_type: 请求类型
+ voice_base64 (str): 语音的Base64编码字符串
+ Returns:
+ (Optional[str]): 生成的文本描述或None
"""
+ # 模型选择
+ model_info, api_provider, client = self._select_model()
- # 合并重试策略
- default_retry = {
- "max_retries": 3,
- "base_wait": 10,
- "retry_codes": [429, 413, 500, 503],
- "abort_codes": [400, 401, 402, 403],
- }
- policy = {**default_retry, **(retry_policy or {})}
+ # 请求并处理返回值
+ response = await self._execute_request(
+ api_provider=api_provider,
+ client=client,
+ request_type=RequestType.AUDIO,
+ model_info=model_info,
+ audio_base64=voice_base64,
+ )
+ return response.content or None
- api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
+ async def generate_response_async(
+ self,
+ prompt: str,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]:
+ """
+ 异步生成响应
+ Args:
+ prompt (str): 提示词
+ temperature (float, optional): 温度参数
+ max_tokens (int, optional): 最大token数
+ Returns:
+ (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表
+ """
+ # 请求体构建
+ message_builder = MessageBuilder()
+ message_builder.add_text_content(prompt)
+ messages = [message_builder.build()]
+ tool_built = self._build_tool_options(tools)
+ # 模型选择
+ model_info, api_provider, client = self._select_model()
- stream_mode = self.stream
+ # 请求并处理返回值
+ response = await self._execute_request(
+ api_provider=api_provider,
+ client=client,
+ request_type=RequestType.RESPONSE,
+ model_info=model_info,
+ message_list=messages,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ tool_options=tool_built,
+ )
+ content = response.content
+ reasoning_content = response.reasoning_content or ""
+ tool_calls = response.tool_calls
+ # 从内容中提取标签的推理内容(向后兼容)
+ if not reasoning_content and content:
+ content, extracted_reasoning = self._extract_reasoning(content)
+ reasoning_content = extracted_reasoning
+ if usage := response.usage:
+ llm_usage_recorder.record_usage_to_database(
+ model_info=model_info,
+ model_usage=usage,
+ user_id="system",
+ request_type=self.request_type,
+ endpoint="/chat/completions",
+ )
+ if not content:
+ logger.warning("生成的响应为空")
+ content = "生成的响应为空,请检查模型配置或输入内容是否正确"
- # 构建请求体
- if image_base64:
- payload = await self._build_payload(prompt, image_base64, image_format)
- elif file_bytes:
- payload = await self._build_formdata_payload(file_bytes, file_format)
- elif payload is None:
- payload = await self._build_payload(prompt)
+ return content, (reasoning_content, model_info.name, tool_calls)
- if not file_bytes:
- if stream_mode:
- payload["stream"] = stream_mode
+ async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
+ """获取嵌入向量
+ Args:
+ embedding_input (str): 获取嵌入的目标
+ Returns:
+ (Tuple[List[float], str]): (嵌入向量,使用的模型名称)
+ """
+ # 无需构建消息体,直接使用输入文本
+ model_info, api_provider, client = self._select_model()
- if self.temp != 0.7:
- payload["temperature"] = self.temp
+ # 请求并处理返回值
+ response = await self._execute_request(
+ api_provider=api_provider,
+ client=client,
+ request_type=RequestType.EMBEDDING,
+ model_info=model_info,
+ embedding_input=embedding_input,
+ )
- # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
- if self.has_enable_thinking:
- payload["enable_thinking"] = self.enable_thinking
+ embedding = response.embedding
- # 添加thinking_budget参数(只有配置文件中声明了才添加)
- if self.has_thinking_budget:
- payload["thinking_budget"] = self.thinking_budget
+ if usage := response.usage:
+ llm_usage_recorder.record_usage_to_database(
+ model_info=model_info,
+ model_usage=usage,
+ user_id="system",
+ request_type=self.request_type,
+ endpoint="/embeddings",
+ )
- if self.max_tokens:
- payload["max_tokens"] = self.max_tokens
+ if not embedding:
+ raise RuntimeError("获取embedding失败")
- # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- # payload["max_tokens"] = global_config.model.model_max_output_length
- # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
- payload["max_completion_tokens"] = payload.pop("max_tokens")
+ return embedding, model_info.name
- return {
- "policy": policy,
- "payload": payload,
- "api_url": api_url,
- "stream_mode": stream_mode,
- "image_base64": image_base64, # 保留必要的exception处理所需的原始数据
- "image_format": image_format,
- "file_bytes": file_bytes,
- "file_format": file_format,
- "prompt": prompt,
- }
+ def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
+ """
+ 根据总tokens和惩罚值选择的模型
+ """
+ least_used_model_name = min(
+ self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300
+ )
+ model_info = model_config.get_model_info(least_used_model_name)
+ api_provider = model_config.get_provider(model_info.api_provider)
+ client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider))
+ logger.debug(f"选择请求模型: {model_info.name}")
+ return model_info, api_provider, client
async def _execute_request(
self,
- endpoint: str,
- prompt: str = None,
- image_base64: str = None,
- image_format: str = None,
- file_bytes: bytes = None,
- file_format: str = None,
- payload: dict = None,
- retry_policy: dict = None,
- response_handler: Callable = None,
- user_id: str = "system",
- request_type: str = None,
- ):
- """统一请求执行入口
- Args:
- endpoint: API端点路径 (如 "chat/completions")
- prompt: prompt文本
- image_base64: 图片的base64编码
- image_format: 图片格式
- file_bytes: 文件的二进制数据
- file_format: 文件格式
- payload: 请求体数据
- retry_policy: 自定义重试策略
- response_handler: 自定义响应处理器
- user_id: 用户ID
- request_type: 请求类型
+ api_provider: APIProvider,
+ client: BaseClient,
+ request_type: RequestType,
+ model_info: ModelInfo,
+ message_list: List[Message] | None = None,
+ tool_options: list[ToolOption] | None = None,
+ response_format: RespFormat | None = None,
+ stream_response_handler: Optional[Callable] = None,
+ async_response_parser: Optional[Callable] = None,
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+ embedding_input: str = "",
+ audio_base64: str = "",
+ ) -> APIResponse:
"""
- # 获取请求配置
- request_content = await self._prepare_request(
- endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy
- )
- if request_type is None:
- request_type = self.request_type
- for retry in range(request_content["policy"]["max_retries"]):
+ 实际执行请求的方法
+
+ 包含了重试和异常处理逻辑
+ """
+ retry_remain = api_provider.max_retry
+ compressed_messages: Optional[List[Message]] = None
+ while retry_remain > 0:
try:
- # 使用上下文管理器处理会话
- if file_bytes:
- headers = await self._build_headers(is_formdata=True)
- else:
- headers = await self._build_headers(is_formdata=False)
- # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
- if request_content["stream_mode"]:
- headers["Accept"] = "text/event-stream"
-
- # 添加请求发送前的调试信息
- logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求")
- logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}")
- logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}")
-
- if not file_bytes:
- # 安全地记录请求体(隐藏敏感信息)
- safe_payload = await _safely_record(request_content, request_content["payload"])
- logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
- else:
- logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}")
-
- async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session:
- post_kwargs = {"headers": headers}
- # form-data数据上传方式不同
- if file_bytes:
- post_kwargs["data"] = request_content["payload"]
- else:
- post_kwargs["json"] = request_content["payload"]
-
- async with session.post(request_content["api_url"], **post_kwargs) as response:
- handled_result = await self._handle_response(
- response, request_content, retry, response_handler, user_id, request_type, endpoint
- )
- return handled_result
-
- except Exception as e:
- handled_payload, count_delta = await self._handle_exception(e, retry, request_content)
- retry += count_delta # 降级不计入重试次数
- if handled_payload:
- # 如果降级成功,重新构建请求体
- request_content["payload"] = handled_payload
- continue
-
- logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败")
- raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败")
-
- async def _handle_response(
- self,
- response: ClientResponse,
- request_content: Dict[str, Any],
- retry_count: int,
- response_handler: Callable,
- user_id,
- request_type,
- endpoint,
- ):
- policy = request_content["policy"]
- stream_mode = request_content["stream_mode"]
- if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]:
- await self._handle_error_response(response, retry_count, policy)
- return None
-
- response.raise_for_status()
- result = {}
- if stream_mode:
- # 将流式输出转化为非流式输出
- result = await self._handle_stream_output(response)
- else:
- result = await response.json()
- return (
- response_handler(result)
- if response_handler
- else self._default_response_handler(result, user_id, request_type, endpoint)
- )
-
- async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]:
- flag_delta_content_finished = False
- accumulated_content = ""
- usage = None # 初始化usage变量,避免未定义错误
- reasoning_content = ""
- content = ""
- tool_calls = None # 初始化工具调用变量
-
- async for line_bytes in response.content:
- try:
- line = line_bytes.decode("utf-8").strip()
- if not line:
- continue
- if line.startswith("data:"):
- data_str = line[5:].strip()
- if data_str == "[DONE]":
- break
- try:
- chunk = json.loads(data_str)
- if flag_delta_content_finished:
- chunk_usage = chunk.get("usage", None)
- if chunk_usage:
- usage = chunk_usage # 获取token用量
- else:
- delta = chunk["choices"][0]["delta"]
- delta_content = delta.get("content")
- if delta_content is None:
- delta_content = ""
- accumulated_content += delta_content
-
- # 提取工具调用信息
- if "tool_calls" in delta:
- if tool_calls is None:
- tool_calls = delta["tool_calls"]
- else:
- # 合并工具调用信息
- tool_calls.extend(delta["tool_calls"])
-
- # 检测流式输出文本是否结束
- finish_reason = chunk["choices"][0].get("finish_reason")
- if delta.get("reasoning_content", None):
- reasoning_content += delta["reasoning_content"]
- if finish_reason == "stop" or finish_reason == "tool_calls":
- chunk_usage = chunk.get("usage", None)
- if chunk_usage:
- usage = chunk_usage
- break
- # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk
- flag_delta_content_finished = True
- except Exception as e:
- logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}")
- except Exception as e:
- if isinstance(e, GeneratorExit):
- log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..."
- else:
- log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}"
- logger.warning(log_content)
- # 确保资源被正确清理
- try:
- await response.release()
- except Exception as cleanup_error:
- logger.error(f"清理资源时发生错误: {cleanup_error}")
- # 返回已经累积的内容
- content = accumulated_content
- if not content:
- content = accumulated_content
- think_match = re.search(r"(.*?)", content, re.DOTALL)
- if think_match:
- reasoning_content = think_match.group(1).strip()
- content = re.sub(r".*?", "", content, flags=re.DOTALL).strip()
-
- # 构建消息对象
- message = {
- "content": content,
- "reasoning_content": reasoning_content,
- }
-
- # 如果有工具调用,添加到消息中
- if tool_calls:
- message["tool_calls"] = tool_calls
-
- result = {
- "choices": [{"message": message}],
- "usage": usage,
- }
- return result
-
- async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]):
- if response.status in policy["retry_codes"]:
- wait_time = policy["base_wait"] * (2**retry_count)
- logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试")
- if response.status == 413:
- logger.warning("请求体过大,尝试压缩...")
- raise PayLoadTooLargeError("请求体过大")
- elif response.status in [500, 503]:
- logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
- )
- raise RuntimeError("服务器负载过高,模型回复失败QAQ")
- else:
- logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...")
- raise RuntimeError("请求限制(429)")
- elif response.status in policy["abort_codes"]:
- # 特别处理400错误,添加详细调试信息
- if response.status == 400:
- logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断")
- logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}")
- logger.error(f"🔍 [调试信息] API地址: {self.base_url}")
- logger.error("🔍 [调试信息] 模型配置参数:")
- logger.error(f" - enable_thinking: {self.enable_thinking}")
- logger.error(f" - temp: {self.temp}")
- logger.error(f" - thinking_budget: {self.thinking_budget}")
- logger.error(f" - stream: {self.stream}")
- logger.error(f" - max_tokens: {self.max_tokens}")
- logger.error(f" - pri_in: {self.pri_in}")
- logger.error(f" - pri_out: {self.pri_out}")
- logger.error(f"🔍 [调试信息] 原始params: {self.params}")
-
- # 尝试获取服务器返回的详细错误信息
- try:
- error_text = await response.text()
- logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}")
-
- try:
- error_json = json.loads(error_text)
- logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}")
- except json.JSONDecodeError:
- logger.error("🔍 [调试信息] 错误响应不是有效的JSON格式")
- except Exception as e:
- logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}")
-
- raise RequestAbortException("参数错误,请检查调试信息", response)
- elif response.status != 403:
- raise RequestAbortException("请求出现错误,中断处理", response)
- else:
- raise PermissionDeniedException("模型禁止访问")
-
- async def _handle_exception(
- self, exception, retry_count: int, request_content: Dict[str, Any]
- ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]:
- policy = request_content["policy"]
- payload = request_content["payload"]
- wait_time = policy["base_wait"] * (2**retry_count)
- keep_request = False
- if retry_count < policy["max_retries"] - 1:
- keep_request = True
- if isinstance(exception, RequestAbortException):
- response = exception.response
- logger.error(
- f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}"
- )
-
- # 如果是400错误,额外输出请求体信息用于调试
- if response.status == 400:
- logger.error("🔍 [异常调试] 400错误 - 请求体调试信息:")
- try:
- safe_payload = await _safely_record(request_content, payload)
- logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}")
- except Exception as debug_error:
- logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}")
- logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}")
- if isinstance(payload, dict):
- logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}")
-
- # print(request_content)
- # print(response)
- # 尝试获取并记录服务器返回的详细错误信息
- try:
- error_json = await response.json()
- if error_json and isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj: dict = error_item["error"]
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(
- f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- # 处理单个错误对象的情况
- error_obj = error_json.get("error", {})
- error_code = error_obj.get("code")
- error_message = error_obj.get("message")
- error_status = error_obj.get("status")
- logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}")
- else:
- # 记录原始错误响应内容
- logger.error(f"服务器错误响应: {error_json}")
- except Exception as e:
- logger.warning(f"无法解析服务器错误响应: {str(e)}")
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
-
- elif isinstance(exception, PermissionDeniedException):
- # 只针对硅基流动的V3和R1进行降级处理
- if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/":
- old_model_name = self.model_name
- self.model_name = self.model_name[4:] # 移除"Pro/"前缀
- logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}")
-
- # 对全局配置进行更新
- if global_config.model.replyer_2.get("name") == old_model_name:
- global_config.model.replyer_2["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}")
- if global_config.model.replyer_1.get("name") == old_model_name:
- global_config.model.replyer_1["name"] = self.model_name
- logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}")
-
- if payload and "model" in payload:
- payload["model"] = self.model_name
-
- await asyncio.sleep(wait_time)
- return payload, -1
- raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}")
-
- elif isinstance(exception, PayLoadTooLargeError):
- if keep_request:
- image_base64 = request_content["image_base64"]
- compressed_image_base64 = compress_base64_image_by_scale(image_base64)
- new_payload = await self._build_payload(
- request_content["prompt"], compressed_image_base64, request_content["image_format"]
- )
- return new_payload, 0
- else:
- return None, 0
-
- elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError):
- if keep_request:
- logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}")
- raise RuntimeError(f"网络请求失败: {str(exception)}")
-
- elif isinstance(exception, aiohttp.ClientResponseError):
- # 处理aiohttp抛出的,除了policy中的status的响应错误
- if keep_request:
- logger.error(
- f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}"
- )
- try:
- error_text = await exception.response.text()
- error_json = json.loads(error_text)
- if isinstance(error_json, list) and len(error_json) > 0:
- # 处理多个错误的情况
- for error_item in error_json:
- if "error" in error_item and isinstance(error_item["error"], dict):
- error_obj = error_item["error"]
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- elif isinstance(error_json, dict) and "error" in error_json:
- error_obj = error_json.get("error", {})
- logger.error(
- f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, "
- f"状态={error_obj.get('status')}, "
- f"消息={error_obj.get('message')}"
- )
- else:
- logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}")
- except (json.JSONDecodeError, TypeError) as json_err:
- logger.warning(
- f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}"
+ if request_type == RequestType.RESPONSE:
+ assert message_list is not None, "message_list cannot be None for response requests"
+ return await client.get_response(
+ model_info=model_info,
+ message_list=(compressed_messages or message_list),
+ tool_options=tool_options,
+ max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
+ temperature=self.model_for_task.temperature if temperature is None else temperature,
+ response_format=response_format,
+ stream_response_handler=stream_response_handler,
+ async_response_parser=async_response_parser,
+ extra_params=model_info.extra_params,
)
- except Exception as parse_err:
- logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}")
+ elif request_type == RequestType.EMBEDDING:
+ assert embedding_input, "embedding_input cannot be empty for embedding requests"
+ return await client.get_embedding(
+ model_info=model_info,
+ embedding_input=embedding_input,
+ extra_params=model_info.extra_params,
+ )
+ elif request_type == RequestType.AUDIO:
+ assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
+ return await client.get_audio_transcriptions(
+ model_info=model_info,
+ audio_base64=audio_base64,
+ extra_params=model_info.extra_params,
+ )
+ except Exception as e:
+ logger.debug(f"请求失败: {str(e)}")
+ # 处理异常
+ total_tokens, penalty = self.model_usage[model_info.name]
+ self.model_usage[model_info.name] = (total_tokens, penalty + 1)
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(
- f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}"
- )
- # 安全地检查和记录请求详情
- handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
- )
- raise RuntimeError(
- f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}"
+ wait_interval, compressed_messages = self._default_exception_handler(
+ e,
+ self.task_name,
+ model_name=model_info.name,
+ remain_try=retry_remain,
+ retry_interval=api_provider.retry_interval,
+ messages=(message_list, compressed_messages is not None) if message_list else None,
)
- else:
- if keep_request:
- logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}")
- await asyncio.sleep(wait_time)
- return None, 0
- else:
- logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}")
- # 安全地检查和记录请求详情
- handled_payload = await _safely_record(request_content, payload)
- logger.critical(
- f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}"
- )
- raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}")
+ if wait_interval == -1:
+ retry_remain = 0 # 不再重试
+ elif wait_interval > 0:
+ logger.info(f"等待 {wait_interval} 秒后重试...")
+ await asyncio.sleep(wait_interval)
+ finally:
+ # 放在finally防止死循环
+ retry_remain -= 1
+ logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次")
+ raise RuntimeError("请求失败,已达到最大重试次数")
- async def _transform_parameters(self, params: dict) -> dict:
+ def _default_exception_handler(
+ self,
+ e: Exception,
+ task_name: str,
+ model_name: str,
+ remain_try: int,
+ retry_interval: int = 10,
+ messages: Tuple[List[Message], bool] | None = None,
+ ) -> Tuple[int, List[Message] | None]:
"""
- 根据模型名称转换参数:
- - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数,
- 并将 'max_tokens' 重命名为 'max_completion_tokens'
+ 默认异常处理函数
+ Args:
+ e (Exception): 异常对象
+ task_name (str): 任务名称
+ model_name (str): 模型名称
+ remain_try (int): 剩余尝试次数
+ retry_interval (int): 重试间隔
+ messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
+ Returns:
+ (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
"""
- # 复制一份参数,避免直接修改原始数据
- new_params = dict(params)
-
- logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换")
- logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}")
- logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}")
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION:
- logger.debug("🔍 [参数转换] 检测到CoT模型,开始参数转换")
- # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度
- if "temperature" in new_params and new_params["temperature"] == 0.7:
- removed_temp = new_params.pop("temperature")
- logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}")
- # 如果存在 'max_tokens',则重命名为 'max_completion_tokens'
- if "max_tokens" in new_params:
- old_value = new_params["max_tokens"]
- new_params["max_completion_tokens"] = new_params.pop("max_tokens")
- logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})")
+ if isinstance(e, NetworkConnectionError): # 网络连接错误
+ return self._check_retry(
+ remain_try,
+ retry_interval,
+ can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确",
+ )
+ elif isinstance(e, ReqAbortException):
+ logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}")
+ return -1, None # 不再重试请求该模型
+ elif isinstance(e, RespNotOkException):
+ return self._handle_resp_not_ok(
+ e,
+ task_name,
+ model_name,
+ remain_try,
+ retry_interval,
+ messages,
+ )
+ elif isinstance(e, RespParseException):
+ # 响应解析错误
+ logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}")
+ logger.debug(f"附加内容: {str(e.ext_info)}")
+ return -1, None # 不再重试请求该模型
else:
- logger.debug("🔍 [参数转换] 非CoT模型,无需参数转换")
-
- logger.debug(f"🔍 [参数转换] 转换前参数: {params}")
- logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}")
- return new_params
+ logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}")
+ return -1, None # 不再重试请求该模型
- async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData:
- """构建form-data请求体"""
- # 目前只适配了音频文件
- # 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑
- data = aiohttp.FormData()
- content_type_list = {
- "wav": "audio/wav",
- "mp3": "audio/mpeg",
- "ogg": "audio/ogg",
- "flac": "audio/flac",
- "aac": "audio/aac",
- }
+ def _check_retry(
+ self,
+ remain_try: int,
+ retry_interval: int,
+ can_retry_msg: str,
+ cannot_retry_msg: str,
+ can_retry_callable: Callable | None = None,
+ **kwargs,
+ ) -> Tuple[int, List[Message] | None]:
+ """辅助函数:检查是否可以重试
+ Args:
+ remain_try (int): 剩余尝试次数
+ retry_interval (int): 重试间隔
+ can_retry_msg (str): 可以重试时的提示信息
+ cannot_retry_msg (str): 不可以重试时的提示信息
+ can_retry_callable (Callable | None): 可以重试时调用的函数(如果有)
+ **kwargs: 其他参数
- content_type = content_type_list.get(file_format)
- if not content_type:
- logger.warning(f"暂不支持的文件类型: {file_format}")
-
- data.add_field(
- "file",
- io.BytesIO(file_bytes),
- filename=f"file.{file_format}",
- content_type=f"{content_type}", # 根据实际文件类型设置
- )
- data.add_field("model", self.model_name)
- return data
-
- async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict:
- """构建请求体"""
- # 复制一份参数,避免直接修改 self.params
- logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体")
- logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}")
-
- params_copy = await self._transform_parameters(self.params)
- logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}")
-
- if image_base64:
- messages = [
- {
- "role": "user",
- "content": [
- {"type": "text", "text": prompt},
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"},
- },
- ],
- }
- ]
- else:
- messages = [{"role": "user", "content": prompt}]
-
- payload = {
- "model": self.model_name,
- "messages": messages,
- **params_copy,
- }
-
- logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}")
-
- # 添加temp参数(如果不是默认值0.7)
- if self.temp != 0.7:
- payload["temperature"] = self.temp
- logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}")
-
- # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false)
- if self.has_enable_thinking:
- payload["enable_thinking"] = self.enable_thinking
- logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}")
-
- # 添加thinking_budget参数(只有配置文件中声明了才添加)
- if self.has_thinking_budget:
- payload["thinking_budget"] = self.thinking_budget
- logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}")
-
- if self.max_tokens:
- payload["max_tokens"] = self.max_tokens
- logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}")
-
- # if "max_tokens" not in payload and "max_completion_tokens" not in payload:
- # payload["max_tokens"] = global_config.model.model_max_output_length
- # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
- if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload:
- old_value = payload["max_tokens"]
- payload["max_completion_tokens"] = payload.pop("max_tokens")
- logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})")
-
- logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}")
- return payload
-
- def _default_response_handler(
- self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
- ) -> Tuple:
- """默认响应解析"""
- if "choices" in result and result["choices"]:
- message = result["choices"][0]["message"]
- content = message.get("content", "")
- content, reasoning = self._extract_reasoning(content)
- reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = message.get("reasoning_content", "")
- if not reasoning_content:
- reasoning_content = reasoning
-
- # 提取工具调用信息
- tool_calls = message.get("tool_calls", None)
-
- # 记录token使用情况
- usage = result.get("usage", {})
- if usage:
- prompt_tokens = usage.get("prompt_tokens", 0)
- completion_tokens = usage.get("completion_tokens", 0)
- total_tokens = usage.get("total_tokens", 0)
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id=user_id,
- request_type=request_type if request_type is not None else self.request_type,
- endpoint=endpoint,
- )
-
- # 只有当tool_calls存在且不为空时才返回
- if tool_calls:
- logger.debug(f"检测到工具调用: {tool_calls}")
- return content, reasoning_content, tool_calls
+ Returns:
+ (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
+ """
+ if remain_try > 0:
+ # 还有重试机会
+ logger.warning(f"{can_retry_msg}")
+ if can_retry_callable is not None:
+ return retry_interval, can_retry_callable(**kwargs)
else:
- return content, reasoning_content
- elif "text" in result and result["text"]:
- return result["text"]
- return "没有返回结果", ""
+ return retry_interval, None
+ else:
+ # 达到最大重试次数
+ logger.warning(f"{cannot_retry_msg}")
+ return -1, None # 不再重试请求该模型
+
+ def _handle_resp_not_ok(
+ self,
+ e: RespNotOkException,
+ task_name: str,
+ model_name: str,
+ remain_try: int,
+ retry_interval: int = 10,
+ messages: tuple[list[Message], bool] | None = None,
+ ):
+ """
+ 处理响应错误异常
+ Args:
+ e (RespNotOkException): 响应错误异常对象
+ task_name (str): 任务名称
+ model_name (str): 模型名称
+ remain_try (int): 剩余尝试次数
+ retry_interval (int): 重试间隔
+ messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过)
+ Returns:
+ (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息))
+ """
+ # 响应错误
+ if e.status_code in [400, 401, 402, 403, 404]:
+ # 客户端错误
+ logger.warning(
+ f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}"
+ )
+ return -1, None # 不再重试请求该模型
+ elif e.status_code == 413:
+ if messages and not messages[1]:
+ # 消息列表不为空且未压缩,尝试压缩消息
+ return self._check_retry(
+ remain_try,
+ 0,
+ can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试",
+ cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求",
+ can_retry_callable=compress_messages,
+ messages=messages[0],
+ )
+ # 没有消息可压缩
+ logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。")
+ return -1, None
+ elif e.status_code == 429:
+ # 请求过于频繁
+ return self._check_retry(
+ remain_try,
+ retry_interval,
+ can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求",
+ )
+ elif e.status_code >= 500:
+ # 服务器错误
+ return self._check_retry(
+ remain_try,
+ retry_interval,
+ can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试",
+ cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试",
+ )
+ else:
+ # 未知错误
+ logger.warning(
+ f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}"
+ )
+ return -1, None
+
+ def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]:
+ # sourcery skip: extract-method
+ """构建工具选项列表"""
+ if not tools:
+ return None
+ tool_options: List[ToolOption] = []
+ for tool in tools:
+ tool_legal = True
+ tool_options_builder = ToolOptionBuilder()
+ tool_options_builder.set_name(tool.get("name", ""))
+ tool_options_builder.set_description(tool.get("description", ""))
+ parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", [])
+ for param in parameters:
+ try:
+ assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组"
+ assert isinstance(param[0], str), "参数名称必须是字符串"
+ assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举"
+ assert isinstance(param[2], str), "参数描述必须是字符串"
+ assert isinstance(param[3], bool), "参数是否必填必须是布尔值"
+ assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None"
+ tool_options_builder.add_param(
+ name=param[0],
+ param_type=param[1],
+ description=param[2],
+ required=param[3],
+ enum_values=param[4],
+ )
+ except AssertionError as ae:
+ tool_legal = False
+ logger.error(f"{param[0]} 参数定义错误: {str(ae)}")
+ except Exception as e:
+ tool_legal = False
+ logger.error(f"构建工具参数失败: {str(e)}")
+ if tool_legal:
+ tool_options.append(tool_options_builder.build())
+ return tool_options or None
@staticmethod
def _extract_reasoning(content: str) -> Tuple[str, str]:
- """CoT思维链提取"""
+ """CoT思维链提取,向后兼容"""
match = re.search(r"(?:)?(.*?)", content, re.DOTALL)
content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip()
- if match:
- reasoning = match.group(1).strip()
- else:
- reasoning = ""
+ reasoning = match[1].strip() if match else ""
return content, reasoning
-
- async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict:
- """构建请求头"""
- if no_key:
- if is_formdata:
- return {"Authorization": "Bearer **********"}
- return {"Authorization": "Bearer **********", "Content-Type": "application/json"}
- else:
- if is_formdata:
- return {"Authorization": f"Bearer {self.api_key}"}
- return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
- # 防止小朋友们截图自己的key
-
- async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple:
- """根据输入的提示和图片生成模型的异步响应"""
-
- response = await self._execute_request(
- endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format
- )
- # 根据返回值的长度决定怎么处理
- if len(response) == 3:
- content, reasoning_content, tool_calls = response
- return content, reasoning_content, tool_calls
- else:
- content, reasoning_content = response
- return content, reasoning_content
-
- async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple:
- """根据输入的语音文件生成模型的异步响应"""
- response = await self._execute_request(
- endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav"
- )
- return response
-
- async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]:
- """异步方式根据输入的提示生成模型的响应"""
- # 构建请求体,不硬编码max_tokens
- data = {
- "model": self.model_name,
- "messages": [{"role": "user", "content": prompt}],
- **self.params,
- **kwargs,
- }
-
- response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
- # 原样返回响应,不做处理
-
- if len(response) == 3:
- content, reasoning_content, tool_calls = response
- return content, (reasoning_content, self.model_name, tool_calls)
- else:
- content, reasoning_content = response
- return content, (reasoning_content, self.model_name)
-
- async def get_embedding(self, text: str) -> Union[list, None]:
- """异步方法:获取文本的embedding向量
-
- Args:
- text: 需要获取embedding的文本
-
- Returns:
- list: embedding向量,如果失败则返回None
- """
-
- if len(text) < 1:
- logger.debug("该消息没有长度,不再发送获取embedding向量的请求")
- return None
-
- def embedding_handler(result):
- """处理响应"""
- if "data" in result and len(result["data"]) > 0:
- # 提取 token 使用信息
- usage = result.get("usage", {})
- if usage:
- prompt_tokens = usage.get("prompt_tokens", 0)
- completion_tokens = usage.get("completion_tokens", 0)
- total_tokens = usage.get("total_tokens", 0)
- # 记录 token 使用情况
- self._record_usage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- user_id="system", # 可以根据需要修改 user_id
- # request_type="embedding", # 请求类型为 embedding
- request_type=self.request_type, # 请求类型为 text
- endpoint="/embeddings", # API 端点
- )
- return result["data"][0].get("embedding", None)
- return result["data"][0].get("embedding", None)
- return None
-
- embedding = await self._execute_request(
- endpoint="/embeddings",
- prompt=text,
- payload={"model": self.model_name, "input": text, "encoding_format": "float"},
- retry_policy={"max_retries": 2, "base_wait": 6},
- response_handler=embedding_handler,
- )
- return embedding
-
-
-def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
- """压缩base64格式的图片到指定大小
- Args:
- base64_data: base64编码的图片数据
- target_size: 目标文件大小(字节),默认0.8MB
- Returns:
- str: 压缩后的base64图片数据
- """
- try:
- # 将base64转换为字节数据
- # 确保base64字符串只包含ASCII字符
- if isinstance(base64_data, str):
- base64_data = base64_data.encode("ascii", errors="ignore").decode("ascii")
- image_data = base64.b64decode(base64_data)
-
- # 如果已经小于目标大小,直接返回原图
- if len(image_data) <= 2 * 1024 * 1024:
- return base64_data
-
- # 将字节数据转换为图片对象
- img = Image.open(io.BytesIO(image_data))
-
- # 获取原始尺寸
- original_width, original_height = img.size
-
- # 计算缩放比例
- scale = min(1.0, (target_size / len(image_data)) ** 0.5)
-
- # 计算新的尺寸
- new_width = int(original_width * scale)
- new_height = int(original_height * scale)
-
- # 创建内存缓冲区
- output_buffer = io.BytesIO()
-
- # 如果是GIF,处理所有帧
- if getattr(img, "is_animated", False):
- frames = []
- for frame_idx in range(img.n_frames):
- img.seek(frame_idx)
- new_frame = img.copy()
- new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折
- frames.append(new_frame)
-
- # 保存到缓冲区
- frames[0].save(
- output_buffer,
- format="GIF",
- save_all=True,
- append_images=frames[1:],
- optimize=True,
- duration=img.info.get("duration", 100),
- loop=img.info.get("loop", 0),
- )
- else:
- # 处理静态图片
- resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
-
- # 保存到缓冲区,保持原始格式
- if img.format == "PNG" and img.mode in ("RGBA", "LA"):
- resized_img.save(output_buffer, format="PNG", optimize=True)
- else:
- resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True)
-
- # 获取压缩后的数据并转换为base64
- compressed_data = output_buffer.getvalue()
- logger.info(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}")
- logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB")
-
- return base64.b64encode(compressed_data).decode("utf-8")
-
- except Exception as e:
- logger.error(f"压缩图片失败: {str(e)}")
- import traceback
-
- logger.error(traceback.format_exc())
- return base64_data
diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py
index 867ba8be..5a1f5808 100644
--- a/src/mais4u/mai_think.py
+++ b/src/mais4u/mai_think.py
@@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger
+
logger = get_logger(__name__)
+
def init_prompt():
Prompt(
"""
@@ -32,10 +34,8 @@ def init_prompt():
)
-
-
class MaiThinking:
- def __init__(self,chat_id):
+ def __init__(self, chat_id):
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform
@@ -44,11 +44,11 @@ class MaiThinking:
self.is_group = True
else:
self.is_group = False
-
+
self.s4u_message_processor = S4UMessageProcessor()
-
+
self.mind = ""
-
+
self.memory_block = ""
self.relation_info_block = ""
self.time_block = ""
@@ -59,17 +59,13 @@ class MaiThinking:
self.identity = ""
self.sender = ""
self.target = ""
-
- self.thinking_model = LLMRequest(
- model=global_config.model.replyer_1,
- request_type="thinking",
- )
+
+ self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking")
async def do_think_before_response(self):
pass
- async def do_think_after_response(self,reponse:str):
-
+ async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt",
mind=self.mind,
@@ -85,47 +81,44 @@ class MaiThinking:
sender=self.sender,
target=self.target,
)
-
+
result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result
-
+
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
-
-
+
msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind)
-
-
+
async def do_think_when_receive_message(self):
pass
-
- async def build_internal_message_recv(self,message_text:str):
-
+
+ async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}"
-
+
message_dict = {
"message_info": {
"message_id": msg_id,
"time": time.time(),
"user_info": {
- "user_id": "internal", # 内部用户ID
- "user_nickname": "内心", # 内部昵称
- "platform": self.platform, # 平台标记为 internal
+ "user_id": "internal", # 内部用户ID
+ "user_nickname": "内心", # 内部昵称
+ "platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充
},
- "platform": self.platform, # 平台
+ "platform": self.platform, # 平台
# 其他 message_info 字段按需补充
},
"message_segment": {
- "type": "text", # 消息类型
- "data": message_text, # 消息内容
+ "type": "text", # 消息类型
+ "data": message_text, # 消息内容
# 其他 segment 字段按需补充
},
- "raw_message": message_text, # 原始消息内容
- "processed_plain_text": message_text, # 处理后的纯文本
+ "raw_message": message_text, # 原始消息内容
+ "processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False,
"has_emoji": False,
@@ -139,45 +132,36 @@ class MaiThinking:
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0,
}
-
+
if self.is_group:
message_dict["message_info"]["group_info"] = {
"platform": self.platform,
"group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name,
}
-
+
msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True
-
+
return msg_recv
-
-
-
+
class MaiThinkingManager:
def __init__(self):
self.mai_think_list = []
-
- def get_mai_think(self,chat_id):
+
+ def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id:
return mai_think
mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think)
return mai_think
-
+
+
mai_thinking_manager = MaiThinkingManager()
-
+
init_prompt()
-
-
-
-
-
-
-
-
diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py
index e7380822..8e05a025 100644
--- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py
+++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py
@@ -1,14 +1,16 @@
import json
import time
+
+from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
-from json_repair import repair_json
+
from src.mais4u.s4u_config import s4u_config
logger = get_logger("action")
@@ -32,7 +34,7 @@ BODY_CODE = {
"帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210",
- "平静,双手后放":"平静,双手后放",
+ "平静,双手后放": "平静,双手后放",
"思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般",
@@ -94,19 +96,15 @@ class ChatAction:
self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion)
- print(global_config.model.emotion)
-
- self.action_model = LLMRequest(
- model=global_config.model.emotion,
- temperature=0.7,
- request_type="motion",
- )
+ print(model_config.model_task_config.emotion)
- self.last_change_time = 0
+ self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
+
+ self.last_change_time: float = 0
async def send_action_update(self):
"""发送动作更新到前端"""
-
+
body_code = BODY_CODE.get(self.body_action, "")
await send_api.custom_to_stream(
message_type="body_action",
@@ -115,13 +113,11 @@ class ChatAction:
storage_message=False,
show_log=True,
)
-
-
async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0
- message_time = message.message_info.time
+ message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -147,13 +143,13 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
-
+
try:
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions)
-
+
prompt = await global_prompt_manager.format_prompt(
"change_action_prompt",
chat_talking_prompt=chat_talking_prompt,
@@ -163,19 +159,18 @@ class ChatAction:
)
logger.info(f"prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
- action_data = json.loads(repair_json(response))
-
- if action_data:
+ if action_data := json.loads(repair_json(response)):
# 记录原动作,切换后进入冷却
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
- if new_body_action != prev_body_action:
- if prev_body_action:
- self.body_action_cooldown[prev_body_action] = 3
+ if new_body_action != prev_body_action and prev_body_action:
+ self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新
@@ -213,7 +208,6 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
try:
-
# 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
@@ -228,17 +222,17 @@ class ChatAction:
)
logger.info(f"prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}")
- action_data = json.loads(repair_json(response))
- if action_data:
+ if action_data := json.loads(repair_json(response)):
prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action)
- if new_body_action != prev_body_action:
- if prev_body_action:
- self.body_action_cooldown[prev_body_action] = 6
+ if new_body_action != prev_body_action and prev_body_action:
+ self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action
# 发送动作更新
await self.send_action_update()
@@ -306,9 +300,6 @@ class ActionManager:
return new_action_state
-
-
-
init_prompt()
action_manager = ActionManager()
diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py
index e447ae19..78df5e98 100644
--- a/src/mais4u/mais4u_chat/s4u_chat.py
+++ b/src/mais4u/mais4u_chat/s4u_chat.py
@@ -137,7 +137,7 @@ class MessageSenderContainer:
await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e:
- logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True)
+ logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved.
diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py
index c936cea1..11d8c7ca 100644
--- a/src/mais4u/mais4u_chat/s4u_mood_manager.py
+++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py
@@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
@@ -114,18 +114,12 @@ class ChatMood:
self.regression_count: int = 0
- self.mood_model = LLMRequest(
- model=global_config.model.emotion,
- temperature=0.7,
- request_type="mood_text",
- )
+ self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
self.mood_model_numerical = LLMRequest(
- model=global_config.model.emotion,
- temperature=0.4,
- request_type="mood_numerical",
+ model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
)
- self.last_change_time = 0
+ self.last_change_time: float = 0
# 发送初始情绪状态到ws端
asyncio.create_task(self.send_emotion_update(self.mood_values))
@@ -164,7 +158,7 @@ class ChatMood:
async def update_mood_by_message(self, message: MessageRecv):
self.regression_count = 0
- message_time = message.message_info.time
+ message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_change_time,
@@ -199,7 +193,9 @@ class ChatMood:
mood_state=self.mood_state,
)
logger.debug(f"text mood prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
logger.info(f"text mood response: {response}")
logger.debug(f"text mood reasoning_content: {reasoning_content}")
return response
@@ -216,8 +212,8 @@ class ChatMood:
fear=self.mood_values["fear"],
)
logger.debug(f"numerical mood prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
- prompt=prompt
+ response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
+ prompt=prompt, temperature=0.4
)
logger.info(f"numerical mood response: {response}")
logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
@@ -276,7 +272,9 @@ class ChatMood:
mood_state=self.mood_state,
)
logger.debug(f"text regress prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
logger.info(f"text regress response: {response}")
logger.debug(f"text regress reasoning_content: {reasoning_content}")
return response
@@ -293,8 +291,9 @@ class ChatMood:
fear=self.mood_values["fear"],
)
logger.debug(f"numerical regress prompt: {prompt}")
- response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async(
- prompt=prompt
+ response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
+ prompt=prompt,
+ temperature=0.4,
)
logger.info(f"numerical regress response: {response}")
logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
@@ -447,6 +446,7 @@ class MoodManager:
# 发送初始情绪状态到ws端
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
+
if ENABLE_S4U:
init_prompt()
mood_manager = MoodManager()
diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py
index d748c25e..72324d74 100644
--- a/src/mais4u/mais4u_chat/s4u_prompt.py
+++ b/src/mais4u/mais4u_chat/s4u_prompt.py
@@ -150,19 +150,18 @@ class PromptBuilder:
relation_prompt = ""
if global_config.relationship.enable_relationship and who_chat_in_group:
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
-
+
# 将 (platform, user_id, nickname) 转换为 person_id
person_ids = []
for person in who_chat_in_group:
person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_ids.append(person_id)
-
+
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = await asyncio.gather(
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
)
- relation_info = "".join(relation_info_list)
- if relation_info:
+ if relation_info := "".join(relation_info_list):
relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info
)
@@ -186,9 +185,9 @@ class PromptBuilder:
timestamp=time.time(),
limit=300,
)
-
- talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
+
+ talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list = []
background_dialogue_list = []
@@ -258,19 +257,19 @@ class PromptBuilder:
all_msg_seg_list.append(msg_seg_str)
for msg in all_msg_seg_list:
core_msg_str += msg
-
-
+
+
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id,
timestamp=time.time(),
limit=20,
- )
+ )
all_dialogue_prompt_str = build_readable_messages(
all_dialogue_prompt,
timestamp_mode="normal_no_YMD",
show_pic=False,
)
-
+
return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str
diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py
index 339b46c3..c0ca2658 100644
--- a/src/mais4u/mais4u_chat/s4u_stream_generator.py
+++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py
@@ -1,7 +1,7 @@
import os
from typing import AsyncGenerator
from src.mais4u.openai_client import AsyncOpenAIClient
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger
@@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator:
def __init__(self):
- replyer_1_config = global_config.model.replyer_1
- provider = replyer_1_config.get("provider")
- if not provider:
- logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段")
- raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段")
+ replyer_1_config = model_config.model_task_config.replyer_1
+ model_to_use = replyer_1_config.model_list[0]
+ model_info = model_config.get_model_info(model_to_use)
+ if not model_info:
+ logger.error(f"模型 {model_to_use} 在配置中未找到")
+ raise ValueError(f"模型 {model_to_use} 在配置中未找到")
+ provider_name = model_info.api_provider
+ provider_info = model_config.get_provider(provider_name)
+ if not provider_info:
+ logger.error("`replyer_1` 找不到对应的Provider")
+ raise ValueError("`replyer_1` 找不到对应的Provider")
- api_key = os.environ.get(f"{provider.upper()}_KEY")
- base_url = os.environ.get(f"{provider.upper()}_BASE_URL")
+ api_key = provider_info.api_key
+ base_url = provider_info.base_url
if not api_key:
- logger.error(f"环境变量 {provider.upper()}_KEY 未设置")
- raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置")
+ logger.error(f"{provider_name}没有配置API KEY")
+ raise ValueError(f"{provider_name}没有配置API KEY")
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
- self.model_1_name = replyer_1_config.get("name")
- if not self.model_1_name:
- logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段")
- raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段")
+ self.model_1_name = model_to_use
self.replyer_1_config = replyer_1_config
self.current_model_name = "unknown model"
@@ -44,10 +47,10 @@ class S4UStreamGenerator:
r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符
re.UNICODE | re.DOTALL,
)
-
- self.chat_stream =None
-
- async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""):
+
+ self.chat_stream = None
+
+ async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
# person_id = PersonInfoManager.get_person_id(
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
# )
@@ -71,14 +74,10 @@ class S4UStreamGenerator:
[这是用户发来的新消息, 你需要结合上下文,对此进行回复]:
{message.processed_plain_text}
"""
- return True,message_txt
+ return True, message_txt
else:
message_txt = message.processed_plain_text
- return False,message_txt
-
-
-
-
+ return False, message_txt
async def generate_response(
self, message: MessageRecvS4U, previous_reply_context: str = ""
@@ -88,7 +87,7 @@ class S4UStreamGenerator:
self.partial_response = ""
message_txt = message.processed_plain_text
if not message.is_internal:
- interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context)
+ interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
if interupted:
message_txt = message_txt_added
@@ -105,7 +104,6 @@ class S4UStreamGenerator:
current_client = self.client_1
self.current_model_name = self.model_1_name
-
extra_kwargs = {}
if self.replyer_1_config.get("enable_thinking") is not None:
extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")
diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py
index 528eaecc..a08d18cd 100644
--- a/src/mais4u/mais4u_chat/super_chat_manager.py
+++ b/src/mais4u/mais4u_chat/super_chat_manager.py
@@ -214,51 +214,49 @@ class SuperChatManager:
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id)
-
+
if not superchats:
return ""
-
+
# 限制显示数量
display_superchats = superchats[:max_count]
-
- lines = []
- lines.append("📢 当前有效超级弹幕:")
-
+
+ lines = ["📢 当前有效超级弹幕:"]
for i, sc in enumerate(display_superchats, 1):
remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60)
-
+
time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒"
-
+
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度
- line = line[:97] + "..."
+ line = f"{line[:97]}..."
line += f" (剩余{time_display})"
lines.append(line)
-
+
if len(superchats) > max_count:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
-
+
return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id)
-
+
if not superchats:
return "当前没有有效的超级弹幕"
lines = []
for sc in superchats:
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
if len(single_sc_str) > 100:
- single_sc_str = single_sc_str[:97] + "..."
+ single_sc_str = f"{single_sc_str[:97]}..."
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
lines.append(single_sc_str)
-
+
total_amount = sum(sc.price for sc in superchats)
count = len(superchats)
highest_amount = max(sc.price for sc in superchats)
-
+
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元"
if lines:
final_str += "\n" + "\n".join(lines)
@@ -287,7 +285,7 @@ class SuperChatManager:
"lowest_amount": min(amounts)
}
- async def shutdown(self):
+ async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel()
@@ -300,6 +298,7 @@ class SuperChatManager:
+# sourcery skip: assign-if-exp
if ENABLE_S4U:
super_chat_manager = SuperChatManager()
else:
diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py
index edc200f6..c71c160d 100644
--- a/src/mais4u/mais4u_chat/yes_or_no.py
+++ b/src/mais4u/mais4u_chat/yes_or_no.py
@@ -1,19 +1,14 @@
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import model_config
from src.plugin_system.apis import send_api
+
logger = get_logger(__name__)
-head_actions_list = [
- "不做额外动作",
- "点头一次",
- "点头两次",
- "摇头",
- "歪脑袋",
- "低头望向一边"
-]
+head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
-async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""):
+
+async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
prompt = f"""
{chat_history}
以上是对方的发言:
@@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
低头望向一边
请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。"""
- model = LLMRequest(
- model=global_config.model.emotion,
- temperature=0.7,
- request_type="motion",
- )
-
+ model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
+
try:
# logger.info(f"prompt: {prompt}")
- response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt)
+ response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
logger.info(f"response: {response}")
-
- if response in head_actions_list:
- head_action = response
- else:
- head_action = "不做额外动作"
-
+
+ head_action = response if response in head_actions_list else "不做额外动作"
await send_api.custom_to_stream(
message_type="head_action",
content=head_action,
@@ -53,11 +40,7 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
storage_message=False,
show_log=True,
)
-
-
-
+
except Exception as e:
logger.error(f"yes_or_no_head error: {e}")
return "不做额外动作"
-
-
diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py
index eae0ea71..8daf38e6 100644
--- a/src/mood/mood_manager.py
+++ b/src/mood/mood_manager.py
@@ -3,13 +3,14 @@ import random
import time
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv
+from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager
-from src.chat.message_receive.chat_stream import get_chat_manager
+
logger = get_logger("mood")
@@ -49,7 +50,7 @@ class ChatMood:
chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id)
-
+
if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found")
@@ -59,11 +60,7 @@ class ChatMood:
self.regression_count: int = 0
- self.mood_model = LLMRequest(
- model=global_config.model.emotion,
- temperature=0.7,
- request_type="mood",
- )
+ self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
self.last_change_time: float = 0
@@ -83,12 +80,16 @@ class ChatMood:
logger.debug(
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
)
- update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier)
+ update_probability = global_config.mood.mood_update_threshold * min(
+ 1.0, base_probability * time_multiplier * interest_multiplier
+ )
if random.random() > update_probability:
return
- logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}")
+ logger.debug(
+ f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
+ )
message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
@@ -124,7 +125,9 @@ class ChatMood:
mood_state=self.mood_state,
)
- response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}")
logger.info(f"{self.log_prefix} response: {response}")
@@ -171,7 +174,9 @@ class ChatMood:
mood_state=self.mood_state,
)
- response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt)
+ response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
+ prompt=prompt, temperature=0.7
+ )
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}")
diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py
index 6be0ad27..4d5fe709 100644
--- a/src/person_info/person_info.py
+++ b/src/person_info/person_info.py
@@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
"""
@@ -54,11 +54,7 @@ person_info_default = {
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
- # TODO: API-Adapter修改标记
- self.qv_name_llm = LLMRequest(
- model=global_config.model.utils,
- request_type="relation.qv_name",
- )
+ self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
@@ -199,7 +195,7 @@ class PersonInfoManager:
if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True
-
+
# 尝试创建
PersonInfo.create(**p_data)
return True
@@ -376,7 +372,7 @@ class PersonInfoManager:
"nickname": "昵称",
"reason": "理由"
}"""
- response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt)
+ response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}")
result = self._extract_json_from_text(response)
@@ -592,7 +588,7 @@ class PersonInfoManager:
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record:
return record, False # 记录存在,未创建
-
+
# 记录不存在,尝试创建
try:
PersonInfo.create(**init_data)
@@ -622,7 +618,7 @@ class PersonInfoManager:
"points": [],
"forgotten_points": [],
}
-
+
# 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS:
if key in initial_data:
@@ -630,12 +626,12 @@ class PersonInfoManager:
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False)
-
+
model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
-
+
if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")
diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py
index 99f3be30..267ed96f 100644
--- a/src/person_info/relationship_fetcher.py
+++ b/src/person_info/relationship_fetcher.py
@@ -7,7 +7,7 @@ from typing import List, Dict, Any
from json_repair import repair_json
from src.common.logger import get_logger
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
@@ -73,14 +73,12 @@ class RelationshipFetcher:
# LLM模型配置
self.llm_model = LLMRequest(
- model=global_config.model.utils_small,
- request_type="relation.fetcher",
+ model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
)
# 小模型用于即时信息提取
self.instant_llm_model = LLMRequest(
- model=global_config.model.utils_small,
- request_type="relation.fetch",
+ model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
)
name = get_chat_manager().get_stream_name(self.chat_id)
@@ -96,7 +94,7 @@ class RelationshipFetcher:
if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id]
- async def build_relation_info(self, person_id, points_num = 3):
+ async def build_relation_info(self, person_id, points_num=3):
# 清理过期的信息缓存
self._cleanup_expired_cache()
@@ -361,7 +359,6 @@ class RelationshipFetcher:
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
logger.error(traceback.format_exc())
-
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中
diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py
index 6c269357..9d7a48b9 100644
--- a/src/person_info/relationship_manager.py
+++ b/src/person_info/relationship_manager.py
@@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager
import time
import random
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
import json
from json_repair import repair_json
@@ -20,9 +20,8 @@ logger = get_logger("relation")
class RelationshipManager:
def __init__(self):
self.relationship_llm = LLMRequest(
- model=global_config.model.utils,
- request_type="relationship", # 用于动作规划
- )
+ model_set=model_config.model_task_config.utils, request_type="relationship"
+ ) # 用于动作规划
@staticmethod
async def is_known_some_one(platform, user_id):
@@ -181,18 +180,14 @@ class RelationshipManager:
try:
points = repair_json(points)
points_data = json.loads(points)
-
+
# 只处理正确的格式,错误格式直接跳过
if points_data == "none" or not points_data:
points_list = []
elif isinstance(points_data, str) and points_data.lower() == "none":
points_list = []
elif isinstance(points_data, list):
- # 正确格式:数组格式 [{"point": "...", "weight": 10}, ...]
- if not points_data: # 空数组
- points_list = []
- else:
- points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
+ points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
else:
# 错误格式,直接跳过不解析
logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}")
diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py
index eb07dbc9..a102ecd0 100644
--- a/src/plugin_system/__init__.py
+++ b/src/plugin_system/__init__.py
@@ -9,6 +9,7 @@ from .base import (
BasePlugin,
BaseAction,
BaseCommand,
+ BaseTool,
ConfigField,
ComponentType,
ActionActivationType,
@@ -17,11 +18,13 @@ from .base import (
ActionInfo,
CommandInfo,
PluginInfo,
+ ToolInfo,
PythonDependency,
BaseEventHandler,
EventHandlerInfo,
EventType,
MaiMessages,
+ ToolParamType,
)
# 导入工具模块
@@ -34,6 +37,7 @@ from .utils import (
from .apis import (
chat_api,
+ tool_api,
component_manage_api,
config_api,
database_api,
@@ -44,17 +48,17 @@ from .apis import (
person_api,
plugin_manage_api,
send_api,
- utils_api,
register_plugin,
get_logger,
)
-__version__ = "1.0.0"
+__version__ = "2.0.0"
__all__ = [
# API 模块
"chat_api",
+ "tool_api",
"component_manage_api",
"config_api",
"database_api",
@@ -65,13 +69,13 @@ __all__ = [
"person_api",
"plugin_manage_api",
"send_api",
- "utils_api",
"register_plugin",
"get_logger",
# 基础类
"BasePlugin",
"BaseAction",
"BaseCommand",
+ "BaseTool",
"BaseEventHandler",
# 类型定义
"ComponentType",
@@ -81,9 +85,11 @@ __all__ = [
"ActionInfo",
"CommandInfo",
"PluginInfo",
+ "ToolInfo",
"PythonDependency",
"EventHandlerInfo",
"EventType",
+ "ToolParamType",
# 消息
"MaiMessages",
# 装饰器
diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py
index 0882fbdc..362c9858 100644
--- a/src/plugin_system/apis/__init__.py
+++ b/src/plugin_system/apis/__init__.py
@@ -17,7 +17,7 @@ from src.plugin_system.apis import (
person_api,
plugin_manage_api,
send_api,
- utils_api,
+ tool_api,
)
from .logging_api import get_logger
from .plugin_register_api import register_plugin
@@ -35,7 +35,7 @@ __all__ = [
"person_api",
"plugin_manage_api",
"send_api",
- "utils_api",
"get_logger",
"register_plugin",
+ "tool_api",
]
diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py
index d9ea051d..1ffa0833 100644
--- a/src/plugin_system/apis/component_manage_api.py
+++ b/src/plugin_system/apis/component_manage_api.py
@@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
EventHandlerInfo,
PluginInfo,
ComponentType,
+ ToolInfo,
)
@@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name)
+def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
+ """
+ 获取指定 Tool 的注册信息。
+
+ Args:
+ tool_name (str): Tool 名称。
+
+ Returns:
+ ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。
+ """
+ from src.plugin_system.core.component_registry import component_registry
+
+ return component_registry.get_registered_tool_info(tool_name)
+
+
# === EventHandler 特定查询方法 ===
def get_registered_event_handler_info(
event_handler_name: str,
@@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
+ case ComponentType.TOOL:
+ return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _:
@@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
+ case ComponentType.TOOL:
+ return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _:
raise ValueError(f"未知 component type: {component_type}")
+
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
"""
获取指定消息流中禁用的组件列表。
@@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id)
+ case ComponentType.TOOL:
+ return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _:
- raise ValueError(f"未知 component type: {component_type}")
\ No newline at end of file
+ raise ValueError(f"未知 component type: {component_type}")
diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py
index d46bfba3..8b253806 100644
--- a/src/plugin_system/apis/database_api.py
+++ b/src/plugin_system/apis/database_api.py
@@ -152,10 +152,7 @@ async def db_query(
except DoesNotExist:
# 记录不存在
- if query_type == "get" and single_result:
- return None
- return []
-
+ return None if query_type == "get" and single_result else []
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
@@ -170,7 +167,8 @@ async def db_query(
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
-) -> Union[Dict[str, Any], None]:
+) -> Optional[Dict[str, Any]]:
+ # sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
@@ -203,10 +201,9 @@ async def db_save(
try:
# 如果提供了key_field和key_value,尝试更新现有记录
if key_field and key_value is not None:
- # 查找现有记录
- existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1))
-
- if existing_records:
+ if existing_records := list(
+ model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
+ ):
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items():
@@ -244,8 +241,8 @@ async def db_get(
Args:
model_class: Peewee模型类
filters: 过滤条件,字段名和值的字典
- order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
limit: 结果数量限制
+ order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
Returns:
@@ -310,7 +307,7 @@ async def store_action_info(
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
-) -> Union[Dict[str, Any], None]:
+) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py
index cafb52df..479f3aec 100644
--- a/src/plugin_system/apis/emoji_api.py
+++ b/src/plugin_system/apis/emoji_api.py
@@ -65,14 +65,14 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None
-async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]:
+async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包
Args:
count: 要获取的表情包数量,默认为1
Returns:
- Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None
+ List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表
Raises:
TypeError: 如果count不是整数类型
@@ -94,13 +94,13 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
- return None
+ return []
# 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包")
- return None
+ return []
if len(valid_emojis) < count:
logger.warning(
@@ -127,14 +127,14 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
- return None
+ return []
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results
except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
- return None
+ return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
@@ -162,10 +162,11 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 筛选匹配情感的表情包
matching_emojis = []
- for emoji_obj in all_emojis:
- if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]:
- matching_emojis.append(emoji_obj)
-
+ matching_emojis.extend(
+ emoji_obj
+ for emoji_obj in all_emojis
+ if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
+ )
if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None
@@ -256,10 +257,11 @@ def get_descriptions() -> List[str]:
emoji_manager = get_emoji_manager()
descriptions = []
- for emoji_obj in emoji_manager.emoji_objects:
- if not emoji_obj.is_deleted and emoji_obj.description:
- descriptions.append(emoji_obj.description)
-
+ descriptions.extend(
+ emoji_obj.description
+ for emoji_obj in emoji_manager.emoji_objects
+ if not emoji_obj.is_deleted and emoji_obj.description
+ )
return descriptions
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")
diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py
index f911454c..0e6e6551 100644
--- a/src/plugin_system/apis/generator_api.py
+++ b/src/plugin_system/apis/generator_api.py
@@ -12,6 +12,7 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install
from src.common.logger import get_logger
+from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response
@@ -31,7 +32,7 @@ logger = get_logger("generator_api")
def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
- model_configs: Optional[List[Dict[str, Any]]] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
"""获取回复器对象
@@ -42,7 +43,7 @@ def get_replyer(
Args:
chat_stream: 聊天流对象(优先)
chat_id: 聊天ID(实际上就是stream_id)
- model_configs: 模型配置列表
+ model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型
Returns:
@@ -58,7 +59,7 @@ def get_replyer(
return replyer_manager.get_replyer(
chat_stream=chat_stream,
chat_id=chat_id,
- model_configs=model_configs,
+ model_set_with_weight=model_set_with_weight,
request_type=request_type,
)
except Exception as e:
@@ -83,31 +84,36 @@ async def generate_reply(
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
- model_configs: Optional[List[Dict[str, Any]]] = None,
- request_type: str = "",
- enable_timeout: bool = False,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复
Args:
chat_stream: 聊天流对象(优先)
chat_id: 聊天ID(备用)
- action_data: 动作数据
+ action_data: 动作数据(向下兼容,包含reply_to和extra_info)
+ reply_to: 回复对象,格式为 "发送者:消息内容"
+ extra_info: 额外信息,用于补充上下文
+ available_actions: 可用动作
+ enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词
+ model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
+ request_type: 请求类型(可选,记录LLM使用)
Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
"""
try:
# 获取回复器
- replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type)
+ replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复")
-
+
if not reply_to and action_data:
reply_to = action_data.get("reply_to", "")
if not extra_info and action_data:
@@ -118,7 +124,6 @@ async def generate_reply(
reply_to=reply_to,
extra_info=extra_info,
available_actions=available_actions,
- enable_timeout=enable_timeout,
enable_tool=enable_tool,
)
reply_set = []
@@ -150,33 +155,35 @@ async def rewrite_reply(
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
- model_configs: Optional[List[Dict[str, Any]]] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
-) -> Tuple[bool, List[Tuple[str, Any]]]:
+ return_prompt: bool = False,
+) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""重写回复
Args:
chat_stream: 聊天流对象(优先)
- reply_data: 回复数据字典(备用,当其他参数缺失时从此获取)
+ reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取)
chat_id: 聊天ID(备用)
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
- model_configs: 模型配置列表
+ model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
raw_reply: 原始回复内容
reason: 回复原因
reply_to: 回复对象
+ return_prompt: 是否返回提示词
Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
"""
try:
# 获取回复器
- replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs)
+ replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
- return False, []
+ return False, [], None
logger.info("[GeneratorAPI] 开始重写回复")
@@ -187,10 +194,11 @@ async def rewrite_reply(
reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复
- success, content = await replyer.rewrite_reply_with_context(
+ success, content, prompt = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply,
reason=reason,
reply_to=reply_to,
+ return_prompt=return_prompt,
)
reply_set = []
if content:
@@ -201,14 +209,14 @@ async def rewrite_reply(
else:
logger.warning("[GeneratorAPI] 重写回复失败")
- return success, reply_set
+ return success, reply_set, prompt if return_prompt else None
except ValueError as ve:
raise ve
except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
- return False, []
+ return False, [], None
async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
@@ -234,3 +242,27 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
return []
+
+async def generate_response_custom(
+ chat_stream: Optional[ChatStream] = None,
+ chat_id: Optional[str] = None,
+ model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
+ prompt: str = "",
+) -> Optional[str]:
+ replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
+ if not replyer:
+ logger.error("[GeneratorAPI] 无法获取回复器")
+ return None
+
+ try:
+ logger.debug("[GeneratorAPI] 开始生成自定义回复")
+ response, _, _, _ = await replyer.llm_generate_content(prompt)
+ if response:
+ logger.debug("[GeneratorAPI] 自定义回复生成成功")
+ return response
+ else:
+ logger.warning("[GeneratorAPI] 自定义回复生成失败")
+ return None
+ except Exception as e:
+ logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
+ return None
\ No newline at end of file
diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py
index 72b865b8..9d37a8e3 100644
--- a/src/plugin_system/apis/llm_api.py
+++ b/src/plugin_system/apis/llm_api.py
@@ -7,10 +7,12 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
"""
-from typing import Tuple, Dict, Any
+from typing import Tuple, Dict, List, Any, Optional
from src.common.logger import get_logger
+from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
+from src.config.config import global_config, model_config
+from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api")
@@ -19,9 +21,7 @@ logger = get_logger("llm_api")
# =============================================================================
-
-
-def get_available_models() -> Dict[str, Any]:
+def get_available_models() -> Dict[str, TaskConfig]:
"""获取所有可用的模型配置
Returns:
@@ -33,14 +33,14 @@ def get_available_models() -> Dict[str, Any]:
return {}
# 自动获取所有属性并转换为字典形式
- rets = {}
- models = global_config.model
+ models = model_config.model_task_config
attrs = dir(models)
+ rets: Dict[str, TaskConfig] = {}
for attr in attrs:
if not attr.startswith("__"):
try:
value = getattr(models, attr)
- if not callable(value): # 排除方法
+ if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value
except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
@@ -53,7 +53,11 @@ def get_available_models() -> Dict[str, Any]:
async def generate_with_model(
- prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs
+ prompt: str,
+ model_config: TaskConfig,
+ request_type: str = "plugin.generate",
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容
@@ -61,22 +65,62 @@ async def generate_with_model(
prompt: 提示词
model_config: 模型配置(从 get_available_models 获取的模型配置)
request_type: 请求类型标识
- **kwargs: 其他模型特定参数,如temperature、max_tokens等
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
- model_name = model_config.get("name")
- logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容")
+ model_name_list = model_config.model_list
+ logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
- llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs)
+ llm_request = LLMRequest(model_set=model_config, request_type=request_type)
- response, (reasoning, model_name) = await llm_request.generate_response_async(prompt)
- return True, response, reasoning, model_name
+ response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
+ return True, response, reasoning_content, model_name
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", ""
+
+async def generate_with_model_with_tools(
+ prompt: str,
+ model_config: TaskConfig,
+ tool_options: List[Dict[str, Any]] | None = None,
+ request_type: str = "plugin.generate",
+ temperature: Optional[float] = None,
+ max_tokens: Optional[int] = None,
+) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
+ """使用指定模型和工具生成内容
+
+ Args:
+ prompt: 提示词
+ model_config: 模型配置(从 get_available_models 获取的模型配置)
+ tool_options: 工具选项列表
+ request_type: 请求类型标识
+ temperature: 温度参数
+ max_tokens: 最大token数
+
+ Returns:
+ Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
+ """
+ try:
+ model_name_list = model_config.model_list
+ logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
+ logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
+
+ llm_request = LLMRequest(model_set=model_config, request_type=request_type)
+
+ response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
+ prompt,
+ tools=tool_options,
+ temperature=temperature,
+ max_tokens=max_tokens
+ )
+ return True, response, reasoning_content, model_name, tool_call
+
+ except Exception as e:
+ error_msg = f"生成内容时出错: {str(e)}"
+ logger.error(f"[LLMAPI] {error_msg}")
+ return False, error_msg, "", "", None
diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py
index 7794ee81..7cf9dc04 100644
--- a/src/plugin_system/apis/message_api.py
+++ b/src/plugin_system/apis/message_api.py
@@ -207,7 +207,7 @@ def get_random_chat_messages(
def get_messages_by_time_for_users(
- start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
+ start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""
获取指定用户在所有聊天中指定时间范围内的消息
@@ -287,7 +287,7 @@ def get_messages_before_time_in_chat(
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
-def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]:
+def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
"""
获取指定用户在指定时间戳之前的消息
@@ -372,7 +372,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
return num_new_messages_since(chat_id, start_time, end_time)
-def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int:
+def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
"""
计算指定聊天中指定用户从开始时间到结束时间的新消息数量
diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py
index 1c01119b..693e42b4 100644
--- a/src/plugin_system/apis/plugin_manage_api.py
+++ b/src/plugin_system/apis/plugin_manage_api.py
@@ -1,10 +1,12 @@
from typing import Tuple, List
+
+
def list_loaded_plugins() -> List[str]:
"""
列出所有当前加载的插件。
Returns:
- list: 当前加载的插件名称列表。
+ List[str]: 当前加载的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
@@ -16,17 +18,38 @@ def list_registered_plugins() -> List[str]:
列出所有已注册的插件。
Returns:
- list: 已注册的插件名称列表。
+ List[str]: 已注册的插件名称列表。
"""
from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins()
+def get_plugin_path(plugin_name: str) -> str:
+ """
+ 获取指定插件的路径。
+
+ Args:
+ plugin_name (str): 插件名称。
+
+ Returns:
+ str: 插件目录的绝对路径。
+
+ Raises:
+ ValueError: 如果插件不存在。
+ """
+ from src.plugin_system.core.plugin_manager import plugin_manager
+
+ if plugin_path := plugin_manager.get_plugin_path(plugin_name):
+ return plugin_path
+ else:
+ raise ValueError(f"插件 '{plugin_name}' 不存在。")
+
+
async def remove_plugin(plugin_name: str) -> bool:
"""
卸载指定的插件。
-
+
**此函数是异步的,确保在异步环境中调用。**
Args:
@@ -43,7 +66,7 @@ async def remove_plugin(plugin_name: str) -> bool:
async def reload_plugin(plugin_name: str) -> bool:
"""
重新加载指定的插件。
-
+
**此函数是异步的,确保在异步环境中调用。**
Args:
@@ -71,6 +94,7 @@ def load_plugin(plugin_name: str) -> Tuple[bool, int]:
return plugin_manager.load_registered_plugin_classes(plugin_name)
+
def add_plugin_directory(plugin_directory: str) -> bool:
"""
添加插件目录。
@@ -84,6 +108,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
return plugin_manager.add_plugin_directory(plugin_directory)
+
def rescan_plugin_directory() -> Tuple[int, int]:
"""
重新扫描插件目录,加载新插件。
@@ -92,4 +117,4 @@ def rescan_plugin_directory() -> Tuple[int, int]:
"""
from src.plugin_system.core.plugin_manager import plugin_manager
- return plugin_manager.rescan_plugin_directory()
\ No newline at end of file
+ return plugin_manager.rescan_plugin_directory()
diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py
index f7af0259..10fbd804 100644
--- a/src/plugin_system/apis/send_api.py
+++ b/src/plugin_system/apis/send_api.py
@@ -49,7 +49,7 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
reply_to: str = "",
- reply_to_platform_id: str = "",
+ reply_to_platform_id: Optional[str] = None,
storage_message: bool = True,
show_log: bool = True,
) -> bool:
@@ -60,8 +60,11 @@ async def _send_to_target(
content: 消息内容
stream_id: 目标流ID
display_message: 显示消息
- typing: 是否显示正在输入
- reply_to: 回复消息的格式,如"发送者:消息内容"
+ typing: 是否模拟打字等待。
+ reply_to: 回复消息,格式为"发送者:消息内容"
+ reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
+ storage_message: 是否存储消息到数据库
+ show_log: 发送是否显示日志
Returns:
bool: 是否发送成功
@@ -97,6 +100,10 @@ async def _send_to_target(
anchor_message = None
if reply_to:
anchor_message = await _find_reply_message(target_stream, reply_to)
+ if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
+ reply_to_platform_id = (
+ f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
+ )
# 构建发送消息对象
bot_message = MessageSending(
@@ -262,12 +269,22 @@ async def text_to_stream(
stream_id: 聊天流ID
typing: 是否显示正在输入
reply_to: 回复消息,格式为"发送者:消息内容"
+ reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!)
storage_message: 是否存储消息到数据库
Returns:
bool: 是否发送成功
"""
- return await _send_to_target("text", text, stream_id, "", typing, reply_to, reply_to_platform_id, storage_message)
+ return await _send_to_target(
+ "text",
+ text,
+ stream_id,
+ "",
+ typing,
+ reply_to,
+ reply_to_platform_id=reply_to_platform_id,
+ storage_message=storage_message,
+ )
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
@@ -318,7 +335,7 @@ async def command_to_stream(
async def custom_to_stream(
message_type: str,
- content: str,
+ content: str | dict,
stream_id: str,
display_message: str = "",
typing: bool = False,
@@ -350,249 +367,3 @@ async def custom_to_stream(
storage_message=storage_message,
show_log=show_log,
)
-
-
-async def text_to_group(
- text: str,
- group_id: str,
- platform: str = "qq",
- typing: bool = False,
- reply_to: str = "",
- storage_message: bool = True,
-) -> bool:
- """向群聊发送文本消息
-
- Args:
- text: 要发送的文本内容
- group_id: 群聊ID
- platform: 平台,默认为"qq"
- typing: 是否显示正在输入
- reply_to: 回复消息,格式为"发送者:消息内容"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
-
- return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
-
-
-async def text_to_user(
- text: str,
- user_id: str,
- platform: str = "qq",
- typing: bool = False,
- reply_to: str = "",
- storage_message: bool = True,
-) -> bool:
- """向用户发送私聊文本消息
-
- Args:
- text: 要发送的文本内容
- user_id: 用户ID
- platform: 平台,默认为"qq"
- typing: 是否显示正在输入
- reply_to: 回复消息,格式为"发送者:消息内容"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
- return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
-
-
-async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向群聊发送表情包
-
- Args:
- emoji_base64: 表情包的base64编码
- group_id: 群聊ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
- return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
-
-
-async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向用户发送表情包
-
- Args:
- emoji_base64: 表情包的base64编码
- user_id: 用户ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
- return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
-
-
-async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向群聊发送图片
-
- Args:
- image_base64: 图片的base64编码
- group_id: 群聊ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
- return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
-
-
-async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向用户发送图片
-
- Args:
- image_base64: 图片的base64编码
- user_id: 用户ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
- return await _send_to_target("image", image_base64, stream_id, "", typing=False)
-
-
-async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向群聊发送命令
-
- Args:
- command: 命令
- group_id: 群聊ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
- return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
-
-
-async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
- """向用户发送命令
-
- Args:
- command: 命令
- user_id: 用户ID
- platform: 平台,默认为"qq"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
- return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
-
-
-# =============================================================================
-# 通用发送函数 - 支持任意消息类型
-# =============================================================================
-
-
-async def custom_to_group(
- message_type: str,
- content: str,
- group_id: str,
- platform: str = "qq",
- display_message: str = "",
- typing: bool = False,
- reply_to: str = "",
- storage_message: bool = True,
-) -> bool:
- """向群聊发送自定义类型消息
-
- Args:
- message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
- content: 消息内容(通常是base64编码或文本)
- group_id: 群聊ID
- platform: 平台,默认为"qq"
- display_message: 显示消息
- typing: 是否显示正在输入
- reply_to: 回复消息,格式为"发送者:消息内容"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
- return await _send_to_target(
- message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
- )
-
-
-async def custom_to_user(
- message_type: str,
- content: str,
- user_id: str,
- platform: str = "qq",
- display_message: str = "",
- typing: bool = False,
- reply_to: str = "",
- storage_message: bool = True,
-) -> bool:
- """向用户发送自定义类型消息
-
- Args:
- message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
- content: 消息内容(通常是base64编码或文本)
- user_id: 用户ID
- platform: 平台,默认为"qq"
- display_message: 显示消息
- typing: 是否显示正在输入
- reply_to: 回复消息,格式为"发送者:消息内容"
-
- Returns:
- bool: 是否发送成功
- """
- stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
- return await _send_to_target(
- message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
- )
-
-
-async def custom_message(
- message_type: str,
- content: str,
- target_id: str,
- is_group: bool = True,
- platform: str = "qq",
- display_message: str = "",
- typing: bool = False,
- reply_to: str = "",
- storage_message: bool = True,
-) -> bool:
- """发送自定义消息的通用接口
-
- Args:
- message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"、"audio"等
- content: 消息内容
- target_id: 目标ID(群ID或用户ID)
- is_group: 是否为群聊,True为群聊,False为私聊
- platform: 平台,默认为"qq"
- display_message: 显示消息
- typing: 是否显示正在输入
- reply_to: 回复消息,格式为"发送者:消息内容"
-
- Returns:
- bool: 是否发送成功
-
- 示例:
- # 发送视频到群聊
- await send_api.custom_message("video", video_base64, "123456", True)
-
- # 发送文件到用户
- await send_api.custom_message("file", file_base64, "987654", False)
-
- # 发送音频到群聊并回复特定消息
- await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
- """
- stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
- return await _send_to_target(
- message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
- )
diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py
new file mode 100644
index 00000000..a6704126
--- /dev/null
+++ b/src/plugin_system/apis/tool_api.py
@@ -0,0 +1,27 @@
+from typing import Optional, Type
+from src.plugin_system.base.base_tool import BaseTool
+from src.plugin_system.base.component_types import ComponentType
+
+from src.common.logger import get_logger
+
+logger = get_logger("tool_api")
+
+
+def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
+ """获取公开工具实例"""
+ from src.plugin_system.core import component_registry
+
+ tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
+ return tool_class() if tool_class else None
+
+
+def get_llm_available_tool_definitions():
+ """获取LLM可用的工具定义列表
+
+ Returns:
+ List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)]
+ """
+ from src.plugin_system.core import component_registry
+
+ llm_available_tools = component_registry.get_llm_available_tools()
+ return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
diff --git a/src/plugin_system/apis/utils_api.py b/src/plugin_system/apis/utils_api.py
deleted file mode 100644
index 45996df5..00000000
--- a/src/plugin_system/apis/utils_api.py
+++ /dev/null
@@ -1,168 +0,0 @@
-"""工具类API模块
-
-提供了各种辅助功能
-使用方式:
- from src.plugin_system.apis import utils_api
- plugin_path = utils_api.get_plugin_path()
- data = utils_api.read_json_file("data.json")
- timestamp = utils_api.get_timestamp()
-"""
-
-import os
-import json
-import time
-import inspect
-import datetime
-import uuid
-from typing import Any, Optional
-from src.common.logger import get_logger
-
-logger = get_logger("utils_api")
-
-
-# =============================================================================
-# 文件操作API函数
-# =============================================================================
-
-
-def get_plugin_path(caller_frame=None) -> str:
- """获取调用者插件的路径
-
- Args:
- caller_frame: 调用者的栈帧,默认为None(自动获取)
-
- Returns:
- str: 插件目录的绝对路径
- """
- try:
- if caller_frame is None:
- caller_frame = inspect.currentframe().f_back # type: ignore
-
- plugin_module_path = inspect.getfile(caller_frame) # type: ignore
- plugin_dir = os.path.dirname(plugin_module_path)
- return plugin_dir
- except Exception as e:
- logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
- return ""
-
-
-def read_json_file(file_path: str, default: Any = None) -> Any:
- """读取JSON文件
-
- Args:
- file_path: 文件路径,可以是相对于插件目录的路径
- default: 如果文件不存在或读取失败时返回的默认值
-
- Returns:
- Any: JSON数据或默认值
- """
- try:
- # 如果是相对路径,则相对于调用者的插件目录
- if not os.path.isabs(file_path):
- caller_frame = inspect.currentframe().f_back # type: ignore
- plugin_dir = get_plugin_path(caller_frame)
- file_path = os.path.join(plugin_dir, file_path)
-
- if not os.path.exists(file_path):
- logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
- return default
-
- with open(file_path, "r", encoding="utf-8") as f:
- return json.load(f)
- except Exception as e:
- logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
- return default
-
-
-def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
- """写入JSON文件
-
- Args:
- file_path: 文件路径,可以是相对于插件目录的路径
- data: 要写入的数据
- indent: JSON缩进
-
- Returns:
- bool: 是否写入成功
- """
- try:
- # 如果是相对路径,则相对于调用者的插件目录
- if not os.path.isabs(file_path):
- caller_frame = inspect.currentframe().f_back # type: ignore
- plugin_dir = get_plugin_path(caller_frame)
- file_path = os.path.join(plugin_dir, file_path)
-
- # 确保目录存在
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
-
- with open(file_path, "w", encoding="utf-8") as f:
- json.dump(data, f, ensure_ascii=False, indent=indent)
- return True
- except Exception as e:
- logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
- return False
-
-
-# =============================================================================
-# 时间相关API函数
-# =============================================================================
-
-
-def get_timestamp() -> int:
- """获取当前时间戳
-
- Returns:
- int: 当前时间戳(秒)
- """
- return int(time.time())
-
-
-def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
- """格式化时间
-
- Args:
- timestamp: 时间戳,如果为None则使用当前时间
- format_str: 时间格式字符串
-
- Returns:
- str: 格式化后的时间字符串
- """
- try:
- if timestamp is None:
- timestamp = time.time()
- return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
- except Exception as e:
- logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
- return ""
-
-
-def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
- """解析时间字符串为时间戳
-
- Args:
- time_str: 时间字符串
- format_str: 时间格式字符串
-
- Returns:
- int: 时间戳(秒)
- """
- try:
- dt = datetime.datetime.strptime(time_str, format_str)
- return int(dt.timestamp())
- except Exception as e:
- logger.error(f"[UtilsAPI] 解析时间失败: {e}")
- return 0
-
-
-# =============================================================================
-# 其他工具函数
-# =============================================================================
-
-
-def generate_unique_id() -> str:
- """生成唯一ID
-
- Returns:
- str: 唯一ID
- """
- return str(uuid.uuid4())
diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py
index a95e05ae..bc63d35d 100644
--- a/src/plugin_system/base/__init__.py
+++ b/src/plugin_system/base/__init__.py
@@ -6,6 +6,7 @@
from .base_plugin import BasePlugin
from .base_action import BaseAction
+from .base_tool import BaseTool
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .component_types import (
@@ -15,11 +16,13 @@ from .component_types import (
ComponentInfo,
ActionInfo,
CommandInfo,
+ ToolInfo,
PluginInfo,
PythonDependency,
EventHandlerInfo,
EventType,
MaiMessages,
+ ToolParamType,
)
from .config_types import ConfigField
@@ -27,12 +30,14 @@ __all__ = [
"BasePlugin",
"BaseAction",
"BaseCommand",
+ "BaseTool",
"ComponentType",
"ActionActivationType",
"ChatMode",
"ComponentInfo",
"ActionInfo",
"CommandInfo",
+ "ToolInfo",
"PluginInfo",
"PythonDependency",
"ConfigField",
@@ -40,4 +45,5 @@ __all__ = [
"EventType",
"BaseEventHandler",
"MaiMessages",
+ "ToolParamType",
]
diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py
index 7acd14a4..66d723f5 100644
--- a/src/plugin_system/base/base_action.py
+++ b/src/plugin_system/base/base_action.py
@@ -208,7 +208,7 @@ class BaseAction(ABC):
return False, f"等待新消息失败: {str(e)}"
async def send_text(
- self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False
+ self, content: str, reply_to: str = "", typing: bool = False
) -> bool:
"""发送文本消息
@@ -227,7 +227,6 @@ class BaseAction(ABC):
text=content,
stream_id=self.chat_id,
reply_to=reply_to,
- reply_to_platform_id=reply_to_platform_id,
typing=typing,
)
diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py
index 3cf82390..ea28c514 100644
--- a/src/plugin_system/base/base_plugin.py
+++ b/src/plugin_system/base/base_plugin.py
@@ -3,10 +3,11 @@ from typing import List, Type, Tuple, Union
from .plugin_base import PluginBase
from src.common.logger import get_logger
-from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo
+from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
+from .base_tool import BaseTool
logger = get_logger("base_plugin")
@@ -31,6 +32,7 @@ class BasePlugin(PluginBase):
Tuple[ActionInfo, Type[BaseAction]],
Tuple[CommandInfo, Type[BaseCommand]],
Tuple[EventHandlerInfo, Type[BaseEventHandler]],
+ Tuple[ToolInfo, Type[BaseTool]],
]
]:
"""获取插件包含的组件列表
diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py
new file mode 100644
index 00000000..1d589eca
--- /dev/null
+++ b/src/plugin_system/base/base_tool.py
@@ -0,0 +1,91 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Tuple
+from rich.traceback import install
+
+from src.common.logger import get_logger
+from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType
+
+install(extra_lines=3)
+
+logger = get_logger("base_tool")
+
+
+class BaseTool(ABC):
+ """所有工具的基类"""
+
+ name: str = ""
+ """工具的名称"""
+ description: str = ""
+ """工具的描述"""
+ parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = []
+ """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式
+ param_name: 参数名称
+ param_type: 参数类型
+ description: 参数描述
+ required: 是否必填
+ enum_values: 枚举值列表
+ 例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])]
+ """
+ available_for_llm: bool = False
+ """是否可供LLM使用"""
+
+ @classmethod
+ def get_tool_definition(cls) -> dict[str, Any]:
+ """获取工具定义,用于LLM工具调用
+
+ Returns:
+ dict: 工具定义字典
+ """
+ if not cls.name or not cls.description or not cls.parameters:
+ raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
+
+ return {"name": cls.name, "description": cls.description, "parameters": cls.parameters}
+
+ @classmethod
+ def get_tool_info(cls) -> ToolInfo:
+ """获取工具信息"""
+ if not cls.name or not cls.description or not cls.parameters:
+ raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
+
+ return ToolInfo(
+ name=cls.name,
+ tool_description=cls.description,
+ enabled=cls.available_for_llm,
+ tool_parameters=cls.parameters,
+ component_type=ComponentType.TOOL,
+ )
+
+ @abstractmethod
+ async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
+ """执行工具函数(供llm调用)
+ 通过该方法,maicore会通过llm的tool call来调用工具
+ 传入的是json格式的参数,符合parameters定义的格式
+
+ Args:
+ function_args: 工具调用参数
+
+ Returns:
+ dict: 工具执行结果
+ """
+ raise NotImplementedError("子类必须实现execute方法")
+
+ async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]:
+ """直接执行工具函数(供插件调用)
+ 通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数
+ 插件可以直接调用此方法,用更加明了的方式传入参数
+ 示例: result = await tool.direct_execute(arg1="参数",arg2="参数2")
+
+ 工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑
+
+ Args:
+ **function_args: 工具调用参数
+
+ Returns:
+ dict: 工具执行结果
+ """
+ parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名
+ for param_name in parameter_required:
+ if param_name not in function_args:
+ raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}")
+
+ return await self.execute(function_args)
diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py
index eeb2a5a0..7775f5fb 100644
--- a/src/plugin_system/base/component_types.py
+++ b/src/plugin_system/base/component_types.py
@@ -1,8 +1,9 @@
from enum import Enum
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
from maim_message import Seg
+from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
# 组件类型枚举
class ComponentType(Enum):
@@ -10,6 +11,7 @@ class ComponentType(Enum):
ACTION = "action" # 动作组件
COMMAND = "command" # 命令组件
+ TOOL = "tool" # 服务组件(预留)
SCHEDULER = "scheduler" # 定时任务组件(预留)
EVENT_HANDLER = "event_handler" # 事件处理组件(预留)
@@ -146,6 +148,18 @@ class CommandInfo(ComponentInfo):
self.component_type = ComponentType.COMMAND
+@dataclass
+class ToolInfo(ComponentInfo):
+ """工具组件信息"""
+
+ tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
+ tool_description: str = "" # 工具描述
+
+ def __post_init__(self):
+ super().__post_init__()
+ self.component_type = ComponentType.TOOL
+
+
@dataclass
class EventHandlerInfo(ComponentInfo):
"""事件处理器组件信息"""
diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py
index 2ea89b88..59a03b73 100644
--- a/src/plugin_system/core/component_registry.py
+++ b/src/plugin_system/core/component_registry.py
@@ -6,6 +6,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import (
ComponentInfo,
ActionInfo,
+ ToolInfo,
CommandInfo,
EventHandlerInfo,
PluginInfo,
@@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import (
)
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_action import BaseAction
+from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.base_events_handler import BaseEventHandler
logger = get_logger("component_registry")
@@ -30,7 +32,7 @@ class ComponentRegistry:
"""组件注册表 命名空间式组件名 -> 组件信息"""
self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType}
"""类型 -> 组件原名称 -> 组件信息"""
- self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {}
+ self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {}
"""命名空间式组件名 -> 组件类"""
# 插件注册表
@@ -49,6 +51,10 @@ class ComponentRegistry:
self._command_patterns: Dict[Pattern, str] = {}
"""编译后的正则 -> command名"""
+ # 工具特定注册表
+ self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类
+ self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类
+
# EventHandler特定注册表
self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {}
"""event_handler名 -> event_handler类"""
@@ -79,7 +85,9 @@ class ComponentRegistry:
return True
def register_component(
- self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]]
+ self,
+ component_info: ComponentInfo,
+ component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]],
) -> bool:
"""注册组件
@@ -125,6 +133,10 @@ class ComponentRegistry:
assert isinstance(component_info, CommandInfo)
assert issubclass(component_class, BaseCommand)
ret = self._register_command_component(component_info, component_class)
+ case ComponentType.TOOL:
+ assert isinstance(component_info, ToolInfo)
+ assert issubclass(component_class, BaseTool)
+ ret = self._register_tool_component(component_info, component_class)
case ComponentType.EVENT_HANDLER:
assert isinstance(component_info, EventHandlerInfo)
assert issubclass(component_class, BaseEventHandler)
@@ -180,6 +192,18 @@ class ComponentRegistry:
return True
+ def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool:
+ """注册Tool组件到Tool特定注册表"""
+ tool_name = tool_info.name
+
+ self._tool_registry[tool_name] = tool_class
+
+ # 如果是llm可用的且启用的工具,添加到 llm可用工具列表
+ if tool_info.enabled:
+ self._llm_available_tools[tool_name] = tool_class
+
+ return True
+
def _register_event_handler_component(
self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler]
) -> bool:
@@ -222,6 +246,9 @@ class ComponentRegistry:
keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name]
for key in keys_to_remove:
self._command_patterns.pop(key)
+ case ComponentType.TOOL:
+ self._tool_registry.pop(component_name)
+ self._llm_available_tools.pop(component_name)
case ComponentType.EVENT_HANDLER:
from .events_manager import events_manager # 延迟导入防止循环导入问题
@@ -234,13 +261,13 @@ class ComponentRegistry:
self._components_classes.pop(namespaced_name)
logger.info(f"组件 {component_name} 已移除")
return True
- except KeyError:
- logger.warning(f"移除组件时未找到组件: {component_name}")
+ except KeyError as e:
+ logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}")
return False
except Exception as e:
logger.error(f"移除组件 {component_name} 时发生错误: {e}")
return False
-
+
def remove_plugin_registry(self, plugin_name: str) -> bool:
"""移除插件注册信息
@@ -281,6 +308,10 @@ class ComponentRegistry:
assert isinstance(target_component_info, CommandInfo)
pattern = target_component_info.command_pattern
self._command_patterns[re.compile(pattern)] = component_name
+ case ComponentType.TOOL:
+ assert isinstance(target_component_info, ToolInfo)
+ assert issubclass(target_component_class, BaseTool)
+ self._llm_available_tools[component_name] = target_component_class
case ComponentType.EVENT_HANDLER:
assert isinstance(target_component_info, EventHandlerInfo)
assert issubclass(target_component_class, BaseEventHandler)
@@ -308,20 +339,29 @@ class ComponentRegistry:
logger.warning(f"组件 {component_name} 未注册,无法禁用")
return False
target_component_info.enabled = False
- match component_type:
- case ComponentType.ACTION:
- self._default_actions.pop(component_name, None)
- case ComponentType.COMMAND:
- self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
- case ComponentType.EVENT_HANDLER:
- self._enabled_event_handlers.pop(component_name, None)
- from .events_manager import events_manager # 延迟导入防止循环导入问题
+ try:
+ match component_type:
+ case ComponentType.ACTION:
+ self._default_actions.pop(component_name)
+ case ComponentType.COMMAND:
+ self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name}
+ case ComponentType.TOOL:
+ self._llm_available_tools.pop(component_name)
+ case ComponentType.EVENT_HANDLER:
+ self._enabled_event_handlers.pop(component_name)
+ from .events_manager import events_manager # 延迟导入防止循环导入问题
- await events_manager.unregister_event_subscriber(component_name)
- self._components[component_name].enabled = False
- self._components_by_type[component_type][component_name].enabled = False
- logger.info(f"组件 {component_name} 已禁用")
- return True
+ await events_manager.unregister_event_subscriber(component_name)
+ self._components[component_name].enabled = False
+ self._components_by_type[component_type][component_name].enabled = False
+ logger.info(f"组件 {component_name} 已禁用")
+ return True
+ except KeyError as e:
+ logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}")
+ return False
+ except Exception as e:
+ logger.error(f"禁用组件 {component_name} 时发生错误: {e}")
+ return False
# === 组件查询方法 ===
def get_component_info(
@@ -371,7 +411,7 @@ class ComponentRegistry:
self,
component_name: str,
component_type: Optional[ComponentType] = None,
- ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]:
+ ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]:
"""获取组件类,支持自动命名空间解析
Args:
@@ -476,6 +516,27 @@ class ComponentRegistry:
command_info,
)
+ # === Tool 特定查询方法 ===
+ def get_tool_registry(self) -> Dict[str, Type[BaseTool]]:
+ """获取Tool注册表"""
+ return self._tool_registry.copy()
+
+ def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]:
+ """获取LLM可用的Tool列表"""
+ return self._llm_available_tools.copy()
+
+ def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]:
+ """获取Tool信息
+
+ Args:
+ tool_name: 工具名称
+
+ Returns:
+ ToolInfo: 工具信息对象,如果工具不存在则返回 None
+ """
+ info = self.get_component_info(tool_name, ComponentType.TOOL)
+ return info if isinstance(info, ToolInfo) else None
+
# === EventHandler 特定查询方法 ===
def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]:
@@ -529,17 +590,21 @@ class ComponentRegistry:
"""获取注册中心统计信息"""
action_components: int = 0
command_components: int = 0
+ tool_components: int = 0
events_handlers: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
elif component.component_type == ComponentType.COMMAND:
command_components += 1
+ elif component.component_type == ComponentType.TOOL:
+ tool_components += 1
elif component.component_type == ComponentType.EVENT_HANDLER:
events_handlers += 1
return {
"action_components": action_components,
"command_components": command_components,
+ "tool_components": tool_components,
"event_handlers": events_handlers,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py
index 9f7052f5..bb6f06b4 100644
--- a/src/plugin_system/core/global_announcement_manager.py
+++ b/src/plugin_system/core/global_announcement_manager.py
@@ -13,6 +13,8 @@ class GlobalAnnouncementManager:
self._user_disabled_commands: Dict[str, List[str]] = {}
# 用户禁用的事件处理器,chat_id -> [handler_name]
self._user_disabled_event_handlers: Dict[str, List[str]] = {}
+ # 用户禁用的工具,chat_id -> [tool_name]
+ self._user_disabled_tools: Dict[str, List[str]] = {}
def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool:
"""禁用特定聊天的某个动作"""
@@ -77,6 +79,27 @@ class GlobalAnnouncementManager:
return False
return False
+ def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
+ """禁用特定聊天的某个工具"""
+ if chat_id not in self._user_disabled_tools:
+ self._user_disabled_tools[chat_id] = []
+ if tool_name in self._user_disabled_tools[chat_id]:
+ logger.warning(f"工具 {tool_name} 已经被禁用")
+ return False
+ self._user_disabled_tools[chat_id].append(tool_name)
+ return True
+
+ def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool:
+ """启用特定聊天的某个工具"""
+ if chat_id in self._user_disabled_tools:
+ try:
+ self._user_disabled_tools[chat_id].remove(tool_name)
+ return True
+ except ValueError:
+ logger.warning(f"工具 {tool_name} 不在禁用列表中")
+ return False
+ return False
+
def get_disabled_chat_actions(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有动作"""
return self._user_disabled_actions.get(chat_id, []).copy()
@@ -88,6 +111,10 @@ class GlobalAnnouncementManager:
def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]:
"""获取特定聊天禁用的所有事件处理器"""
return self._user_disabled_event_handlers.get(chat_id, []).copy()
+
+ def get_disabled_chat_tools(self, chat_id: str) -> List[str]:
+ """获取特定聊天禁用的所有工具"""
+ return self._user_disabled_tools.get(chat_id, []).copy()
global_announcement_manager = GlobalAnnouncementManager()
diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py
index dfafda18..014b7a0c 100644
--- a/src/plugin_system/core/plugin_manager.py
+++ b/src/plugin_system/core/plugin_manager.py
@@ -224,6 +224,18 @@ class PluginManager:
list: 已注册的插件类名称列表。
"""
return list(self.plugin_classes.keys())
+
+ def get_plugin_path(self, plugin_name: str) -> Optional[str]:
+ """
+ 获取指定插件的路径。
+
+ Args:
+ plugin_name: 插件名称
+
+ Returns:
+ Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。
+ """
+ return self.plugin_paths.get(plugin_name)
# === 私有方法 ===
# == 目录管理 ==
@@ -346,6 +358,7 @@ class PluginManager:
stats = component_registry.get_registry_stats()
action_count = stats.get("action_components", 0)
command_count = stats.get("command_components", 0)
+ tool_count = stats.get("tool_components", 0)
event_handler_count = stats.get("event_handlers", 0)
total_components = stats.get("total_components", 0)
@@ -353,7 +366,7 @@ class PluginManager:
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
- f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})"
+ f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})"
)
# 显示详细的插件列表
@@ -388,6 +401,9 @@ class PluginManager:
command_components = [
c for c in plugin_info.components if c.component_type == ComponentType.COMMAND
]
+ tool_components = [
+ c for c in plugin_info.components if c.component_type == ComponentType.TOOL
+ ]
event_handler_components = [
c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER
]
@@ -399,7 +415,9 @@ class PluginManager:
if command_components:
command_names = [c.name for c in command_components]
logger.info(f" ⚡ Command组件: {', '.join(command_names)}")
-
+ if tool_components:
+ tool_names = [c.name for c in tool_components]
+ logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}")
if event_handler_components:
event_handler_names = [c.name for c in event_handler_components]
logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}")
diff --git a/src/tools/tool_executor.py b/src/plugin_system/core/tool_use.py
similarity index 76%
rename from src/tools/tool_executor.py
rename to src/plugin_system/core/tool_use.py
index 0f50ca2a..7a5eee31 100644
--- a/src/tools/tool_executor.py
+++ b/src/plugin_system/core/tool_use.py
@@ -1,14 +1,16 @@
-from src.llm_models.utils_model import LLMRequest
-from src.config.config import global_config
import time
-from src.common.logger import get_logger
+from typing import List, Dict, Tuple, Optional, Any
+from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
+from src.plugin_system.base.base_tool import BaseTool
+from src.plugin_system.core.global_announcement_manager import global_announcement_manager
+from src.llm_models.utils_model import LLMRequest
+from src.llm_models.payload_content import ToolCall
+from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from src.tools.tool_use import ToolUser
-from src.chat.utils.json_utils import process_llm_tool_calls
-from typing import List, Dict, Tuple, Optional
from src.chat.message_receive.chat_stream import get_chat_manager
+from src.common.logger import get_logger
-logger = get_logger("tool_executor")
+logger = get_logger("tool_use")
def init_tool_executor_prompt():
@@ -28,6 +30,10 @@ If you need to use a tool, please directly call the corresponding tool function.
Prompt(tool_executor_prompt, "tool_executor_prompt")
+# 初始化提示词
+init_tool_executor_prompt()
+
+
class ToolExecutor:
"""独立的工具执行器组件
@@ -46,13 +52,7 @@ class ToolExecutor:
self.chat_stream = get_chat_manager().get_stream(self.chat_id)
self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]"
- self.llm_model = LLMRequest(
- model=global_config.model.tool_use,
- request_type="tool_executor",
- )
-
- # 初始化工具实例
- self.tool_instance = ToolUser()
+ self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor")
# 缓存配置
self.enable_cache = enable_cache
@@ -63,7 +63,7 @@ class ToolExecutor:
async def execute_from_chat_message(
self, target_message: str, chat_history: str, sender: str, return_details: bool = False
- ) -> Tuple[List[Dict], List[str], str]:
+ ) -> Tuple[List[Dict[str, Any]], List[str], str]:
"""从聊天消息执行工具
Args:
@@ -73,7 +73,7 @@ class ToolExecutor:
return_details: 是否返回详细信息(使用的工具列表和提示词)
Returns:
- 如果return_details为False: List[Dict] - 工具执行结果列表
+ 如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空)
如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词)
"""
@@ -82,15 +82,15 @@ class ToolExecutor:
if cached_result := self._get_from_cache(cache_key):
logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行")
if not return_details:
- return cached_result, [], "使用缓存结果"
+ return cached_result, [], ""
# 从缓存结果中提取工具名称
used_tools = [result.get("tool_name", "unknown") for result in cached_result]
- return cached_result, used_tools, "使用缓存结果"
+ return cached_result, used_tools, ""
# 缓存未命中,执行工具调用
# 获取可用工具
- tools = self.tool_instance._define_tools()
+ tools = self._get_tool_definitions()
# 获取当前时间
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
@@ -110,17 +110,12 @@ class ToolExecutor:
logger.debug(f"{self.log_prefix}开始LLM工具调用分析")
# 调用LLM进行工具决策
- response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools)
-
- # 解析LLM响应
- if len(other_info) == 3:
- reasoning_content, model_name, tool_calls = other_info
- else:
- reasoning_content, model_name = other_info
- tool_calls = None
+ response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
+ prompt=prompt, tools=tools
+ )
# 执行工具调用
- tool_results, used_tools = await self._execute_tool_calls(tool_calls)
+ tool_results, used_tools = await self.execute_tool_calls(tool_calls)
# 缓存结果
if tool_results:
@@ -134,7 +129,12 @@ class ToolExecutor:
else:
return tool_results, [], ""
- async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]:
+ def _get_tool_definitions(self) -> List[Dict[str, Any]]:
+ all_tools = get_llm_available_tool_definitions()
+ user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id)
+ return [definition for name, definition in all_tools if name not in user_disabled_tools]
+
+ async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]:
"""执行工具调用
Args:
@@ -143,36 +143,23 @@ class ToolExecutor:
Returns:
Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表)
"""
- tool_results = []
+ tool_results: List[Dict[str, Any]] = []
used_tools = []
if not tool_calls:
logger.debug(f"{self.log_prefix}无需执行工具")
- return tool_results, used_tools
+ return [], []
logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}")
- # 处理工具调用
- success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls)
-
- if not success:
- logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}")
- return tool_results, used_tools
-
- if not valid_tool_calls:
- logger.debug(f"{self.log_prefix}无有效工具调用")
- return tool_results, used_tools
-
# 执行每个工具调用
- for tool_call in valid_tool_calls:
+ for tool_call in tool_calls:
try:
- tool_name = tool_call.get("name", "unknown_tool")
- used_tools.append(tool_name)
-
+ tool_name = tool_call.func_name
logger.debug(f"{self.log_prefix}执行工具: {tool_name}")
# 执行工具
- result = await self.tool_instance.execute_tool_call(tool_call)
+ result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
@@ -182,15 +169,15 @@ class ToolExecutor:
"tool_name": tool_name,
"timestamp": time.time(),
}
- tool_results.append(tool_info)
-
- logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
content = tool_info["content"]
if not isinstance(content, (str, list, tuple)):
- content = str(content)
+ tool_info["content"] = str(content)
+
+ tool_results.append(tool_info)
+ used_tools.append(tool_name)
+ logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}")
preview = content[:200]
logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...")
-
except Exception as e:
logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}")
# 添加错误信息到结果中
@@ -205,6 +192,42 @@ class ToolExecutor:
return tool_results, used_tools
+ async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]:
+ # sourcery skip: use-assigned-variable
+ """执行单个工具调用
+
+ Args:
+ tool_call: 工具调用对象
+
+ Returns:
+ Optional[Dict]: 工具调用结果,如果失败则返回None
+ """
+ try:
+ function_name = tool_call.func_name
+ function_args = tool_call.args or {}
+ function_args["llm_called"] = True # 标记为LLM调用
+
+ # 获取对应工具实例
+ tool_instance = tool_instance or get_tool_instance(function_name)
+ if not tool_instance:
+ logger.warning(f"未知工具名称: {function_name}")
+ return None
+
+ # 执行工具
+ result = await tool_instance.execute(function_args)
+ if result:
+ return {
+ "tool_call_id": tool_call.call_id,
+ "role": "tool",
+ "name": function_name,
+ "type": "function",
+ "content": result["content"],
+ }
+ return None
+ except Exception as e:
+ logger.error(f"执行工具调用时发生错误: {str(e)}")
+ raise e
+
def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str:
"""生成缓存键
@@ -272,18 +295,7 @@ class ToolExecutor:
if expired_keys:
logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存")
- def get_available_tools(self) -> List[str]:
- """获取可用工具列表
-
- Returns:
- List[str]: 可用工具名称列表
- """
- tools = self.tool_instance._define_tools()
- return [tool.get("function", {}).get("name", "unknown") for tool in tools]
-
- async def execute_specific_tool(
- self, tool_name: str, tool_args: Dict, validate_args: bool = True
- ) -> Optional[Dict]:
+ async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]:
"""直接执行指定工具
Args:
@@ -295,11 +307,15 @@ class ToolExecutor:
Optional[Dict]: 工具执行结果,失败时返回None
"""
try:
- tool_call = {"name": tool_name, "arguments": tool_args}
+ tool_call = ToolCall(
+ call_id=f"direct_tool_{time.time()}",
+ func_name=tool_name,
+ args=tool_args,
+ )
logger.info(f"{self.log_prefix}直接执行工具: {tool_name}")
- result = await self.tool_instance.execute_tool_call(tool_call)
+ result = await self.execute_tool_call(tool_call)
if result:
tool_info = {
@@ -366,12 +382,8 @@ class ToolExecutor:
logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}")
-# 初始化提示词
-init_tool_executor_prompt()
-
-
"""
-使用示例:
+ToolExecutor使用示例:
# 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3)
executor = ToolExecutor(executor_id="my_executor")
@@ -394,13 +406,12 @@ results, used_tools, prompt = await executor.execute_from_chat_message(
)
# 5. 直接执行特定工具
-result = await executor.execute_specific_tool(
+result = await executor.execute_specific_tool_simple(
tool_name="get_knowledge",
tool_args={"query": "机器学习"}
)
# 6. 缓存管理
-available_tools = executor.get_available_tools()
cache_status = executor.get_cache_status() # 查看缓存状态
executor.clear_cache() # 清空缓存
executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置
diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py
index fa922dc1..790f2096 100644
--- a/src/plugins/built_in/core_actions/emoji.py
+++ b/src/plugins/built_in/core_actions/emoji.py
@@ -58,6 +58,7 @@ class EmojiAction(BaseAction):
associated_types = ["emoji"]
async def execute(self) -> Tuple[bool, str]:
+ # sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression
"""执行表情动作"""
logger.info(f"{self.log_prefix} 决定发送表情")
diff --git a/src/tools/not_using/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py
similarity index 83%
rename from src/tools/not_using/lpmm_get_knowledge.py
rename to src/plugins/built_in/knowledge/lpmm_get_knowledge.py
index 467db6ed..fd3d811b 100644
--- a/src/tools/not_using/lpmm_get_knowledge.py
+++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py
@@ -1,10 +1,9 @@
-from src.tools.tool_can_use.base_tool import BaseTool
-
-# from src.common.database import db
-from src.common.logger import get_logger
from typing import Dict, Any
-from src.chat.knowledge.knowledge_lib import qa_manager
+from src.common.logger import get_logger
+from src.config.config import global_config
+from src.chat.knowledge.knowledge_lib import qa_manager
+from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("lpmm_get_knowledge_tool")
@@ -14,14 +13,11 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
name = "lpmm_search_knowledge"
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
- parameters = {
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "搜索查询关键词"},
- "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
- },
- "required": ["query"],
- }
+ parameters = [
+ ("query", ToolParamType.STRING, "搜索查询关键词", True, None),
+ ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None),
+ ]
+ available_for_llm = global_config.lpmm_knowledge.enable
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
"""执行知识库搜索
diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py
index de846dd5..c2489a38 100644
--- a/src/plugins/built_in/plugin_management/plugin.py
+++ b/src/plugins/built_in/plugin_management/plugin.py
@@ -11,6 +11,7 @@ from src.plugin_system import (
component_manage_api,
ComponentInfo,
ComponentType,
+ send_api,
)
@@ -27,8 +28,15 @@ class ManagementCommand(BaseCommand):
or not self.message.message_info.user_info
or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore
):
- await self.send_text("你没有权限使用插件管理命令")
+ await self._send_message("你没有权限使用插件管理命令")
return False, "没有权限", True
+ if not self.message.chat_stream:
+ await self._send_message("无法获取聊天流信息")
+ return False, "无法获取聊天流信息", True
+ self.stream_id = self.message.chat_stream.stream_id
+ if not self.stream_id:
+ await self._send_message("无法获取聊天流信息")
+ return False, "无法获取聊天流信息", True
command_list = self.matched_groups["manage_command"].strip().split(" ")
if len(command_list) == 1:
await self.show_help("all")
@@ -42,7 +50,7 @@ class ManagementCommand(BaseCommand):
case "help":
await self.show_help("all")
case _:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 3:
if command_list[1] == "plugin":
@@ -56,7 +64,7 @@ class ManagementCommand(BaseCommand):
case "rescan":
await self._rescan_plugin_dirs()
case _:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] == "list":
@@ -64,10 +72,10 @@ class ManagementCommand(BaseCommand):
elif command_list[2] == "help":
await self.show_help("component")
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 4:
if command_list[1] == "plugin":
@@ -81,28 +89,28 @@ class ManagementCommand(BaseCommand):
case "add_dir":
await self._add_dir(command_list[3])
case _:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[1] == "component":
if command_list[2] != "list":
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components()
elif command_list[3] == "disabled":
await self._list_disabled_components()
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 5:
if command_list[1] != "component":
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] != "list":
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[3] == "enabled":
await self._list_enabled_components(target_type=command_list[4])
@@ -111,11 +119,11 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "type":
await self._list_registered_components_by_type(command_list[4])
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if len(command_list) == 6:
if command_list[1] != "component":
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
if command_list[2] == "enable":
if command_list[3] == "global":
@@ -123,7 +131,7 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "local":
await self._locally_enable_component(command_list[4], command_list[5])
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
elif command_list[2] == "disable":
if command_list[3] == "global":
@@ -131,10 +139,10 @@ class ManagementCommand(BaseCommand):
elif command_list[3] == "local":
await self._locally_disable_component(command_list[4], command_list[5])
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
else:
- await self.send_text("插件管理命令不合法")
+ await self._send_message("插件管理命令不合法")
return False, "命令不合法", True
return True, "命令执行完成", True
@@ -180,51 +188,51 @@ class ManagementCommand(BaseCommand):
)
case _:
return
- await self.send_text(help_msg)
+ await self._send_message(help_msg)
async def _list_loaded_plugins(self):
plugins = plugin_manage_api.list_loaded_plugins()
- await self.send_text(f"已加载的插件: {', '.join(plugins)}")
+ await self._send_message(f"已加载的插件: {', '.join(plugins)}")
async def _list_registered_plugins(self):
plugins = plugin_manage_api.list_registered_plugins()
- await self.send_text(f"已注册的插件: {', '.join(plugins)}")
+ await self._send_message(f"已注册的插件: {', '.join(plugins)}")
async def _rescan_plugin_dirs(self):
plugin_manage_api.rescan_plugin_directory()
- await self.send_text("插件目录重新扫描执行中")
+ await self._send_message("插件目录重新扫描执行中")
async def _load_plugin(self, plugin_name: str):
success, count = plugin_manage_api.load_plugin(plugin_name)
if success:
- await self.send_text(f"插件加载成功: {plugin_name}")
+ await self._send_message(f"插件加载成功: {plugin_name}")
else:
if count == 0:
- await self.send_text(f"插件{plugin_name}为禁用状态")
- await self.send_text(f"插件加载失败: {plugin_name}")
+ await self._send_message(f"插件{plugin_name}为禁用状态")
+ await self._send_message(f"插件加载失败: {plugin_name}")
async def _unload_plugin(self, plugin_name: str):
success = await plugin_manage_api.remove_plugin(plugin_name)
if success:
- await self.send_text(f"插件卸载成功: {plugin_name}")
+ await self._send_message(f"插件卸载成功: {plugin_name}")
else:
- await self.send_text(f"插件卸载失败: {plugin_name}")
+ await self._send_message(f"插件卸载失败: {plugin_name}")
async def _reload_plugin(self, plugin_name: str):
success = await plugin_manage_api.reload_plugin(plugin_name)
if success:
- await self.send_text(f"插件重新加载成功: {plugin_name}")
+ await self._send_message(f"插件重新加载成功: {plugin_name}")
else:
- await self.send_text(f"插件重新加载失败: {plugin_name}")
+ await self._send_message(f"插件重新加载失败: {plugin_name}")
async def _add_dir(self, dir_path: str):
- await self.send_text(f"正在添加插件目录: {dir_path}")
+ await self._send_message(f"正在添加插件目录: {dir_path}")
success = plugin_manage_api.add_plugin_directory(dir_path)
await asyncio.sleep(0.5) # 防止乱序发送
if success:
- await self.send_text(f"插件目录添加成功: {dir_path}")
+ await self._send_message(f"插件目录添加成功: {dir_path}")
else:
- await self.send_text(f"插件目录添加失败: {dir_path}")
+ await self._send_message(f"插件目录添加失败: {dir_path}")
def _fetch_all_registered_components(self) -> List[ComponentInfo]:
all_plugin_info = component_manage_api.get_all_plugin_info()
@@ -255,29 +263,29 @@ class ManagementCommand(BaseCommand):
async def _list_all_registered_components(self):
components_info = self._fetch_all_registered_components()
if not components_info:
- await self.send_text("没有注册的组件")
+ await self._send_message("没有注册的组件")
return
all_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in components_info
)
- await self.send_text(f"已注册的组件: {all_components_str}")
+ await self._send_message(f"已注册的组件: {all_components_str}")
async def _list_enabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
- await self.send_text("没有注册的组件")
+ await self._send_message("没有注册的组件")
return
if target_type == "global":
enabled_components = [component for component in components_info if component.enabled]
if not enabled_components:
- await self.send_text("没有满足条件的已启用全局组件")
+ await self._send_message("没有满足条件的已启用全局组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
- await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}")
+ await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
enabled_components = [
@@ -286,28 +294,28 @@ class ManagementCommand(BaseCommand):
if (component.name not in locally_disabled_components and component.enabled)
]
if not enabled_components:
- await self.send_text("本聊天没有满足条件的已启用组件")
+ await self._send_message("本聊天没有满足条件的已启用组件")
return
enabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in enabled_components
)
- await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}")
+ await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}")
async def _list_disabled_components(self, target_type: str = "global"):
components_info = self._fetch_all_registered_components()
if not components_info:
- await self.send_text("没有注册的组件")
+ await self._send_message("没有注册的组件")
return
if target_type == "global":
disabled_components = [component for component in components_info if not component.enabled]
if not disabled_components:
- await self.send_text("没有满足条件的已禁用全局组件")
+ await self._send_message("没有满足条件的已禁用全局组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
- await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}")
+ await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}")
elif target_type == "local":
locally_disabled_components = self._fetch_locally_disabled_components()
disabled_components = [
@@ -316,12 +324,12 @@ class ManagementCommand(BaseCommand):
if (component.name in locally_disabled_components or not component.enabled)
]
if not disabled_components:
- await self.send_text("本聊天没有满足条件的已禁用组件")
+ await self._send_message("本聊天没有满足条件的已禁用组件")
return
disabled_components_str = ", ".join(
f"{component.name} ({component.component_type})" for component in disabled_components
)
- await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
+ await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}")
async def _list_registered_components_by_type(self, target_type: str):
match target_type:
@@ -332,18 +340,18 @@ class ManagementCommand(BaseCommand):
case "event_handler":
component_type = ComponentType.EVENT_HANDLER
case _:
- await self.send_text(f"未知组件类型: {target_type}")
+ await self._send_message(f"未知组件类型: {target_type}")
return
components_info = component_manage_api.get_components_info_by_type(component_type)
if not components_info:
- await self.send_text(f"没有注册的 {target_type} 组件")
+ await self._send_message(f"没有注册的 {target_type} 组件")
return
components_str = ", ".join(
f"{name} ({component.component_type})" for name, component in components_info.items()
)
- await self.send_text(f"注册的 {target_type} 组件: {components_str}")
+ await self._send_message(f"注册的 {target_type} 组件: {components_str}")
async def _globally_enable_component(self, component_name: str, component_type: str):
match component_type:
@@ -354,12 +362,12 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
- await self.send_text(f"未知组件类型: {component_type}")
+ await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.globally_enable_component(component_name, target_component_type):
- await self.send_text(f"全局启用组件成功: {component_name}")
+ await self._send_message(f"全局启用组件成功: {component_name}")
else:
- await self.send_text(f"全局启用组件失败: {component_name}")
+ await self._send_message(f"全局启用组件失败: {component_name}")
async def _globally_disable_component(self, component_name: str, component_type: str):
match component_type:
@@ -370,13 +378,13 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
- await self.send_text(f"未知组件类型: {component_type}")
+ await self._send_message(f"未知组件类型: {component_type}")
return
success = await component_manage_api.globally_disable_component(component_name, target_component_type)
if success:
- await self.send_text(f"全局禁用组件成功: {component_name}")
+ await self._send_message(f"全局禁用组件成功: {component_name}")
else:
- await self.send_text(f"全局禁用组件失败: {component_name}")
+ await self._send_message(f"全局禁用组件失败: {component_name}")
async def _locally_enable_component(self, component_name: str, component_type: str):
match component_type:
@@ -387,16 +395,16 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
- await self.send_text(f"未知组件类型: {component_type}")
+ await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_enable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
- await self.send_text(f"本地启用组件成功: {component_name}")
+ await self._send_message(f"本地启用组件成功: {component_name}")
else:
- await self.send_text(f"本地启用组件失败: {component_name}")
+ await self._send_message(f"本地启用组件失败: {component_name}")
async def _locally_disable_component(self, component_name: str, component_type: str):
match component_type:
@@ -407,16 +415,19 @@ class ManagementCommand(BaseCommand):
case "event_handler":
target_component_type = ComponentType.EVENT_HANDLER
case _:
- await self.send_text(f"未知组件类型: {component_type}")
+ await self._send_message(f"未知组件类型: {component_type}")
return
if component_manage_api.locally_disable_component(
component_name,
target_component_type,
self.message.chat_stream.stream_id,
):
- await self.send_text(f"本地禁用组件成功: {component_name}")
+ await self._send_message(f"本地禁用组件成功: {component_name}")
else:
- await self.send_text(f"本地禁用组件失败: {component_name}")
+ await self._send_message(f"本地禁用组件失败: {component_name}")
+
+ async def _send_message(self, message: str):
+ await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False)
@register_plugin
@@ -430,7 +441,9 @@ class PluginManagementPlugin(BasePlugin):
"plugin": {
"enabled": ConfigField(bool, default=False, description="是否启用插件"),
"config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"),
- "permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID"),
+ "permission": ConfigField(
+ list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID"
+ ),
},
}
diff --git a/src/tools/not_using/get_knowledge.py b/src/tools/not_using/get_knowledge.py
deleted file mode 100644
index c436d774..00000000
--- a/src/tools/not_using/get_knowledge.py
+++ /dev/null
@@ -1,133 +0,0 @@
-from src.tools.tool_can_use.base_tool import BaseTool
-from src.chat.utils.utils import get_embedding
-from src.common.database.database_model import Knowledges # Updated import
-from src.common.logger import get_logger
-from typing import Any, Union, List # Added List
-import json # Added for parsing embedding
-import math # Added for cosine similarity
-
-logger = get_logger("get_knowledge_tool")
-
-
-class SearchKnowledgeTool(BaseTool):
- """从知识库中搜索相关信息的工具"""
-
- name = "search_knowledge"
- description = "使用工具从知识库中搜索相关信息"
- parameters = {
- "type": "object",
- "properties": {
- "query": {"type": "string", "description": "搜索查询关键词"},
- "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"},
- },
- "required": ["query"],
- }
-
- async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
- """执行知识库搜索
-
- Args:
- function_args: 工具参数
-
- Returns:
- dict: 工具执行结果
- """
- query = "" # Initialize query to ensure it's defined in except block
- try:
- query = function_args.get("query")
- threshold = function_args.get("threshold", 0.4)
-
- # 调用知识库搜索
- embedding = await get_embedding(query, request_type="info_retrieval")
- if embedding:
- knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
- if knowledge_info:
- content = f"你知道这些知识: {knowledge_info}"
- else:
- content = f"你不太了解有关{query}的知识"
- return {"type": "knowledge", "id": query, "content": content}
- return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"}
- except Exception as e:
- logger.error(f"知识库搜索工具执行失败: {str(e)}")
- return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
-
- @staticmethod
- def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
- """计算两个向量之间的余弦相似度"""
- dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False))
- magnitude1 = math.sqrt(sum(p * p for p in vec1))
- magnitude2 = math.sqrt(sum(q * q for q in vec2))
- if magnitude1 == 0 or magnitude2 == 0:
- return 0.0
- return dot_product / (magnitude1 * magnitude2)
-
- @staticmethod
- def get_info_from_db(
- query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
- ) -> Union[str, list]:
- """从数据库中获取相关信息
-
- Args:
- query_embedding: 查询的嵌入向量
- limit: 最大返回结果数
- threshold: 相似度阈值
- return_raw: 是否返回原始结果
-
- Returns:
- Union[str, list]: 格式化的信息字符串或原始结果列表
- """
- if not query_embedding:
- return "" if not return_raw else []
-
- similar_items = []
- try:
- all_knowledges = Knowledges.select()
- for item in all_knowledges:
- try:
- item_embedding_str = item.embedding
- if not item_embedding_str:
- logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
- continue
- item_embedding = json.loads(item_embedding_str)
- if not isinstance(item_embedding, list) or not all(
- isinstance(x, (int, float)) for x in item_embedding
- ):
- logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
- continue
- except json.JSONDecodeError:
- logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
- continue
- except AttributeError:
- logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
- continue
-
- similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
-
- if similarity >= threshold:
- similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
-
- # 按相似度降序排序
- similar_items.sort(key=lambda x: x["similarity"], reverse=True)
-
- # 应用限制
- results = similar_items[:limit]
- logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
-
- except Exception as e:
- logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
- return "" if not return_raw else []
-
- if not results:
- return "" if not return_raw else []
-
- if return_raw:
- # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理
- # 这里返回包含内容和相似度的字典列表
- return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
- else:
- # 返回所有找到的内容,用换行分隔
- return "\n".join(str(result["content"]) for result in results)
-
-
-# 注册工具
-# register_tool(SearchKnowledgeTool)
diff --git a/src/tools/tool_can_use/__init__.py b/src/tools/tool_can_use/__init__.py
deleted file mode 100644
index 14bae04c..00000000
--- a/src/tools/tool_can_use/__init__.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from src.tools.tool_can_use.base_tool import (
- BaseTool,
- register_tool,
- discover_tools,
- get_all_tool_definitions,
- get_tool_instance,
- TOOL_REGISTRY,
-)
-
-__all__ = [
- "BaseTool",
- "register_tool",
- "discover_tools",
- "get_all_tool_definitions",
- "get_tool_instance",
- "TOOL_REGISTRY",
-]
-
-# 自动发现并注册工具
-discover_tools()
diff --git a/src/tools/tool_can_use/base_tool.py b/src/tools/tool_can_use/base_tool.py
deleted file mode 100644
index 89d051dc..00000000
--- a/src/tools/tool_can_use/base_tool.py
+++ /dev/null
@@ -1,115 +0,0 @@
-from typing import List, Any, Optional, Type
-import inspect
-import importlib
-import pkgutil
-import os
-from src.common.logger import get_logger
-from rich.traceback import install
-
-install(extra_lines=3)
-
-logger = get_logger("base_tool")
-
-# 工具注册表
-TOOL_REGISTRY = {}
-
-
-class BaseTool:
- """所有工具的基类"""
-
- # 工具名称,子类必须重写
- name = None
- # 工具描述,子类必须重写
- description = None
- # 工具参数定义,子类必须重写
- parameters = None
-
- @classmethod
- def get_tool_definition(cls) -> dict[str, Any]:
- """获取工具定义,用于LLM工具调用
-
- Returns:
- dict: 工具定义字典
- """
- if not cls.name or not cls.description or not cls.parameters:
- raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性")
-
- return {
- "type": "function",
- "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters},
- }
-
- async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
- """执行工具函数
-
- Args:
- function_args: 工具调用参数
-
- Returns:
- dict: 工具执行结果
- """
- raise NotImplementedError("子类必须实现execute方法")
-
-
-def register_tool(tool_class: Type[BaseTool]):
- """注册工具到全局注册表
-
- Args:
- tool_class: 工具类
- """
- if not issubclass(tool_class, BaseTool):
- raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类")
-
- tool_name = tool_class.name
- if not tool_name:
- raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性")
-
- TOOL_REGISTRY[tool_name] = tool_class
- logger.info(f"已注册: {tool_name}")
-
-
-def discover_tools():
- """自动发现并注册tool_can_use目录下的所有工具"""
- # 获取当前目录路径
- current_dir = os.path.dirname(os.path.abspath(__file__))
- package_name = os.path.basename(current_dir)
-
- # 遍历包中的所有模块
- for _, module_name, _ in pkgutil.iter_modules([current_dir]):
- # 跳过当前模块和__pycache__
- if module_name == "base_tool" or module_name.startswith("__"):
- continue
-
- # 导入模块
- module = importlib.import_module(f"src.tools.{package_name}.{module_name}")
-
- # 查找模块中的工具类
- for _, obj in inspect.getmembers(module):
- if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool:
- register_tool(obj)
-
- logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具")
-
-
-def get_all_tool_definitions() -> List[dict[str, Any]]:
- """获取所有已注册工具的定义
-
- Returns:
- List[dict]: 工具定义列表
- """
- return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()]
-
-
-def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
- """获取指定名称的工具实例
-
- Args:
- tool_name: 工具名称
-
- Returns:
- Optional[BaseTool]: 工具实例,如果找不到则返回None
- """
- tool_class = TOOL_REGISTRY.get(tool_name)
- if not tool_class:
- return None
- return tool_class()
diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py
deleted file mode 100644
index 236a4587..00000000
--- a/src/tools/tool_can_use/compare_numbers_tool.py
+++ /dev/null
@@ -1,45 +0,0 @@
-from src.tools.tool_can_use.base_tool import BaseTool
-from src.common.logger import get_logger
-from typing import Any
-
-logger = get_logger("compare_numbers_tool")
-
-
-class CompareNumbersTool(BaseTool):
- """比较两个数大小的工具"""
-
- name = "compare_numbers"
- description = "使用工具 比较两个数的大小,返回较大的数"
- parameters = {
- "type": "object",
- "properties": {
- "num1": {"type": "number", "description": "第一个数字"},
- "num2": {"type": "number", "description": "第二个数字"},
- },
- "required": ["num1", "num2"],
- }
-
- async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
- """执行比较两个数的大小
-
- Args:
- function_args: 工具参数
-
- Returns:
- dict: 工具执行结果
- """
- num1: int | float = function_args.get("num1") # type: ignore
- num2: int | float = function_args.get("num2") # type: ignore
-
- try:
- if num1 > num2:
- result = f"{num1} 大于 {num2}"
- elif num1 < num2:
- result = f"{num1} 小于 {num2}"
- else:
- result = f"{num1} 等于 {num2}"
-
- return {"name": self.name, "content": result}
- except Exception as e:
- logger.error(f"比较数字失败: {str(e)}")
- return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py
deleted file mode 100644
index 17e62468..00000000
--- a/src/tools/tool_can_use/rename_person_tool.py
+++ /dev/null
@@ -1,103 +0,0 @@
-from src.tools.tool_can_use.base_tool import BaseTool
-from src.person_info.person_info import get_person_info_manager
-from src.common.logger import get_logger
-
-
-logger = get_logger("rename_person_tool")
-
-
-class RenamePersonTool(BaseTool):
- name = "rename_person"
- description = (
- "这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。"
- )
- parameters = {
- "type": "object",
- "properties": {
- "person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"},
- "message_content": {
- "type": "string",
- "description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。",
- },
- },
- "required": ["person_name"],
- }
-
- async def execute(self, function_args: dict):
- """
- 执行取名工具逻辑
-
- Args:
- function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典
- message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确)
-
- Returns:
- dict: 包含执行结果的字典
- """
- person_name_to_find = function_args.get("person_name")
- request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串
-
- if not person_name_to_find:
- return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"}
- person_info_manager = get_person_info_manager()
- try:
- # 1. 根据昵称查找用户信息
- logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...")
- person_info = await person_info_manager.get_person_info_by_name(person_name_to_find)
-
- if not person_info:
- logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。")
- return {
- "name": self.name,
- "content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。",
- }
-
- person_id = person_info.get("person_id")
- user_nickname = person_info.get("nickname") # 这是用户原始昵称
- user_cardname = person_info.get("user_cardname")
- user_avatar = person_info.get("user_avatar")
-
- if not person_id:
- logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id")
- return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"}
-
- # 2. 调用 qv_person_name 进行取名
- logger.debug(
- f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'"
- )
- result = await person_info_manager.qv_person_name(
- person_id=person_id,
- user_nickname=user_nickname, # type: ignore
- user_cardname=user_cardname, # type: ignore
- user_avatar=user_avatar, # type: ignore
- request=request_context,
- )
-
- # 3. 处理结果
- if result and result.get("nickname"):
- new_name = result["nickname"]
- # reason = result.get("reason", "未提供理由")
- logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}")
-
- content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}"
- logger.info(content)
- return {"name": self.name, "content": content}
- else:
- logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。")
- # 尝试从内存中获取可能已经更新的名字
- current_name = await person_info_manager.get_value(person_id, "person_name")
- if current_name and current_name != person_name_to_find:
- return {
- "name": self.name,
- "content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。",
- }
- else:
- return {
- "name": self.name,
- "content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。",
- }
-
- except Exception as e:
- error_msg = f"重命名失败: {str(e)}"
- logger.error(error_msg, exc_info=True)
- return {"name": self.name, "content": error_msg}
diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py
deleted file mode 100644
index 6a8cd48a..00000000
--- a/src/tools/tool_use.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import json
-from src.common.logger import get_logger
-from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance
-
-logger = get_logger("tool_use")
-
-
-class ToolUser:
- @staticmethod
- def _define_tools():
- """获取所有已注册工具的定义
-
- Returns:
- list: 工具定义列表
- """
- return get_all_tool_definitions()
-
- @staticmethod
- async def execute_tool_call(tool_call):
- # sourcery skip: use-assigned-variable
- """执行特定的工具调用
-
- Args:
- tool_call: 工具调用对象
- message_txt: 原始消息文本
-
- Returns:
- dict: 工具调用结果
- """
- try:
- function_name = tool_call["function"]["name"]
- function_args = json.loads(tool_call["function"]["arguments"])
-
- # 获取对应工具实例
- tool_instance = get_tool_instance(function_name)
- if not tool_instance:
- logger.warning(f"未知工具名称: {function_name}")
- return None
-
- # 执行工具
- result = await tool_instance.execute(function_args)
- if result:
- # 直接使用 function_name 作为 tool_type
- tool_type = function_name
-
- return {
- "tool_call_id": tool_call["id"],
- "role": "tool",
- "name": function_name,
- "type": tool_type,
- "content": result["content"],
- }
- return None
- except Exception as e:
- logger.error(f"执行工具调用时发生错误: {str(e)}")
- return None
diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml
index 39857d66..fae41f82 100644
--- a/template/bot_config_template.toml
+++ b/template/bot_config_template.toml
@@ -1,5 +1,5 @@
[inner]
-version = "4.5.0"
+version = "6.0.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件,请在修改后将version的值进行变更
@@ -213,134 +213,9 @@ file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ER
suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库
library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别
-#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写
-
-# stream = : 用于指定模型是否是使用流式输出
-# pri_in = : 用于指定模型输入价格
-# pri_out = : 用于指定模型输出价格
-# temp = : 用于指定模型温度
-# enable_thinking = : 用于指定模型是否启用思考
-# thinking_budget = : 用于指定模型思考最长长度
-
[debug]
show_prompt = false # 是否显示prompt
-
-[model]
-model_max_output_length = 1024 # 模型单次返回的最大token数
-
-#------------必填:组件模型------------
-
-[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
-pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
-#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数
-temp = 0.2 #模型的温度,新V3建议0.1-0.3
-
-[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
-# 强烈建议使用免费的小模型
-name = "Qwen/Qwen3-8B"
-provider = "SILICONFLOW"
-pri_in = 0
-pri_out = 0
-temp = 0.7
-enable_thinking = false # 是否启用思考
-
-[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
-pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
-#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数
-temp = 0.2 #模型的温度,新V3建议0.1-0.3
-
-[model.replyer_2] # 次要回复模型
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2 #模型的输入价格(非必填,可以记录消耗)
-pri_out = 8 #模型的输出价格(非必填,可以记录消耗)
-#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数
-temp = 0.2 #模型的温度,新V3建议0.1-0.3
-
-[model.planner] #决策:负责决定麦麦该做什么的模型
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2
-pri_out = 8
-temp = 0.3
-
-[model.emotion] #负责麦麦的情绪变化
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2
-pri_out = 8
-temp = 0.3
-
-
-[model.memory] # 记忆模型
-name = "Qwen/Qwen3-30B-A3B"
-provider = "SILICONFLOW"
-pri_in = 0.7
-pri_out = 2.8
-temp = 0.7
-enable_thinking = false # 是否启用思考
-
-[model.vlm] # 图像识别模型
-name = "Pro/Qwen/Qwen2.5-VL-7B-Instruct"
-provider = "SILICONFLOW"
-pri_in = 0.35
-pri_out = 0.35
-
-[model.voice] # 语音识别模型
-name = "FunAudioLLM/SenseVoiceSmall"
-provider = "SILICONFLOW"
-pri_in = 0
-pri_out = 0
-
-[model.tool_use] #工具调用模型,需要使用支持工具调用的模型
-name = "Qwen/Qwen3-14B"
-provider = "SILICONFLOW"
-pri_in = 0.5
-pri_out = 2
-temp = 0.7
-enable_thinking = false # 是否启用思考(qwen3 only)
-
-#嵌入模型
-[model.embedding]
-name = "BAAI/bge-m3"
-provider = "SILICONFLOW"
-pri_in = 0
-pri_out = 0
-
-
-#------------LPMM知识库模型------------
-
-[model.lpmm_entity_extract] # 实体提取模型
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2
-pri_out = 8
-temp = 0.2
-
-
-[model.lpmm_rdf_build] # RDF构建模型
-name = "Pro/deepseek-ai/DeepSeek-V3"
-provider = "SILICONFLOW"
-pri_in = 2
-pri_out = 8
-temp = 0.2
-
-
-[model.lpmm_qa] # 问答模型
-name = "Qwen/Qwen3-30B-A3B"
-provider = "SILICONFLOW"
-pri_in = 0.7
-pri_out = 2.8
-temp = 0.7
-enable_thinking = false # 是否启用思考
-
[maim_message]
auth_token = [] # 认证令牌,用于API验证,为空则不启用验证
# 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器
@@ -356,8 +231,4 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效
enable = true
[experimental] #实验性功能
-enable_friend_chat = false # 是否启用好友聊天
-
-
-
-
+enable_friend_chat = false # 是否启用好友聊天
\ No newline at end of file
diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml
deleted file mode 100644
index 5bf24732..00000000
--- a/template/lpmm_config_template.toml
+++ /dev/null
@@ -1,60 +0,0 @@
-[lpmm]
-version = "0.1.0"
-
-# LLM API 服务提供商,可配置多个
-[[llm_providers]]
-name = "localhost"
-base_url = "http://127.0.0.1:8888/v1/"
-api_key = "lm_studio"
-
-[[llm_providers]]
-name = "siliconflow"
-base_url = "https://api.siliconflow.cn/v1/"
-api_key = ""
-
-[entity_extract.llm]
-# 设置用于实体提取的LLM模型
-provider = "siliconflow" # 服务提供商
-model = "deepseek-ai/DeepSeek-V3" # 模型名称
-
-[rdf_build.llm]
-# 设置用于RDF构建的LLM模型
-provider = "siliconflow" # 服务提供商
-model = "deepseek-ai/DeepSeek-V3" # 模型名称
-
-[embedding]
-# 设置用于文本嵌入的Embedding模型
-provider = "siliconflow" # 服务提供商
-model = "Pro/BAAI/bge-m3" # 模型名称
-dimension = 1024 # 嵌入维度
-
-[rag.params]
-# RAG参数配置
-synonym_search_top_k = 10 # 同义词搜索TopK
-synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词)
-
-[qa.llm]
-# 设置用于QA的LLM模型
-provider = "siliconflow" # 服务提供商
-model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # 模型名称
-
-[info_extraction]
-workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5
-
-[qa.params]
-# QA参数配置
-relation_search_top_k = 10 # 关系搜索TopK
-relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系)
-paragraph_search_top_k = 1000 # 段落搜索TopK(不能过小,可能影响搜索结果)
-paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用)
-ent_filter_top_k = 10 # 实体过滤TopK
-ppr_damping = 0.8 # PPR阻尼系数
-res_top_k = 3 # 最终提供的文段TopK
-
-[persistence]
-# 持久化配置(存储中间数据,防止重复计算)
-data_root_path = "data" # 数据根目录
-imported_data_path = "data/imported_lpmm_data" # 转换为json的raw文件数据路径
-openie_data_path = "data/openie" # OpenIE数据路径
-embedding_data_dir = "data/embedding" # 嵌入数据目录
-rag_data_dir = "data/rag" # RAG数据目录
diff --git a/template/model_config_template.toml b/template/model_config_template.toml
new file mode 100644
index 00000000..3dcff6f8
--- /dev/null
+++ b/template/model_config_template.toml
@@ -0,0 +1,171 @@
+[inner]
+version = "1.1.1"
+
+# 配置文件版本号迭代规则同bot_config.toml
+
+[[api_providers]] # API服务提供商(可以配置多个)
+name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名)
+base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL
+api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥)
+client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini")
+max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数)
+timeout = 30 # API请求超时时间(单位:秒)
+retry_interval = 10 # 重试间隔时间(单位:秒)
+
+[[api_providers]] # SiliconFlow的API服务商配置
+name = "SiliconFlow"
+base_url = "https://api.siliconflow.cn/v1"
+api_key = "your-siliconflow-api-key"
+client_type = "openai"
+max_retry = 2
+timeout = 30
+retry_interval = 10
+
+[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini"
+name = "Google"
+base_url = "https://api.google.com/v1"
+api_key = "your-google-api-key-1"
+client_type = "gemini"
+max_retry = 2
+timeout = 30
+retry_interval = 10
+
+
+[[models]] # 模型(可以配置多个)
+model_identifier = "deepseek-chat" # 模型标识符(API服务商提供的模型标识符)
+name = "deepseek-v3" # 模型名称(可随意命名,在后面中需使用这个命名)
+api_provider = "DeepSeek" # API服务商名称(对应在api_providers中配置的服务商名称)
+price_in = 2.0 # 输入价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0)
+price_out = 8.0 # 输出价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0)
+#force_stream_mode = true # 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出,若无该字段,默认值为false)
+
+[[models]]
+model_identifier = "Pro/deepseek-ai/DeepSeek-V3"
+name = "siliconflow-deepseek-v3"
+api_provider = "SiliconFlow"
+price_in = 2.0
+price_out = 8.0
+
+[[models]]
+model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
+name = "deepseek-r1-distill-qwen-32b"
+api_provider = "SiliconFlow"
+price_in = 4.0
+price_out = 16.0
+
+[[models]]
+model_identifier = "Qwen/Qwen3-8B"
+name = "qwen3-8b"
+api_provider = "SiliconFlow"
+price_in = 0
+price_out = 0
+[models.extra_params] # 可选的额外参数配置
+enable_thinking = false # 不启用思考
+
+[[models]]
+model_identifier = "Qwen/Qwen3-14B"
+name = "qwen3-14b"
+api_provider = "SiliconFlow"
+price_in = 0.5
+price_out = 2.0
+[models.extra_params] # 可选的额外参数配置
+enable_thinking = false # 不启用思考
+
+[[models]]
+model_identifier = "Qwen/Qwen3-30B-A3B"
+name = "qwen3-30b"
+api_provider = "SiliconFlow"
+price_in = 0.7
+price_out = 2.8
+[models.extra_params] # 可选的额外参数配置
+enable_thinking = false # 不启用思考
+
+[[models]]
+model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct"
+name = "qwen2.5-vl-72b"
+api_provider = "SiliconFlow"
+price_in = 4.13
+price_out = 4.13
+
+[[models]]
+model_identifier = "FunAudioLLM/SenseVoiceSmall"
+name = "sensevoice-small"
+api_provider = "SiliconFlow"
+price_in = 0
+price_out = 0
+
+[[models]]
+model_identifier = "BAAI/bge-m3"
+name = "bge-m3"
+api_provider = "SiliconFlow"
+price_in = 0
+price_out = 0
+
+
+[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型
+model_list = ["siliconflow-deepseek-v3"] # 使用的模型列表,每个子项对应上面的模型名称(name)
+temperature = 0.2 # 模型温度,新V3建议0.1-0.3
+max_tokens = 800 # 最大输出token数
+
+[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型
+model_list = ["qwen3-8b"]
+temperature = 0.7
+max_tokens = 800
+
+[model_task_config.replyer_1] # 首要回复模型,还用于表达器和表达方式学习
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2 # 模型温度,新V3建议0.1-0.3
+max_tokens = 800
+
+[model_task_config.replyer_2] # 次要回复模型
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.7
+max_tokens = 800
+
+[model_task_config.planner] #决策:负责决定麦麦该做什么的模型
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.3
+max_tokens = 800
+
+[model_task_config.emotion] #负责麦麦的情绪变化
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.3
+max_tokens = 800
+
+[model_task_config.memory] # 记忆模型
+model_list = ["qwen3-30b"]
+temperature = 0.7
+max_tokens = 800
+
+[model_task_config.vlm] # 图像识别模型
+model_list = ["qwen2.5-vl-72b"]
+max_tokens = 800
+
+[model_task_config.voice] # 语音识别模型
+model_list = ["sensevoice-small"]
+
+[model_task_config.tool_use] #工具调用模型,需要使用支持工具调用的模型
+model_list = ["qwen3-14b"]
+temperature = 0.7
+max_tokens = 800
+
+#嵌入模型
+[model_task_config.embedding]
+model_list = ["bge-m3"]
+
+#------------LPMM知识库模型------------
+
+[model_task_config.lpmm_entity_extract] # 实体提取模型
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+
+[model_task_config.lpmm_rdf_build] # RDF构建模型
+model_list = ["siliconflow-deepseek-v3"]
+temperature = 0.2
+max_tokens = 800
+
+[model_task_config.lpmm_qa] # 问答模型
+model_list = ["deepseek-r1-distill-qwen-32b"]
+temperature = 0.7
+max_tokens = 800
diff --git a/template/template.env b/template/template.env
index 4718203d..d9b6e2bd 100644
--- a/template/template.env
+++ b/template/template.env
@@ -1,23 +1,2 @@
HOST=127.0.0.1
-PORT=8000
-
-# 密钥和url
-
-# 硅基流动
-SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
-# DeepSeek官方
-DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
-# 阿里百炼
-BAILIAN_BASE_URL = https://dashscope.aliyuncs.com/compatible-mode/v1
-# 火山引擎
-HUOSHAN_BASE_URL =
-# xxxxx平台
-xxxxxxx_BASE_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
-
-# 定义你要用的api的key(需要去对应网站申请哦)
-DEEP_SEEK_KEY=
-CHAT_ANY_WHERE_KEY=
-SILICONFLOW_KEY=
-BAILIAN_KEY =
-HUOSHAN_KEY =
-xxxxxxx_KEY=
+PORT=8000
\ No newline at end of file