diff --git a/.gitignore b/.gitignore index 4db85eab..f51b8d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -41,6 +41,7 @@ config/bot_config.toml.bak config/lpmm_config.toml config/lpmm_config.toml.bak template/compare/bot_config_template.toml +template/compare/model_config_template.toml (测试版)麦麦生成人格.bat (临时版)麦麦开始学习.bat src/plugins/utils/statistic.py @@ -321,4 +322,5 @@ run_pet.bat config.toml -interested_rates.txt \ No newline at end of file +interested_rates.txt +MaiBot.code-workspace diff --git a/bot.py b/bot.py index 72ea65d2..b8f154cd 100644 --- a/bot.py +++ b/bot.py @@ -74,36 +74,6 @@ def easter_egg(): print(rainbow_text) -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") - async def graceful_shutdown(): try: @@ -229,9 +199,6 @@ def raw_main(): easter_egg() - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) - # 返回MainSystem实例 return MainSystem() diff --git a/changelogs/changelog.md b/changelogs/changelog.md index c835e684..35631077 100644 --- a/changelogs/changelog.md +++ b/changelogs/changelog.md @@ -1,5 +1,17 @@ # Changelog +## [0.10.0] - 2025-7-1 +### 主要功能更改 +- 工具系统重构,现在合并到了插件系统中 +- 彻底重构了整个LLM Request了,现在支持模型轮询和更多灵活的参数 + - 同时重构了整个模型配置系统,升级需要重新配置llm配置文件 +- 随着LLM Request的重构,插件系统彻底重构完成。插件系统进入稳定状态,仅增加新的API + - 具体相比于之前的更改可以查看[changes.md](./changes.md) + +### 细节优化 +- 修复了lint爆炸的问题,代码更加规范了 +- 修改了log的颜色,更加护眼 + ## [0.9.1] - 2025-7-26 ### 主要修复和优化 diff --git a/changes.md b/changelogs/changes.md similarity index 89% rename from changes.md rename to changelogs/changes.md index b776991d..db41703c 100644 --- a/changes.md +++ b/changelogs/changes.md @@ -25,6 +25,7 @@ - 这意味着你终于可以动态控制是否继续后续消息的处理了。 8. 移除了dependency_manager,但是依然保留了`python_dependencies`属性,等待后续重构。 - 一并移除了文档有关manager的内容。 +9. 增加了工具的有关api # 插件系统修改 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** @@ -57,30 +58,12 @@ 15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。 - 通过`disable_specific_chat_action`,`enable_specific_chat_action`,`disable_specific_chat_command`,`enable_specific_chat_command`,`disable_specific_chat_event_handler`,`enable_specific_chat_event_handler`来操作 - 同样不保存到配置文件~ +16. 把`BaseTool`一并合并进入了插件系统 # 官方插件修改 1. `HelloWorld`插件现在有一个样例的`EventHandler`。 -2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。 - -### TODO -把这个看起来就很别扭的config获取方式改一下 - - -# 吐槽 -```python -plugin_path = Path(plugin_file) -if plugin_path.parent.name != "plugins": - # 插件包格式:parent_dir.plugin - module_name = f"plugins.{plugin_path.parent.name}.plugin" -else: - # 单文件格式:plugins.filename - module_name = f"plugins.{plugin_path.stem}" -``` -```python -plugin_path = Path(plugin_file) -module_name = ".".join(plugin_path.parent.parts) -``` -这两个区别很大的。 +2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。(需要自行启用) +3. `HelloWorld`插件现在有一个样例的`CompareNumbersTool`。 ### 执笔BGM 塞壬唱片! \ No newline at end of file diff --git a/docs/image-1.png b/docs/image-1.png new file mode 100644 index 00000000..c7a0adc8 Binary files /dev/null and b/docs/image-1.png differ diff --git a/docs/image.png b/docs/image.png new file mode 100644 index 00000000..63416251 Binary files /dev/null and b/docs/image.png differ diff --git a/docs/model_configuration_guide.md b/docs/model_configuration_guide.md new file mode 100644 index 00000000..6bbe05af --- /dev/null +++ b/docs/model_configuration_guide.md @@ -0,0 +1,331 @@ +# 模型配置指南 + +本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。 + +## 配置文件结构 + +配置文件主要包含以下几个部分: +- 版本信息 +- API服务提供商配置 +- 模型配置 +- 模型任务配置 + +## 1. 版本信息 + +```toml +[inner] +version = "1.1.1" +``` + +用于标识配置文件的版本,遵循语义化版本规则。 + +## 2. API服务提供商配置 + +### 2.1 基本配置 + +使用 `[[api_providers]]` 数组配置多个API服务提供商: + +```toml +[[api_providers]] +name = "DeepSeek" # 服务商名称(自定义) +base_url = "https://api.deepseek.cn/v1" # API服务的基础URL +api_key = "your-api-key-here" # API密钥 +client_type = "openai" # 客户端类型 +max_retry = 2 # 最大重试次数 +timeout = 30 # 超时时间(秒) +retry_interval = 10 # 重试间隔(秒) +``` + +### 2.2 配置参数说明 + +| 参数 | 必填 | 说明 | 默认值 | +|------|------|------|--------| +| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - | +| `base_url` | ✅ | API服务的基础URL | - | +| `api_key` | ✅ | API密钥,请替换为实际密钥 | - | +| `client_type` | ❌ | 客户端类型:`openai`(OpenAI格式)或 `gemini`(Gemini格式,现在支持不良好) | `openai` | +| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 | +| `timeout` | ❌ | API请求超时时间(秒) | 30 | +| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 | + +### 2.3 支持的服务商示例 + +#### DeepSeek +```toml +[[api_providers]] +name = "DeepSeek" +base_url = "https://api.deepseek.cn/v1" +api_key = "your-deepseek-api-key" +client_type = "openai" +``` + +#### SiliconFlow +```toml +[[api_providers]] +name = "SiliconFlow" +base_url = "https://api.siliconflow.cn/v1" +api_key = "your-siliconflow-api-key" +client_type = "openai" +``` + +#### Google Gemini +```toml +[[api_providers]] +name = "Google" +base_url = "https://api.google.com/v1" +api_key = "your-google-api-key" +client_type = "gemini" # 注意:Gemini需要使用特殊客户端 +``` + +## 3. 模型配置 + +### 3.1 基本模型配置 + +使用 `[[models]]` 数组配置多个模型: + +```toml +[[models]] +model_identifier = "deepseek-chat" # 模型在API服务商中的标识符 +name = "deepseek-v3" # 自定义模型名称 +api_provider = "DeepSeek" # 引用的API服务商名称 +price_in = 2.0 # 输入价格(元/M token) +price_out = 8.0 # 输出价格(元/M token) +``` + +### 3.2 高级模型配置 + +#### 强制流式输出 +对于不支持非流式输出的模型: +```toml +[[models]] +model_identifier = "some-model" +name = "custom-name" +api_provider = "Provider" +force_stream_mode = true # 启用强制流式输出 +``` + +#### 额外参数配置`extra_params` +```toml +[[models]] +model_identifier = "Qwen/Qwen3-8B" +name = "qwen3-8b" +api_provider = "SiliconFlow" +[models.extra_params] +enable_thinking = false # 禁用思考 +``` +这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置,**配置时应参考相应的API文档**。 + +比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。 + +![SiliconFlow文档截图](image-1.png) + +以豆包文档为另一个例子 + +![豆包文档截图](image.png) + +得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为 +```toml +[[models]] +# 你的模型 +[models.extra_params] +thinking = {type = "disabled"} # 禁用思考 +``` +请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构,具体内容取决于API服务商的要求。 + +### 3.3 配置参数说明 + +| 参数 | 必填 | 说明 | +|------|------|------| +| `model_identifier` | ✅ | API服务商提供的模型标识符 | +| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 | +| `api_provider` | ✅ | 对应的API服务商名称 | +| `price_in` | ❌ | 输入价格(元/M token),用于成本统计 | +| `price_out` | ❌ | 输出价格(元/M token),用于成本统计 | +| `force_stream_mode` | ❌ | 是否强制使用流式输出 | +| `extra_params` | ❌ | 额外的模型参数配置 | + +## 4. 模型任务配置 + +### utils - 工具模型 +用于表情包模块、取名模块、关系模块等核心功能: +```toml +[model_task_config.utils] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 +``` + +### utils_small - 小型工具模型 +用于高频率调用的场景,建议使用速度快的小模型: +```toml +[model_task_config.utils_small] +model_list = ["qwen3-8b"] +temperature = 0.7 +max_tokens = 800 +``` + +### replyer_1 - 主要回复模型 +首要回复模型,也用于表达器和表达方式学习: +```toml +[model_task_config.replyer_1] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 +``` + +### replyer_2 - 次要回复模型 +```toml +[model_task_config.replyer_2] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.7 +max_tokens = 800 +``` + +### planner - 决策模型 +负责决定MaiBot该做什么: +```toml +[model_task_config.planner] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.3 +max_tokens = 800 +``` + +### emotion - 情绪模型 +负责MaiBot的情绪变化: +```toml +[model_task_config.emotion] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.3 +max_tokens = 800 +``` + +### memory - 记忆模型 +```toml +[model_task_config.memory] +model_list = ["qwen3-30b"] +temperature = 0.7 +max_tokens = 800 +``` + +### vlm - 视觉语言模型 +用于图像识别: +```toml +[model_task_config.vlm] +model_list = ["qwen2.5-vl-72b"] +max_tokens = 800 +``` + +### voice - 语音识别模型 +```toml +[model_task_config.voice] +model_list = ["sensevoice-small"] +``` + +### embedding - 嵌入模型 +```toml +[model_task_config.embedding] +model_list = ["bge-m3"] +``` + +### tool_use - 工具调用模型 +需要使用支持工具调用的模型: +```toml +[model_task_config.tool_use] +model_list = ["qwen3-14b"] +temperature = 0.7 +max_tokens = 800 +``` + +### lpmm_entity_extract - 实体提取模型 +```toml +[model_task_config.lpmm_entity_extract] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 +``` + +### lpmm_rdf_build - RDF构建模型 +```toml +[model_task_config.lpmm_rdf_build] +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 +``` + +### lpmm_qa - 问答模型 +```toml +[model_task_config.lpmm_qa] +model_list = ["deepseek-r1-distill-qwen-32b"] +temperature = 0.7 +max_tokens = 800 +``` + +## 5. 配置建议 + +### 5.1 Temperature 参数选择 + +| 任务类型 | 推荐温度 | 说明 | +|----------|----------|------| +| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 | +| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 | +| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 | + +### 5.2 模型选择建议 + +| 任务类型 | 推荐模型类型 | 示例 | +|----------|--------------|------| +| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 | +| 高频率任务 | 小模型 | Qwen3-8B | +| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice | +| 工具调用 | 支持Function Call的模型 | Qwen3-14B | + +### 5.3 成本优化 + +1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型 +2. **合理配置max_tokens**:根据实际需求设置,避免浪费 +3. **选择免费模型**:对于测试环境,优先使用price为0的模型 + +## 6. 配置验证 + +### 6.1 必要检查项 + +1. ✅ API密钥是否正确配置 +2. ✅ 模型标识符是否与API服务商提供的一致 +3. ✅ 任务配置中引用的模型名称是否在models中定义 +4. ✅ 多模态任务是否配置了对应的专用模型 + +### 6.2 测试配置 + +建议在正式使用前: +1. 使用少量测试数据验证配置 +2. 检查API调用是否正常 +3. 确认成本统计功能正常工作 + +## 7. 故障排除 + +### 7.1 常见问题 + +**问题1**: API调用失败 +- 检查API密钥是否正确 +- 确认base_url是否可访问 +- 检查模型标识符是否正确 + +**问题2**: 模型未找到 +- 确认模型名称在任务配置和模型定义中一致 +- 检查api_provider名称是否匹配 + +**问题3**: 响应异常 +- 检查温度参数是否合理(0-1之间) +- 确认max_tokens设置是否合适 +- 验证模型是否支持所需功能 + +### 7.2 日志查看 + +查看 `logs/` 目录下的日志文件,寻找相关错误信息。 + +## 8. 更新和维护 + +1. **定期更新**: 关注API服务商的模型更新,及时调整配置 +2. **性能监控**: 监控模型调用的成本和性能 +3. **备份配置**: 在修改前备份当前配置文件 + diff --git a/docs/plugins/api/component-manage-api.md b/docs/plugins/api/component-manage-api.md new file mode 100644 index 00000000..a857fb27 --- /dev/null +++ b/docs/plugins/api/component-manage-api.md @@ -0,0 +1,194 @@ +# 组件管理API + +组件管理API模块提供了对插件组件的查询和管理功能,使得插件能够获取和使用组件相关的信息。 + +## 导入方式 +```python +from src.plugin_system.apis import component_manage_api +# 或者 +from src.plugin_system import component_manage_api +``` + +## 功能概述 + +组件管理API主要提供以下功能: +- **插件信息查询** - 获取所有插件或指定插件的信息。 +- **组件查询** - 按名称或类型查询组件信息。 +- **组件管理** - 启用或禁用组件,支持全局和局部操作。 + +## 主要功能 + +### 1. 获取所有插件信息 +```python +def get_all_plugin_info() -> Dict[str, PluginInfo]: +``` +获取所有插件的信息。 + +**Returns:** +- `Dict[str, PluginInfo]` - 包含所有插件信息的字典,键为插件名称,值为 `PluginInfo` 对象。 + +### 2. 获取指定插件信息 +```python +def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]: +``` +获取指定插件的信息。 + +**Args:** +- `plugin_name` (str): 插件名称。 + +**Returns:** +- `Optional[PluginInfo]`: 插件信息对象,如果插件不存在则返回 `None`。 + +### 3. 获取指定组件信息 +```python +def get_component_info(component_name: str, component_type: ComponentType) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +``` +获取指定组件的信息。 + +**Args:** +- `component_name` (str): 组件名称。 +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 组件信息对象,如果组件不存在则返回 `None`。 + +### 4. 获取指定类型的所有组件信息 +```python +def get_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +``` +获取指定类型的所有组件信息。 + +**Args:** +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。 + +### 5. 获取指定类型的所有启用的组件信息 +```python +def get_enabled_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]: +``` +获取指定类型的所有启用的组件信息。 + +**Args:** +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。 + +### 6. 获取指定 Action 的注册信息 +```python +def get_registered_action_info(action_name: str) -> Optional[ActionInfo]: +``` +获取指定 Action 的注册信息。 + +**Args:** +- `action_name` (str): Action 名称。 + +**Returns:** +- `Optional[ActionInfo]` - Action 信息对象,如果 Action 不存在则返回 `None`。 + +### 7. 获取指定 Command 的注册信息 +```python +def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: +``` +获取指定 Command 的注册信息。 + +**Args:** +- `command_name` (str): Command 名称。 + +**Returns:** +- `Optional[CommandInfo]` - Command 信息对象,如果 Command 不存在则返回 `None`。 + +### 8. 获取指定 Tool 的注册信息 +```python +def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: +``` +获取指定 Tool 的注册信息。 + +**Args:** +- `tool_name` (str): Tool 名称。 + +**Returns:** +- `Optional[ToolInfo]` - Tool 信息对象,如果 Tool 不存在则返回 `None`。 + +### 9. 获取指定 EventHandler 的注册信息 +```python +def get_registered_event_handler_info(event_handler_name: str) -> Optional[EventHandlerInfo]: +``` +获取指定 EventHandler 的注册信息。 + +**Args:** +- `event_handler_name` (str): EventHandler 名称。 + +**Returns:** +- `Optional[EventHandlerInfo]` - EventHandler 信息对象,如果 EventHandler 不存在则返回 `None`。 + +### 10. 全局启用指定组件 +```python +def globally_enable_component(component_name: str, component_type: ComponentType) -> bool: +``` +全局启用指定组件。 + +**Args:** +- `component_name` (str): 组件名称。 +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `bool` - 启用成功返回 `True`,否则返回 `False`。 + +### 11. 全局禁用指定组件 +```python +async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool: +``` +全局禁用指定组件。 + +**此函数是异步的,确保在异步环境中调用。** + +**Args:** +- `component_name` (str): 组件名称。 +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `bool` - 禁用成功返回 `True`,否则返回 `False`。 + +### 12. 局部启用指定组件 +```python +def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: +``` +局部启用指定组件。 + +**Args:** +- `component_name` (str): 组件名称。 +- `component_type` (ComponentType): 组件类型。 +- `stream_id` (str): 消息流 ID。 + +**Returns:** +- `bool` - 启用成功返回 `True`,否则返回 `False`。 + +### 13. 局部禁用指定组件 +```python +def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool: +``` +局部禁用指定组件。 + +**Args:** +- `component_name` (str): 组件名称。 +- `component_type` (ComponentType): 组件类型。 +- `stream_id` (str): 消息流 ID。 + +**Returns:** +- `bool` - 禁用成功返回 `True`,否则返回 `False`。 + +### 14. 获取指定消息流中禁用的组件列表 +```python +def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: +``` +获取指定消息流中禁用的组件列表。 + +**Args:** +- `stream_id` (str): 消息流 ID。 +- `component_type` (ComponentType): 组件类型。 + +**Returns:** +- `list[str]` - 禁用的组件名称列表。 diff --git a/docs/plugins/api/config-api.md b/docs/plugins/api/config-api.md index 2a5691fc..2ee1cdfc 100644 --- a/docs/plugins/api/config-api.md +++ b/docs/plugins/api/config-api.md @@ -1,6 +1,6 @@ # 配置API -配置API模块提供了配置读取和用户信息获取等功能,让插件能够安全地访问全局配置和用户信息。 +配置API模块提供了配置读取功能,让插件能够安全地访问全局配置和插件配置。 ## 导入方式 diff --git a/docs/plugins/api/database-api.md b/docs/plugins/api/database-api.md index 174bef15..5b6b4468 100644 --- a/docs/plugins/api/database-api.md +++ b/docs/plugins/api/database-api.md @@ -6,72 +6,51 @@ ```python from src.plugin_system.apis import database_api +# 或者 +from src.plugin_system import database_api ``` ## 主要功能 -### 1. 通用数据库查询 - -#### `db_query(model_class, query_type="get", filters=None, data=None, limit=None, order_by=None, single_result=False)` -执行数据库查询操作的通用接口 - -**参数:** -- `model_class`:Peewee模型类,如ActionRecords、Messages等 -- `query_type`:查询类型,可选值: "get", "create", "update", "delete", "count" -- `filters`:过滤条件字典,键为字段名,值为要匹配的值 -- `data`:用于创建或更新的数据字典 -- `limit`:限制结果数量 -- `order_by`:排序字段列表,使用字段名,前缀'-'表示降序 -- `single_result`:是否只返回单个结果 - -**返回:** -根据查询类型返回不同的结果: -- "get":返回查询结果列表或单个结果 -- "create":返回创建的记录 -- "update":返回受影响的行数 -- "delete":返回受影响的行数 -- "count":返回记录数量 - -### 2. 便捷查询函数 - -#### `db_save(model_class, data, key_field=None, key_value=None)` -保存数据到数据库(创建或更新) - -**参数:** -- `model_class`:Peewee模型类 -- `data`:要保存的数据字典 -- `key_field`:用于查找现有记录的字段名 -- `key_value`:用于查找现有记录的字段值 - -**返回:** -- `Dict[str, Any]`:保存后的记录数据,失败时返回None - -#### `db_get(model_class, filters=None, order_by=None, limit=None)` -简化的查询函数 - -**参数:** -- `model_class`:Peewee模型类 -- `filters`:过滤条件字典 -- `order_by`:排序字段 -- `limit`:限制结果数量 - -**返回:** -- `Union[List[Dict], Dict, None]`:查询结果 - -### 3. 专用函数 - -#### `store_action_info(...)` -存储动作信息的专用函数 - -## 使用示例 - -### 1. 基本查询操作 +### 1. 通用数据库操作 ```python -from src.plugin_system.apis import database_api -from src.common.database.database_model import Messages, ActionRecords +async def db_query( + model_class: Type[Model], + data: Optional[Dict[str, Any]] = None, + query_type: Optional[str] = "get", + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[List[str]] = None, + single_result: Optional[bool] = False, +) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: +``` +执行数据库查询操作的通用接口。 -# 查询最近10条消息 +**Args:** +- `model_class`: Peewee模型类。 + - Peewee模型类可以在`src.common.database.database_model`模块中找到,如`ActionRecords`、`Messages`等。 +- `data`: 用于创建或更新的数据 +- `query_type`: 查询类型 + - 可选值: `get`, `create`, `update`, `delete`, `count`。 +- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。 +- `limit`: 限制结果数量。 +- `order_by`: 排序字段列表,使用字段名,前缀'-'表示降序。 + - 排序字段,前缀`-`表示降序,例如`-time`表示按时间字段(即`time`字段)降序 +- `single_result`: 是否只返回单个结果。 + +**Returns:** +- 根据查询类型返回不同的结果: + - `get`: 返回查询结果列表或单个结果。(如果 `single_result=True`) + - `create`: 返回创建的记录。 + - `update`: 返回受影响的行数。 + - `delete`: 返回受影响的行数。 + - `count`: 返回记录数量。 + +#### 示例 + +1. 查询最近10条消息 +```python messages = await database_api.db_query( Messages, query_type="get", @@ -79,180 +58,159 @@ messages = await database_api.db_query( limit=10, order_by=["-time"] ) - -# 查询单条记录 -message = await database_api.db_query( - Messages, - query_type="get", - filters={"message_id": "msg_123"}, - single_result=True -) ``` - -### 2. 创建记录 - +2. 创建一条记录 ```python -# 创建新的动作记录 new_record = await database_api.db_query( ActionRecords, + data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}, query_type="create", - data={ - "action_id": "action_123", - "time": time.time(), - "action_name": "TestAction", - "action_done": True - } ) - -print(f"创建了记录: {new_record['id']}") ``` - -### 3. 更新记录 - +3. 更新记录 ```python -# 更新动作状态 updated_count = await database_api.db_query( ActionRecords, + data={"action_done": True}, query_type="update", - filters={"action_id": "action_123"}, - data={"action_done": True, "completion_time": time.time()} + filters={"action_id": "123"}, ) - -print(f"更新了 {updated_count} 条记录") ``` - -### 4. 删除记录 - +4. 删除记录 ```python -# 删除过期记录 deleted_count = await database_api.db_query( ActionRecords, query_type="delete", - filters={"time__lt": time.time() - 86400} # 删除24小时前的记录 + filters={"action_id": "123"} ) - -print(f"删除了 {deleted_count} 条过期记录") ``` - -### 5. 统计查询 - +5. 计数 ```python -# 统计消息数量 -message_count = await database_api.db_query( +count = await database_api.db_query( Messages, query_type="count", filters={"chat_id": chat_stream.stream_id} ) - -print(f"该聊天有 {message_count} 条消息") ``` -### 6. 使用便捷函数 - +### 2. 数据库保存 +```python +async def db_save( + model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None +) -> Optional[Dict[str, Any]]: +``` +保存数据到数据库(创建或更新) + +如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; + +如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 + +**Args:** +- `model_class`: Peewee模型类。 +- `data`: 要保存的数据字典。 +- `key_field`: 用于查找现有记录的字段名,例如"action_id"。 +- `key_value`: 用于查找现有记录的字段值。 + +**Returns:** +- `Optional[Dict[str, Any]]`: 保存后的记录数据,失败时返回None。 + +#### 示例 +创建或更新一条记录 ```python -# 使用db_save进行创建或更新 record = await database_api.db_save( ActionRecords, { - "action_id": "action_123", + "action_id": "123", "time": time.time(), "action_name": "TestAction", "action_done": True }, key_field="action_id", - key_value="action_123" + key_value="123" ) +``` -# 使用db_get进行简单查询 -recent_messages = await database_api.db_get( +### 3. 数据库获取 +```python +async def db_get( + model_class: Type[Model], + filters: Optional[Dict[str, Any]] = None, + limit: Optional[int] = None, + order_by: Optional[str] = None, + single_result: Optional[bool] = False, +) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: +``` + +从数据库获取记录 + +这是db_query方法的简化版本,专注于数据检索操作。 + +**Args:** +- `model_class`: Peewee模型类。 +- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。 +- `limit`: 限制结果数量。 +- `order_by`: 排序字段,使用字段名,前缀'-'表示降序。 +- `single_result`: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表 + +**Returns:** +- `Union[List[Dict], Dict, None]`: 查询结果列表或单个结果(如果`single_result=True`),失败时返回None。 + +#### 示例 +1. 获取单个记录 +```python +record = await database_api.db_get( + ActionRecords, + filters={"action_id": "123"}, + limit=1 +) +``` +2. 获取最近10条记录 +```python +records = await database_api.db_get( Messages, filters={"chat_id": chat_stream.stream_id}, + limit=10, order_by="-time", - limit=5 ) ``` -## 高级用法 - -### 复杂查询示例 - +### 4. 动作信息存储 ```python -# 查询特定用户在特定时间段的消息 -user_messages = await database_api.db_query( - Messages, - query_type="get", - filters={ - "user_id": "123456", - "time__gte": start_time, # 大于等于开始时间 - "time__lt": end_time # 小于结束时间 - }, - order_by=["-time"], - limit=50 +async def store_action_info( + chat_stream=None, + action_build_into_prompt: bool = False, + action_prompt_display: str = "", + action_done: bool = True, + thinking_id: str = "", + action_data: Optional[dict] = None, + action_name: str = "", +) -> Optional[Dict[str, Any]]: +``` +存储动作信息到数据库,是一种针对 Action 的 `db_save()` 的封装函数。 + +将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。 + +**Args:** +- `chat_stream`: 聊天流对象,包含聊天ID等信息。 +- `action_build_into_prompt`: 是否将动作信息构建到提示中。 +- `action_prompt_display`: 动作提示的显示文本。 +- `action_done`: 动作是否完成。 +- `thinking_id`: 思考过程的ID。 +- `action_data`: 动作的数据字典。 +- `action_name`: 动作的名称。 + +**Returns:** +- `Optional[Dict[str, Any]]`: 存储后的记录数据,失败时返回None。 + +#### 示例 +```python +record = await database_api.store_action_info( + chat_stream=chat_stream, + action_build_into_prompt=True, + action_prompt_display="执行了回复动作", + action_done=True, + thinking_id="thinking_123", + action_data={"content": "Hello"}, + action_name="reply_action" ) - -# 批量处理 -for message in user_messages: - print(f"消息内容: {message['plain_text']}") - print(f"发送时间: {message['time']}") -``` - -### 插件中的数据持久化 - -```python -from src.plugin_system.base import BasePlugin -from src.plugin_system.apis import database_api - -class DataPlugin(BasePlugin): - async def handle_action(self, action_data, chat_stream): - # 保存插件数据 - plugin_data = { - "plugin_name": self.plugin_name, - "chat_id": chat_stream.stream_id, - "data": json.dumps(action_data), - "created_time": time.time() - } - - # 使用自定义表模型(需要先定义) - record = await database_api.db_save( - PluginData, # 假设的插件数据模型 - plugin_data, - key_field="plugin_name", - key_value=self.plugin_name - ) - - return {"success": True, "record_id": record["id"]} -``` - -## 数据模型 - -### 常用模型类 -系统提供了以下常用的数据模型: - -- `Messages`:消息记录 -- `ActionRecords`:动作记录 -- `UserInfo`:用户信息 -- `GroupInfo`:群组信息 - -### 字段说明 - -#### Messages模型主要字段 -- `message_id`:消息ID -- `chat_id`:聊天ID -- `user_id`:用户ID -- `plain_text`:纯文本内容 -- `time`:时间戳 - -#### ActionRecords模型主要字段 -- `action_id`:动作ID -- `action_name`:动作名称 -- `action_done`:是否完成 -- `time`:创建时间 - -## 注意事项 - -1. **异步操作**:所有数据库API都是异步的,必须使用`await` -2. **错误处理**:函数内置错误处理,失败时返回None或空列表 -3. **数据类型**:返回的都是字典格式的数据,不是模型对象 -4. **性能考虑**:使用`limit`参数避免查询大量数据 -5. **过滤条件**:支持简单的等值过滤,复杂查询需要使用原生Peewee语法 -6. **事务**:如需事务支持,建议直接使用Peewee的事务功能 \ No newline at end of file +``` \ No newline at end of file diff --git a/docs/plugins/api/emoji-api.md b/docs/plugins/api/emoji-api.md index 6dd071b9..ce9dd0c8 100644 --- a/docs/plugins/api/emoji-api.md +++ b/docs/plugins/api/emoji-api.md @@ -6,11 +6,13 @@ ```python from src.plugin_system.apis import emoji_api +# 或者 +from src.plugin_system import emoji_api ``` -## 🆕 **二步走识别优化** +## 二步走识别优化 -从最新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案: +从新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案: ### **收到表情包时的识别流程** 1. **第一步**:VLM视觉分析 - 生成详细描述 @@ -30,217 +32,84 @@ from src.plugin_system.apis import emoji_api ## 主要功能 ### 1. 表情包获取 - -#### `get_by_description(description: str) -> Optional[Tuple[str, str, str]]` +```python +async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]: +``` 根据场景描述选择表情包 -**参数:** -- `description`:场景描述文本,例如"开心的大笑"、"轻微的讽刺"、"表示无奈和沮丧"等 +**Args:** +- `description`:表情包的描述文本,例如"开心"、"难过"、"愤怒"等 -**返回:** -- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None +**Returns:** +- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到匹配的表情包则返回None -**示例:** +#### 示例 ```python -emoji_result = await emoji_api.get_by_description("开心的大笑") +emoji_result = await emoji_api.get_by_description("大笑") if emoji_result: emoji_base64, description, matched_scene = emoji_result print(f"获取到表情包: {description}, 场景: {matched_scene}") # 可以将emoji_base64用于发送表情包 ``` -#### `get_random() -> Optional[Tuple[str, str, str]]` -随机获取表情包 - -**返回:** -- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 随机场景) 或 None - -**示例:** +### 2. 随机获取表情包 ```python -random_emoji = await emoji_api.get_random() -if random_emoji: - emoji_base64, description, scene = random_emoji - print(f"随机表情包: {description}") +async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: ``` +随机获取指定数量的表情包 -#### `get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]` -根据场景关键词获取表情包 +**Args:** +- `count`:要获取的表情包数量,默认为1 -**参数:** -- `emotion`:场景关键词,如"大笑"、"讽刺"、"无奈"等 +**Returns:** +- `List[Tuple[str, str, str]]`:一个包含多个表情包的列表,每个元素是一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到或出错则返回空列表 -**返回:** -- `Optional[Tuple[str, str, str]]`:(base64编码, 表情包描述, 匹配的场景) 或 None - -**示例:** +### 3. 根据情感获取表情包 ```python -emoji_result = await emoji_api.get_by_emotion("讽刺") -if emoji_result: - emoji_base64, description, scene = emoji_result - # 发送讽刺表情包 +async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: ``` +根据情感标签获取表情包 -### 2. 表情包信息查询 +**Args:** +- `emotion`:情感标签,例如"开心"、"悲伤"、"愤怒"等 -#### `get_count() -> int` -获取表情包数量 +**Returns:** +- `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到则返回None -**返回:** -- `int`:当前可用的表情包数量 +### 4. 获取表情包数量 +```python +def get_count() -> int: +``` +获取当前可用表情包的数量 -#### `get_info() -> dict` -获取表情包系统信息 +### 5. 获取表情包系统信息 +```python +def get_info() -> Dict[str, Any]: +``` +获取表情包系统的基本信息 -**返回:** -- `dict`:包含表情包数量、最大数量等信息 +**Returns:** +- `Dict[str, Any]`:包含表情包数量、描述等信息的字典,包含以下键: + - `current_count`:当前表情包数量 + - `max_count`:最大表情包数量 + - `available_emojis`:当前可用的表情包数量 -**返回字典包含:** -- `current_count`:当前表情包数量 -- `max_count`:最大表情包数量 -- `available_emojis`:可用表情包数量 +### 6. 获取所有可用的情感标签 +```python +def get_emotions() -> List[str]: +``` +获取所有可用的情感标签 **(已经去重)** -#### `get_emotions() -> list` -获取所有可用的场景关键词 - -**返回:** -- `list`:所有表情包的场景关键词列表(去重) - -#### `get_descriptions() -> list` +### 7. 获取所有表情包描述 +```python +def get_descriptions() -> List[str]: +``` 获取所有表情包的描述列表 -**返回:** -- `list`:所有表情包的描述文本列表 - -## 使用示例 - -### 1. 智能表情包选择 - -```python -from src.plugin_system.apis import emoji_api - -async def send_emotion_response(message_text: str, chat_stream): - """根据消息内容智能选择表情包回复""" - - # 分析消息场景 - if "哈哈" in message_text or "好笑" in message_text: - emoji_result = await emoji_api.get_by_description("开心的大笑") - elif "无语" in message_text or "算了" in message_text: - emoji_result = await emoji_api.get_by_description("表示无奈和沮丧") - elif "呵呵" in message_text or "是吗" in message_text: - emoji_result = await emoji_api.get_by_description("轻微的讽刺") - elif "生气" in message_text or "愤怒" in message_text: - emoji_result = await emoji_api.get_by_description("愤怒和不满") - else: - # 随机选择一个表情包 - emoji_result = await emoji_api.get_random() - - if emoji_result: - emoji_base64, description, scene = emoji_result - # 使用send_api发送表情包 - from src.plugin_system.apis import send_api - success = await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id) - return success - - return False -``` - -### 2. 表情包管理功能 - -```python -async def show_emoji_stats(): - """显示表情包统计信息""" - - # 获取基本信息 - count = emoji_api.get_count() - info = emoji_api.get_info() - scenes = emoji_api.get_emotions() # 实际返回的是场景关键词 - - stats = f""" -📊 表情包统计信息: -- 总数量: {count} -- 可用数量: {info['available_emojis']} -- 最大容量: {info['max_count']} -- 支持场景: {len(scenes)}种 - -🎭 支持的场景关键词: {', '.join(scenes[:10])}{'...' if len(scenes) > 10 else ''} - """ - - return stats -``` - -### 3. 表情包测试功能 - -```python -async def test_emoji_system(): - """测试表情包系统的各种功能""" - - print("=== 表情包系统测试 ===") - - # 测试场景描述查找 - test_descriptions = ["开心的大笑", "轻微的讽刺", "表示无奈和沮丧", "愤怒和不满"] - for desc in test_descriptions: - result = await emoji_api.get_by_description(desc) - if result: - _, description, scene = result - print(f"✅ 场景'{desc}' -> {description} ({scene})") - else: - print(f"❌ 场景'{desc}' -> 未找到") - - # 测试关键词查找 - scenes = emoji_api.get_emotions() - if scenes: - test_scene = scenes[0] - result = await emoji_api.get_by_emotion(test_scene) - if result: - print(f"✅ 关键词'{test_scene}' -> 找到匹配表情包") - - # 测试随机获取 - random_result = await emoji_api.get_random() - if random_result: - print("✅ 随机获取 -> 成功") - - print(f"📊 系统信息: {emoji_api.get_info()}") -``` - -### 4. 在Action中使用表情包 - -```python -from src.plugin_system.base import BaseAction - -class EmojiAction(BaseAction): - async def execute(self, action_data, chat_stream): - # 从action_data获取场景描述或关键词 - scene_keyword = action_data.get("scene", "") - scene_description = action_data.get("description", "") - - emoji_result = None - - # 优先使用具体的场景描述 - if scene_description: - emoji_result = await emoji_api.get_by_description(scene_description) - # 其次使用场景关键词 - elif scene_keyword: - emoji_result = await emoji_api.get_by_emotion(scene_keyword) - # 最后随机选择 - else: - emoji_result = await emoji_api.get_random() - - if emoji_result: - emoji_base64, description, scene = emoji_result - return { - "success": True, - "emoji_base64": emoji_base64, - "description": description, - "scene": scene - } - - return {"success": False, "message": "未找到合适的表情包"} -``` - ## 场景描述说明 ### 常用场景描述 -表情包系统支持多种具体的场景描述,常见的包括: +表情包系统支持多种具体的场景描述,举例如下: - **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈 - **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头 @@ -248,8 +117,8 @@ class EmojiAction(BaseAction): - **惊讶类场景**:震惊的表情、意外的发现、困惑的思考 - **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子 -### 场景关键词示例 -系统支持的场景关键词包括: +### 情感关键词示例 +系统支持的情感关键词举例如下: - 大笑、微笑、兴奋、手舞足蹈 - 无奈、沮丧、讽刺、无语、摇头 - 愤怒、不满、生气、瞪视、抓狂 @@ -263,9 +132,9 @@ class EmojiAction(BaseAction): ## 注意事项 -1. **异步函数**:获取表情包的函数都是异步的,需要使用 `await` +1. **异步函数**:部分函数是异步的,需要使用 `await` 2. **返回格式**:表情包以base64编码返回,可直接用于发送 -3. **错误处理**:所有函数都有错误处理,失败时返回None或默认值 +3. **错误处理**:所有函数都有错误处理,失败时返回None,空列表或默认值 4. **使用统计**:系统会记录表情包的使用次数 5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在 6. **编码格式**:返回的是base64编码的图片数据,可直接用于网络传输 diff --git a/docs/plugins/api/generator-api.md b/docs/plugins/api/generator-api.md index 964fff84..afeb6eec 100644 --- a/docs/plugins/api/generator-api.md +++ b/docs/plugins/api/generator-api.md @@ -6,241 +6,151 @@ ```python from src.plugin_system.apis import generator_api +# 或者 +from src.plugin_system import generator_api ``` ## 主要功能 ### 1. 回复器获取 - -#### `get_replyer(chat_stream=None, platform=None, chat_id=None, is_group=True)` +```python +def get_replyer( + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + request_type: str = "replyer", +) -> Optional[DefaultReplyer]: +``` 获取回复器对象 -**参数:** -- `chat_stream`:聊天流对象(优先) -- `platform`:平台名称,如"qq" -- `chat_id`:聊天ID(群ID或用户ID) -- `is_group`:是否为群聊 +优先使用chat_stream,如果没有则使用chat_id直接查找。 -**返回:** -- `DefaultReplyer`:回复器对象,如果获取失败则返回None +使用 ReplyerManager 来管理实例,避免重复创建。 -**示例:** +**Args:** +- `chat_stream`: 聊天流对象 +- `chat_id`: 聊天ID(实际上就是`stream_id`) +- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组 +- `request_type`: 请求类型,用于记录LLM使用情况,可以不写 + +**Returns:** +- `DefaultReplyer`: 回复器对象,如果获取失败则返回None + +#### 示例 ```python # 使用聊天流获取回复器 replyer = generator_api.get_replyer(chat_stream=chat_stream) -# 使用平台和ID获取回复器 -replyer = generator_api.get_replyer( - platform="qq", - chat_id="123456789", - is_group=True -) +# 使用平台和ID获取回复器 +replyer = generator_api.get_replyer(chat_id="123456789") ``` ### 2. 回复生成 - -#### `generate_reply(chat_stream=None, action_data=None, platform=None, chat_id=None, is_group=True)` +```python +async def generate_reply( + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + action_data: Optional[Dict[str, Any]] = None, + reply_to: str = "", + extra_info: str = "", + available_actions: Optional[Dict[str, ActionInfo]] = None, + enable_tool: bool = False, + enable_splitter: bool = True, + enable_chinese_typo: bool = True, + return_prompt: bool = False, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + request_type: str = "generator_api", +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +``` 生成回复 -**参数:** -- `chat_stream`:聊天流对象(优先) -- `action_data`:动作数据 -- `platform`:平台名称(备用) -- `chat_id`:聊天ID(备用) -- `is_group`:是否为群聊(备用) +优先使用chat_stream,如果没有则使用chat_id直接查找。 -**返回:** -- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合) +**Args:** +- `chat_stream`: 聊天流对象 +- `chat_id`: 聊天ID(实际上就是`stream_id`) +- `action_data`: 动作数据(向下兼容,包含`reply_to`和`extra_info`) +- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}` +- `extra_info`: 附加信息 +- `available_actions`: 可用动作字典,格式为 `{"action_name": ActionInfo}` +- `enable_tool`: 是否启用工具 +- `enable_splitter`: 是否启用分割器 +- `enable_chinese_typo`: 是否启用中文错别字 +- `return_prompt`: 是否返回提示词 +- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组 +- `request_type`: 请求类型(可选,记录LLM使用) +- `request_type`: 请求类型,用于记录LLM使用情况 -**示例:** +**Returns:** +- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词) + +#### 示例 ```python -success, reply_set = await generator_api.generate_reply( +success, reply_set, prompt = await generator_api.generate_reply( chat_stream=chat_stream, - action_data={"message": "你好", "intent": "greeting"} + action_data=action_data, + reply_to="麦麦:你好", + available_actions=action_info, + enable_tool=True, + return_prompt=True ) - if success: for reply_type, reply_content in reply_set: print(f"回复类型: {reply_type}, 内容: {reply_content}") + if prompt: + print(f"使用的提示词: {prompt}") ``` -#### `rewrite_reply(chat_stream=None, reply_data=None, platform=None, chat_id=None, is_group=True)` -重写回复 - -**参数:** -- `chat_stream`:聊天流对象(优先) -- `reply_data`:回复数据 -- `platform`:平台名称(备用) -- `chat_id`:聊天ID(备用) -- `is_group`:是否为群聊(备用) - -**返回:** -- `Tuple[bool, List[Tuple[str, Any]]]`:(是否成功, 回复集合) - -**示例:** +### 3. 回复重写 ```python -success, reply_set = await generator_api.rewrite_reply( +async def rewrite_reply( + chat_stream: Optional[ChatStream] = None, + reply_data: Optional[Dict[str, Any]] = None, + chat_id: Optional[str] = None, + enable_splitter: bool = True, + enable_chinese_typo: bool = True, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + raw_reply: str = "", + reason: str = "", + reply_to: str = "", + return_prompt: bool = False, +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: +``` +重写回复,使用新的内容替换旧的回复内容。 + +优先使用chat_stream,如果没有则使用chat_id直接查找。 + +**Args:** +- `chat_stream`: 聊天流对象 +- `reply_data`: 回复数据,包含`raw_reply`, `reason`和`reply_to`,**(向下兼容备用,当其他参数缺失时从此获取)** +- `chat_id`: 聊天ID(实际上就是`stream_id`) +- `enable_splitter`: 是否启用分割器 +- `enable_chinese_typo`: 是否启用中文错别字 +- `model_set_with_weight`: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 +- `raw_reply`: 原始回复内容 +- `reason`: 重写原因 +- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}` + +**Returns:** +- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词) + +#### 示例 +```python +success, reply_set, prompt = await generator_api.rewrite_reply( chat_stream=chat_stream, - reply_data={"original_text": "原始回复", "style": "more_friendly"} + raw_reply="原始回复内容", + reason="重写原因", + reply_to="麦麦:你好", + return_prompt=True ) +if success: + for reply_type, reply_content in reply_set: + print(f"回复类型: {reply_type}, 内容: {reply_content}") + if prompt: + print(f"使用的提示词: {prompt}") ``` -## 使用示例 - -### 1. 基础回复生成 - -```python -from src.plugin_system.apis import generator_api - -async def generate_greeting_reply(chat_stream, user_name): - """生成问候回复""" - - action_data = { - "intent": "greeting", - "user_name": user_name, - "context": "morning_greeting" - } - - success, reply_set = await generator_api.generate_reply( - chat_stream=chat_stream, - action_data=action_data - ) - - if success and reply_set: - # 获取第一个回复 - reply_type, reply_content = reply_set[0] - return reply_content - - return "你好!" # 默认回复 -``` - -### 2. 在Action中使用回复生成器 - -```python -from src.plugin_system.base import BaseAction - -class ChatAction(BaseAction): - async def execute(self, action_data, chat_stream): - # 准备回复数据 - reply_context = { - "message_type": "response", - "user_input": action_data.get("user_message", ""), - "intent": action_data.get("intent", ""), - "entities": action_data.get("entities", {}), - "context": self.get_conversation_context(chat_stream) - } - - # 生成回复 - success, reply_set = await generator_api.generate_reply( - chat_stream=chat_stream, - action_data=reply_context - ) - - if success: - return { - "success": True, - "replies": reply_set, - "generated_count": len(reply_set) - } - - return { - "success": False, - "error": "回复生成失败", - "fallback_reply": "抱歉,我现在无法理解您的消息。" - } -``` - -### 3. 多样化回复生成 - -```python -async def generate_diverse_replies(chat_stream, topic, count=3): - """生成多个不同风格的回复""" - - styles = ["formal", "casual", "humorous"] - all_replies = [] - - for i, style in enumerate(styles[:count]): - action_data = { - "topic": topic, - "style": style, - "variation": i - } - - success, reply_set = await generator_api.generate_reply( - chat_stream=chat_stream, - action_data=action_data - ) - - if success and reply_set: - all_replies.extend(reply_set) - - return all_replies -``` - -### 4. 回复重写功能 - -```python -async def improve_reply(chat_stream, original_reply, improvement_type="more_friendly"): - """改进原始回复""" - - reply_data = { - "original_text": original_reply, - "improvement_type": improvement_type, - "target_audience": "young_users", - "tone": "positive" - } - - success, improved_replies = await generator_api.rewrite_reply( - chat_stream=chat_stream, - reply_data=reply_data - ) - - if success and improved_replies: - # 返回改进后的第一个回复 - _, improved_content = improved_replies[0] - return improved_content - - return original_reply # 如果改进失败,返回原始回复 -``` - -### 5. 条件回复生成 - -```python -async def conditional_reply_generation(chat_stream, user_message, user_emotion): - """根据用户情感生成条件回复""" - - # 根据情感调整回复策略 - if user_emotion == "sad": - action_data = { - "intent": "comfort", - "tone": "empathetic", - "style": "supportive" - } - elif user_emotion == "angry": - action_data = { - "intent": "calm", - "tone": "peaceful", - "style": "understanding" - } - else: - action_data = { - "intent": "respond", - "tone": "neutral", - "style": "helpful" - } - - action_data["user_message"] = user_message - action_data["user_emotion"] = user_emotion - - success, reply_set = await generator_api.generate_reply( - chat_stream=chat_stream, - action_data=action_data - ) - - return reply_set if success else [] -``` - -## 回复集合格式 +## 回复集合`reply_set`格式 ### 回复类型 生成的回复集合包含多种类型的回复: @@ -260,82 +170,32 @@ reply_set = [ ] ``` -## 高级用法 - -### 1. 自定义回复器配置 - +### 4. 自定义提示词回复 ```python -async def generate_with_custom_config(chat_stream, action_data): - """使用自定义配置生成回复""" - - # 获取回复器 - replyer = generator_api.get_replyer(chat_stream=chat_stream) - - if replyer: - # 可以访问回复器的内部方法 - success, reply_set = await replyer.generate_reply_with_context( - reply_data=action_data, - # 可以传递额外的配置参数 - ) - return success, reply_set - - return False, [] +async def generate_response_custom( + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + prompt: str = "", +) -> Optional[str]: ``` +生成自定义提示词回复 -### 2. 回复质量评估 +优先使用chat_stream,如果没有则使用chat_id直接查找。 -```python -async def generate_and_evaluate_replies(chat_stream, action_data): - """生成回复并评估质量""" - - success, reply_set = await generator_api.generate_reply( - chat_stream=chat_stream, - action_data=action_data - ) - - if success: - evaluated_replies = [] - for reply_type, reply_content in reply_set: - # 简单的质量评估 - quality_score = evaluate_reply_quality(reply_content) - evaluated_replies.append({ - "type": reply_type, - "content": reply_content, - "quality": quality_score - }) - - # 按质量排序 - evaluated_replies.sort(key=lambda x: x["quality"], reverse=True) - return evaluated_replies - - return [] +**Args:** +- `chat_stream`: 聊天流对象 +- `chat_id`: 聊天ID(备用) +- `model_set_with_weight`: 模型集合配置列表 +- `prompt`: 自定义提示词 -def evaluate_reply_quality(reply_content): - """简单的回复质量评估""" - if not reply_content: - return 0 - - score = 50 # 基础分 - - # 长度适中加分 - if 5 <= len(reply_content) <= 100: - score += 20 - - # 包含积极词汇加分 - positive_words = ["好", "棒", "不错", "感谢", "开心"] - for word in positive_words: - if word in reply_content: - score += 10 - break - - return min(score, 100) -``` +**Returns:** +- `Optional[str]`: 生成的自定义回复内容,如果生成失败则返回None ## 注意事项 -1. **异步操作**:所有生成函数都是异步的,必须使用`await` -2. **错误处理**:函数内置错误处理,失败时返回False和空列表 -3. **聊天流依赖**:需要有效的聊天流对象才能正常工作 -4. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时 -5. **回复格式**:返回的回复集合是元组列表,包含类型和内容 -6. **上下文感知**:生成器会考虑聊天上下文和历史消息 \ No newline at end of file +1. **异步操作**:部分函数是异步的,须使用`await` +2. **聊天流依赖**:需要有效的聊天流对象才能正常工作 +3. **性能考虑**:回复生成可能需要一些时间,特别是使用LLM时 +4. **回复格式**:返回的回复集合是元组列表,包含类型和内容 +5. **上下文感知**:生成器会考虑聊天上下文和历史消息,除非你用的是自定义提示词。 \ No newline at end of file diff --git a/docs/plugins/api/llm-api.md b/docs/plugins/api/llm-api.md index e0879ddf..9a266933 100644 --- a/docs/plugins/api/llm-api.md +++ b/docs/plugins/api/llm-api.md @@ -6,239 +6,34 @@ LLM API模块提供与大语言模型交互的功能,让插件能够使用系 ```python from src.plugin_system.apis import llm_api +# 或者 +from src.plugin_system import llm_api ``` ## 主要功能 -### 1. 模型管理 - -#### `get_available_models() -> Dict[str, Any]` -获取所有可用的模型配置 - -**返回:** -- `Dict[str, Any]`:模型配置字典,key为模型名称,value为模型配置 - -**示例:** +### 1. 查询可用模型 ```python -models = llm_api.get_available_models() -for model_name, model_config in models.items(): - print(f"模型: {model_name}") - print(f"配置: {model_config}") +def get_available_models() -> Dict[str, TaskConfig]: ``` +获取所有可用的模型配置。 -### 2. 内容生成 +**Return:** +- `Dict[str, TaskConfig]`:模型配置字典,key为模型名称,value为模型配置对象。 -#### `generate_with_model(prompt, model_config, request_type="plugin.generate", **kwargs)` -使用指定模型生成内容 - -**参数:** -- `prompt`:提示词 -- `model_config`:模型配置(从 get_available_models 获取) -- `request_type`:请求类型标识 -- `**kwargs`:其他模型特定参数,如temperature、max_tokens等 - -**返回:** -- `Tuple[bool, str, str, str]`:(是否成功, 生成的内容, 推理过程, 模型名称) - -**示例:** +### 2. 使用模型生成内容 ```python -models = llm_api.get_available_models() -default_model = models.get("default") - -if default_model: - success, response, reasoning, model_name = await llm_api.generate_with_model( - prompt="请写一首关于春天的诗", - model_config=default_model, - temperature=0.7, - max_tokens=200 - ) - - if success: - print(f"生成内容: {response}") - print(f"使用模型: {model_name}") +async def generate_with_model( + prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs +) -> Tuple[bool, str, str, str]: ``` +使用指定模型生成内容。 -## 使用示例 +**Args:** +- `prompt`:提示词。 +- `model_config`:模型配置对象(从 `get_available_models` 获取)。 +- `request_type`:请求类型标识,默认为 `"plugin.generate"`。 +- `**kwargs`:其他模型特定参数,如 `temperature`、`max_tokens` 等。 -### 1. 基础文本生成 - -```python -from src.plugin_system.apis import llm_api - -async def generate_story(topic: str): - """生成故事""" - models = llm_api.get_available_models() - model = models.get("default") - - if not model: - return "未找到可用模型" - - prompt = f"请写一个关于{topic}的短故事,大约100字左右。" - - success, story, reasoning, model_name = await llm_api.generate_with_model( - prompt=prompt, - model_config=model, - request_type="story.generate", - temperature=0.8, - max_tokens=150 - ) - - return story if success else "故事生成失败" -``` - -### 2. 在Action中使用LLM - -```python -from src.plugin_system.base import BaseAction - -class LLMAction(BaseAction): - async def execute(self, action_data, chat_stream): - # 获取用户输入 - user_input = action_data.get("user_message", "") - intent = action_data.get("intent", "chat") - - # 获取模型配置 - models = llm_api.get_available_models() - model = models.get("default") - - if not model: - return {"success": False, "error": "未配置LLM模型"} - - # 构建提示词 - prompt = self.build_prompt(user_input, intent) - - # 生成回复 - success, response, reasoning, model_name = await llm_api.generate_with_model( - prompt=prompt, - model_config=model, - request_type=f"plugin.{self.plugin_name}", - temperature=0.7 - ) - - if success: - return { - "success": True, - "response": response, - "model_used": model_name, - "reasoning": reasoning - } - - return {"success": False, "error": response} - - def build_prompt(self, user_input: str, intent: str) -> str: - """构建提示词""" - base_prompt = "你是一个友善的AI助手。" - - if intent == "question": - return f"{base_prompt}\n\n用户问题:{user_input}\n\n请提供准确、有用的回答:" - elif intent == "chat": - return f"{base_prompt}\n\n用户说:{user_input}\n\n请进行自然的对话:" - else: - return f"{base_prompt}\n\n用户输入:{user_input}\n\n请回复:" -``` - -### 3. 多模型对比 - -```python -async def compare_models(prompt: str): - """使用多个模型生成内容并对比""" - models = llm_api.get_available_models() - results = {} - - for model_name, model_config in models.items(): - success, response, reasoning, actual_model = await llm_api.generate_with_model( - prompt=prompt, - model_config=model_config, - request_type="comparison.test" - ) - - results[model_name] = { - "success": success, - "response": response, - "model": actual_model, - "reasoning": reasoning - } - - return results -``` - -### 4. 智能对话插件 - -```python -class ChatbotPlugin(BasePlugin): - async def handle_action(self, action_data, chat_stream): - user_message = action_data.get("message", "") - - # 获取历史对话上下文 - context = self.get_conversation_context(chat_stream) - - # 构建对话提示词 - prompt = self.build_conversation_prompt(user_message, context) - - # 获取模型配置 - models = llm_api.get_available_models() - chat_model = models.get("chat", models.get("default")) - - if not chat_model: - return {"success": False, "message": "聊天模型未配置"} - - # 生成回复 - success, response, reasoning, model_name = await llm_api.generate_with_model( - prompt=prompt, - model_config=chat_model, - request_type="chat.conversation", - temperature=0.8, - max_tokens=500 - ) - - if success: - # 保存对话历史 - self.save_conversation(chat_stream, user_message, response) - - return { - "success": True, - "reply": response, - "model": model_name - } - - return {"success": False, "message": "回复生成失败"} - - def build_conversation_prompt(self, user_message: str, context: list) -> str: - """构建对话提示词""" - prompt = "你是一个有趣、友善的聊天机器人。请自然地回复用户的消息。\n\n" - - # 添加历史对话 - if context: - prompt += "对话历史:\n" - for msg in context[-5:]: # 只保留最近5条 - prompt += f"用户: {msg['user']}\n机器人: {msg['bot']}\n" - prompt += "\n" - - prompt += f"用户: {user_message}\n机器人: " - return prompt -``` - -## 模型配置说明 - -### 常用模型类型 -- `default`:默认模型 -- `chat`:聊天专用模型 -- `creative`:创意生成模型 -- `code`:代码生成模型 - -### 配置参数 -LLM模型支持的常用参数: -- `temperature`:控制输出随机性(0.0-1.0) -- `max_tokens`:最大生成长度 -- `top_p`:核采样参数 -- `frequency_penalty`:频率惩罚 -- `presence_penalty`:存在惩罚 - -## 注意事项 - -1. **异步操作**:LLM生成是异步的,必须使用`await` -2. **错误处理**:生成失败时返回False和错误信息 -3. **配置依赖**:需要正确配置模型才能使用 -4. **请求类型**:建议为不同用途设置不同的request_type -5. **性能考虑**:LLM调用可能较慢,考虑超时和缓存 -6. **成本控制**:注意控制max_tokens以控制成本 \ No newline at end of file +**Return:** +- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。 \ No newline at end of file diff --git a/docs/plugins/api/logging-api.md b/docs/plugins/api/logging-api.md new file mode 100644 index 00000000..5576bf5c --- /dev/null +++ b/docs/plugins/api/logging-api.md @@ -0,0 +1,29 @@ +# Logging API + +Logging API模块提供了获取本体logger的功能,允许插件记录日志信息。 + +## 导入方式 + +```python +from src.plugin_system.apis import get_logger +# 或者 +from src.plugin_system import get_logger +``` + +## 主要功能 +### 1. 获取本体logger +```python +def get_logger(name: str) -> structlog.stdlib.BoundLogger: +``` +获取本体logger实例。 + +**Args:** +- `name` (str): 日志记录器的名称。 + +**Returns:** +- 一个logger实例,有以下方法: + - `debug` + - `info` + - `warning` + - `error` + - `critical` \ No newline at end of file diff --git a/docs/plugins/api/message-api.md b/docs/plugins/api/message-api.md index c95a9cc6..85d83a9b 100644 --- a/docs/plugins/api/message-api.md +++ b/docs/plugins/api/message-api.md @@ -1,11 +1,13 @@ # 消息API -> 消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。 +消息API提供了强大的消息查询、计数和格式化功能,让你轻松处理聊天消息数据。 ## 导入方式 ```python from src.plugin_system.apis import message_api +# 或者 +from src.plugin_system import message_api ``` ## 功能概述 @@ -15,297 +17,356 @@ from src.plugin_system.apis import message_api - **消息计数** - 统计新消息数量 - **消息格式化** - 将消息转换为可读格式 ---- +## 主要功能 -## 消息查询API +### 1. 按照事件查询消息 +```python +def get_messages_by_time( + start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False +) -> List[Dict[str, Any]]: +``` +获取指定时间范围内的消息。 -### 按时间查询消息 - -#### `get_messages_by_time(start_time, end_time, limit=0, limit_mode="latest")` - -获取指定时间范围内的消息 - -**参数:** +**Args:** - `start_time` (float): 开始时间戳 -- `end_time` (float): 结束时间戳 +- `end_time` (float): 结束时间戳 - `limit` (int): 限制返回消息数量,0为不限制 - `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False -**返回:** `List[Dict[str, Any]]` - 消息列表 +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 -**示例:** +消息列表中包含的键与`Messages`类的属性一致。(位于`src.common.database.database_model`) + +### 2. 获取指定聊天中指定时间范围内的信息 ```python -import time - -# 获取最近24小时的消息 -now = time.time() -yesterday = now - 24 * 3600 -messages = message_api.get_messages_by_time(yesterday, now, limit=50) +def get_messages_by_time_in_chat( + chat_id: str, + start_time: float, + end_time: float, + limit: int = 0, + limit_mode: str = "latest", + filter_mai: bool = False, +) -> List[Dict[str, Any]]: ``` +获取指定聊天中指定时间范围内的消息。 -### 按聊天查询消息 - -#### `get_messages_by_time_in_chat(chat_id, start_time, end_time, limit=0, limit_mode="latest")` - -获取指定聊天中指定时间范围内的消息 - -**参数:** -- `chat_id` (str): 聊天ID -- 其他参数同上 - -**示例:** -```python -# 获取某个群聊最近的100条消息 -messages = message_api.get_messages_by_time_in_chat( - chat_id="123456789", - start_time=yesterday, - end_time=now, - limit=100 -) -``` - -#### `get_messages_by_time_in_chat_inclusive(chat_id, start_time, end_time, limit=0, limit_mode="latest")` - -获取指定聊天中指定时间范围内的消息(包含边界时间点) - -与 `get_messages_by_time_in_chat` 类似,但包含边界时间戳的消息。 - -#### `get_recent_messages(chat_id, hours=24.0, limit=100, limit_mode="latest")` - -获取指定聊天中最近一段时间的消息(便捷方法) - -**参数:** -- `chat_id` (str): 聊天ID -- `hours` (float): 最近多少小时,默认24小时 -- `limit` (int): 限制返回消息数量,默认100条 -- `limit_mode` (str): 限制模式 - -**示例:** -```python -# 获取最近6小时的消息 -recent_messages = message_api.get_recent_messages( - chat_id="123456789", - hours=6.0, - limit=50 -) -``` - -### 按用户查询消息 - -#### `get_messages_by_time_in_chat_for_users(chat_id, start_time, end_time, person_ids, limit=0, limit_mode="latest")` - -获取指定聊天中指定用户在指定时间范围内的消息 - -**参数:** +**Args:** - `chat_id` (str): 聊天ID - `start_time` (float): 开始时间戳 - `end_time` (float): 结束时间戳 -- `person_ids` (list): 用户ID列表 -- `limit` (int): 限制返回消息数量 -- `limit_mode` (str): 限制模式 +- `limit` (int): 限制返回消息数量,0为不限制 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False -**示例:** +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 + + +### 3. 获取指定聊天中指定时间范围内的信息(包含边界) ```python -# 获取特定用户的消息 -user_messages = message_api.get_messages_by_time_in_chat_for_users( - chat_id="123456789", - start_time=yesterday, - end_time=now, - person_ids=["user1", "user2"] -) +def get_messages_by_time_in_chat_inclusive( + chat_id: str, + start_time: float, + end_time: float, + limit: int = 0, + limit_mode: str = "latest", + filter_mai: bool = False, + filter_command: bool = False, +) -> List[Dict[str, Any]]: ``` +获取指定聊天中指定时间范围内的消息(包含边界)。 -#### `get_messages_by_time_for_users(start_time, end_time, person_ids, limit=0, limit_mode="latest")` +**Args:** +- `chat_id` (str): 聊天ID +- `start_time` (float): 开始时间戳(包含) +- `end_time` (float): 结束时间戳(包含) +- `limit` (int): 限制返回消息数量,0为不限制 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False +- `filter_command` (bool): 是否过滤命令消息,默认False -获取指定用户在所有聊天中指定时间范围内的消息 +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 -### 其他查询方法 -#### `get_random_chat_messages(start_time, end_time, limit=0, limit_mode="latest")` +### 4. 获取指定聊天中指定用户在指定时间范围内的消息 +```python +def get_messages_by_time_in_chat_for_users( + chat_id: str, + start_time: float, + end_time: float, + person_ids: List[str], + limit: int = 0, + limit_mode: str = "latest", +) -> List[Dict[str, Any]]: +``` +获取指定聊天中指定用户在指定时间范围内的消息。 -随机选择一个聊天,返回该聊天在指定时间范围内的消息 - -#### `get_messages_before_time(timestamp, limit=0)` - -获取指定时间戳之前的消息 - -#### `get_messages_before_time_in_chat(chat_id, timestamp, limit=0)` - -获取指定聊天中指定时间戳之前的消息 - -#### `get_messages_before_time_for_users(timestamp, person_ids, limit=0)` - -获取指定用户在指定时间戳之前的消息 - ---- - -## 消息计数API - -### `count_new_messages(chat_id, start_time=0.0, end_time=None)` - -计算指定聊天中从开始时间到结束时间的新消息数量 - -**参数:** +**Args:** - `chat_id` (str): 聊天ID - `start_time` (float): 开始时间戳 -- `end_time` (float): 结束时间戳,如果为None则使用当前时间 +- `end_time` (float): 结束时间戳 +- `person_ids` (List[str]): 用户ID列表 +- `limit` (int): 限制返回消息数量,0为不限制 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 -**返回:** `int` - 新消息数量 +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 -**示例:** + +### 5. 随机选择一个聊天,返回该聊天在指定时间范围内的消息 ```python -# 计算最近1小时的新消息数 -import time -now = time.time() -hour_ago = now - 3600 -new_count = message_api.count_new_messages("123456789", hour_ago, now) -print(f"最近1小时有{new_count}条新消息") +def get_random_chat_messages( + start_time: float, + end_time: float, + limit: int = 0, + limit_mode: str = "latest", + filter_mai: bool = False, +) -> List[Dict[str, Any]]: ``` +随机选择一个聊天,返回该聊天在指定时间范围内的消息。 -### `count_new_messages_for_users(chat_id, start_time, end_time, person_ids)` +**Args:** +- `start_time` (float): 开始时间戳 +- `end_time` (float): 结束时间戳 +- `limit` (int): 限制返回消息数量,0为不限制 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False -计算指定聊天中指定用户从开始时间到结束时间的新消息数量 +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 ---- -## 消息格式化API +### 6. 获取指定用户在所有聊天中指定时间范围内的消息 +```python +def get_messages_by_time_for_users( + start_time: float, + end_time: float, + person_ids: List[str], + limit: int = 0, + limit_mode: str = "latest", +) -> List[Dict[str, Any]]: +``` +获取指定用户在所有聊天中指定时间范围内的消息。 -### `build_readable_messages_to_str(messages, **options)` +**Args:** +- `start_time` (float): 开始时间戳 +- `end_time` (float): 结束时间戳 +- `person_ids` (List[str]): 用户ID列表 +- `limit` (int): 限制返回消息数量,0为不限制 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 -将消息列表构建成可读的字符串 +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 -**参数:** + +### 7. 获取指定时间戳之前的消息 +```python +def get_messages_before_time( + timestamp: float, + limit: int = 0, + filter_mai: bool = False, +) -> List[Dict[str, Any]]: +``` +获取指定时间戳之前的消息。 + +**Args:** +- `timestamp` (float): 时间戳 +- `limit` (int): 限制返回消息数量,0为不限制 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False + +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 + + +### 8. 获取指定聊天中指定时间戳之前的消息 +```python +def get_messages_before_time_in_chat( + chat_id: str, + timestamp: float, + limit: int = 0, + filter_mai: bool = False, +) -> List[Dict[str, Any]]: +``` +获取指定聊天中指定时间戳之前的消息。 + +**Args:** +- `chat_id` (str): 聊天ID +- `timestamp` (float): 时间戳 +- `limit` (int): 限制返回消息数量,0为不限制 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False + +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 + + +### 9. 获取指定用户在指定时间戳之前的消息 +```python +def get_messages_before_time_for_users( + timestamp: float, + person_ids: List[str], + limit: int = 0, +) -> List[Dict[str, Any]]: +``` +获取指定用户在指定时间戳之前的消息。 + +**Args:** +- `timestamp` (float): 时间戳 +- `person_ids` (List[str]): 用户ID列表 +- `limit` (int): 限制返回消息数量,0为不限制 + +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 + + +### 10. 获取指定聊天中最近一段时间的消息 +```python +def get_recent_messages( + chat_id: str, + hours: float = 24.0, + limit: int = 100, + limit_mode: str = "latest", + filter_mai: bool = False, +) -> List[Dict[str, Any]]: +``` +获取指定聊天中最近一段时间的消息。 + +**Args:** +- `chat_id` (str): 聊天ID +- `hours` (float): 最近多少小时,默认24小时 +- `limit` (int): 限制返回消息数量,默认100条 +- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 +- `filter_mai` (bool): 是否过滤掉机器人的消息,默认False + +**Returns:** +- `List[Dict[str, Any]]` - 消息列表 + + +### 11. 计算指定聊天中从开始时间到结束时间的新消息数量 +```python +def count_new_messages( + chat_id: str, + start_time: float = 0.0, + end_time: Optional[float] = None, +) -> int: +``` +计算指定聊天中从开始时间到结束时间的新消息数量。 + +**Args:** +- `chat_id` (str): 聊天ID +- `start_time` (float): 开始时间戳 +- `end_time` (Optional[float]): 结束时间戳,如果为None则使用当前时间 + +**Returns:** +- `int` - 新消息数量 + + +### 12. 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 +```python +def count_new_messages_for_users( + chat_id: str, + start_time: float, + end_time: float, + person_ids: List[str], +) -> int: +``` +计算指定聊天中指定用户从开始时间到结束时间的新消息数量。 + +**Args:** +- `chat_id` (str): 聊天ID +- `start_time` (float): 开始时间戳 +- `end_time` (float): 结束时间戳 +- `person_ids` (List[str]): 用户ID列表 + +**Returns:** +- `int` - 新消息数量 + + +### 13. 将消息列表构建成可读的字符串 +```python +def build_readable_messages_to_str( + messages: List[Dict[str, Any]], + replace_bot_name: bool = True, + merge_messages: bool = False, + timestamp_mode: str = "relative", + read_mark: float = 0.0, + truncate: bool = False, + show_actions: bool = False, +) -> str: +``` +将消息列表构建成可读的字符串。 + +**Args:** - `messages` (List[Dict[str, Any]]): 消息列表 -- `replace_bot_name` (bool): 是否将机器人的名称替换为"你",默认True -- `merge_messages` (bool): 是否合并连续消息,默认False -- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`,默认`"relative"` -- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息,默认0.0 -- `truncate` (bool): 是否截断长消息,默认False -- `show_actions` (bool): 是否显示动作记录,默认False +- `replace_bot_name` (bool): 是否将机器人的名称替换为"你" +- `merge_messages` (bool): 是否合并连续消息 +- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"` +- `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息 +- `truncate` (bool): 是否截断长消息 +- `show_actions` (bool): 是否显示动作记录 -**返回:** `str` - 格式化后的可读字符串 +**Returns:** +- `str` - 格式化后的可读字符串 -**示例:** + +### 14. 将消息列表构建成可读的字符串,并返回详细信息 ```python -# 获取消息并格式化为可读文本 -messages = message_api.get_recent_messages("123456789", hours=2) -readable_text = message_api.build_readable_messages_to_str( - messages, - replace_bot_name=True, - merge_messages=True, - timestamp_mode="relative" -) -print(readable_text) +async def build_readable_messages_with_details( + messages: List[Dict[str, Any]], + replace_bot_name: bool = True, + merge_messages: bool = False, + timestamp_mode: str = "relative", + truncate: bool = False, +) -> Tuple[str, List[Tuple[float, str, str]]]: ``` +将消息列表构建成可读的字符串,并返回详细信息。 -### `build_readable_messages_with_details(messages, **options)` 异步 +**Args:** +- `messages` (List[Dict[str, Any]]): 消息列表 +- `replace_bot_name` (bool): 是否将机器人的名称替换为"你" +- `merge_messages` (bool): 是否合并连续消息 +- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"` +- `truncate` (bool): 是否截断长消息 -将消息列表构建成可读的字符串,并返回详细信息 +**Returns:** +- `Tuple[str, List[Tuple[float, str, str]]]` - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容) -**参数:** 与 `build_readable_messages_to_str` 类似,但不包含 `read_mark` 和 `show_actions` -**返回:** `Tuple[str, List[Tuple[float, str, str]]]` - 格式化字符串和详细信息元组列表(时间戳, 昵称, 内容) - -**示例:** +### 15. 从消息列表中提取不重复的用户ID列表 ```python -# 异步获取详细格式化信息 -readable_text, details = await message_api.build_readable_messages_with_details( - messages, - timestamp_mode="absolute" -) - -for timestamp, nickname, content in details: - print(f"{timestamp}: {nickname} 说: {content}") +async def get_person_ids_from_messages( + messages: List[Dict[str, Any]], +) -> List[str]: ``` +从消息列表中提取不重复的用户ID列表。 -### `get_person_ids_from_messages(messages)` 异步 - -从消息列表中提取不重复的用户ID列表 - -**参数:** +**Args:** - `messages` (List[Dict[str, Any]]): 消息列表 -**返回:** `List[str]` - 用户ID列表 +**Returns:** +- `List[str]` - 用户ID列表 -**示例:** + +### 16. 从消息列表中移除机器人的消息 ```python -# 获取参与对话的所有用户ID -messages = message_api.get_recent_messages("123456789") -person_ids = await message_api.get_person_ids_from_messages(messages) -print(f"参与对话的用户: {person_ids}") +def filter_mai_messages( + messages: List[Dict[str, Any]], +) -> List[Dict[str, Any]]: ``` +从消息列表中移除机器人的消息。 ---- +**Args:** +- `messages` (List[Dict[str, Any]]): 消息列表,每个元素是消息字典 -## 完整使用示例 - -### 场景1:统计活跃度 - -```python -import time -from src.plugin_system.apis import message_api - -async def analyze_chat_activity(chat_id: str): - """分析聊天活跃度""" - now = time.time() - day_ago = now - 24 * 3600 - - # 获取最近24小时的消息 - messages = message_api.get_recent_messages(chat_id, hours=24) - - # 统计消息数量 - total_count = len(messages) - - # 获取参与用户 - person_ids = await message_api.get_person_ids_from_messages(messages) - - # 格式化消息内容 - readable_text = message_api.build_readable_messages_to_str( - messages[-10:], # 最后10条消息 - merge_messages=True, - timestamp_mode="relative" - ) - - return { - "total_messages": total_count, - "active_users": len(person_ids), - "recent_chat": readable_text - } -``` - -### 场景2:查看特定用户的历史消息 - -```python -def get_user_history(chat_id: str, user_id: str, days: int = 7): - """获取用户最近N天的消息历史""" - now = time.time() - start_time = now - days * 24 * 3600 - - # 获取特定用户的消息 - user_messages = message_api.get_messages_by_time_in_chat_for_users( - chat_id=chat_id, - start_time=start_time, - end_time=now, - person_ids=[user_id], - limit=100 - ) - - # 格式化为可读文本 - readable_history = message_api.build_readable_messages_to_str( - user_messages, - replace_bot_name=False, - timestamp_mode="absolute" - ) - - return readable_history -``` - ---- +**Returns:** +- `List[Dict[str, Any]]` - 过滤后的消息列表 ## 注意事项 1. **时间戳格式**:所有时间参数都使用Unix时间戳(float类型) -2. **异步函数**:`build_readable_messages_with_details` 和 `get_person_ids_from_messages` 是异步函数,需要使用 `await` +2. **异步函数**:部分函数是异步函数,需要使用 `await` 3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数 4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息 5. **用户ID**:`person_ids` 参数接受字符串列表,用于筛选特定用户的消息 \ No newline at end of file diff --git a/docs/plugins/api/person-api.md b/docs/plugins/api/person-api.md index 3e1bafaf..f97498dc 100644 --- a/docs/plugins/api/person-api.md +++ b/docs/plugins/api/person-api.md @@ -6,59 +6,65 @@ ```python from src.plugin_system.apis import person_api +# 或者 +from src.plugin_system import person_api ``` ## 主要功能 -### 1. Person ID管理 - -#### `get_person_id(platform: str, user_id: int) -> str` +### 1. Person ID 获取 +```python +def get_person_id(platform: str, user_id: int) -> str: +``` 根据平台和用户ID获取person_id -**参数:** +**Args:** - `platform`:平台名称,如 "qq", "telegram" 等 - `user_id`:用户ID -**返回:** +**Returns:** - `str`:唯一的person_id(MD5哈希值) -**示例:** +#### 示例 ```python person_id = person_api.get_person_id("qq", 123456) -print(f"Person ID: {person_id}") ``` ### 2. 用户信息查询 +```python +async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any: +``` +查询单个用户信息字段值 -#### `get_person_value(person_id: str, field_name: str, default: Any = None) -> Any` -根据person_id和字段名获取某个值 - -**参数:** +**Args:** - `person_id`:用户的唯一标识ID -- `field_name`:要获取的字段名,如 "nickname", "impression" 等 -- `default`:当字段不存在或获取失败时返回的默认值 +- `field_name`:要获取的字段名 +- `default`:字段值不存在时的默认值 -**返回:** +**Returns:** - `Any`:字段值或默认值 -**示例:** +#### 示例 ```python nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") impression = await person_api.get_person_value(person_id, "impression") ``` -#### `get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict` +### 3. 批量用户信息查询 +```python +async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict: +``` 批量获取用户信息字段值 -**参数:** +**Args:** - `person_id`:用户的唯一标识ID - `field_names`:要获取的字段名列表 - `default_dict`:默认值字典,键为字段名,值为默认值 -**返回:** +**Returns:** - `dict`:字段名到值的映射字典 -**示例:** +#### 示例 ```python values = await person_api.get_person_values( person_id, @@ -67,204 +73,31 @@ values = await person_api.get_person_values( ) ``` -### 3. 用户状态查询 - -#### `is_person_known(platform: str, user_id: int) -> bool` +### 4. 判断用户是否已知 +```python +async def is_person_known(platform: str, user_id: int) -> bool: +``` 判断是否认识某个用户 -**参数:** +**Args:** - `platform`:平台名称 - `user_id`:用户ID -**返回:** +**Returns:** - `bool`:是否认识该用户 -**示例:** +### 5. 根据用户名获取Person ID ```python -known = await person_api.is_person_known("qq", 123456) -if known: - print("这个用户我认识") +def get_person_id_by_name(person_name: str) -> str: ``` - -### 4. 用户名查询 - -#### `get_person_id_by_name(person_name: str) -> str` 根据用户名获取person_id -**参数:** +**Args:** - `person_name`:用户名 -**返回:** +**Returns:** - `str`:person_id,如果未找到返回空字符串 -**示例:** -```python -person_id = person_api.get_person_id_by_name("张三") -if person_id: - print(f"找到用户: {person_id}") -``` - -## 使用示例 - -### 1. 基础用户信息获取 - -```python -from src.plugin_system.apis import person_api - -async def get_user_info(platform: str, user_id: int): - """获取用户基本信息""" - - # 获取person_id - person_id = person_api.get_person_id(platform, user_id) - - # 获取用户信息 - user_info = await person_api.get_person_values( - person_id, - ["nickname", "impression", "know_times", "last_seen"], - { - "nickname": "未知用户", - "impression": "", - "know_times": 0, - "last_seen": 0 - } - ) - - return { - "person_id": person_id, - "nickname": user_info["nickname"], - "impression": user_info["impression"], - "know_times": user_info["know_times"], - "last_seen": user_info["last_seen"] - } -``` - -### 2. 在Action中使用用户信息 - -```python -from src.plugin_system.base import BaseAction - -class PersonalizedAction(BaseAction): - async def execute(self, action_data, chat_stream): - # 获取发送者信息 - user_id = chat_stream.user_info.user_id - platform = chat_stream.platform - - # 获取person_id - person_id = person_api.get_person_id(platform, user_id) - - # 获取用户昵称和印象 - nickname = await person_api.get_person_value(person_id, "nickname", "朋友") - impression = await person_api.get_person_value(person_id, "impression", "") - - # 根据用户信息个性化回复 - if impression: - response = f"你好 {nickname}!根据我对你的了解:{impression}" - else: - response = f"你好 {nickname}!很高兴见到你。" - - return { - "success": True, - "response": response, - "user_info": { - "nickname": nickname, - "impression": impression - } - } -``` - -### 3. 用户识别和欢迎 - -```python -async def welcome_user(chat_stream): - """欢迎用户,区分新老用户""" - - user_id = chat_stream.user_info.user_id - platform = chat_stream.platform - - # 检查是否认识这个用户 - is_known = await person_api.is_person_known(platform, user_id) - - if is_known: - # 老用户,获取详细信息 - person_id = person_api.get_person_id(platform, user_id) - nickname = await person_api.get_person_value(person_id, "nickname", "老朋友") - know_times = await person_api.get_person_value(person_id, "know_times", 0) - - welcome_msg = f"欢迎回来,{nickname}!我们已经聊过 {know_times} 次了。" - else: - # 新用户 - welcome_msg = "你好!很高兴认识你,我是MaiBot。" - - return welcome_msg -``` - -### 4. 用户搜索功能 - -```python -async def find_user_by_name(name: str): - """根据名字查找用户""" - - person_id = person_api.get_person_id_by_name(name) - - if not person_id: - return {"found": False, "message": f"未找到名为 '{name}' 的用户"} - - # 获取用户详细信息 - user_info = await person_api.get_person_values( - person_id, - ["nickname", "platform", "user_id", "impression", "know_times"], - {} - ) - - return { - "found": True, - "person_id": person_id, - "info": user_info - } -``` - -### 5. 用户印象分析 - -```python -async def analyze_user_relationship(chat_stream): - """分析用户关系""" - - user_id = chat_stream.user_info.user_id - platform = chat_stream.platform - person_id = person_api.get_person_id(platform, user_id) - - # 获取关系相关信息 - relationship_info = await person_api.get_person_values( - person_id, - ["nickname", "impression", "know_times", "relationship_level", "last_interaction"], - { - "nickname": "未知", - "impression": "", - "know_times": 0, - "relationship_level": "stranger", - "last_interaction": 0 - } - ) - - # 分析关系程度 - know_times = relationship_info["know_times"] - if know_times == 0: - relationship = "陌生人" - elif know_times < 5: - relationship = "新朋友" - elif know_times < 20: - relationship = "熟人" - else: - relationship = "老朋友" - - return { - "nickname": relationship_info["nickname"], - "relationship": relationship, - "impression": relationship_info["impression"], - "interaction_count": know_times - } -``` - ## 常用字段说明 ### 基础信息字段 @@ -274,69 +107,13 @@ async def analyze_user_relationship(chat_stream): ### 关系信息字段 - `impression`:对用户的印象 -- `know_times`:交互次数 -- `relationship_level`:关系等级 -- `last_seen`:最后见面时间 -- `last_interaction`:最后交互时间 +- `points`: 用户特征点 -### 个性化字段 -- `preferences`:用户偏好 -- `interests`:兴趣爱好 -- `mood_history`:情绪历史 -- `topic_interests`:话题兴趣 - -## 最佳实践 - -### 1. 错误处理 -```python -async def safe_get_user_info(person_id: str, field: str): - """安全获取用户信息""" - try: - value = await person_api.get_person_value(person_id, field) - return value if value is not None else "未设置" - except Exception as e: - logger.error(f"获取用户信息失败: {e}") - return "获取失败" -``` - -### 2. 批量操作 -```python -async def get_complete_user_profile(person_id: str): - """获取完整用户档案""" - - # 一次性获取所有需要的字段 - fields = [ - "nickname", "impression", "know_times", - "preferences", "interests", "relationship_level" - ] - - defaults = { - "nickname": "用户", - "impression": "", - "know_times": 0, - "preferences": "{}", - "interests": "[]", - "relationship_level": "stranger" - } - - profile = await person_api.get_person_values(person_id, fields, defaults) - - # 处理JSON字段 - try: - profile["preferences"] = json.loads(profile["preferences"]) - profile["interests"] = json.loads(profile["interests"]) - except: - profile["preferences"] = {} - profile["interests"] = [] - - return profile -``` +其他字段可以参考`PersonInfo`类的属性(位于`src.common.database.database_model`) ## 注意事项 -1. **异步操作**:大部分查询函数都是异步的,需要使用`await` -2. **错误处理**:所有函数都有错误处理,失败时记录日志并返回默认值 -3. **数据类型**:返回的数据可能是字符串、数字或JSON,需要适当处理 -4. **性能考虑**:批量查询优于单个查询 -5. **隐私保护**:确保用户信息的使用符合隐私政策 -6. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用 \ No newline at end of file +1. **异步操作**:部分查询函数都是异步的,需要使用`await` +2. **性能考虑**:批量查询优于单个查询 +3. **隐私保护**:确保用户信息的使用符合隐私政策 +4. **数据一致性**:person_id是用户的唯一标识,应妥善保存和使用 \ No newline at end of file diff --git a/docs/plugins/api/plugin-manage-api.md b/docs/plugins/api/plugin-manage-api.md new file mode 100644 index 00000000..688ea9ef --- /dev/null +++ b/docs/plugins/api/plugin-manage-api.md @@ -0,0 +1,105 @@ +# 插件管理API + +插件管理API模块提供了对插件的加载、卸载、重新加载以及目录管理功能。 + +## 导入方式 +```python +from src.plugin_system.apis import plugin_manage_api +# 或者 +from src.plugin_system import plugin_manage_api +``` + +## 功能概述 + +插件管理API主要提供以下功能: +- **插件查询** - 列出当前加载的插件或已注册的插件。 +- **插件管理** - 加载、卸载、重新加载插件。 +- **插件目录管理** - 添加插件目录并重新扫描。 + +## 主要功能 + +### 1. 列出当前加载的插件 +```python +def list_loaded_plugins() -> List[str]: +``` +列出所有当前加载的插件。 + +**Returns:** +- `List[str]` - 当前加载的插件名称列表。 + +### 2. 列出所有已注册的插件 +```python +def list_registered_plugins() -> List[str]: +``` +列出所有已注册的插件。 + +**Returns:** +- `List[str]` - 已注册的插件名称列表。 + +### 3. 获取插件路径 +```python +def get_plugin_path(plugin_name: str) -> str: +``` +获取指定插件的路径。 + +**Args:** +- `plugin_name` (str): 要查询的插件名称。 +**Returns:** +- `str` - 插件的路径,如果插件不存在则 raise ValueError。 + +### 4. 卸载指定的插件 +```python +async def remove_plugin(plugin_name: str) -> bool: +``` +卸载指定的插件。 + +**Args:** +- `plugin_name` (str): 要卸载的插件名称。 + +**Returns:** +- `bool` - 卸载是否成功。 + +### 5. 重新加载指定的插件 +```python +async def reload_plugin(plugin_name: str) -> bool: +``` +重新加载指定的插件。 + +**Args:** +- `plugin_name` (str): 要重新加载的插件名称。 + +**Returns:** +- `bool` - 重新加载是否成功。 + +### 6. 加载指定的插件 +```python +def load_plugin(plugin_name: str) -> Tuple[bool, int]: +``` +加载指定的插件。 + +**Args:** +- `plugin_name` (str): 要加载的插件名称。 + +**Returns:** +- `Tuple[bool, int]` - 加载是否成功,成功或失败的个数。 + +### 7. 添加插件目录 +```python +def add_plugin_directory(plugin_directory: str) -> bool: +``` +添加插件目录。 + +**Args:** +- `plugin_directory` (str): 要添加的插件目录路径。 + +**Returns:** +- `bool` - 添加是否成功。 + +### 8. 重新扫描插件目录 +```python +def rescan_plugin_directory() -> Tuple[int, int]: +``` +重新扫描插件目录,加载新插件。 + +**Returns:** +- `Tuple[int, int]` - 成功加载的插件数量和失败的插件数量。 \ No newline at end of file diff --git a/docs/plugins/api/send-api.md b/docs/plugins/api/send-api.md index 79335c61..8b3c607f 100644 --- a/docs/plugins/api/send-api.md +++ b/docs/plugins/api/send-api.md @@ -6,86 +6,108 @@ ```python from src.plugin_system.apis import send_api +# 或者 +from src.plugin_system import send_api ``` ## 主要功能 -### 1. 文本消息发送 +### 1. 发送文本消息 +```python +async def text_to_stream( + text: str, + stream_id: str, + typing: bool = False, + reply_to: str = "", + storage_message: bool = True, +) -> bool: +``` +发送文本消息到指定的流 -#### `text_to_group(text, group_id, platform="qq", typing=False, reply_to="", storage_message=True)` -向群聊发送文本消息 +**Args:** +- `text` (str): 要发送的文本内容 +- `stream_id` (str): 聊天流ID +- `typing` (bool): 是否显示正在输入 +- `reply_to` (str): 回复消息,格式为"发送者:消息内容" +- `storage_message` (bool): 是否存储消息到数据库 -**参数:** -- `text`:要发送的文本内容 -- `group_id`:群聊ID -- `platform`:平台,默认为"qq" -- `typing`:是否显示正在输入 -- `reply_to`:回复消息的格式,如"发送者:消息内容" -- `storage_message`:是否存储到数据库 +**Returns:** +- `bool` - 是否发送成功 -**返回:** -- `bool`:是否发送成功 +### 2. 发送表情包 +```python +async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool: +``` +向指定流发送表情包。 -#### `text_to_user(text, user_id, platform="qq", typing=False, reply_to="", storage_message=True)` -向用户发送私聊文本消息 +**Args:** +- `emoji_base64` (str): 表情包的base64编码 +- `stream_id` (str): 聊天流ID +- `storage_message` (bool): 是否存储消息到数据库 -**参数与返回值同上** +**Returns:** +- `bool` - 是否发送成功 -### 2. 表情包发送 +### 3. 发送图片 +```python +async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool: +``` +向指定流发送图片。 -#### `emoji_to_group(emoji_base64, group_id, platform="qq", storage_message=True)` -向群聊发送表情包 +**Args:** +- `image_base64` (str): 图片的base64编码 +- `stream_id` (str): 聊天流ID +- `storage_message` (bool): 是否存储消息到数据库 -**参数:** -- `emoji_base64`:表情包的base64编码 -- `group_id`:群聊ID -- `platform`:平台,默认为"qq" -- `storage_message`:是否存储到数据库 +**Returns:** +- `bool` - 是否发送成功 -#### `emoji_to_user(emoji_base64, user_id, platform="qq", storage_message=True)` -向用户发送表情包 +### 4. 发送命令 +```python +async def command_to_stream(command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "") -> bool: +``` +向指定流发送命令。 -### 3. 图片发送 +**Args:** +- `command` (Union[str, dict]): 命令内容 +- `stream_id` (str): 聊天流ID +- `storage_message` (bool): 是否存储消息到数据库 +- `display_message` (str): 显示消息 -#### `image_to_group(image_base64, group_id, platform="qq", storage_message=True)` -向群聊发送图片 +**Returns:** +- `bool` - 是否发送成功 -#### `image_to_user(image_base64, user_id, platform="qq", storage_message=True)` -向用户发送图片 +### 5. 发送自定义类型消息 +```python +async def custom_to_stream( + message_type: str, + content: str, + stream_id: str, + display_message: str = "", + typing: bool = False, + reply_to: str = "", + storage_message: bool = True, + show_log: bool = True, +) -> bool: +``` +向指定流发送自定义类型消息。 -### 4. 命令发送 +**Args:** +- `message_type` (str): 消息类型,如"text"、"image"、"emoji"、"video"、"file"等 +- `content` (str): 消息内容(通常是base64编码或文本) +- `stream_id` (str): 聊天流ID +- `display_message` (str): 显示消息 +- `typing` (bool): 是否显示正在输入 +- `reply_to` (str): 回复消息,格式为"发送者:消息内容" +- `storage_message` (bool): 是否存储消息到数据库 +- `show_log` (bool): 是否显示日志 -#### `command_to_group(command, group_id, platform="qq", storage_message=True)` -向群聊发送命令 - -#### `command_to_user(command, user_id, platform="qq", storage_message=True)` -向用户发送命令 - -### 5. 自定义消息发送 - -#### `custom_to_group(message_type, content, group_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)` -向群聊发送自定义类型消息 - -#### `custom_to_user(message_type, content, user_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)` -向用户发送自定义类型消息 - -#### `custom_message(message_type, content, target_id, is_group=True, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)` -通用的自定义消息发送 - -**参数:** -- `message_type`:消息类型,如"text"、"image"、"emoji"等 -- `content`:消息内容 -- `target_id`:目标ID(群ID或用户ID) -- `is_group`:是否为群聊 -- `platform`:平台 -- `display_message`:显示消息 -- `typing`:是否显示正在输入 -- `reply_to`:回复消息 -- `storage_message`:是否存储 +**Returns:** +- `bool` - 是否发送成功 ## 使用示例 -### 1. 基础文本发送 +### 1. 基础文本发送,并回复消息 ```python from src.plugin_system.apis import send_api @@ -93,57 +115,23 @@ from src.plugin_system.apis import send_api async def send_hello(chat_stream): """发送问候消息""" - if chat_stream.group_info: - # 群聊 - success = await send_api.text_to_group( - text="大家好!", - group_id=chat_stream.group_info.group_id, - typing=True - ) - else: - # 私聊 - success = await send_api.text_to_user( - text="你好!", - user_id=chat_stream.user_info.user_id, - typing=True - ) + success = await send_api.text_to_stream( + text="Hello, world!", + stream_id=chat_stream.stream_id, + typing=True, + reply_to="User:How are you?", + storage_message=True + ) return success ``` -### 2. 回复特定消息 - -```python -async def reply_to_message(chat_stream, reply_text, original_sender, original_message): - """回复特定消息""" - - # 构建回复格式 - reply_to = f"{original_sender}:{original_message}" - - if chat_stream.group_info: - success = await send_api.text_to_group( - text=reply_text, - group_id=chat_stream.group_info.group_id, - reply_to=reply_to - ) - else: - success = await send_api.text_to_user( - text=reply_text, - user_id=chat_stream.user_info.user_id, - reply_to=reply_to - ) - - return success -``` - -### 3. 发送表情包 +### 2. 发送表情包 ```python +from src.plugin_system.apis import emoji_api async def send_emoji_reaction(chat_stream, emotion): """根据情感发送表情包""" - - from src.plugin_system.apis import emoji_api - # 获取表情包 emoji_result = await emoji_api.get_by_emotion(emotion) if not emoji_result: @@ -152,107 +140,10 @@ async def send_emoji_reaction(chat_stream, emotion): emoji_base64, description, matched_emotion = emoji_result # 发送表情包 - if chat_stream.group_info: - success = await send_api.emoji_to_group( - emoji_base64=emoji_base64, - group_id=chat_stream.group_info.group_id - ) - else: - success = await send_api.emoji_to_user( - emoji_base64=emoji_base64, - user_id=chat_stream.user_info.user_id - ) - - return success -``` - -### 4. 在Action中发送消息 - -```python -from src.plugin_system.base import BaseAction - -class MessageAction(BaseAction): - async def execute(self, action_data, chat_stream): - message_type = action_data.get("type", "text") - content = action_data.get("content", "") - - if message_type == "text": - success = await self.send_text(chat_stream, content) - elif message_type == "emoji": - success = await self.send_emoji(chat_stream, content) - elif message_type == "image": - success = await self.send_image(chat_stream, content) - else: - success = False - - return {"success": success} - - async def send_text(self, chat_stream, text): - if chat_stream.group_info: - return await send_api.text_to_group(text, chat_stream.group_info.group_id) - else: - return await send_api.text_to_user(text, chat_stream.user_info.user_id) - - async def send_emoji(self, chat_stream, emoji_base64): - if chat_stream.group_info: - return await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id) - else: - return await send_api.emoji_to_user(emoji_base64, chat_stream.user_info.user_id) - - async def send_image(self, chat_stream, image_base64): - if chat_stream.group_info: - return await send_api.image_to_group(image_base64, chat_stream.group_info.group_id) - else: - return await send_api.image_to_user(image_base64, chat_stream.user_info.user_id) -``` - -### 5. 批量发送消息 - -```python -async def broadcast_message(message: str, target_groups: list): - """向多个群组广播消息""" - - results = {} - - for group_id in target_groups: - try: - success = await send_api.text_to_group( - text=message, - group_id=group_id, - typing=True - ) - results[group_id] = success - except Exception as e: - results[group_id] = False - print(f"发送到群 {group_id} 失败: {e}") - - return results -``` - -### 6. 智能消息发送 - -```python -async def smart_send(chat_stream, message_data): - """智能发送不同类型的消息""" - - message_type = message_data.get("type", "text") - content = message_data.get("content", "") - options = message_data.get("options", {}) - - # 根据聊天流类型选择发送方法 - target_id = (chat_stream.group_info.group_id if chat_stream.group_info - else chat_stream.user_info.user_id) - is_group = chat_stream.group_info is not None - - # 使用通用发送方法 - success = await send_api.custom_message( - message_type=message_type, - content=content, - target_id=target_id, - is_group=is_group, - typing=options.get("typing", False), - reply_to=options.get("reply_to", ""), - display_message=options.get("display_message", "") + success = await send_api.emoji_to_stream( + emoji_base64=emoji_base64, + stream_id=chat_stream.stream_id, + storage_message=False # 不存储到数据库 ) return success @@ -273,90 +164,6 @@ async def smart_send(chat_stream, message_data): 系统会自动查找匹配的原始消息并进行回复。 -## 高级用法 - -### 1. 消息发送队列 - -```python -import asyncio - -class MessageQueue: - def __init__(self): - self.queue = asyncio.Queue() - self.running = False - - async def add_message(self, chat_stream, message_type, content, options=None): - """添加消息到队列""" - message_item = { - "chat_stream": chat_stream, - "type": message_type, - "content": content, - "options": options or {} - } - await self.queue.put(message_item) - - async def process_queue(self): - """处理消息队列""" - self.running = True - - while self.running: - try: - message_item = await asyncio.wait_for(self.queue.get(), timeout=1.0) - - # 发送消息 - success = await smart_send( - message_item["chat_stream"], - { - "type": message_item["type"], - "content": message_item["content"], - "options": message_item["options"] - } - ) - - # 标记任务完成 - self.queue.task_done() - - # 发送间隔 - await asyncio.sleep(0.5) - - except asyncio.TimeoutError: - continue - except Exception as e: - print(f"处理消息队列出错: {e}") -``` - -### 2. 消息模板系统 - -```python -class MessageTemplate: - def __init__(self): - self.templates = { - "welcome": "欢迎 {nickname} 加入群聊!", - "goodbye": "{nickname} 离开了群聊。", - "notification": "🔔 通知:{message}", - "error": "❌ 错误:{error_message}", - "success": "✅ 成功:{message}" - } - - def format_message(self, template_name: str, **kwargs) -> str: - """格式化消息模板""" - template = self.templates.get(template_name, "{message}") - return template.format(**kwargs) - - async def send_template(self, chat_stream, template_name: str, **kwargs): - """发送模板消息""" - message = self.format_message(template_name, **kwargs) - - if chat_stream.group_info: - return await send_api.text_to_group(message, chat_stream.group_info.group_id) - else: - return await send_api.text_to_user(message, chat_stream.user_info.user_id) - -# 使用示例 -template_system = MessageTemplate() -await template_system.send_template(chat_stream, "welcome", nickname="张三") -``` - ## 注意事项 1. **异步操作**:所有发送函数都是异步的,必须使用`await` diff --git a/docs/plugins/api/tool-api.md b/docs/plugins/api/tool-api.md new file mode 100644 index 00000000..d86734fc --- /dev/null +++ b/docs/plugins/api/tool-api.md @@ -0,0 +1,55 @@ +# 工具API + +工具API模块提供了获取和管理工具实例的功能,让插件能够访问系统中注册的工具。 + +## 导入方式 + +```python +from src.plugin_system.apis import tool_api +# 或者 +from src.plugin_system import tool_api +``` + +## 主要功能 + +### 1. 获取工具实例 + +```python +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: +``` + +获取指定名称的工具实例。 + +**Args**: +- `tool_name`: 工具名称字符串 + +**Returns**: +- `Optional[BaseTool]`: 工具实例,如果工具不存在则返回 None + +### 2. 获取LLM可用的工具定义 + +```python +def get_llm_available_tool_definitions(): +``` + +获取所有LLM可用的工具定义列表。 + +**Returns**: +- `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组 + - 其具体定义请参照[tool-components.md](../tool-components.md)中的工具定义格式。 +#### 示例: + +```python +# 获取所有LLM可用的工具定义 +tools = tool_api.get_llm_available_tool_definitions() +for tool_name, tool_definition in tools: + print(f"工具: {tool_name}") + print(f"定义: {tool_definition}") +``` + +## 注意事项 + +1. **工具存在性检查**:使用前请检查工具实例是否为 None +2. **权限控制**:某些工具可能有使用权限限制 +3. **异步调用**:大多数工具方法是异步的,需要使用 await +4. **错误处理**:调用工具时请做好异常处理 diff --git a/docs/plugins/api/utils-api.md b/docs/plugins/api/utils-api.md deleted file mode 100644 index bbab092e..00000000 --- a/docs/plugins/api/utils-api.md +++ /dev/null @@ -1,435 +0,0 @@ -# 工具API - -工具API模块提供了各种辅助功能,包括文件操作、时间处理、唯一ID生成等常用工具函数。 - -## 导入方式 - -```python -from src.plugin_system.apis import utils_api -``` - -## 主要功能 - -### 1. 文件操作 - -#### `get_plugin_path(caller_frame=None) -> str` -获取调用者插件的路径 - -**参数:** -- `caller_frame`:调用者的栈帧,默认为None(自动获取) - -**返回:** -- `str`:插件目录的绝对路径 - -**示例:** -```python -plugin_path = utils_api.get_plugin_path() -print(f"插件路径: {plugin_path}") -``` - -#### `read_json_file(file_path: str, default: Any = None) -> Any` -读取JSON文件 - -**参数:** -- `file_path`:文件路径,可以是相对于插件目录的路径 -- `default`:如果文件不存在或读取失败时返回的默认值 - -**返回:** -- `Any`:JSON数据或默认值 - -**示例:** -```python -# 读取插件配置文件 -config = utils_api.read_json_file("config.json", {}) -settings = utils_api.read_json_file("data/settings.json", {"enabled": True}) -``` - -#### `write_json_file(file_path: str, data: Any, indent: int = 2) -> bool` -写入JSON文件 - -**参数:** -- `file_path`:文件路径,可以是相对于插件目录的路径 -- `data`:要写入的数据 -- `indent`:JSON缩进 - -**返回:** -- `bool`:是否写入成功 - -**示例:** -```python -data = {"name": "test", "value": 123} -success = utils_api.write_json_file("output.json", data) -``` - -### 2. 时间相关 - -#### `get_timestamp() -> int` -获取当前时间戳 - -**返回:** -- `int`:当前时间戳(秒) - -#### `format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str` -格式化时间 - -**参数:** -- `timestamp`:时间戳,如果为None则使用当前时间 -- `format_str`:时间格式字符串 - -**返回:** -- `str`:格式化后的时间字符串 - -#### `parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int` -解析时间字符串为时间戳 - -**参数:** -- `time_str`:时间字符串 -- `format_str`:时间格式字符串 - -**返回:** -- `int`:时间戳(秒) - -### 3. 其他工具 - -#### `generate_unique_id() -> str` -生成唯一ID - -**返回:** -- `str`:唯一ID - -## 使用示例 - -### 1. 插件数据管理 - -```python -from src.plugin_system.apis import utils_api - -class DataPlugin(BasePlugin): - def __init__(self): - self.plugin_path = utils_api.get_plugin_path() - self.data_file = "plugin_data.json" - self.load_data() - - def load_data(self): - """加载插件数据""" - default_data = { - "users": {}, - "settings": {"enabled": True}, - "stats": {"message_count": 0} - } - self.data = utils_api.read_json_file(self.data_file, default_data) - - def save_data(self): - """保存插件数据""" - return utils_api.write_json_file(self.data_file, self.data) - - async def handle_action(self, action_data, chat_stream): - # 更新统计信息 - self.data["stats"]["message_count"] += 1 - self.data["stats"]["last_update"] = utils_api.get_timestamp() - - # 保存数据 - if self.save_data(): - return {"success": True, "message": "数据已保存"} - else: - return {"success": False, "message": "数据保存失败"} -``` - -### 2. 日志记录系统 - -```python -class PluginLogger: - def __init__(self, plugin_name: str): - self.plugin_name = plugin_name - self.log_file = f"{plugin_name}_log.json" - self.logs = utils_api.read_json_file(self.log_file, []) - - def log_event(self, event_type: str, message: str, data: dict = None): - """记录事件""" - log_entry = { - "id": utils_api.generate_unique_id(), - "timestamp": utils_api.get_timestamp(), - "formatted_time": utils_api.format_time(), - "event_type": event_type, - "message": message, - "data": data or {} - } - - self.logs.append(log_entry) - - # 保持最新的100条记录 - if len(self.logs) > 100: - self.logs = self.logs[-100:] - - # 保存到文件 - utils_api.write_json_file(self.log_file, self.logs) - - def get_logs_by_type(self, event_type: str) -> list: - """获取指定类型的日志""" - return [log for log in self.logs if log["event_type"] == event_type] - - def get_recent_logs(self, count: int = 10) -> list: - """获取最近的日志""" - return self.logs[-count:] - -# 使用示例 -logger = PluginLogger("my_plugin") -logger.log_event("user_action", "用户发送了消息", {"user_id": "123", "message": "hello"}) -``` - -### 3. 配置管理系统 - -```python -class ConfigManager: - def __init__(self, config_file: str = "plugin_config.json"): - self.config_file = config_file - self.default_config = { - "enabled": True, - "debug": False, - "max_users": 100, - "response_delay": 1.0, - "features": { - "auto_reply": True, - "logging": True - } - } - self.config = self.load_config() - - def load_config(self) -> dict: - """加载配置""" - return utils_api.read_json_file(self.config_file, self.default_config) - - def save_config(self) -> bool: - """保存配置""" - return utils_api.write_json_file(self.config_file, self.config, indent=4) - - def get(self, key: str, default=None): - """获取配置值,支持嵌套访问""" - keys = key.split('.') - value = self.config - - for k in keys: - if isinstance(value, dict) and k in value: - value = value[k] - else: - return default - - return value - - def set(self, key: str, value): - """设置配置值,支持嵌套设置""" - keys = key.split('.') - config = self.config - - for k in keys[:-1]: - if k not in config: - config[k] = {} - config = config[k] - - config[keys[-1]] = value - - def update_config(self, updates: dict): - """批量更新配置""" - def deep_update(base, updates): - for key, value in updates.items(): - if isinstance(value, dict) and key in base and isinstance(base[key], dict): - deep_update(base[key], value) - else: - base[key] = value - - deep_update(self.config, updates) - -# 使用示例 -config = ConfigManager() -print(f"调试模式: {config.get('debug', False)}") -print(f"自动回复: {config.get('features.auto_reply', True)}") - -config.set('features.new_feature', True) -config.save_config() -``` - -### 4. 缓存系统 - -```python -class PluginCache: - def __init__(self, cache_file: str = "plugin_cache.json", ttl: int = 3600): - self.cache_file = cache_file - self.ttl = ttl # 缓存过期时间(秒) - self.cache = self.load_cache() - - def load_cache(self) -> dict: - """加载缓存""" - return utils_api.read_json_file(self.cache_file, {}) - - def save_cache(self): - """保存缓存""" - return utils_api.write_json_file(self.cache_file, self.cache) - - def get(self, key: str): - """获取缓存值""" - if key not in self.cache: - return None - - item = self.cache[key] - current_time = utils_api.get_timestamp() - - # 检查是否过期 - if current_time - item["timestamp"] > self.ttl: - del self.cache[key] - return None - - return item["value"] - - def set(self, key: str, value): - """设置缓存值""" - self.cache[key] = { - "value": value, - "timestamp": utils_api.get_timestamp() - } - self.save_cache() - - def clear_expired(self): - """清理过期缓存""" - current_time = utils_api.get_timestamp() - expired_keys = [] - - for key, item in self.cache.items(): - if current_time - item["timestamp"] > self.ttl: - expired_keys.append(key) - - for key in expired_keys: - del self.cache[key] - - if expired_keys: - self.save_cache() - - return len(expired_keys) - -# 使用示例 -cache = PluginCache(ttl=1800) # 30分钟过期 -cache.set("user_data_123", {"name": "张三", "score": 100}) -user_data = cache.get("user_data_123") -``` - -### 5. 时间处理工具 - -```python -class TimeHelper: - @staticmethod - def get_time_info(): - """获取当前时间的详细信息""" - timestamp = utils_api.get_timestamp() - return { - "timestamp": timestamp, - "datetime": utils_api.format_time(timestamp), - "date": utils_api.format_time(timestamp, "%Y-%m-%d"), - "time": utils_api.format_time(timestamp, "%H:%M:%S"), - "year": utils_api.format_time(timestamp, "%Y"), - "month": utils_api.format_time(timestamp, "%m"), - "day": utils_api.format_time(timestamp, "%d"), - "weekday": utils_api.format_time(timestamp, "%A") - } - - @staticmethod - def time_ago(timestamp: int) -> str: - """计算时间差""" - current = utils_api.get_timestamp() - diff = current - timestamp - - if diff < 60: - return f"{diff}秒前" - elif diff < 3600: - return f"{diff // 60}分钟前" - elif diff < 86400: - return f"{diff // 3600}小时前" - else: - return f"{diff // 86400}天前" - - @staticmethod - def parse_duration(duration_str: str) -> int: - """解析时间段字符串,返回秒数""" - import re - - pattern = r'(\d+)([smhd])' - matches = re.findall(pattern, duration_str.lower()) - - total_seconds = 0 - for value, unit in matches: - value = int(value) - if unit == 's': - total_seconds += value - elif unit == 'm': - total_seconds += value * 60 - elif unit == 'h': - total_seconds += value * 3600 - elif unit == 'd': - total_seconds += value * 86400 - - return total_seconds - -# 使用示例 -time_info = TimeHelper.get_time_info() -print(f"当前时间: {time_info['datetime']}") - -last_seen = 1699000000 -print(f"最后见面: {TimeHelper.time_ago(last_seen)}") - -duration = TimeHelper.parse_duration("1h30m") # 1小时30分钟 = 5400秒 -``` - -## 最佳实践 - -### 1. 错误处理 -```python -def safe_file_operation(file_path: str, data: dict): - """安全的文件操作""" - try: - success = utils_api.write_json_file(file_path, data) - if not success: - logger.warning(f"文件写入失败: {file_path}") - return success - except Exception as e: - logger.error(f"文件操作出错: {e}") - return False -``` - -### 2. 路径处理 -```python -import os - -def get_data_path(filename: str) -> str: - """获取数据文件的完整路径""" - plugin_path = utils_api.get_plugin_path() - data_dir = os.path.join(plugin_path, "data") - - # 确保数据目录存在 - os.makedirs(data_dir, exist_ok=True) - - return os.path.join(data_dir, filename) -``` - -### 3. 定期清理 -```python -async def cleanup_old_files(): - """清理旧文件""" - plugin_path = utils_api.get_plugin_path() - current_time = utils_api.get_timestamp() - - for filename in os.listdir(plugin_path): - if filename.endswith('.tmp'): - file_path = os.path.join(plugin_path, filename) - file_time = os.path.getmtime(file_path) - - # 删除超过24小时的临时文件 - if current_time - file_time > 86400: - os.remove(file_path) -``` - -## 注意事项 - -1. **相对路径**:文件路径支持相对于插件目录的路径 -2. **自动创建目录**:写入文件时会自动创建必要的目录 -3. **错误处理**:所有函数都有错误处理,失败时返回默认值 -4. **编码格式**:文件读写使用UTF-8编码 -5. **时间格式**:时间戳使用秒为单位 -6. **JSON格式**:JSON文件使用可读性好的缩进格式 \ No newline at end of file diff --git a/docs/plugins/index.md b/docs/plugins/index.md index af8fad85..2454c98a 100644 --- a/docs/plugins/index.md +++ b/docs/plugins/index.md @@ -10,6 +10,7 @@ - [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件 - [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件 +- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力 - [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件 - [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构 @@ -43,24 +44,24 @@ Command vs Action 选择指南 - [LLM API](api/llm-api.md) - 大语言模型交互接口,可以使用内置LLM生成内容 - [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器 -### 表情包api +### 表情包API - [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口 -### 关系系统api +### 关系系统API - [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口 ### 数据与配置API - [🗄️ 数据库API](api/database-api.md) - 数据库操作接口 - [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口 +### 插件和组件管理API +- [🔌 插件API](api/plugin-manage-api.md) - 插件加载和管理接口 +- [🧩 组件API](api/component-manage-api.md) - 组件注册和管理接口 + +### 日志API +- [📜 日志API](api/logging-api.md) - logger实例获取接口 ### 工具API -- [工具API](api/utils-api.md) - 文件操作、时间处理等工具函数 - - -## 实验性 - -这些功能将在未来重构或移除 -- [🔧 工具系统详解](tool-system.md) - 工具系统的使用和开发 +- [🔧 工具API](api/tool-api.md) - tool获取接口 diff --git a/docs/plugins/tool-system.md b/docs/plugins/tool-components.md similarity index 58% rename from docs/plugins/tool-system.md rename to docs/plugins/tool-components.md index baa43528..059656aa 100644 --- a/docs/plugins/tool-system.md +++ b/docs/plugins/tool-components.md @@ -1,10 +1,10 @@ -# 🔧 工具系统详解 +# 🔧 工具组件详解 -## 📖 什么是工具系统 +## 📖 什么是工具 -工具系统是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。 +工具是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门,可以拓展麦麦能做的事情,那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。 -### 🎯 工具系统的特点 +### 🎯 工具的特点 - 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力 - 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据 @@ -20,14 +20,11 @@ | **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 | | **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 | -## 🏗️ 工具基本结构 - -### 必要组件 +## 🏗️ Tool组件的基本结构 每个工具必须继承 `BaseTool` 基类并实现以下属性和方法: - ```python -from src.tools.tool_can_use.base_tool import BaseTool, register_tool +from src.plugin_system import BaseTool, ToolParamType class MyTool(BaseTool): # 工具名称,必须唯一 @@ -36,21 +33,29 @@ class MyTool(BaseTool): # 工具描述,告诉LLM这个工具的用途 description = "这个工具用于获取特定类型的信息" - # 参数定义,遵循JSONSchema格式 - parameters = { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "查询参数" - }, - "limit": { - "type": "integer", - "description": "结果数量限制" - } - }, - "required": ["query"] - } + # 参数定义,仅定义参数 + # 比如想要定义一个类似下面的openai格式的参数表,则可以这么定义: + # { + # "type": "object", + # "properties": { + # "query": { + # "type": "string", + # "description": "查询参数" + # }, + # "limit": { + # "type": "integer", + # "description": "结果数量限制" + # "enum": [10, 20, 50] # 可选值 + # } + # }, + # "required": ["query"] + # } + parameters = [ + ("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数 + ("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数 + ] + + available_for_llm = True # 是否对LLM可用 async def execute(self, function_args: Dict[str, Any]): """执行工具逻辑""" @@ -69,7 +74,12 @@ class MyTool(BaseTool): |-----|------|------| | `name` | str | 工具的唯一标识名称 | | `description` | str | 工具功能描述,帮助LLM理解用途 | -| `parameters` | dict | JSONSchema格式的参数定义 | +| `parameters` | list[tuple] | 参数定义 | + +其构造而成的工具定义为: +```python +{"name": cls.name, "description": cls.description, "parameters": cls.parameters} +``` ### 方法说明 @@ -77,15 +87,6 @@ class MyTool(BaseTool): |-----|------|--------|------| | `execute` | `function_args` | `dict` | 执行工具核心逻辑 | -## 🔄 自动注册机制 - -工具系统采用自动发现和注册机制: - -1. **文件扫描**:系统自动遍历 `tool_can_use` 目录中的所有Python文件 -2. **类识别**:寻找继承自 `BaseTool` 的工具类 -3. **自动注册**:只需要实现对应的类并把文件放在正确文件夹中就可自动注册 -4. **即用即加载**:工具在需要时被实例化和调用 - --- ## 🎨 完整工具示例 @@ -93,7 +94,7 @@ class MyTool(BaseTool): 完成一个天气查询工具 ```python -from src.tools.tool_can_use.base_tool import BaseTool, register_tool +from src.plugin_system import BaseTool import aiohttp import json @@ -102,23 +103,13 @@ class WeatherTool(BaseTool): name = "weather_query" description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等" + available_for_llm = True # 允许LLM调用此工具 + parameters = [ + ("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None), + ("country", ToolParamType.STRING, "国家代码,如:CN、US,可选参数", False, None) + ] - parameters = { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "要查询天气的城市名称,如:北京、上海、纽约" - }, - "country": { - "type": "string", - "description": "国家代码,如:CN、US,可选参数" - } - }, - "required": ["city"] - } - - async def execute(self, function_args, message_txt=""): + async def execute(self, function_args: dict): """执行天气查询""" try: city = function_args.get("city") @@ -177,55 +168,12 @@ class WeatherTool(BaseTool): --- -## 📊 工具开发步骤 - -### 1. 创建工具文件 - -在 `src/tools/tool_can_use/` 目录下创建新的Python文件: - -```bash -# 例如创建 my_new_tool.py -touch src/tools/tool_can_use/my_new_tool.py -``` - -### 2. 实现工具类 - -```python -from src.tools.tool_can_use.base_tool import BaseTool, register_tool - -class MyNewTool(BaseTool): - name = "my_new_tool" - description = "新工具的功能描述" - - parameters = { - "type": "object", - "properties": { - # 定义参数 - }, - "required": [] - } - - async def execute(self, function_args, message_txt=""): - # 实现工具逻辑 - return { - "name": self.name, - "content": "执行结果" - } -``` - -### 3. 系统集成 - -工具创建完成后,系统会自动发现和注册,无需额外配置。 - ---- - ## 🚨 注意事项和限制 ### 当前限制 -1. **独立开发**:需要单独编写,暂未完全融入插件系统 -2. **适用范围**:主要适用于信息获取场景 -3. **配置要求**:必须开启工具处理器 +1. **适用范围**:主要适用于信息获取场景 +2. **配置要求**:必须开启工具处理器 ### 开发建议 @@ -238,66 +186,49 @@ class MyNewTool(BaseTool): ## 🎯 最佳实践 ### 1. 工具命名规范 - +#### ✅ 好的命名 ```python -# ✅ 好的命名 name = "weather_query" # 清晰表达功能 name = "knowledge_search" # 描述性强 name = "stock_price_check" # 功能明确 - -# ❌ 避免的命名 +``` +#### ❌ 避免的命名 +```python name = "tool1" # 无意义 name = "wq" # 过于简短 name = "weather_and_news" # 功能过于复杂 ``` ### 2. 描述规范 - +#### ✅ 良好的描述 ```python -# ✅ 好的描述 description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况" - -# ❌ 避免的描述 +``` +#### ❌ 避免的描述 +```python description = "天气" # 过于简单 description = "获取信息" # 不够具体 ``` ### 3. 参数设计 +#### ✅ 合理的参数设计 ```python -# ✅ 合理的参数设计 -parameters = { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "城市名称,如:北京、上海" - }, - "unit": { - "type": "string", - "description": "温度单位:celsius(摄氏度) 或 fahrenheit(华氏度)", - "enum": ["celsius", "fahrenheit"] - } - }, - "required": ["city"] -} - -# ❌ 避免的参数设计 -parameters = { - "type": "object", - "properties": { - "data": { - "type": "string", - "description": "数据" # 描述不清晰 - } - } -} +parameters = [ + ("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None), + ("unit", ToolParamType.STRING, "温度单位:celsius 或 fahrenheit", False, ["celsius", "fahrenheit"]) +] +``` +#### ❌ 避免的参数设计 +```python +parameters = [ + ("data", "string", "数据", True) # 参数过于模糊 +] ``` ### 4. 结果格式化 - +#### ✅ 良好的结果格式 ```python -# ✅ 良好的结果格式 def _format_result(self, data): return f""" 🔍 查询结果 @@ -307,12 +238,9 @@ def _format_result(self, data): 📝 说明: {data['description']} ━━━━━━━━━━━━ """.strip() - -# ❌ 避免的结果格式 +``` +#### ❌ 避免的结果格式 +```python def _format_result(self, data): return str(data) # 直接返回原始数据 ``` - ---- - -🎉 **工具系统为麦麦提供了强大的信息获取能力!合理使用工具可以让麦麦变得更加智能和博学。** \ No newline at end of file diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 8ede9616..f9855481 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -1,18 +1,55 @@ -from typing import List, Tuple, Type +from typing import List, Tuple, Type, Any from src.plugin_system import ( BasePlugin, register_plugin, BaseAction, BaseCommand, + BaseTool, ComponentInfo, ActionActivationType, ConfigField, BaseEventHandler, EventType, MaiMessages, + ToolParamType ) +class CompareNumbersTool(BaseTool): + """比较两个数大小的工具""" + + name = "compare_numbers" + description = "使用工具 比较两个数的大小,返回较大的数" + parameters = [ + ("num1", ToolParamType.FLOAT, "第一个数字", True, None), + ("num2", ToolParamType.FLOAT, "第二个数字", True, None), + ] + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """执行比较两个数的大小 + + Args: + function_args: 工具参数 + + Returns: + dict: 工具执行结果 + """ + num1: int | float = function_args.get("num1") # type: ignore + num2: int | float = function_args.get("num2") # type: ignore + + try: + if num1 > num2: + result = f"{num1} 大于 {num2}" + elif num1 < num2: + result = f"{num1} 小于 {num2}" + else: + result = f"{num1} 等于 {num2}" + + return {"name": self.name, "content": result} + except Exception as e: + return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} + + # ===== Action组件 ===== class HelloAction(BaseAction): """问候Action - 简单的问候动作""" @@ -132,7 +169,9 @@ class HelloWorldPlugin(BasePlugin): "enabled": ConfigField(type=bool, default=False, description="是否启用插件"), }, "greeting": { - "message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), + "message": ConfigField( + type=list, default=["嗨!很开心见到你!😊", "Ciallo~(∠・ω< )⌒★"], description="默认问候消息" + ), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), }, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, @@ -142,6 +181,7 @@ class HelloWorldPlugin(BasePlugin): def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: return [ (HelloAction.get_action_info(), HelloAction), + (CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具 (ByeAction.get_action_info(), ByeAction), # 添加告别Action (TimeCommand.get_command_info(), TimeCommand), (PrintMessage.get_handler_info(), PrintMessage), diff --git a/requirements.txt b/requirements.txt index a09637a9..999bd5fd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ matplotlib networkx numpy openai +google-genai pandas peewee pyarrow diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 63a4d985..1177650d 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") - def ensure_openie_dir(): """确保OpenIE数据目录存在""" if not os.path.exists(OPENIE_DIR): @@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k def main(): # sourcery skip: dict-comprehension # 新增确认提示 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index c36a7789..47ad55a8 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -25,9 +25,8 @@ from rich.progress import ( TextColumn, ) from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest -from dotenv import load_dotenv logger = get_logger("LPMM知识库-信息提取") @@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") -ENV_FILE = os.path.join(ROOT_PATH, ".env") - -if os.path.exists(".env"): - load_dotenv(".env", override=True) - print("成功加载环境变量配置") -else: - print("未找到.env文件,请确保程序所需的环境变量被正确设置") - raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") - -env_mask = {key: os.getenv(key) for key in os.environ} -def scan_provider(env_config: dict): - provider = {} - - # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 - # 避免 GPG_KEY 这样的变量干扰检查 - env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) - - # 遍历 env_config 的所有键 - for key in env_config: - # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 - if key.endswith("_BASE_URL") or key.endswith("_KEY"): - # 提取 provider 名称 - provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 - - # 初始化 provider 的字典(如果尚未初始化) - if provider_name not in provider: - provider[provider_name] = {"url": None, "key": None} - - # 根据键的类型填充 url 或 key - if key.endswith("_BASE_URL"): - provider[provider_name]["url"] = env_config[key] - elif key.endswith("_KEY"): - provider[provider_name]["key"] = env_config[key] - - # 检查每个 provider 是否同时存在 url 和 key - for provider_name, config in provider.items(): - if config["url"] is None or config["key"] is None: - logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") - raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -96,11 +56,11 @@ open_ie_doc_lock = Lock() shutdown_event = Event() lpmm_entity_extract_llm = LLMRequest( - model=global_config.model.lpmm_entity_extract, + model_set=model_config.model_task_config.lpmm_entity_extract, request_type="lpmm.entity_extract" ) lpmm_rdf_build_llm = LLMRequest( - model=global_config.model.lpmm_rdf_build, + model_set=model_config.model_task_config.lpmm_rdf_build, request_type="lpmm.rdf_build" ) def process_single_text(pg_hash, raw_data): @@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) ensure_dirs() # 确保目录存在 - env_config = {key: os.getenv(key) for key in os.environ} - scan_provider(env_config) # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index efa8f69b..d51fa96b 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -414,7 +414,7 @@ class HeartFChatting: else: logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容") - action_message: Dict[str, Any] = message_data or target_message # type: ignore + action_message = message_data or target_message if action_type == "reply": # 等待回复生成完毕 if self.loop_mode == ChatMode.NORMAL: diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 918b8396..6d50d890 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -8,15 +8,15 @@ import traceback import io import re import binascii + from typing import Optional, Tuple, List, Any from PIL import Image from rich.traceback import install - from src.common.database.database_model import Emoji from src.common.database.database import db as peewee_db from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.llm_models.utils_model import LLMRequest @@ -379,9 +379,9 @@ class EmojiManager: self._scan_task = None - self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") + self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji") self.llm_emotion_judge = LLMRequest( - model=global_config.model.utils, max_tokens=600, request_type="emoji" + model_set=model_config.model_task_config.utils, request_type="emoji" ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) self.emoji_num = 0 @@ -492,6 +492,7 @@ class EmojiManager: return None def _levenshtein_distance(self, s1: str, s2: str) -> int: + # sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison """计算两个字符串的编辑距离 Args: @@ -629,11 +630,11 @@ class EmojiManager: if success: # 注册成功则跳出循环 break - else: - # 注册失败则删除对应文件 - file_path = os.path.join(EMOJI_DIR, filename) - os.remove(file_path) - logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") + + # 注册失败则删除对应文件 + file_path = os.path.join(EMOJI_DIR, filename) + os.remove(file_path) + logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") except Exception as e: logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") @@ -694,6 +695,7 @@ class EmojiManager: return [] async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: + # sourcery skip: use-next """从内存中的 emoji_objects 列表获取表情包 参数: @@ -709,10 +711,10 @@ class EmojiManager: async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: """根据哈希值获取已注册表情包的描述 - + Args: emoji_hash: 表情包的哈希值 - + Returns: Optional[str]: 表情包描述,如果未找到则返回None """ @@ -722,7 +724,7 @@ class EmojiManager: if emoji and emoji.description: logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...") return emoji.description - + # 如果内存中没有,从数据库查找 self._ensure_db() try: @@ -732,9 +734,9 @@ class EmojiManager: return emoji_record.description except Exception as e: logger.error(f"从数据库查询表情包描述时出错: {e}") - + return None - + except Exception as e: logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") return None @@ -779,6 +781,7 @@ class EmojiManager: return False async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: + # sourcery skip: use-getitem-for-re-match-groups """替换一个表情包 Args: @@ -820,7 +823,7 @@ class EmojiManager: ) # 调用大模型进行决策 - decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8) + decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600) logger.info(f"[决策] 结果: {decision}") # 解析决策结果 @@ -828,9 +831,7 @@ class EmojiManager: logger.info("[决策] 不删除任何表情包") return False - # 尝试从决策中提取表情包编号 - match = re.search(r"删除编号(\d+)", decision) - if match: + if match := re.search(r"删除编号(\d+)", decision): emoji_index = int(match.group(1)) - 1 # 转换为0-based索引 # 检查索引是否有效 @@ -889,6 +890,7 @@ class EmojiManager: existing_description = None try: from src.common.database.database_model import Images + existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji")) if existing_image and existing_image.description: existing_description = existing_image.description @@ -902,15 +904,21 @@ class EmojiManager: logger.info("[优化] 复用已有的详细描述,跳过VLM调用") else: logger.info("[VLM分析] 生成新的详细描述") - if image_format == "gif" or image_format == "GIF": + if image_format in ["gif", "GIF"]: image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore if not image_base64: raise RuntimeError("GIF表情包转换失败") prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000 + ) else: - prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + prompt = ( + "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + ) + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.3, max_tokens=1000 + ) # 审核表情包 if global_config.emoji.content_filtration: @@ -922,7 +930,9 @@ class EmojiManager: 4. 不要出现5个以上文字 请回答这个表情包是否满足上述要求,是则回答是,否则回答否,不要出现任何其他内容 ''' - content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + content, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.3, max_tokens=1000 + ) if content == "否": return "", [] @@ -933,7 +943,9 @@ class EmojiManager: 你可以关注其幽默和讽刺意味,动用贴吧,微博,小红书的知识,必须从互联网梗,meme的角度去分析 请直接输出描述,不要出现任何其他内容,如果有多个描述,可以用逗号分隔 """ - emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) + emotions_text, _ = await self.llm_emotion_judge.generate_response_async( + emotion_prompt, temperature=0.7, max_tokens=600 + ) # 处理情感列表 emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 1870c470..a9808503 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -7,12 +7,12 @@ from datetime import datetime from typing import List, Dict, Optional, Any, Tuple from src.common.logger import get_logger +from src.common.database.database_model import Expression from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import model_config from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager -from src.common.database.database_model import Expression MAX_EXPRESSION_COUNT = 300 @@ -80,11 +80,8 @@ def init_prompt() -> None: class ExpressionLearner: def __init__(self) -> None: - # TODO: API-Adapter修改标记 self.express_learn_model: LLMRequest = LLMRequest( - model=global_config.model.replyer_1, - temperature=0.3, - request_type="expressor.learner", + model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner" ) self.llm_model = None self._ensure_expression_directories() @@ -101,7 +98,7 @@ class ExpressionLearner: os.path.join(base_dir, "learnt_style"), os.path.join(base_dir, "learnt_grammar"), ] - + for directory in directories_to_create: try: os.makedirs(directory, exist_ok=True) @@ -116,7 +113,7 @@ class ExpressionLearner: """ base_dir = os.path.join("data", "expression") done_flag = os.path.join(base_dir, "done.done") - + # 确保基础目录存在 try: os.makedirs(base_dir, exist_ok=True) @@ -124,28 +121,28 @@ class ExpressionLearner: except Exception as e: logger.error(f"创建表达方式目录失败: {e}") return - + if os.path.exists(done_flag): logger.info("表达方式JSON已迁移,无需重复迁移。") return - + logger.info("开始迁移表达方式JSON到数据库...") migrated_count = 0 - + for type in ["learnt_style", "learnt_grammar"]: type_str = "style" if type == "learnt_style" else "grammar" type_dir = os.path.join(base_dir, type) if not os.path.exists(type_dir): logger.debug(f"目录不存在,跳过: {type_dir}") continue - + try: chat_ids = os.listdir(type_dir) logger.debug(f"在 {type_dir} 中找到 {len(chat_ids)} 个聊天ID目录") except Exception as e: logger.error(f"读取目录失败 {type_dir}: {e}") continue - + for chat_id in chat_ids: expr_file = os.path.join(type_dir, chat_id, "expressions.json") if not os.path.exists(expr_file): @@ -153,24 +150,24 @@ class ExpressionLearner: try: with open(expr_file, "r", encoding="utf-8") as f: expressions = json.load(f) - + if not isinstance(expressions, list): logger.warning(f"表达方式文件格式错误,跳过: {expr_file}") continue - + for expr in expressions: if not isinstance(expr, dict): continue - + situation = expr.get("situation") style_val = expr.get("style") count = expr.get("count", 1) last_active_time = expr.get("last_active_time", time.time()) - + if not situation or not style_val: logger.warning(f"表达方式缺少必要字段,跳过: {expr}") continue - + # 查重:同chat_id+type+situation+style from src.common.database.database_model import Expression @@ -201,7 +198,7 @@ class ExpressionLearner: logger.error(f"JSON解析失败 {expr_file}: {e}") except Exception as e: logger.error(f"迁移表达方式 {expr_file} 失败: {e}") - + # 标记迁移完成 try: # 确保done.done文件的父目录存在 @@ -209,7 +206,7 @@ class ExpressionLearner: if not os.path.exists(done_parent_dir): os.makedirs(done_parent_dir, exist_ok=True) logger.debug(f"为done.done创建父目录: {done_parent_dir}") - + with open(done_flag, "w", encoding="utf-8") as f: f.write("done\n") logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件") @@ -229,13 +226,13 @@ class ExpressionLearner: # 查找所有create_date为空的表达方式 old_expressions = Expression.select().where(Expression.create_date.is_null()) updated_count = 0 - + for expr in old_expressions: # 使用last_active_time作为create_date expr.create_date = expr.last_active_time expr.save() updated_count += 1 - + if updated_count > 0: logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") except Exception as e: @@ -287,25 +284,29 @@ class ExpressionLearner: 获取指定chat_id的表达方式创建信息,按创建日期排序 """ try: - expressions = (Expression.select() - .where(Expression.chat_id == chat_id) - .order_by(Expression.create_date.desc()) - .limit(limit)) - + expressions = ( + Expression.select() + .where(Expression.chat_id == chat_id) + .order_by(Expression.create_date.desc()) + .limit(limit) + ) + result = [] for expr in expressions: create_date = expr.create_date if expr.create_date is not None else expr.last_active_time - result.append({ - "situation": expr.situation, - "style": expr.style, - "type": expr.type, - "count": expr.count, - "create_date": create_date, - "create_date_formatted": format_create_date(create_date), - "last_active_time": expr.last_active_time, - "last_active_formatted": format_create_date(expr.last_active_time), - }) - + result.append( + { + "situation": expr.situation, + "style": expr.style, + "type": expr.type, + "count": expr.count, + "create_date": create_date, + "create_date_formatted": format_create_date(create_date), + "last_active_time": expr.last_active_time, + "last_active_formatted": format_create_date(expr.last_active_time), + } + ) + return result except Exception as e: logger.error(f"获取表达方式创建信息失败: {e}") @@ -355,19 +356,19 @@ class ExpressionLearner: try: # 获取所有表达方式 all_expressions = Expression.select() - + updated_count = 0 deleted_count = 0 - + for expr in all_expressions: # 计算时间差 last_active = expr.last_active_time time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 - + # 计算衰减值 decay_value = self.calculate_decay_factor(time_diff_days) new_count = max(0.01, expr.count - decay_value) - + if new_count <= 0.01: # 如果count太小,删除这个表达方式 expr.delete_instance() @@ -377,10 +378,10 @@ class ExpressionLearner: expr.count = new_count expr.save() updated_count += 1 - + if updated_count > 0 or deleted_count > 0: logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") - + except Exception as e: logger.error(f"数据库全局衰减失败: {e}") @@ -527,7 +528,7 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的prompt: {prompt}") try: - response, _ = await self.express_learn_model.generate_response_async(prompt) + response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) except Exception as e: logger.error(f"学习{type_str}失败: {e}") return None diff --git a/src/chat/express/expression_selector.py b/src/chat/express/expression_selector.py index 910b43c2..111225c8 100644 --- a/src/chat/express/expression_selector.py +++ b/src/chat/express/expression_selector.py @@ -1,16 +1,17 @@ import json import time import random +import hashlib from typing import List, Dict, Tuple, Optional, Any from json_repair import repair_json from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config from src.common.logger import get_logger +from src.common.database.database_model import Expression from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from .expression_learner import get_expression_learner -from src.common.database.database_model import Expression logger = get_logger("expression_selector") @@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis class ExpressionSelector: def __init__(self): self.expression_learner = get_expression_learner() - # TODO: API-Adapter修改标记 self.llm_model = LLMRequest( - model=global_config.model.utils_small, - request_type="expression.selector", + model_set=model_config.model_task_config.utils_small, request_type="expression.selector" ) @staticmethod @@ -92,7 +91,6 @@ class ExpressionSelector: id_str = parts[1] stream_type = parts[2] is_group = stream_type == "group" - import hashlib if is_group: components = [platform, str(id_str)] else: @@ -108,8 +106,7 @@ class ExpressionSelector: for group in groups: group_chat_ids = [] for stream_config_str in group: - chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str) - if chat_id_candidate: + if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str): group_chat_ids.append(chat_id_candidate) if chat_id in group_chat_ids: return group_chat_ids @@ -118,9 +115,10 @@ class ExpressionSelector: def get_random_expressions( self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: + # sourcery skip: extract-duplicate-method, move-assign # 支持多chat_id合并抽选 related_chat_ids = self.get_related_chat_ids(chat_id) - + # 优化:一次性查询所有相关chat_id的表达方式 style_query = Expression.select().where( (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") @@ -128,7 +126,7 @@ class ExpressionSelector: grammar_query = Expression.select().where( (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar") ) - + style_exprs = [ { "situation": expr.situation, @@ -138,9 +136,10 @@ class ExpressionSelector: "source_id": expr.chat_id, "type": "style", "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } for expr in style_query + } + for expr in style_query ] - + grammar_exprs = [ { "situation": expr.situation, @@ -150,9 +149,10 @@ class ExpressionSelector: "source_id": expr.chat_id, "type": "grammar", "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, - } for expr in grammar_query + } + for expr in grammar_query ] - + style_num = int(total_num * style_percentage) grammar_num = int(total_num * grammar_percentage) # 按权重抽样(使用count作为权重) @@ -174,22 +174,22 @@ class ExpressionSelector: return updates_by_key = {} for expr in expressions_to_update: - source_id = expr.get("source_id") - expr_type = expr.get("type", "style") - situation = expr.get("situation") - style = expr.get("style") + source_id: str = expr.get("source_id") # type: ignore + expr_type: str = expr.get("type", "style") + situation: str = expr.get("situation") # type: ignore + style: str = expr.get("style") # type: ignore if not source_id or not situation or not style: logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") continue key = (source_id, expr_type, situation, style) if key not in updates_by_key: updates_by_key[key] = expr - for (chat_id, expr_type, situation, style), _expr in updates_by_key.items(): + for chat_id, expr_type, situation, style in updates_by_key: query = Expression.select().where( - (Expression.chat_id == chat_id) & - (Expression.type == expr_type) & - (Expression.situation == situation) & - (Expression.style == style) + (Expression.chat_id == chat_id) + & (Expression.type == expr_type) + & (Expression.situation == situation) + & (Expression.style == style) ) if query.exists(): expr_obj = query.get() @@ -264,7 +264,7 @@ class ExpressionSelector: # 4. 调用LLM try: - content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt) + content, _ = await self.llm_model.generate_response_async(prompt=prompt) # logger.info(f"{self.log_prefix} LLM返回结果: {content}") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index d732683a..d0f6e774 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -3,6 +3,7 @@ import json import os import math import asyncio +from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict, List, Tuple import numpy as np @@ -11,8 +12,6 @@ import pandas as pd # import tqdm import faiss -# from .llm_client import LLMClient -# from .lpmmconfig import global_config from .utils.hash import get_sha256 from .global_logger import logger from rich.traceback import install @@ -26,12 +25,20 @@ from rich.progress import ( SpinnerColumn, TextColumn, ) -from src.manager.local_store_manager import local_storage from src.chat.utils.utils import get_embedding from src.config.config import global_config install(extra_lines=3) + +# 多线程embedding配置常量 +DEFAULT_MAX_WORKERS = 10 # 默认最大线程数 +DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小 +MIN_CHUNK_SIZE = 1 # 最小分块大小 +MAX_CHUNK_SIZE = 50 # 最大分块大小 +MIN_WORKERS = 1 # 最小线程数 +MAX_WORKERS = 20 # 最大线程数 + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") @@ -87,13 +94,23 @@ class EmbeddingStoreItem: class EmbeddingStore: - def __init__(self, namespace: str, dir_path: str): + def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): self.namespace = namespace self.dir = dir_path self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.index_file_path = f"{dir_path}/{namespace}.index" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" + # 多线程配置参数验证和设置 + self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers)) + self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size)) + + # 如果配置值被调整,记录日志 + if self.max_workers != max_workers: + logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})") + if self.chunk_size != chunk_size: + logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})") + self.store = {} self.faiss_index = None @@ -125,16 +142,134 @@ class EmbeddingStore: return [] return result + def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]: + """使用多线程批量获取嵌入向量 + + Args: + strs: 要获取嵌入的字符串列表 + chunk_size: 每个线程处理的数据块大小 + max_workers: 最大线程数 + progress_callback: 进度回调函数,接收一个参数表示完成的数量 + + Returns: + 包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致 + """ + if not strs: + return [] + + # 分块 + chunks = [] + for i in range(0, len(strs), chunk_size): + chunk = strs[i:i + chunk_size] + chunks.append((i, chunk)) # 保存起始索引以维持顺序 + + # 结果存储,使用字典按索引存储以保证顺序 + results = {} + + def process_chunk(chunk_data): + """处理单个数据块的函数""" + start_idx, chunk_strs = chunk_data + chunk_results = [] + + # 为每个线程创建独立的LLMRequest实例 + from src.llm_models.utils_model import LLMRequest + from src.config.config import model_config + + try: + # 创建线程专用的LLM实例 + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding") + + for i, s in enumerate(chunk_strs): + try: + # 直接使用异步函数 + embedding = asyncio.run(llm.get_embedding(s)) + if embedding and len(embedding) > 0: + chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量 + else: + logger.error(f"获取嵌入失败: {s}") + chunk_results.append((start_idx + i, s, [])) + + # 每完成一个嵌入立即更新进度 + if progress_callback: + progress_callback(1) + + except Exception as e: + logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}") + chunk_results.append((start_idx + i, s, [])) + + # 即使失败也要更新进度 + if progress_callback: + progress_callback(1) + + except Exception as e: + logger.error(f"创建LLM实例失败: {e}") + # 如果创建LLM实例失败,返回空结果 + for i, s in enumerate(chunk_strs): + chunk_results.append((start_idx + i, s, [])) + # 即使失败也要更新进度 + if progress_callback: + progress_callback(1) + + return chunk_results + + # 使用线程池处理 + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # 提交所有任务 + future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks} + + # 收集结果(进度已在process_chunk中实时更新) + for future in as_completed(future_to_chunk): + try: + chunk_results = future.result() + for idx, s, embedding in chunk_results: + results[idx] = (s, embedding) + except Exception as e: + chunk = future_to_chunk[future] + logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}") + # 为失败的块添加空结果 + start_idx, chunk_strs = chunk + for i, s in enumerate(chunk_strs): + results[start_idx + i] = (s, []) + + # 按原始顺序返回结果 + ordered_results = [] + for i in range(len(strs)): + if i in results: + ordered_results.append(results[i]) + else: + # 防止遗漏 + ordered_results.append((strs[i], [])) + + return ordered_results + def get_test_file_path(self): return EMBEDDING_TEST_FILE def save_embedding_test_vectors(self): - """保存测试字符串的嵌入到本地""" + """保存测试字符串的嵌入到本地(使用多线程优化)""" + logger.info("开始保存测试字符串的嵌入向量...") + + # 使用多线程批量获取测试字符串的嵌入 + embedding_results = self._get_embeddings_batch_threaded( + EMBEDDING_TEST_STRINGS, + chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + ) + + # 构建测试向量字典 test_vectors = {} - for idx, s in enumerate(EMBEDDING_TEST_STRINGS): - test_vectors[str(idx)] = self._get_embedding(s) + for idx, (s, embedding) in enumerate(embedding_results): + if embedding: + test_vectors[str(idx)] = embedding + else: + logger.error(f"获取测试字符串嵌入失败: {s}") + # 使用原始单线程方法作为后备 + test_vectors[str(idx)] = self._get_embedding(s) + with open(self.get_test_file_path(), "w", encoding="utf-8") as f: json.dump(test_vectors, f, ensure_ascii=False, indent=2) + + logger.info("测试字符串嵌入向量保存完成") def load_embedding_test_vectors(self): """加载本地保存的测试字符串嵌入""" @@ -145,29 +280,64 @@ class EmbeddingStore: return json.load(f) def check_embedding_model_consistency(self): - """校验当前模型与本地嵌入模型是否一致""" + """校验当前模型与本地嵌入模型是否一致(使用多线程优化)""" local_vectors = self.load_embedding_test_vectors() if local_vectors is None: logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") self.save_embedding_test_vectors() return True - for idx, s in enumerate(EMBEDDING_TEST_STRINGS): - local_emb = local_vectors.get(str(idx)) - if local_emb is None: + + # 检查本地向量完整性 + for idx in range(len(EMBEDDING_TEST_STRINGS)): + if local_vectors.get(str(idx)) is None: logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") self.save_embedding_test_vectors() return True - new_emb = self._get_embedding(s) + + logger.info("开始检验嵌入模型一致性...") + + # 使用多线程批量获取当前模型的嵌入 + embedding_results = self._get_embeddings_batch_threaded( + EMBEDDING_TEST_STRINGS, + chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)), + max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS)) + ) + + # 检查一致性 + for idx, (s, new_emb) in enumerate(embedding_results): + local_emb = local_vectors.get(str(idx)) + if not new_emb: + logger.error(f"获取测试字符串嵌入失败: {s}") + return False + sim = cosine_similarity(local_emb, new_emb) if sim < EMBEDDING_SIM_THRESHOLD: - logger.error("嵌入模型一致性校验失败") + logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}") return False + logger.info("嵌入模型一致性校验通过。") return True def batch_insert_strs(self, strs: List[str], times: int) -> None: - """向库中存入字符串""" + """向库中存入字符串(使用多线程优化)""" + if not strs: + return + total = len(strs) + + # 过滤已存在的字符串 + new_strs = [] + for s in strs: + item_hash = self.namespace + "-" + get_sha256(s) + if item_hash not in self.store: + new_strs.append(s) + + if not new_strs: + logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理") + return + + logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串") + with Progress( SpinnerColumn(), TextColumn("[progress.description]{task.description}"), @@ -181,19 +351,38 @@ class EmbeddingStore: transient=False, ) as progress: task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) - for s in strs: - # 计算hash去重 - item_hash = self.namespace + "-" + get_sha256(s) - if item_hash in self.store: - progress.update(task, advance=1) - continue - - # 获取embedding - embedding = self._get_embedding(s) - - # 存入 - self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) - progress.update(task, advance=1) + + # 首先更新已存在项的进度 + already_processed = total - len(new_strs) + if already_processed > 0: + progress.update(task, advance=already_processed) + + if new_strs: + # 使用实例配置的参数,智能调整分块和线程数 + optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size)) + optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1)) + + logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}") + + # 定义进度更新回调函数 + def update_progress(count): + progress.update(task, advance=count) + + # 批量获取嵌入,并实时更新进度 + embedding_results = self._get_embeddings_batch_threaded( + new_strs, + chunk_size=optimal_chunk_size, + max_workers=optimal_max_workers, + progress_callback=update_progress + ) + + # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了) + for s, embedding in embedding_results: + item_hash = self.namespace + "-" + get_sha256(s) + if embedding: # 只有成功获取到嵌入才存入 + self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) + else: + logger.warning(f"跳过存储失败的嵌入: {s[:50]}...") def save_to_file(self) -> None: """保存到文件""" @@ -316,31 +505,37 @@ class EmbeddingStore: class EmbeddingManager: - def __init__(self): + def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE): + """ + 初始化EmbeddingManager + + Args: + max_workers: 最大线程数 + chunk_size: 每个线程处理的数据块大小 + """ self.paragraphs_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "paragraph", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.entities_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "entity", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.relation_embedding_store = EmbeddingStore( - local_storage["pg_namespace"], # type: ignore + "relation", # type: ignore EMBEDDING_DATA_DIR_STR, + max_workers=max_workers, + chunk_size=chunk_size, ) self.stored_pg_hashes = set() def check_all_embedding_model_consistency(self): """对所有嵌入库做模型一致性校验""" - for store in [ - self.paragraphs_embedding_store, - self.entities_embedding_store, - self.relation_embedding_store, - ]: - if not store.check_embedding_model_consistency(): - return False - return True + return self.paragraphs_embedding_store.check_embedding_model_consistency() def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): """将段落编码存入Embedding库""" diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index 16d4e080..340a678d 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -8,12 +8,15 @@ from . import prompt_template from .knowledge_lib import INVALID_ENTITY from src.llm_models.utils_model import LLMRequest from json_repair import repair_json + + def _extract_json_from_text(text: str): + # sourcery skip: assign-if-exp, extract-method """从文本中提取JSON数据的高容错方法""" if text is None: logger.error("输入文本为None") return [] - + try: fixed_json = repair_json(text) if isinstance(fixed_json, str): @@ -24,7 +27,7 @@ def _extract_json_from_text(text: str): # 如果是列表,直接返回 if isinstance(parsed_json, list): return parsed_json - + # 如果是字典且只有一个项目,可能包装了列表 if isinstance(parsed_json, dict): # 如果字典只有一个键,并且值是列表,返回那个列表 @@ -33,7 +36,7 @@ def _extract_json_from_text(text: str): if isinstance(value, list): return value return parsed_json - + # 其他情况,尝试转换为列表 logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}") return [] @@ -42,44 +45,40 @@ def _extract_json_from_text(text: str): logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") return [] + def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: + # sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) - + # 使用 asyncio.run 来运行异步方法 try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe( - llm_req.generate_response_async(entity_extract_context), loop - ) - response, (reasoning_content, model_name) = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop) + response, _ = future.result() except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, (reasoning_content, model_name) = asyncio.run( - llm_req.generate_response_async(entity_extract_context) - ) + response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context)) # 添加调试日志 logger.debug(f"LLM返回的原始响应: {response}") - + entity_extract_result = _extract_json_from_text(response) - + # 检查返回的是否为有效的实体列表 if not isinstance(entity_extract_result, list): - # 如果不是列表,可能是字典格式,尝试从中提取列表 - if isinstance(entity_extract_result, dict): - # 尝试常见的键名 - for key in ['entities', 'result', 'data', 'items']: - if key in entity_extract_result and isinstance(entity_extract_result[key], list): - entity_extract_result = entity_extract_result[key] - break - else: - # 如果找不到合适的列表,抛出异常 - raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + if not isinstance(entity_extract_result, dict): + raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + + # 尝试常见的键名 + for key in ["entities", "result", "data", "items"]: + if key in entity_extract_result and isinstance(entity_extract_result[key], list): + entity_extract_result = entity_extract_result[key] + break else: - raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") - + # 如果找不到合适的列表,抛出异常 + raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") # 过滤无效实体 entity_extract_result = [ entity @@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY) ] - if len(entity_extract_result) == 0: - raise Exception("实体提取结果为空") + if not entity_extract_result: + raise ValueError("实体提取结果为空") return entity_extract_result @@ -98,45 +97,44 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - + # 使用 asyncio.run 来运行异步方法 try: # 如果当前已有事件循环在运行,使用它 loop = asyncio.get_running_loop() - future = asyncio.run_coroutine_threadsafe( - llm_req.generate_response_async(rdf_extract_context), loop - ) - response, (reasoning_content, model_name) = future.result() + future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop) + response, _ = future.result() except RuntimeError: # 如果没有运行中的事件循环,直接使用 asyncio.run - response, (reasoning_content, model_name) = asyncio.run( - llm_req.generate_response_async(rdf_extract_context) - ) + response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context)) # 添加调试日志 logger.debug(f"RDF LLM返回的原始响应: {response}") - + rdf_triple_result = _extract_json_from_text(response) - + # 检查返回的是否为有效的三元组列表 if not isinstance(rdf_triple_result, list): - # 如果不是列表,可能是字典格式,尝试从中提取列表 - if isinstance(rdf_triple_result, dict): - # 尝试常见的键名 - for key in ['triples', 'result', 'data', 'items']: - if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): - rdf_triple_result = rdf_triple_result[key] - break - else: - # 如果找不到合适的列表,抛出异常 - raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + if not isinstance(rdf_triple_result, dict): + raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + + # 尝试常见的键名 + for key in ["triples", "result", "data", "items"]: + if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): + rdf_triple_result = rdf_triple_result[key] + break else: - raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") - + # 如果找不到合适的列表,抛出异常 + raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") # 验证三元组格式 for triple in rdf_triple_result: - if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: - raise Exception("RDF提取结果格式错误") + if ( + not isinstance(triple, list) + or len(triple) != 3 + or (triple[0] is None or triple[1] is None or triple[2] is None) + or "" in triple + ): + raise ValueError("RDF提取结果格式错误") return rdf_triple_result diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index 083a741d..da082e39 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -20,8 +20,7 @@ from quick_algo import di_graph, pagerank from .utils.hash import get_sha256 from .embedding_store import EmbeddingManager, EmbeddingStoreItem -from .lpmmconfig import global_config -from src.manager.local_store_manager import local_storage +from src.config.config import global_config from .global_logger import logger @@ -30,19 +29,9 @@ def _get_kg_dir(): """ 安全地获取KG数据目录路径 """ - root_path: str = local_storage["root_path"] - if root_path is None: - # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用 - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) - logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}") - - # 获取RAG数据目录 - rag_data_dir: str = global_config["persistence"]["rag_data_dir"] - if rag_data_dir is None: - kg_dir = os.path.join(root_path, "data/rag") - else: - kg_dir = os.path.join(root_path, rag_data_dir) + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) + kg_dir = os.path.join(root_path, "data/rag") return str(kg_dir).replace("\\", "/") @@ -65,9 +54,9 @@ class KGManager: # 持久化相关 - 使用延迟初始化的路径 self.dir_path = get_kg_dir_str() - self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" + self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -122,8 +111,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) - hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) + hash_key1 = "entity" + "-" + get_sha256(triple[0]) + hash_key2 = "entity" + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -141,8 +130,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) - pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) + ent_hash_key = "entity" + "-" + get_sha256(triple[0]) + pg_hash_key = "paragraph" + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -157,12 +146,12 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) - ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) + ent_hash_list.add("entity" + "-" + get_sha256(triple[0])) + ent_hash_list.add("entity" + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() - synonym_result = dict() + synonym_result = {} # rich 进度条 total = len(ent_hash_list) @@ -190,14 +179,14 @@ class KGManager: assert isinstance(ent, EmbeddingStoreItem) # 查询相似实体 similar_ents = embedding_manager.entities_embedding_store.search_top_k( - ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] + ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k ) res_ent = [] # Debug for res_ent_hash, similarity in similar_ents: if res_ent_hash == ent_hash: # 避免自连接 continue - if similarity < global_config["rag"]["params"]["synonym_threshold"]: + if similarity < global_config.lpmm_knowledge.rag_synonym_threshold: # 相似度阈值 continue node_to_node[(res_ent_hash, ent_hash)] = similarity @@ -263,7 +252,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(local_storage["ent_namespace"]): + if node_hash.startswith("entity"): # 新增实体节点 node = embedding_manager.entities_embedding_store.store.get(node_hash) if node is None: @@ -275,7 +264,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(local_storage["pg_namespace"]): + elif node_hash.startswith("paragraph"): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) if node is None: @@ -359,7 +348,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) + ent_hash = "entity" + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -380,7 +369,7 @@ class KGManager: for ent_hash in ent_weights.keys(): ent_weights[ent_hash] = 1.0 else: - down_edge = global_config["qa"]["params"]["paragraph_node_weight"] + down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight # 缩放取值区间至[down_edge, 1] for ent_hash, score in ent_weights.items(): # 缩放相似度 @@ -389,7 +378,7 @@ class KGManager: ) + down_edge # 取平均相似度的top_k实体 - top_k = global_config["qa"]["params"]["ent_filter_top_k"] + top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k if len(ent_mean_scores) > top_k: # 从大到小排序,取后len - k个 ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} @@ -418,7 +407,7 @@ class KGManager: for pg_hash, score in pg_sim_scores.items(): pg_weights[pg_hash] = ( - score * global_config["qa"]["params"]["paragraph_node_weight"] + score * global_config.lpmm_knowledge.qa_paragraph_node_weight ) # 文段权重 = 归一化相似度 * 文段节点权重参数 del pg_sim_scores @@ -431,7 +420,7 @@ class KGManager: self.graph, personalization=ppr_node_weights, max_iter=100, - alpha=global_config["qa"]["params"]["ppr_damping"], + alpha=global_config.lpmm_knowledge.qa_ppr_damping, ) # 获取最终结果 @@ -439,7 +428,7 @@ class KGManager: passage_node_res = [ (node_key, score) for node_key, score in ppr_res.items() - if node_key.startswith(local_storage["pg_namespace"]) + if node_key.startswith("paragraph") ] del ppr_res diff --git a/src/chat/knowledge/knowledge_lib.py b/src/chat/knowledge/knowledge_lib.py index 1e87d382..13629f18 100644 --- a/src/chat/knowledge/knowledge_lib.py +++ b/src/chat/knowledge/knowledge_lib.py @@ -1,12 +1,8 @@ -from src.chat.knowledge.lpmmconfig import global_config from src.chat.knowledge.embedding_store import EmbeddingManager -from src.chat.knowledge.llm_client import LLMClient -from src.chat.knowledge.mem_active_manager import MemoryActiveManager from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.global_logger import logger -from src.config.config import global_config as bot_global_config -from src.manager.local_store_manager import local_storage +from src.config.config import global_config import os INVALID_ENTITY = [ @@ -21,9 +17,6 @@ INVALID_ENTITY = [ "她们", "它们", ] -PG_NAMESPACE = "paragraph" -ENT_NAMESPACE = "entity" -REL_NAMESPACE = "relation" RAG_GRAPH_NAMESPACE = "rag-graph" RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" @@ -34,67 +27,13 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", DATA_PATH = os.path.join(ROOT_PATH, "data") -def _initialize_knowledge_local_storage(): - """ - 初始化知识库相关的本地存储配置 - 使用字典批量设置,避免重复的if判断 - """ - # 定义所有需要初始化的配置项 - default_configs = { - # 路径配置 - "root_path": ROOT_PATH, - "data_path": f"{ROOT_PATH}/data", - # 实体和命名空间配置 - "lpmm_invalid_entity": INVALID_ENTITY, - "pg_namespace": PG_NAMESPACE, - "ent_namespace": ENT_NAMESPACE, - "rel_namespace": REL_NAMESPACE, - # RAG相关命名空间配置 - "rag_graph_namespace": RAG_GRAPH_NAMESPACE, - "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE, - "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE, - } - - # 日志级别映射:重要配置用info,其他用debug - important_configs = {"root_path", "data_path"} - - # 批量设置配置项 - initialized_count = 0 - for key, default_value in default_configs.items(): - if local_storage[key] is None: - local_storage[key] = default_value - - # 根据重要性选择日志级别 - if key in important_configs: - logger.info(f"设置{key}: {default_value}") - else: - logger.debug(f"设置{key}: {default_value}") - - initialized_count += 1 - - if initialized_count > 0: - logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置") - else: - logger.debug("知识库本地存储配置已存在,跳过初始化") - - -# 初始化本地存储路径 -# sourcery skip: dict-comprehension -_initialize_knowledge_local_storage() - qa_manager = None inspire_manager = None # 检查LPMM知识库是否启用 -if bot_global_config.lpmm_knowledge.enable: +if global_config.lpmm_knowledge.enable: logger.info("正在初始化Mai-LPMM") logger.info("创建LLM客户端") - llm_client_list = {} - for key in global_config["llm_providers"]: - llm_client_list[key] = LLMClient( - global_config["llm_providers"][key]["base_url"], # type: ignore - global_config["llm_providers"][key]["api_key"], # type: ignore - ) # 初始化Embedding库 embed_manager = EmbeddingManager() @@ -120,7 +59,7 @@ if bot_global_config.lpmm_knowledge.enable: # 数据比对:Embedding库与KG的段落hash集合 for pg_hash in kg_manager.stored_paragraph_hashes: - key = f"{PG_NAMESPACE}-{pg_hash}" + key = f"paragraph-{pg_hash}" if key not in embed_manager.stored_pg_hashes: logger.warning(f"KG中存在Embedding库中不存在的段落:{key}") @@ -130,11 +69,11 @@ if bot_global_config.lpmm_knowledge.enable: kg_manager, ) - # 记忆激活(用于记忆库) - inspire_manager = MemoryActiveManager( - embed_manager, - llm_client_list[global_config["embedding"]["provider"]], - ) + # # 记忆激活(用于记忆库) + # inspire_manager = MemoryActiveManager( + # embed_manager, + # llm_client_list[global_config["embedding"]["provider"]], + # ) else: logger.info("LPMM知识库已禁用,跳过初始化") # 创建空的占位符对象,避免导入错误 diff --git a/src/chat/knowledge/llm_client.py b/src/chat/knowledge/llm_client.py deleted file mode 100644 index 52d0dca0..00000000 --- a/src/chat/knowledge/llm_client.py +++ /dev/null @@ -1,45 +0,0 @@ -from openai import OpenAI - - -class LLMMessage: - def __init__(self, role, content): - self.role = role - self.content = content - - def to_dict(self): - return {"role": self.role, "content": self.content} - - -class LLMClient: - """LLM客户端,对应一个API服务商""" - - def __init__(self, url, api_key): - self.client = OpenAI( - base_url=url, - api_key=api_key, - ) - - def send_chat_request(self, model, messages): - """发送对话请求,等待返回结果""" - response = self.client.chat.completions.create(model=model, messages=messages, stream=False) - if hasattr(response.choices[0].message, "reasoning_content"): - # 有单独的推理内容块 - reasoning_content = response.choices[0].message.reasoning_content - content = response.choices[0].message.content - else: - # 无单独的推理内容块 - response = response.choices[0].message.content.split("")[-1].split("") - # 如果有推理内容,则分割推理内容和内容 - if len(response) == 2: - reasoning_content = response[0] - content = response[1] - else: - reasoning_content = None - content = response[0] - - return reasoning_content, content - - def send_embedding_request(self, model, text): - """发送嵌入请求,等待返回结果""" - text = text.replace("\n", " ") - return self.client.embeddings.create(input=[text], model=model).data[0].embedding diff --git a/src/chat/knowledge/lpmmconfig.py b/src/chat/knowledge/lpmmconfig.py deleted file mode 100644 index 49f77725..00000000 --- a/src/chat/knowledge/lpmmconfig.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import toml -import sys - -# import argparse -from .global_logger import logger - -PG_NAMESPACE = "paragraph" -ENT_NAMESPACE = "entity" -REL_NAMESPACE = "relation" - -RAG_GRAPH_NAMESPACE = "rag-graph" -RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" -RAG_PG_HASH_NAMESPACE = "rag-pg-hash" - -# 无效实体 -INVALID_ENTITY = [ - "", - "你", - "他", - "她", - "它", - "我们", - "你们", - "他们", - "她们", - "它们", -] - - -def _load_config(config, config_file_path): - """读取TOML格式的配置文件""" - if not os.path.exists(config_file_path): - return - with open(config_file_path, "r", encoding="utf-8") as f: - file_config = toml.load(f) - - # Check if all top-level keys from default config exist in the file config - for key in config.keys(): - if key not in file_config: - logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。") - logger.critical("请通过template/lpmm_config_template.toml文件进行更新") - sys.exit(1) - - if "llm_providers" in file_config: - for provider in file_config["llm_providers"]: - if provider["name"] not in config["llm_providers"]: - config["llm_providers"][provider["name"]] = {} - config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"] - config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"] - - if "entity_extract" in file_config: - config["entity_extract"] = file_config["entity_extract"] - - if "rdf_build" in file_config: - config["rdf_build"] = file_config["rdf_build"] - - if "embedding" in file_config: - config["embedding"] = file_config["embedding"] - - if "rag" in file_config: - config["rag"] = file_config["rag"] - - if "qa" in file_config: - config["qa"] = file_config["qa"] - - if "persistence" in file_config: - config["persistence"] = file_config["persistence"] - # print(config) - logger.info(f"从文件中读取配置: {config_file_path}") - - -global_config = dict( - { - "lpmm": { - "version": "0.1.0", - }, - "llm_providers": { - "localhost": { - "base_url": "https://api.siliconflow.cn/v1", - "api_key": "sk-ospynxadyorf", - } - }, - "entity_extract": { - "llm": { - "provider": "localhost", - "model": "Pro/deepseek-ai/DeepSeek-V3", - } - }, - "rdf_build": { - "llm": { - "provider": "localhost", - "model": "Pro/deepseek-ai/DeepSeek-V3", - } - }, - "embedding": { - "provider": "localhost", - "model": "Pro/BAAI/bge-m3", - "dimension": 1024, - }, - "rag": { - "params": { - "synonym_search_top_k": 10, - "synonym_threshold": 0.75, - } - }, - "qa": { - "params": { - "relation_search_top_k": 10, - "relation_threshold": 0.75, - "paragraph_search_top_k": 10, - "paragraph_node_weight": 0.05, - "ent_filter_top_k": 10, - "ppr_damping": 0.8, - "res_top_k": 10, - }, - "llm": { - "provider": "localhost", - "model": "qa", - }, - }, - "persistence": { - "data_root_path": "data", - "raw_data_path": "data/raw.json", - "openie_data_path": "data/openie.json", - "embedding_data_dir": "data/embedding", - "rag_data_dir": "data/rag", - }, - "info_extraction": { - "workers": 10, - }, - } -) - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml") -_load_config(global_config, config_path) diff --git a/src/chat/knowledge/mem_active_manager.py b/src/chat/knowledge/mem_active_manager.py index 3998c066..a55b929f 100644 --- a/src/chat/knowledge/mem_active_manager.py +++ b/src/chat/knowledge/mem_active_manager.py @@ -1,3 +1,4 @@ +raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it") from .lpmmconfig import global_config from .embedding_store import EmbeddingManager from .llm_client import LLMClient diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index c83683b7..5354447a 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -2,16 +2,14 @@ import time from typing import Tuple, List, Dict, Optional from .global_logger import logger - -# from . import prompt_template from .embedding_store import EmbeddingManager -# from .llm_client import LLMClient from .kg_manager import KGManager + # from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k from src.llm_models.utils_model import LLMRequest from src.chat.utils.utils import get_embedding -from src.config.config import global_config +from src.config.config import global_config, model_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -21,17 +19,12 @@ class QAManager: self, embed_manager: EmbeddingManager, kg_manager: KGManager, - ): self.embed_manager = embed_manager self.kg_manager = kg_manager - # TODO: API-Adapter修改标记 - self.qa_model = LLMRequest( - model=global_config.model.lpmm_qa, - request_type="lpmm.qa" - ) + self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa") - async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: + async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]: """处理查询""" # 生成问题的Embedding @@ -49,66 +42,70 @@ class QAManager: question_embedding, global_config.lpmm_knowledge.qa_relation_search_top_k, ) - if relation_search_res is not None: - # 过滤阈值 - # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 - relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) - if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: - # 未找到相关关系 - logger.debug("未找到相关关系,跳过关系检索") - relation_search_res = [] + if relation_search_res is None: + return None + # 过滤阈值 + # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 + relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) + if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: + # 未找到相关关系 + logger.debug("未找到相关关系,跳过关系检索") + relation_search_res = [] - part_end_time = time.perf_counter() - logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") + part_end_time = time.perf_counter() + logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") - for res in relation_search_res: - rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str - print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") + for res in relation_search_res: + rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str + print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") - # TODO: 使用LLM过滤三元组结果 - # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") - # part_start_time = time.time() + # TODO: 使用LLM过滤三元组结果 + # logger.info(f"LLM过滤三元组用时:{time.time() - part_start_time:.2f}s") + # part_start_time = time.time() - # 根据问题Embedding查询Paragraph Embedding库 + # 根据问题Embedding查询Paragraph Embedding库 + part_start_time = time.perf_counter() + paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( + question_embedding, + global_config.lpmm_knowledge.qa_paragraph_search_top_k, + ) + part_end_time = time.perf_counter() + logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") + + if len(relation_search_res) != 0: + logger.info("找到相关关系,将使用RAG进行检索") + # 使用KG检索 part_start_time = time.perf_counter() - paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( - question_embedding, - global_config.lpmm_knowledge.qa_paragraph_search_top_k, + result, ppr_node_weights = self.kg_manager.kg_search( + relation_search_res, paragraph_search_res, self.embed_manager ) part_end_time = time.perf_counter() - logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") - - if len(relation_search_res) != 0: - logger.info("找到相关关系,将使用RAG进行检索") - # 使用KG检索 - part_start_time = time.perf_counter() - result, ppr_node_weights = self.kg_manager.kg_search( - relation_search_res, paragraph_search_res, self.embed_manager - ) - part_end_time = time.perf_counter() - logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s") - else: - logger.info("未找到相关关系,将使用文段检索结果") - result = paragraph_search_res - ppr_node_weights = None - - # 过滤阈值 - result = dyn_select_top_k(result, 0.5, 1.0) - - for res in result: - raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str - print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") - - return result, ppr_node_weights + logger.info(f"RAG检索用时:{part_end_time - part_start_time:.5f}s") else: - return None + logger.info("未找到相关关系,将使用文段检索结果") + result = paragraph_search_res + ppr_node_weights = None - async def get_knowledge(self, question: str) -> str: + # 过滤阈值 + result = dyn_select_top_k(result, 0.5, 1.0) + + for res in result: + raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str + print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n") + + return result, ppr_node_weights + + async def get_knowledge(self, question: str) -> Optional[str]: """获取知识""" # 处理查询 processed_result = await self.process_query(question) if processed_result is not None: query_res = processed_result[0] + # 检查查询结果是否为空 + if not query_res: + logger.debug("知识库查询结果为空,可能是知识库中没有相关内容") + return None + knowledge = [ ( self.embed_manager.paragraphs_embedding_store.store[res[0]].str, diff --git a/src/chat/knowledge/raw_processing.py b/src/chat/knowledge/raw_processing.py deleted file mode 100644 index 98b1f168..00000000 --- a/src/chat/knowledge/raw_processing.py +++ /dev/null @@ -1,48 +0,0 @@ -import json -import os - -from .global_logger import logger -from .lpmmconfig import global_config -from src.chat.knowledge.utils.hash import get_sha256 - - -def load_raw_data(path: str = None) -> tuple[list[str], list[str]]: - """加载原始数据文件 - - 读取原始数据文件,将原始数据加载到内存中 - - Args: - path: 可选,指定要读取的json文件绝对路径 - - Returns: - - raw_data: 原始数据列表 - - sha256_list: 原始数据的SHA256集合 - """ - # 读取指定路径或默认路径的json文件 - json_path = path if path else global_config["persistence"]["raw_data_path"] - if os.path.exists(json_path): - with open(json_path, "r", encoding="utf-8") as f: - import_json = json.loads(f.read()) - else: - raise Exception(f"原始数据文件读取失败: {json_path}") - """ - import_json 内容示例: - import_json = ["The capital of China is Beijing. The capital of France is Paris.",] - """ - raw_data = [] - sha256_list = [] - sha256_set = set() - for item in import_json: - if not isinstance(item, str): - logger.warning("数据类型错误:{}".format(item)) - continue - pg_hash = get_sha256(item) - if pg_hash in sha256_set: - logger.warning("重复数据:{}".format(item)) - continue - sha256_set.add(pg_hash) - sha256_list.append(pg_hash) - raw_data.append(item) - logger.info("共读取到{}条数据".format(len(raw_data))) - - return sha256_list, raw_data diff --git a/src/chat/knowledge/utils/dyn_topk.py b/src/chat/knowledge/utils/dyn_topk.py index eb40ef3a..5304934f 100644 --- a/src/chat/knowledge/utils/dyn_topk.py +++ b/src/chat/knowledge/utils/dyn_topk.py @@ -5,6 +5,10 @@ def dyn_select_top_k( score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float ) -> List[Tuple[Any, float, float]]: """动态TopK选择""" + # 检查输入列表是否为空 + if not score: + return [] + # 按照分数排序(降序) sorted_score = sorted(score, key=lambda x: x[1], reverse=True) diff --git a/src/chat/knowledge/utils/visualize_graph.py b/src/chat/knowledge/utils/visualize_graph.py deleted file mode 100644 index 7ca9b7e6..00000000 --- a/src/chat/knowledge/utils/visualize_graph.py +++ /dev/null @@ -1,17 +0,0 @@ -import networkx as nx -from matplotlib import pyplot as plt - - -def draw_graph_and_show(graph): - """绘制图并显示,画布大小1280*1280""" - fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100) - nx.draw_networkx( - graph, - node_size=100, - width=0.5, - with_labels=True, - labels=nx.get_node_attributes(graph, "content"), - font_family="Sarasa Mono SC", - font_size=8, - ) - fig.show() diff --git a/src/chat/memory_system/Hippocampus.py b/src/chat/memory_system/Hippocampus.py index 26660e5c..fe3c2562 100644 --- a/src/chat/memory_system/Hippocampus.py +++ b/src/chat/memory_system/Hippocampus.py @@ -5,25 +5,27 @@ import random import time import re import json -from itertools import combinations - import jieba import networkx as nx import numpy as np + +from itertools import combinations +from typing import List, Tuple, Coroutine, Any, Set from collections import Counter -from ...llm_models.utils_model import LLMRequest +from rich.traceback import install + +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config, model_config +from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入 from src.common.logger import get_logger from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 -from ..utils.chat_message_builder import ( +from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp, build_readable_messages, get_raw_msg_by_timestamp_with_chat, ) # 导入 build_readable_messages -from ..utils.utils import translate_timestamp_to_human_readable -from rich.traceback import install +from src.chat.utils.utils import translate_timestamp_to_human_readable -from ...config.config import global_config -from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入 install(extra_lines=3) @@ -198,8 +200,7 @@ class Hippocampus: self.parahippocampal_gyrus = ParahippocampalGyrus(self) # 从数据库加载记忆图 self.entorhinal_cortex.sync_memory_from_db() - # TODO: API-Adapter修改标记 - self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder") + self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder") def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表""" @@ -339,9 +340,7 @@ class Hippocampus: else: topic_num = 5 # 51+字符: 5个关键词 (其余长文本) - topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async( - self.find_topic_llm(text, topic_num) - ) + topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num)) # 提取关键词 keywords = re.findall(r"<([^>]+)>", topics_response) @@ -353,12 +352,11 @@ class Hippocampus: for keyword in ",".join(keywords).replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if keyword.strip() ] - + if keywords: logger.info(f"提取关键词: {keywords}") - - return keywords - + + return keywords async def get_memory_from_text( self, @@ -1245,7 +1243,7 @@ class ParahippocampalGyrus: # 2. 使用LLM提取关键主题 topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) - topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async( + topics_response, _ = await self.hippocampus.model_summary.generate_response_async( self.hippocampus.find_topic_llm(input_text, topic_num) ) @@ -1269,7 +1267,7 @@ class ParahippocampalGyrus: logger.debug(f"过滤后话题: {filtered_topics}") # 4. 创建所有话题的摘要生成任务 - tasks = [] + tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = [] for topic in filtered_topics: # 调用修改后的 topic_what,不再需要 time_info topic_what_prompt = self.hippocampus.topic_what(input_text, topic) @@ -1281,7 +1279,7 @@ class ParahippocampalGyrus: continue # 等待所有任务完成 - compressed_memory = set() + compressed_memory: Set[Tuple[str, str]] = set() similar_topics_dict = {} for topic, task in tasks: diff --git a/src/chat/memory_system/instant_memory.py b/src/chat/memory_system/instant_memory.py index f7e54f8e..a702a87e 100644 --- a/src/chat/memory_system/instant_memory.py +++ b/src/chat/memory_system/instant_memory.py @@ -3,13 +3,16 @@ import time import re import json import ast -from json_repair import repair_json -from src.llm_models.utils_model import LLMRequest -from src.common.logger import get_logger import traceback -from src.config.config import global_config +from json_repair import repair_json +from datetime import datetime, timedelta + +from src.llm_models.utils_model import LLMRequest +from src.common.logger import get_logger from src.common.database.database_model import Memory # Peewee Models导入 +from src.config.config import model_config + logger = get_logger(__name__) @@ -35,8 +38,7 @@ class InstantMemory: self.chat_id = chat_id self.last_view_time = time.time() self.summary_model = LLMRequest( - model=global_config.model.memory, - temperature=0.5, + model_set=model_config.model_task_config.memory, request_type="memory.summary", ) @@ -48,14 +50,11 @@ class InstantMemory: """ try: - response, _ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) print(prompt) print(response) - if "1" in response: - return True - else: - return False + return "1" in response except Exception as e: logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") return False @@ -71,9 +70,9 @@ class InstantMemory: }} """ try: - response, _ = await self.summary_model.generate_response_async(prompt) - print(prompt) - print(response) + response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) + # print(prompt) + # print(response) if not response: return None try: @@ -142,7 +141,7 @@ class InstantMemory: 请只输出json格式,不要输出其他多余内容 """ try: - response, _ = await self.summary_model.generate_response_async(prompt) + response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5) print(prompt) print(response) if not response: @@ -177,7 +176,7 @@ class InstantMemory: for mem in query: # 对每条记忆 - mem_keywords = mem.keywords or [] + mem_keywords = mem.keywords or "" parsed = ast.literal_eval(mem_keywords) if isinstance(parsed, list): mem_keywords = [str(k).strip() for k in parsed if str(k).strip()] @@ -201,6 +200,7 @@ class InstantMemory: return None def _parse_time_range(self, time_str): + # sourcery skip: extract-duplicate-method, use-contextlib-suppress """ 支持解析如下格式: - 具体日期时间:YYYY-MM-DD HH:MM:SS @@ -208,8 +208,6 @@ class InstantMemory: - 相对时间:今天,昨天,前天,N天前,N个月前 - 空字符串:返回(None, None) """ - from datetime import datetime, timedelta - now = datetime.now() if not time_str: return 0, now @@ -239,14 +237,12 @@ class InstantMemory: start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0) end = start + timedelta(days=1) return start, end - m = re.match(r"(\d+)天前", time_str) - if m: + if m := re.match(r"(\d+)天前", time_str): days = int(m.group(1)) start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0) end = start + timedelta(days=1) return start, end - m = re.match(r"(\d+)个月前", time_str) - if m: + if m := re.match(r"(\d+)个月前", time_str): months = int(m.group(1)) # 近似每月30天 start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) diff --git a/src/chat/memory_system/memory_activator.py b/src/chat/memory_system/memory_activator.py index 715d9c06..d3cbb5d7 100644 --- a/src/chat/memory_system/memory_activator.py +++ b/src/chat/memory_system/memory_activator.py @@ -1,13 +1,15 @@ -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config -from src.common.logger import get_logger -from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from datetime import datetime -from src.chat.memory_system.Hippocampus import hippocampus_manager -from typing import List, Dict import difflib import json + from json_repair import repair_json +from typing import List, Dict +from datetime import datetime + +from src.llm_models.utils_model import LLMRequest +from src.config.config import global_config, model_config +from src.common.logger import get_logger +from src.chat.utils.prompt_builder import Prompt, global_prompt_manager +from src.chat.memory_system.Hippocampus import hippocampus_manager logger = get_logger("memory_activator") @@ -61,11 +63,8 @@ def init_prompt(): class MemoryActivator: def __init__(self): - # TODO: API-Adapter修改标记 - self.key_words_model = LLMRequest( - model=global_config.model.utils_small, - temperature=0.5, + model_set=model_config.model_task_config.utils_small, request_type="memory.activator", ) @@ -92,7 +91,9 @@ class MemoryActivator: # logger.debug(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt) + response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async( + prompt, temperature=0.5 + ) keywords = list(get_keywords_from_json(response)) diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 7a18dcf0..58dd6d68 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv): self.is_superchat = False self.gift_info = None self.gift_name = None - self.gift_count = None + self.gift_count: Optional[str] = None self.superchat_info = None self.superchat_price = None self.superchat_message_text = None @@ -444,7 +444,7 @@ class MessageSending(MessageProcessBase): is_emoji: bool = False, thinking_start_time: float = 0, apply_set_reply_logic: bool = False, - reply_to: str = None, # type: ignore + reply_to: Optional[str] = None, ): # 调用父类初始化 super().__init__( diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index 21d47c75..267b7a8f 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -1,9 +1,10 @@ from typing import Dict, Optional, Type -from src.plugin_system.base.base_action import BaseAction + from src.chat.message_receive.chat_stream import ChatStream from src.common.logger import get_logger from src.plugin_system.core.component_registry import component_registry from src.plugin_system.base.component_types import ComponentType, ActionInfo +from src.plugin_system.base.base_action import BaseAction logger = get_logger("action_manager") diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index da11c54f..dfa4c79c 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -5,7 +5,7 @@ import time from typing import List, Any, Dict, TYPE_CHECKING, Tuple from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.planner_actions.action_manager import ActionManager @@ -36,10 +36,7 @@ class ActionModifier: self.action_manager = action_manager # 用于LLM判定的小模型 - self.llm_judge = LLMRequest( - model=global_config.model.utils_small, - request_type="action.judge", - ) + self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge") # 缓存相关属性 self._llm_judge_cache = {} # 缓存LLM判定结果 @@ -438,4 +435,4 @@ class ActionModifier: return True else: logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}") - return False \ No newline at end of file + return False diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 0b26a97d..85dd5e63 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -7,7 +7,7 @@ from datetime import datetime from json_repair import repair_json from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config from src.common.logger import get_logger from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import ( @@ -36,8 +36,6 @@ def init_prompt(): {chat_context_description},以下是具体的聊天内容 {chat_content_block} - - {moderation_prompt} 现在请你根据{by_what}选择合适的action和触发action的消息: @@ -73,10 +71,7 @@ class ActionPlanner: self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.action_manager = action_manager # LLM规划器配置 - self.planner_llm = LLMRequest( - model=global_config.model.planner, - request_type="planner", # 用于动作规划 - ) + self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划 self.last_obs_time_mark = 0.0 @@ -140,7 +135,7 @@ class ActionPlanner: # --- 调用 LLM (普通文本生成) --- llm_content = None try: - llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt) + llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index cab6a2b4..c2b6e1cb 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple from datetime import datetime from src.mais4u.mai_think import mai_thinking_manager from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config +from src.config.api_ada_configs import TaskConfig from src.individuality.individuality import get_individuality from src.llm_models.utils_model import LLMRequest from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending @@ -23,14 +24,13 @@ from src.chat.utils.chat_message_builder import ( replace_user_references_sync, ) from src.chat.express.expression_selector import expression_selector -from src.chat.knowledge.knowledge_lib import qa_manager from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.instant_memory import InstantMemory from src.mood.mood_manager import mood_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.person_info import get_person_info_manager -from src.tools.tool_executor import ToolExecutor from src.plugin_system.base.component_types import ActionInfo +from src.plugin_system.apis import llm_api logger = get_logger("replyer") @@ -40,7 +40,7 @@ def init_prompt(): Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") Prompt("在群里聊天", "chat_target_group2") Prompt("和{sender_name}聊天", "chat_target_private2") - + Prompt( """ {expression_habits_block} @@ -102,36 +102,57 @@ def init_prompt(): "s4u_style_prompt", ) + Prompt( + """ +你是一个专门获取知识的助手。你的名字是{bot_name}。现在是{time_now}。 +群里正在进行的聊天内容: +{chat_history} + +现在,{sender}发送了内容:{target_message},你想要回复ta。 +请仔细分析聊天内容,考虑以下几点: +1. 内容中是否包含需要查询信息的问题 +2. 是否有明确的知识获取指令 + +If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed". +""", + name="lpmm_get_knowledge_prompt", + ) + class DefaultReplyer: def __init__( self, chat_stream: ChatStream, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, request_type: str = "focus.replyer", ): self.request_type = request_type - if model_configs: - self.express_model_configs = model_configs + if model_set_with_weight: + # self.express_model_configs = model_configs + self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight else: # 当未提供配置时,使用默认配置并赋予默认权重 - model_config_1 = global_config.model.replyer_1.copy() - model_config_2 = global_config.model.replyer_2.copy() + # model_config_1 = global_config.model.replyer_1.copy() + # model_config_2 = global_config.model.replyer_2.copy() prob_first = global_config.chat.replyer_random_probability - model_config_1["weight"] = prob_first - model_config_2["weight"] = 1.0 - prob_first + # model_config_1["weight"] = prob_first + # model_config_2["weight"] = 1.0 - prob_first - self.express_model_configs = [model_config_1, model_config_2] + # self.express_model_configs = [model_config_1, model_config_2] + self.model_set = [ + (model_config.model_task_config.replyer_1, prob_first), + (model_config.model_task_config.replyer_2, 1.0 - prob_first), + ] - if not self.express_model_configs: - logger.warning("未找到有效的模型配置,回复生成可能会失败。") - # 提供一个最终的回退,以防止在空列表上调用 random.choice - fallback_config = global_config.model.replyer_1.copy() - fallback_config.setdefault("weight", 1.0) - self.express_model_configs = [fallback_config] + # if not self.express_model_configs: + # logger.warning("未找到有效的模型配置,回复生成可能会失败。") + # # 提供一个最终的回退,以防止在空列表上调用 random.choice + # fallback_config = global_config.model.replyer_1.copy() + # fallback_config.setdefault("weight", 1.0) + # self.express_model_configs = [fallback_config] self.chat_stream = chat_stream self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) @@ -139,13 +160,16 @@ class DefaultReplyer: self.heart_fc_sender = HeartFCSender() self.memory_activator = MemoryActivator() self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id) + + from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor,不然会循环依赖 + self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) - def _select_weighted_model_config(self) -> Dict[str, Any]: + def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]: """使用加权随机选择来挑选一个模型配置""" - configs = self.express_model_configs + configs = self.model_set # 提取权重,如果模型配置中没有'weight'键,则默认为1.0 - weights = [config.get("weight", 1.0) for config in configs] + weights = [weight for _, weight in configs] return random.choices(population=configs, weights=weights, k=1)[0] @@ -155,18 +179,16 @@ class DefaultReplyer: extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, enable_tool: bool = True, - enable_timeout: bool = False, ) -> Tuple[bool, Optional[str], Optional[str]]: """ 回复器 (Replier): 负责生成回复文本的核心逻辑。 - + Args: reply_to: 回复对象,格式为 "发送者:消息内容" extra_info: 额外信息,用于补充上下文 available_actions: 可用的动作信息字典 enable_tool: 是否启用工具调用 - enable_timeout: 是否启用超时处理 - + Returns: Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt) """ @@ -177,13 +199,12 @@ class DefaultReplyer: # 3. 构建 Prompt with Timer("构建Prompt", {}): # 内部计时器,可选保留 prompt = await self.build_prompt_reply_context( - reply_to = reply_to, + reply_to=reply_to, extra_info=extra_info, available_actions=available_actions, - enable_timeout=enable_timeout, enable_tool=enable_tool, ) - + if not prompt: logger.warning("构建prompt失败,跳过回复生成") return False, None, None @@ -194,26 +215,8 @@ class DefaultReplyer: model_name = "unknown_model" try: - with Timer("LLM生成", {}): # 内部计时器,可选保留 - # 加权随机选择一个模型配置 - selected_model_config = self._select_weighted_model_config() - logger.info( - f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})" - ) - - express_model = LLMRequest( - model=selected_model_config, - request_type=self.request_type, - ) - - if global_config.debug.show_prompt: - logger.info(f"\n{prompt}\n") - else: - logger.debug(f"\n{prompt}\n") - - content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt) - - logger.debug(f"replyer生成内容: {content}") + content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt) + logger.debug(f"replyer生成内容: {content}") except Exception as llm_e: # 精简报错信息 @@ -232,22 +235,21 @@ class DefaultReplyer: raw_reply: str = "", reason: str = "", reply_to: str = "", - ) -> Tuple[bool, Optional[str]]: + return_prompt: bool = False, + ) -> Tuple[bool, Optional[str], Optional[str]]: """ 表达器 (Expressor): 负责重写和优化回复文本。 - + Args: raw_reply: 原始回复内容 reason: 回复原因 reply_to: 回复对象,格式为 "发送者:消息内容" relation_info: 关系信息 - + Returns: Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容) """ try: - - with Timer("构建Prompt", {}): # 内部计时器,可选保留 prompt = await self.build_prompt_rewrite_context( raw_reply=raw_reply, @@ -260,36 +262,23 @@ class DefaultReplyer: model_name = "unknown_model" if not prompt: logger.error("Prompt 构建失败,无法生成回复。") - return False, None + return False, None, None try: - with Timer("LLM生成", {}): # 内部计时器,可选保留 - # 加权随机选择一个模型配置 - selected_model_config = self._select_weighted_model_config() - logger.info( - f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})" - ) - - express_model = LLMRequest( - model=selected_model_config, - request_type=self.request_type, - ) - - content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt) - - logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n") + content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt) + logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n") except Exception as llm_e: # 精简报错信息 logger.error(f"LLM 生成失败: {llm_e}") - return False, None # LLM 调用失败则无法生成回复 + return False, None, prompt if return_prompt else None # LLM 调用失败则无法生成回复 - return True, content + return True, content, prompt if return_prompt else None except Exception as e: logger.error(f"回复生成意外失败: {e}") traceback.print_exc() - return False, None + return False, None, prompt if return_prompt else None async def build_relation_info(self, reply_to: str = ""): if not global_config.relationship.enable_relationship: @@ -313,11 +302,11 @@ class DefaultReplyer: async def build_expression_habits(self, chat_history: str, target: str) -> str: """构建表达习惯块 - + Args: chat_history: 聊天历史记录 target: 目标消息内容 - + Returns: str: 表达习惯信息字符串 """ @@ -366,17 +355,15 @@ class DefaultReplyer: if style_habits_str.strip() and grammar_habits_str.strip(): expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:" - expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}" - - return expression_habits_block + return f"{expression_habits_title}\n{expression_habits_block}" async def build_memory_block(self, chat_history: str, target: str) -> str: """构建记忆块 - + Args: chat_history: 聊天历史记录 target: 目标消息内容 - + Returns: str: 记忆信息字符串 """ @@ -441,7 +428,7 @@ class DefaultReplyer: for tool_result in tool_results: tool_name = tool_result.get("tool_name", "unknown") content = tool_result.get("content", "") - result_type = tool_result.get("type", "info") + result_type = tool_result.get("type", "tool_result") tool_info_str += f"- 【{tool_name}】{result_type}: {content}\n" @@ -459,10 +446,10 @@ class DefaultReplyer: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: """解析回复目标消息 - + Args: target_message: 目标消息,格式为 "发送者:消息内容" 或 "发送者:消息内容" - + Returns: Tuple[str, str]: (发送者名称, 消息内容) """ @@ -481,10 +468,10 @@ class DefaultReplyer: async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: """构建关键词反应提示 - + Args: target: 目标消息内容 - + Returns: str: 关键词反应提示字符串 """ @@ -523,11 +510,11 @@ class DefaultReplyer: async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: """计时并运行异步任务的辅助函数 - + Args: coroutine: 要执行的协程 name: 任务名称 - + Returns: Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时) """ @@ -537,7 +524,9 @@ class DefaultReplyer: duration = end_time - start_time return name, result, duration - def build_s4u_chat_history_prompts(self, message_list_before_now: List[Dict[str, Any]], target_user_id: str) -> Tuple[str, str]: + def build_s4u_chat_history_prompts( + self, message_list_before_now: List[Dict[str, Any]], target_user_id: str + ) -> Tuple[str, str]: """ 构建 s4u 风格的分离对话 prompt @@ -612,7 +601,7 @@ class DefaultReplyer: chat_info: str, ) -> Any: """构建 mai_think 上下文信息 - + Args: chat_id: 聊天ID memory_block: 记忆块内容 @@ -625,7 +614,7 @@ class DefaultReplyer: sender: 发送者名称 target: 目标消息内容 chat_info: 聊天信息 - + Returns: Any: mai_think 实例 """ @@ -647,19 +636,17 @@ class DefaultReplyer: reply_to: str, extra_info: str = "", available_actions: Optional[Dict[str, ActionInfo]] = None, - enable_timeout: bool = False, enable_tool: bool = True, ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if """ 构建回复器上下文 Args: - reply_data: 回复数据 - replay_data 包含以下字段: - structured_info: 结构化信息,一般是工具调用获得的信息 - reply_to: 回复对象 - extra_info/extra_info_block: 额外信息 + reply_to: 回复对象,格式为 "发送者:消息内容" + extra_info: 额外信息,用于补充上下文 available_actions: 可用动作 + enable_timeout: 是否启用超时处理 + enable_tool: 是否启用工具调用 Returns: str: 构建好的上下文 @@ -727,7 +714,7 @@ class DefaultReplyer: self._time_and_run_task( self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info" ), - self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"), + self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"), ) # 任务名称中英文映射 @@ -877,7 +864,7 @@ class DefaultReplyer: raw_reply: str, reason: str, reply_to: str, - ) -> str: + ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if chat_stream = self.chat_stream chat_id = chat_stream.stream_id is_group_chat = bool(chat_stream.group_info) @@ -1011,6 +998,81 @@ class DefaultReplyer: display_message=display_message, ) + async def llm_generate_content(self, prompt: str): + with Timer("LLM生成", {}): # 内部计时器,可选保留 + # 加权随机选择一个模型配置 + selected_model_config, weight = self._select_weighted_models_config() + logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})") + + express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type) + + if global_config.debug.show_prompt: + logger.info(f"\n{prompt}\n") + else: + logger.debug(f"\n{prompt}\n") + + content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt) + + logger.debug(f"replyer生成内容: {content}") + return content, reasoning_content, model_name, tool_calls + + async def get_prompt_info(self, message: str, reply_to: str): + related_info = "" + start_time = time.time() + from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool + if not reply_to: + logger.debug("没有回复对象,跳过获取知识库内容") + return "" + sender, content = self._parse_reply_target(reply_to) + if not content: + logger.debug("回复对象内容为空,跳过获取知识库内容") + return "" + logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") + # 从LPMM知识库获取知识 + try: + # 检查LPMM知识库是否启用 + if not global_config.lpmm_knowledge.enable: + logger.debug("LPMM知识库未启用,跳过获取知识库内容") + return "" + time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + + bot_name = global_config.bot.nickname + + prompt = await global_prompt_manager.format_prompt( + "lpmm_get_knowledge_prompt", + bot_name=bot_name, + time_now=time_now, + chat_history=message, + sender=sender, + target_message=content, + ) + _, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools( + prompt, + model_config=model_config.model_task_config.tool_use, + tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()], + ) + if tool_calls: + result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool()) + end_time = time.time() + if not result or not result.get("content"): + logger.debug("从LPMM知识库获取知识失败,返回空知识...") + return "" + found_knowledge_from_lpmm = result.get("content", "") + logger.debug( + f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" + ) + related_info += found_knowledge_from_lpmm + logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") + logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") + + return f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" + else: + logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") + return "" + except Exception as e: + logger.error(f"获取知识库内容时发生异常: {str(e)}") + return "" + def weighted_sample_no_replacement(items, weights, k) -> list: """ @@ -1046,38 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list: return selected -async def get_prompt_info(message: str, threshold: float): - related_info = "" - start_time = time.time() - - logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - # 从LPMM知识库获取知识 - try: - # 检查LPMM知识库是否启用 - if qa_manager is None: - logger.debug("LPMM知识库已禁用,跳过知识获取") - return "" - - found_knowledge_from_lpmm = await qa_manager.get_knowledge(message) - - end_time = time.time() - if found_knowledge_from_lpmm is not None: - logger.debug( - f"从LPMM知识库获取知识,相关信息:{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}" - ) - related_info += found_knowledge_from_lpmm - logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}秒") - logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}") - - # 格式化知识信息 - formatted_prompt_info = f"你有以下这些**知识**:\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n" - return formatted_prompt_info - else: - logger.debug("从LPMM知识库获取知识失败,可能是从未导入过知识,返回空知识...") - return "" - except Exception as e: - logger.error(f"获取知识库内容时发生异常: {str(e)}") - return "" - - init_prompt() diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 3f1c731b..bb3a313b 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,6 +1,7 @@ -from typing import Dict, Any, Optional, List +from typing import Dict, Optional, List, Tuple from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer @@ -15,7 +16,7 @@ class ReplyerManager: self, chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer]: """ @@ -49,7 +50,7 @@ class ReplyerManager: # model_configs 只在此时(初始化时)生效 replyer = DefaultReplyer( chat_stream=target_stream, - model_configs=model_configs, # 可以是None,此时使用默认模型 + model_set_with_weight=model_set_with_weight, # 可以是None,此时使用默认模型 request_type=request_type, ) self._repliers[stream_id] = replyer diff --git a/src/chat/utils/json_utils.py b/src/chat/utils/json_utils.py deleted file mode 100644 index 892deac4..00000000 --- a/src/chat/utils/json_utils.py +++ /dev/null @@ -1,223 +0,0 @@ -import ast -import json -import logging - -from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional - -# 定义类型变量用于泛型类型提示 -T = TypeVar("T") - -# 获取logger -logger = logging.getLogger("json_utils") - - -def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: - """ - 安全地解析JSON字符串,出错时返回默认值 - 现在尝试处理单引号和标准JSON - - 参数: - json_str: 要解析的JSON字符串 - default_value: 解析失败时返回的默认值 - - 返回: - 解析后的Python对象,或在解析失败时返回default_value - """ - if not json_str or not isinstance(json_str, str): - logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}") - return default_value - - try: - # 尝试标准的 JSON 解析 - return json.loads(json_str) - except json.JSONDecodeError: - # 如果标准解析失败,尝试用 ast.literal_eval 解析 - try: - # logger.debug(f"标准JSON解析失败,尝试用 ast.literal_eval 解析: {json_str[:100]}...") - result = ast.literal_eval(json_str) - if isinstance(result, dict): - return result - logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}") - return default_value - except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e: - logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...") - return default_value - except Exception as e: - logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...") - return default_value - except Exception as e: - logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...") - return default_value - - -def extract_tool_call_arguments( - tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None -) -> Dict[str, Any]: - """ - 从LLM工具调用对象中提取参数 - - 参数: - tool_call: 工具调用对象字典 - default_value: 解析失败时返回的默认值 - - 返回: - 解析后的参数字典,或在解析失败时返回default_value - """ - default_result = default_value or {} - - if not tool_call or not isinstance(tool_call, dict): - logger.error(f"无效的工具调用对象: {tool_call}") - return default_result - - try: - # 提取function参数 - function_data = tool_call.get("function", {}) - if not function_data or not isinstance(function_data, dict): - logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") - return default_result - - if arguments_str := function_data.get("arguments", "{}"): - # 解析JSON - return safe_json_loads(arguments_str, default_result) - else: - return default_result - - except Exception as e: - logger.error(f"提取工具调用参数时出错: {e}") - return default_result - - -def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str: - """ - 安全地将Python对象序列化为JSON字符串 - - 参数: - obj: 要序列化的Python对象 - default_value: 序列化失败时返回的默认值 - ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符) - pretty: 是否美化输出JSON - - 返回: - 序列化后的JSON字符串,或在序列化失败时返回default_value - """ - try: - indent = 2 if pretty else None - return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent) - except TypeError as e: - logger.error(f"JSON序列化失败(类型错误): {e}") - return default_value - except Exception as e: - logger.error(f"JSON序列化过程中发生意外错误: {e}") - return default_value - - -def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]: - """ - 标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式 - - 参数: - response: 原始LLM响应 - log_prefix: 日志前缀 - - 返回: - 元组 (成功标志, 标准化后的响应列表, 错误消息) - """ - - logger.debug(f"{log_prefix}原始人 LLM响应: {response}") - - # 检查是否为None - if response is None: - return False, [], "LLM响应为None" - - # 记录原始类型 - logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}") - - # 将元组转换为列表 - if isinstance(response, tuple): - logger.debug(f"{log_prefix}将元组响应转换为列表") - response = list(response) - - # 确保是列表类型 - if not isinstance(response, list): - return False, [], f"无法处理的LLM响应类型: {type(response).__name__}" - - # 处理工具调用部分(如果存在) - if len(response) == 3: - content, reasoning, tool_calls = response - - # 将工具调用部分转换为列表(如果是元组) - if isinstance(tool_calls, tuple): - logger.debug(f"{log_prefix}将工具调用元组转换为列表") - tool_calls = list(tool_calls) - response[2] = tool_calls - - return True, response, "" - - -def process_llm_tool_calls( - tool_calls: List[Dict[str, Any]], log_prefix: str = "" -) -> Tuple[bool, List[Dict[str, Any]], str]: - """ - 处理并验证LLM响应中的工具调用列表 - - 参数: - tool_calls: 从LLM响应中直接获取的工具调用列表 - log_prefix: 日志前缀 - - 返回: - 元组 (成功标志, 验证后的工具调用列表, 错误消息) - """ - - # 如果列表为空,表示没有工具调用,这不是错误 - if not tool_calls: - return True, [], "工具调用列表为空" - - # 验证每个工具调用的格式 - valid_tool_calls = [] - for i, tool_call in enumerate(tool_calls): - if not isinstance(tool_call, dict): - logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}") - continue - - # 检查基本结构 - if tool_call.get("type") != "function": - logger.warning( - f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}" - ) - continue - - if "function" not in tool_call or not isinstance(tool_call.get("function"), dict): - logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}") - continue - - func_details = tool_call["function"] - if "name" not in func_details or not isinstance(func_details.get("name"), str): - logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}") - continue - - # 验证参数 'arguments' - args_value = func_details.get("arguments") - - # 1. 检查 arguments 是否存在且是字符串 - if args_value is None or not isinstance(args_value, str): - logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}") - continue - - # 2. 尝试安全地解析 arguments 字符串 - parsed_args = safe_json_loads(args_value, None) - - # 3. 检查解析结果是否为字典 - if parsed_args is None or not isinstance(parsed_args, dict): - logger.warning( - f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, " - f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}" - ) - continue - - # 如果检查通过,将原始的 tool_call 加入有效列表 - valid_tool_calls.append(tool_call) - - if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空 - return False, [], "所有工具调用格式均无效" - - return True, valid_tool_calls, "" diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 3ee4ae7b..0b9ec779 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.chat_stream import get_chat_manager from src.llm_models.utils_model import LLMRequest @@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]: return is_mentioned, reply_probability -async def get_embedding(text, request_type="embedding"): +async def get_embedding(text, request_type="embedding") -> Optional[List[float]]: """获取文本的embedding向量""" - # TODO: API-Adapter修改标记 - llm = LLMRequest(model=global_config.model.embedding, request_type=request_type) - # return llm.get_embedding_sync(text) + llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type) try: - embedding = await llm.get_embedding(text) + embedding, _ = await llm.get_embedding(text) except Exception as e: logger.error(f"获取embedding失败: {str(e)}") embedding = None diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 7f14aa6d..fcf1c717 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -14,7 +14,7 @@ from rich.traceback import install from src.common.logger import get_logger from src.common.database.database import db from src.common.database.database_model import Images, ImageDescriptions -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest install(extra_lines=3) @@ -37,7 +37,7 @@ class ImageManager: self._ensure_image_dir() self._initialized = True - self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") + self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") try: db.connect(reuse_if_open=True) @@ -107,6 +107,7 @@ class ImageManager: # 优先使用EmojiManager查询已注册表情包的描述 try: from src.chat.emoji_system.emoji_manager import get_emoji_manager + emoji_manager = get_emoji_manager() cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash) if cached_emoji_description: @@ -116,13 +117,12 @@ class ImageManager: logger.debug(f"查询EmojiManager时出错: {e}") # 查询ImageDescriptions表的缓存描述 - cached_description = self._get_description_from_db(image_hash, "emoji") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "emoji"): logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") return f"[表情包:{cached_description}]" # === 二步走识别流程 === - + # 第一步:VLM视觉分析 - 生成详细描述 if image_format in ["gif", "GIF"]: image_base64_processed = self.transform_gif(image_base64) @@ -130,10 +130,16 @@ class ImageManager: logger.warning("GIF转换失败,无法获取描述") return "[表情包(GIF处理失败)]" vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg") + detailed_description, _ = await self.vlm.generate_response_for_image( + vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300 + ) else: - vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" - detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format) + vlm_prompt = ( + "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" + ) + detailed_description, _ = await self.vlm.generate_response_for_image( + vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) if detailed_description is None: logger.warning("VLM未能生成表情包详细描述") @@ -150,31 +156,32 @@ class ImageManager: 3. 输出简短精准,不要解释 4. 如果有多个词用逗号分隔 """ - + # 使用较低温度确保输出稳定 - emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji") - emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt) + emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji") + emotion_result, _ = await emotion_llm.generate_response_async( + emotion_prompt, temperature=0.3, max_tokens=50 + ) if emotion_result is None: logger.warning("LLM未能生成情感标签,使用详细描述的前几个词") # 降级处理:从详细描述中提取关键词 import jieba + words = list(jieba.cut(detailed_description)) emotion_result = ",".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情") # 处理情感结果,取前1-2个最重要的标签 emotions = [e.strip() for e in emotion_result.replace(",", ",").split(",") if e.strip()] final_emotion = emotions[0] if emotions else "表情" - + # 如果有第二个情感且不重复,也包含进来 if len(emotions) > 1 and emotions[1] != emotions[0]: final_emotion = f"{emotions[0]},{emotions[1]}" logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}") - # 再次检查缓存,防止并发写入时重复生成 - cached_description = self._get_description_from_db(image_hash, "emoji") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "emoji"): logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") return f"[表情包:{cached_description}]" @@ -242,9 +249,7 @@ class ImageManager: logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...") return f"[图片:{existing_image.description}]" - # 查询ImageDescriptions表的缓存描述 - cached_description = self._get_description_from_db(image_hash, "image") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "image"): logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") return f"[图片:{cached_description}]" @@ -252,7 +257,9 @@ class ImageManager: image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore prompt = global_config.custom_prompt.image_prompt logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) if description is None: logger.warning("AI未能生成图片描述") @@ -445,10 +452,7 @@ class ImageManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - # 检查图片是否已存在 - existing_image = Images.get_or_none(Images.emoji_hash == image_hash) - - if existing_image: + if existing_image := Images.get_or_none(Images.emoji_hash == image_hash): # 检查是否缺少必要字段,如果缺少则创建新记录 if ( not hasattr(existing_image, "image_id") @@ -524,9 +528,7 @@ class ImageManager: # 优先检查是否已有其他相同哈希的图片记录包含描述 existing_with_description = Images.get_or_none( - (Images.emoji_hash == image_hash) & - (Images.description.is_null(False)) & - (Images.description != "") + (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "") ) if existing_with_description and existing_with_description.id != image.id: logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") @@ -538,8 +540,7 @@ class ImageManager: return # 检查ImageDescriptions表的缓存描述 - cached_description = self._get_description_from_db(image_hash, "image") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "image"): logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") image.description = cached_description image.vlm_processed = True @@ -554,15 +555,15 @@ class ImageManager: # 获取VLM描述 logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)") - description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) + description, _ = await self.vlm.generate_response_for_image( + prompt, image_base64, image_format, temperature=0.4, max_tokens=300 + ) if description is None: logger.warning("VLM未能生成图片描述") description = "无法生成描述" - # 再次检查缓存,防止并发写入时重复生成 - cached_description = self._get_description_from_db(image_hash, "image") - if cached_description: + if cached_description := self._get_description_from_db(image_hash, "image"): logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}") description = cached_description @@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str: raise FileNotFoundError(f"图片文件不存在: {image_path}") with open(image_path, "rb") as f: - image_data = f.read() - if not image_data: + if image_data := f.read(): + return base64.b64encode(image_data).decode("utf-8") + else: raise IOError(f"读取图片文件失败: {image_path}") - return base64.b64encode(image_data).decode("utf-8") diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index cf71dc56..49ec1079 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -1,35 +1,29 @@ -import base64 - -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from rich.traceback import install + install(extra_lines=3) logger = get_logger("chat_voice") + async def get_voice_text(voice_base64: str) -> str: - """获取音频文件描述""" + """获取音频文件转录文本""" if not global_config.voice.enable_asr: logger.warning("语音识别未启用,无法处理语音消息") return "[语音]" try: - # 解码base64音频数据 - # 确保base64字符串只包含ASCII字符 - if isinstance(voice_base64, str): - voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii") - voice_bytes = base64.b64decode(voice_base64) - _llm = LLMRequest(model=global_config.model.voice, request_type="voice") - text = await _llm.generate_response_for_voice(voice_bytes) + _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio") + text = await _llm.generate_response_for_voice(voice_base64) if text is None: logger.warning("未能生成语音文本") return "[语音(文本生成失败)]" - + logger.debug(f"描述是{text}") return f"[语音:{text}]" except Exception as e: logger.error(f"语音转文字失败: {str(e)}") return "[语音]" - diff --git a/src/chat/willing/mode_mxp.py b/src/chat/willing/mode_mxp.py index 5a13a628..a249cb6f 100644 --- a/src/chat/willing/mode_mxp.py +++ b/src/chat/willing/mode_mxp.py @@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助 下下策是询问一个菜鸟(@梦溪畔) """ -from .willing_manager import BaseWillingManager from typing import Dict import asyncio import time import math from src.chat.message_receive.chat_stream import ChatStream +from .willing_manager import BaseWillingManager class MxpWillingManager(BaseWillingManager): diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 1d0b8a39..d2b3acce 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -281,20 +281,6 @@ class Memory(BaseModel): table_name = "memory" -class Knowledges(BaseModel): - """ - 用于存储知识库条目的模型。 - """ - - content = TextField() # 知识内容的文本 - embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表 - # 可以添加其他元数据字段,如 source, create_time 等 - - class Meta: - # database = db # 继承自 BaseModel - table_name = "knowledges" - - class Expression(BaseModel): """ 用于存储表达风格的模型。 @@ -382,7 +368,6 @@ def create_tables(): ImageDescriptions, OnlineTime, PersonInfo, - Knowledges, Expression, ThinkingLog, GraphNodes, # 添加图节点表 @@ -408,7 +393,6 @@ def initialize_database(): ImageDescriptions, OnlineTime, PersonInfo, - Knowledges, Expression, Memory, ThinkingLog, diff --git a/src/common/logger.py b/src/common/logger.py index 78446dec..e27fcb4e 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -334,7 +334,7 @@ MODULE_COLORS = { "llm_models": "\033[36m", # 青色 "remote": "\033[38;5;242m", # 深灰色,更不显眼 "planner": "\033[36m", - "memory": "\033[34m", + "memory": "\033[38;5;117m", # 天蓝色 "hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读 "action_manager": "\033[38;5;208m", # 橙色,不与replyer重复 # 关系系统 @@ -352,7 +352,7 @@ MODULE_COLORS = { "expressor": "\033[38;5;166m", # 橙色 # 专注聊天模块 "replyer": "\033[38;5;166m", # 橙色 - "memory_activator": "\033[34m", # 绿色 + "memory_activator": "\033[38;5;117m", # 天蓝色 # 插件系统 "plugins": "\033[31m", # 红色 "plugin_api": "\033[33m", # 黄色 @@ -451,7 +451,7 @@ class ModuleColoredConsoleRenderer: # 日志级别颜色 self._level_colors = { "debug": "\033[38;5;208m", # 橙色 - "info": "\033[34m", # 蓝色 + "info": "\033[38;5;117m", # 天蓝色 "success": "\033[32m", # 绿色 "warning": "\033[33m", # 黄色 "error": "\033[31m", # 红色 diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py new file mode 100644 index 00000000..5f3398e0 --- /dev/null +++ b/src/config/api_ada_configs.py @@ -0,0 +1,142 @@ +from dataclasses import dataclass, field + +from .config_base import ConfigBase + + +@dataclass +class APIProvider(ConfigBase): + """API提供商配置类""" + + name: str + """API提供商名称""" + + base_url: str + """API基础URL""" + + api_key: str = field(default_factory=str, repr=False) + """API密钥列表""" + + client_type: str = field(default="openai") + """客户端类型(如openai/google等,默认为openai)""" + + max_retry: int = 2 + """最大重试次数(单个模型API调用失败,最多重试的次数)""" + + timeout: int = 10 + """API调用的超时时长(超过这个时长,本次请求将被视为“请求超时”,单位:秒)""" + + retry_interval: int = 10 + """重试间隔(如果API调用失败,重试的间隔时间,单位:秒)""" + + def get_api_key(self) -> str: + return self.api_key + + def __post_init__(self): + """确保api_key在repr中不被显示""" + if not self.api_key: + raise ValueError("API密钥不能为空,请在配置中设置有效的API密钥。") + if not self.base_url: + raise ValueError("API基础URL不能为空,请在配置中设置有效的基础URL。") + if not self.name: + raise ValueError("API提供商名称不能为空,请在配置中设置有效的名称。") + + +@dataclass +class ModelInfo(ConfigBase): + """单个模型信息配置类""" + + model_identifier: str + """模型标识符(用于URL调用)""" + + name: str + """模型名称(用于模块调用)""" + + api_provider: str + """API提供商(如OpenAI、Azure等)""" + + price_in: float = field(default=0.0) + """每M token输入价格""" + + price_out: float = field(default=0.0) + """每M token输出价格""" + + force_stream_mode: bool = field(default=False) + """是否强制使用流式输出模式""" + + extra_params: dict = field(default_factory=dict) + """额外参数(用于API调用时的额外配置)""" + + def __post_init__(self): + if not self.model_identifier: + raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。") + if not self.name: + raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。") + if not self.api_provider: + raise ValueError("API提供商不能为空,请在配置中设置有效的API提供商。") + + +@dataclass +class TaskConfig(ConfigBase): + """任务配置类""" + + model_list: list[str] = field(default_factory=list) + """任务使用的模型列表""" + + max_tokens: int = 1024 + """任务最大输出token数""" + + temperature: float = 0.3 + """模型温度""" + + +@dataclass +class ModelTaskConfig(ConfigBase): + """模型配置类""" + + utils: TaskConfig + """组件模型配置""" + + utils_small: TaskConfig + """组件小模型配置""" + + replyer_1: TaskConfig + """normal_chat首要回复模型模型配置""" + + replyer_2: TaskConfig + """normal_chat次要回复模型配置""" + + memory: TaskConfig + """记忆模型配置""" + + emotion: TaskConfig + """情绪模型配置""" + + vlm: TaskConfig + """视觉语言模型配置""" + + voice: TaskConfig + """语音识别模型配置""" + + tool_use: TaskConfig + """专注工具使用模型配置""" + + planner: TaskConfig + """规划模型配置""" + + embedding: TaskConfig + """嵌入模型配置""" + + lpmm_entity_extract: TaskConfig + """LPMM实体提取模型配置""" + + lpmm_rdf_build: TaskConfig + """LPMM RDF构建模型配置""" + + lpmm_qa: TaskConfig + """LPMM问答模型配置""" + + def get_task(self, task_name: str) -> TaskConfig: + """获取指定任务的配置""" + if hasattr(self, task_name): + return getattr(self, task_name) + raise ValueError(f"任务 '{task_name}' 未找到对应的配置") diff --git a/src/config/auto_update.py b/src/config/auto_update.py deleted file mode 100644 index e6471e80..00000000 --- a/src/config/auto_update.py +++ /dev/null @@ -1,162 +0,0 @@ -import shutil -import tomlkit -from tomlkit.items import Table, KeyType -from pathlib import Path -from datetime import datetime - - -def get_key_comment(toml_table, key): - # 获取key的注释(如果有) - if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"): - return toml_table.trivia.comment - if hasattr(toml_table, "value") and isinstance(toml_table.value, dict): - item = toml_table.value.get(key) - if item is not None and hasattr(item, "trivia"): - return item.trivia.comment - if hasattr(toml_table, "keys"): - for k in toml_table.keys(): - if isinstance(k, KeyType) and k.key == key: - return k.trivia.comment - return None - - -def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None): - # 递归比较两个dict,找出新增和删减项,收集注释 - if path is None: - path = [] - if logs is None: - logs = [] - if new_comments is None: - new_comments = {} - if old_comments is None: - old_comments = {} - # 新增项 - for key in new: - if key == "version": - continue - if key not in old: - comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): - compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs) - # 删减项 - for key in old: - if key == "version": - continue - if key not in new: - comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") - return logs - - -def update_config(): - print("开始更新配置文件...") - # 获取根目录路径 - root_dir = Path(__file__).parent.parent.parent.parent - template_dir = root_dir / "template" - config_dir = root_dir / "config" - old_config_dir = config_dir / "old" - - # 创建old目录(如果不存在) - old_config_dir.mkdir(exist_ok=True) - - # 定义文件路径 - template_path = template_dir / "bot_config_template.toml" - old_config_path = config_dir / "bot_config.toml" - new_config_path = config_dir / "bot_config.toml" - - # 读取旧配置文件 - old_config = {} - if old_config_path.exists(): - print(f"发现旧配置文件: {old_config_path}") - with open(old_config_path, "r", encoding="utf-8") as f: - old_config = tomlkit.load(f) - - # 生成带时间戳的新文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml" - - # 移动旧配置文件到old目录 - shutil.move(old_config_path, old_backup_path) - print(f"已备份旧配置文件到: {old_backup_path}") - - # 复制模板文件到配置目录 - print(f"从模板文件创建新配置: {template_path}") - shutil.copy2(template_path, new_config_path) - - # 读取新配置文件 - with open(new_config_path, "r", encoding="utf-8") as f: - new_config = tomlkit.load(f) - - # 检查version是否相同 - if old_config and "inner" in old_config and "inner" in new_config: - old_version = old_config["inner"].get("version") # type: ignore - new_version = new_config["inner"].get("version") # type: ignore - if old_version and new_version and old_version == new_version: - print(f"检测到版本号相同 (v{old_version}),跳过更新") - # 如果version相同,恢复旧配置文件并返回 - shutil.move(old_backup_path, old_config_path) # type: ignore - return - else: - print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}") - - # 输出新增和删减项及注释 - if old_config: - print("配置项变动如下:") - logs = compare_dicts(new_config, old_config) - if logs: - for log in logs: - print(log) - else: - print("无新增或删减项") - - # 递归更新配置 - def update_dict(target, source): - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - if isinstance(value, dict) and isinstance(target[key], (dict, Table)): - update_dict(target[key], value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - if not value: - target[key] = tomlkit.array() - else: - # 特殊处理正则表达式数组和包含正则表达式的结构 - if key == "ban_msgs_regex": - # 直接使用原始值,不进行额外处理 - target[key] = value - elif key == "regex_rules": - # 对于regex_rules,需要特殊处理其中的regex字段 - target[key] = value - else: - # 检查是否包含正则表达式相关的字典项 - contains_regex = False - if value and isinstance(value[0], dict) and "regex" in value[0]: - contains_regex = True - - target[key] = value if contains_regex else tomlkit.array(str(value)) - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - - # 将旧配置的值更新到新配置中 - print("开始合并新旧配置...") - update_dict(new_config, old_config) - - # 保存更新后的配置(保留注释和格式) - with open(new_config_path, "w", encoding="utf-8") as f: - f.write(tomlkit.dumps(new_config)) - print("配置文件更新完成") - - -if __name__ == "__main__": - update_config() diff --git a/src/config/config.py b/src/config/config.py index 805a17d4..368adaa5 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -1,12 +1,14 @@ import os import tomlkit import shutil +import sys from datetime import datetime from tomlkit import TOMLDocument from tomlkit.items import Table, KeyType from dataclasses import field, dataclass from rich.traceback import install +from typing import List, Optional from src.common.logger import get_logger from src.config.config_base import ConfigBase @@ -25,7 +27,6 @@ from src.config.official_configs import ( ResponseSplitterConfig, TelemetryConfig, ExperimentalConfig, - ModelConfig, MessageReceiveConfig, MaimMessageConfig, LPMMKnowledgeConfig, @@ -36,6 +37,13 @@ from src.config.official_configs import ( CustomPromptConfig, ) +from .api_ada_configs import ( + ModelTaskConfig, + ModelInfo, + APIProvider, +) + + install(extra_lines=3) @@ -49,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.9.1" +MMC_VERSION = "0.10.0-snapshot.4" def get_key_comment(toml_table, key): @@ -79,7 +87,7 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in old: comment = get_key_comment(new, key) - logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): compare_dicts(new[key], old[key], path + [str(key)], logs) # 删减项 @@ -88,7 +96,7 @@ def compare_dicts(new, old, path=None, logs=None): continue if key not in new: comment = get_key_comment(old, key) - logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else '无'}") + logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or '无'}") return logs @@ -123,67 +131,110 @@ def compare_default_values(new, old, path=None, logs=None, changes=None): if key in old: if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): compare_default_values(new[key], old[key], path + [str(key)], logs, changes) - else: - # 只要值发生变化就记录 - if new[key] != old[key]: - logs.append( - f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}" - ) - changes.append((path + [str(key)], old[key], new[key])) + elif new[key] != old[key]: + logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}") + changes.append((path + [str(key)], old[key], new[key])) return logs, changes -def update_config(): +def _get_version_from_toml(toml_path) -> Optional[str]: + """从TOML文件中获取版本号""" + if not os.path.exists(toml_path): + return None + with open(toml_path, "r", encoding="utf-8") as f: + doc = tomlkit.load(f) + if "inner" in doc and "version" in doc["inner"]: # type: ignore + return doc["inner"]["version"] # type: ignore + return None + + +def _version_tuple(v): + """将版本字符串转换为元组以便比较""" + if v is None: + return (0,) + return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) + + +def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): + """ + 将source字典的值更新到target字典中(如果target中存在相同的键) + """ + for key, value in source.items(): + # 跳过version字段的更新 + if key == "version": + continue + if key in target: + target_value = target[key] + if isinstance(value, dict) and isinstance(target_value, (dict, Table)): + _update_dict(target_value, value) + else: + try: + # 对数组类型进行特殊处理 + if isinstance(value, list): + # 如果是空数组,确保它保持为空数组 + target[key] = tomlkit.array(str(value)) if value else tomlkit.array() + else: + # 其他类型使用item方法创建新值 + target[key] = tomlkit.item(value) + except (TypeError, ValueError): + # 如果转换失败,直接赋值 + target[key] = value + + +def _update_config_generic(config_name: str, template_name: str): + """ + 通用的配置文件更新函数 + + Args: + config_name: 配置文件名(不含扩展名),如 'bot_config' 或 'model_config' + template_name: 模板文件名(不含扩展名),如 'bot_config_template' 或 'model_config_template' + """ # 获取根目录路径 old_config_dir = os.path.join(CONFIG_DIR, "old") compare_dir = os.path.join(TEMPLATE_DIR, "compare") # 定义文件路径 - template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml") - old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") - compare_path = os.path.join(compare_dir, "bot_config_template.toml") + template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml") + old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml") + compare_path = os.path.join(compare_dir, f"{template_name}.toml") # 创建compare目录(如果不存在) os.makedirs(compare_dir, exist_ok=True) - # 处理compare下的模板文件 - def get_version_from_toml(toml_path): - if not os.path.exists(toml_path): - return None - with open(toml_path, "r", encoding="utf-8") as f: - doc = tomlkit.load(f) - if "inner" in doc and "version" in doc["inner"]: # type: ignore - return doc["inner"]["version"] # type: ignore - return None + template_version = _get_version_from_toml(template_path) + compare_version = _get_version_from_toml(compare_path) - template_version = get_version_from_toml(template_path) - compare_version = get_version_from_toml(compare_path) + # 检查配置文件是否存在 + if not os.path.exists(old_config_path): + logger.info(f"{config_name}.toml配置文件不存在,从模板创建新配置") + os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 + shutil.copy2(template_path, old_config_path) # 复制模板文件 + logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}") + # 新创建配置文件,退出 + sys.exit(0) - def version_tuple(v): - if v is None: - return (0,) - return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split(".")) + compare_config = None + new_config = None + old_config = None # 先读取 compare 下的模板(如果有),用于默认值变动检测 if os.path.exists(compare_path): with open(compare_path, "r", encoding="utf-8") as f: compare_config = tomlkit.load(f) - else: - compare_config = None # 读取当前模板 with open(template_path, "r", encoding="utf-8") as f: new_config = tomlkit.load(f) # 检查默认值变化并处理(只有 compare_config 存在时才做) - if compare_config is not None: + if compare_config: # 读取旧配置 with open(old_config_path, "r", encoding="utf-8") as f: old_config = tomlkit.load(f) logs, changes = compare_default_values(new_config, compare_config) if logs: - logger.info("检测到模板默认值变动如下:") + logger.info(f"检测到{config_name}模板默认值变动如下:") for log in logs: logger.info(log) # 检查旧配置是否等于旧默认值,如果是则更新为新默认值 @@ -192,33 +243,20 @@ def update_config(): if old_value == old_default: set_value_by_path(old_config, path, new_default) logger.info( - f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" + f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" ) else: - logger.info("未检测到模板默认值变动") - # 保存旧配置的变更(后续合并逻辑会用到 old_config) - else: - old_config = None + logger.info(f"未检测到{config_name}模板默认值变动") # 检查 compare 下没有模板,或新模板版本更高,则复制 if not os.path.exists(compare_path): shutil.copy2(template_path, compare_path) - logger.info(f"已将模板文件复制到: {compare_path}") + logger.info(f"已将{config_name}模板文件复制到: {compare_path}") + elif _version_tuple(template_version) > _version_tuple(compare_version): + shutil.copy2(template_path, compare_path) + logger.info(f"{config_name}模板版本较新,已替换compare下的模板: {compare_path}") else: - if version_tuple(template_version) > version_tuple(compare_version): - shutil.copy2(template_path, compare_path) - logger.info(f"模板版本较新,已替换compare下的模板: {compare_path}") - else: - logger.debug(f"compare下的模板版本不低于当前模板,无需替换: {compare_path}") - - # 检查配置文件是否存在 - if not os.path.exists(old_config_path): - logger.info("配置文件不存在,从模板创建新配置") - os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹 - shutil.copy2(template_path, old_config_path) # 复制模板文件 - logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}") - # 如果是新创建的配置文件,直接返回 - quit() + logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}") # 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次) if old_config is None: @@ -226,79 +264,60 @@ def update_config(): old_config = tomlkit.load(f) # new_config 已经读取 - # 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用 - # 检查version是否相同 if old_config and "inner" in old_config and "inner" in new_config: old_version = old_config["inner"].get("version") # type: ignore new_version = new_config["inner"].get("version") # type: ignore if old_version and new_version and old_version == new_version: - logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") + logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新") return else: logger.info( - f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" + f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" ) else: - logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") + logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新") # 创建old目录(如果不存在) os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml") + old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml") # 移动旧配置文件到old目录 shutil.move(old_config_path, old_backup_path) - logger.info(f"已备份旧配置文件到: {old_backup_path}") + logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}") # 复制模板文件到配置目录 shutil.copy2(template_path, new_config_path) - logger.info(f"已创建新配置文件: {new_config_path}") + logger.info(f"已创建新{config_name}配置文件: {new_config_path}") # 输出新增和删减项及注释 if old_config: - logger.info("配置项变动如下:\n----------------------------------------") - logs = compare_dicts(new_config, old_config) - if logs: + logger.info(f"{config_name}配置项变动如下:\n----------------------------------------") + if logs := compare_dicts(new_config, old_config): for log in logs: logger.info(log) else: logger.info("无新增或删减项") - def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): - """ - 将source字典的值更新到target字典中(如果target中存在相同的键) - """ - for key, value in source.items(): - # 跳过version字段的更新 - if key == "version": - continue - if key in target: - target_value = target[key] - if isinstance(value, dict) and isinstance(target_value, (dict, Table)): - update_dict(target_value, value) - else: - try: - # 对数组类型进行特殊处理 - if isinstance(value, list): - # 如果是空数组,确保它保持为空数组 - target[key] = tomlkit.array(str(value)) if value else tomlkit.array() - else: - # 其他类型使用item方法创建新值 - target[key] = tomlkit.item(value) - except (TypeError, ValueError): - # 如果转换失败,直接赋值 - target[key] = value - # 将旧配置的值更新到新配置中 - logger.info("开始合并新旧配置...") - update_dict(new_config, old_config) + logger.info(f"开始合并{config_name}新旧配置...") + _update_dict(new_config, old_config) # 保存更新后的配置(保留注释和格式) with open(new_config_path, "w", encoding="utf-8") as f: f.write(tomlkit.dumps(new_config)) - logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") - quit() + logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") + + +def update_config(): + """更新bot_config.toml配置文件""" + _update_config_generic("bot_config", "bot_config_template") + + +def update_model_config(): + """更新model_config.toml配置文件""" + _update_config_generic("model_config", "model_config_template") @dataclass @@ -323,7 +342,6 @@ class Config(ConfigBase): response_splitter: ResponseSplitterConfig telemetry: TelemetryConfig experimental: ExperimentalConfig - model: ModelConfig maim_message: MaimMessageConfig lpmm_knowledge: LPMMKnowledgeConfig tool: ToolConfig @@ -331,11 +349,69 @@ class Config(ConfigBase): custom_prompt: CustomPromptConfig voice: VoiceConfig + +@dataclass +class APIAdapterConfig(ConfigBase): + """API Adapter配置类""" + + models: List[ModelInfo] + """模型列表""" + + model_task_config: ModelTaskConfig + """模型任务配置""" + + api_providers: List[APIProvider] = field(default_factory=list) + """API提供商列表""" + + def __post_init__(self): + if not self.models: + raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。") + if not self.api_providers: + raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。") + + # 检查API提供商名称是否重复 + provider_names = [provider.name for provider in self.api_providers] + if len(provider_names) != len(set(provider_names)): + raise ValueError("API提供商名称存在重复,请检查配置文件。") + + # 检查模型名称是否重复 + model_names = [model.name for model in self.models] + if len(model_names) != len(set(model_names)): + raise ValueError("模型名称存在重复,请检查配置文件。") + + self.api_providers_dict = {provider.name: provider for provider in self.api_providers} + self.models_dict = {model.name: model for model in self.models} + + for model in self.models: + if not model.model_identifier: + raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空") + if not model.api_provider or model.api_provider not in self.api_providers_dict: + raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在") + + def get_model_info(self, model_name: str) -> ModelInfo: + """根据模型名称获取模型信息""" + if not model_name: + raise ValueError("模型名称不能为空") + if model_name not in self.models_dict: + raise KeyError(f"模型 '{model_name}' 不存在") + return self.models_dict[model_name] + + def get_provider(self, provider_name: str) -> APIProvider: + """根据提供商名称获取API提供商信息""" + if not provider_name: + raise ValueError("API提供商名称不能为空") + if provider_name not in self.api_providers_dict: + raise KeyError(f"API提供商 '{provider_name}' 不存在") + return self.api_providers_dict[provider_name] + + def load_config(config_path: str) -> Config: """ 加载配置文件 - :param config_path: 配置文件路径 - :return: Config对象 + Args: + config_path: 配置文件路径 + Returns: + Config对象 """ # 读取配置文件 with open(config_path, "r", encoding="utf-8") as f: @@ -349,18 +425,32 @@ def load_config(config_path: str) -> Config: raise e -def get_config_dir() -> str: +def api_ada_load_config(config_path: str) -> APIAdapterConfig: """ - 获取配置目录 - :return: 配置目录路径 + 加载API适配器配置文件 + Args: + config_path: 配置文件路径 + Returns: + APIAdapterConfig对象 """ - return CONFIG_DIR + # 读取配置文件 + with open(config_path, "r", encoding="utf-8") as f: + config_data = tomlkit.load(f) + + # 创建APIAdapterConfig对象 + try: + return APIAdapterConfig.from_dict(config_data) + except Exception as e: + logger.critical("API适配器配置文件解析失败") + raise e # 获取配置文件路径 logger.info(f"MaiCore当前版本: {MMC_VERSION}") update_config() +update_model_config() logger.info("正在品鉴配置文件...") global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) +model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml")) logger.info("非常的新鲜,非常的美味!") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 2c9f847c..8f34a184 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass, field -from typing import Any, Literal, Optional +from typing import Literal, Optional from src.config.config_base import ConfigBase @@ -598,51 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase): embedding_dimension: int = 1024 """嵌入向量维度,应该与模型的输出维度一致""" - -@dataclass -class ModelConfig(ConfigBase): - """模型配置类""" - - model_max_output_length: int = 800 # 最大回复长度 - - utils: dict[str, Any] = field(default_factory=lambda: {}) - """组件模型配置""" - - utils_small: dict[str, Any] = field(default_factory=lambda: {}) - """组件小模型配置""" - - replyer_1: dict[str, Any] = field(default_factory=lambda: {}) - """normal_chat首要回复模型模型配置""" - - replyer_2: dict[str, Any] = field(default_factory=lambda: {}) - """normal_chat次要回复模型配置""" - - memory: dict[str, Any] = field(default_factory=lambda: {}) - """记忆模型配置""" - - emotion: dict[str, Any] = field(default_factory=lambda: {}) - """情绪模型配置""" - - vlm: dict[str, Any] = field(default_factory=lambda: {}) - """视觉语言模型配置""" - - voice: dict[str, Any] = field(default_factory=lambda: {}) - """语音识别模型配置""" - - tool_use: dict[str, Any] = field(default_factory=lambda: {}) - """专注工具使用模型配置""" - - planner: dict[str, Any] = field(default_factory=lambda: {}) - """规划模型配置""" - - embedding: dict[str, Any] = field(default_factory=lambda: {}) - """嵌入模型配置""" - - lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM实体提取模型配置""" - - lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM RDF构建模型配置""" - - lpmm_qa: dict[str, Any] = field(default_factory=lambda: {}) - """LPMM问答模型配置""" diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 4c8fcac5..c2655fba 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -4,7 +4,7 @@ import hashlib import time from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager from rich.traceback import install @@ -23,10 +23,7 @@ class Individuality: self.meta_info_file_path = "data/personality/meta.json" self.personality_data_file_path = "data/personality/personality_data.json" - self.model = LLMRequest( - model=global_config.model.utils, - request_type="individuality.compress", - ) + self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress") async def initialize(self) -> None: """初始化个体特征""" @@ -35,7 +32,6 @@ class Individuality: personality_side = global_config.personality.personality_side identity = global_config.personality.identity - person_info_manager = get_person_info_manager() self.bot_person_id = person_info_manager.get_person_id("system", "bot_id") self.name = bot_nickname @@ -85,16 +81,16 @@ class Individuality: bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" else: bot_nickname = "" - + # 从文件获取 short_impression personality, identity = self._get_personality_from_file() - + # 确保short_impression是列表格式且有足够的元素 if not personality or not identity: logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值") personality = "友好活泼" identity = "人类" - + prompt_personality = f"{personality}\n{identity}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" @@ -215,7 +211,7 @@ class Individuality: def _get_personality_from_file(self) -> tuple[str, str]: """从文件获取personality数据 - + Returns: tuple: (personality, identity) """ @@ -226,7 +222,7 @@ class Individuality: def _save_personality_to_file(self, personality: str, identity: str): """保存personality数据到文件 - + Args: personality: 压缩后的人格描述 identity: 压缩后的身份描述 @@ -235,7 +231,7 @@ class Individuality: "personality": personality, "identity": identity, "bot_nickname": self.name, - "last_updated": int(time.time()) + "last_updated": int(time.time()), } self._save_personality_data(personality_data) @@ -269,7 +265,7 @@ class Individuality: 2. 尽量简洁,不超过30字 3. 直接输出压缩后的内容,不要解释""" - response, (_, _) = await self.model.generate_response_async( + response, _ = await self.model.generate_response_async( prompt=prompt, ) @@ -281,7 +277,7 @@ class Individuality: # 压缩失败时使用原始内容 if personality_side: personality_parts.append(personality_side) - + if personality_parts: personality_result = "。".join(personality_parts) else: @@ -308,7 +304,7 @@ class Individuality: 2. 尽量简洁,不超过30字 3. 直接输出压缩后的内容,不要解释""" - response, (_, _) = await self.model.generate_response_async( + response, _ = await self.model.generate_response_async( prompt=prompt, ) diff --git a/src/llm_models/LICENSE b/src/llm_models/LICENSE new file mode 100644 index 00000000..8b3236ed --- /dev/null +++ b/src/llm_models/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 Mai.To.The.Gate + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/src/llm_models/__init__.py b/src/llm_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/llm_models/exceptions.py b/src/llm_models/exceptions.py new file mode 100644 index 00000000..5b04f58c --- /dev/null +++ b/src/llm_models/exceptions.py @@ -0,0 +1,98 @@ +from typing import Any + + +# 常见Error Code Mapping (以OpenAI API为例) +error_code_mapping = { + 400: "参数不正确", + 401: "API-Key错误,认证失败,请检查/config/model_list.toml中的配置是否正确", + 402: "账号余额不足", + 403: "模型拒绝访问,可能需要实名或余额不足", + 404: "Not Found", + 413: "请求体过大,请尝试压缩图片或减少输入内容", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高", +} + + +class NetworkConnectionError(Exception): + """连接异常,常见于网络问题或服务器不可用""" + + def __init__(self): + super().__init__() + + def __str__(self): + return "连接异常,请检查网络连接状态或URL是否正确" + + +class ReqAbortException(Exception): + """请求异常退出,常见于请求被中断或取消""" + + def __init__(self, message: str | None = None): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message or "请求因未知原因异常终止" + + +class RespNotOkException(Exception): + """请求响应异常,见于请求未能成功响应(非 '200 OK')""" + + def __init__(self, status_code: int, message: str | None = None): + super().__init__(message) + self.status_code = status_code + self.message = message + + def __str__(self): + if self.status_code in error_code_mapping: + return error_code_mapping[self.status_code] + elif self.message: + return self.message + else: + return f"未知的异常响应代码:{self.status_code}" + + +class RespParseException(Exception): + """响应解析错误,常见于响应格式不正确或解析方法不匹配""" + + def __init__(self, ext_info: Any, message: str | None = None): + super().__init__(message) + self.ext_info = ext_info + self.message = message + + def __str__(self): + return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法" + + +class PayLoadTooLargeError(Exception): + """自定义异常类,用于处理请求体过大错误""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return "请求体过大,请尝试压缩图片或减少输入内容。" + + +class RequestAbortException(Exception): + """自定义异常类,用于处理请求中断异常""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message + + +class PermissionDeniedException(Exception): + """自定义异常类,用于处理访问拒绝的异常""" + + def __init__(self, message: str): + super().__init__(message) + self.message = message + + def __str__(self): + return self.message diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py new file mode 100644 index 00000000..80f7e115 --- /dev/null +++ b/src/llm_models/model_client/__init__.py @@ -0,0 +1,8 @@ +from src.config.config import model_config + +used_client_types = {provider.client_type for provider in model_config.api_providers} + +if "openai" in used_client_types: + from . import openai_client # noqa: F401 +if "gemini" in used_client_types: + from . import gemini_client # noqa: F401 diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py new file mode 100644 index 00000000..8e8affba --- /dev/null +++ b/src/llm_models/model_client/base_client.py @@ -0,0 +1,172 @@ +import asyncio +from dataclasses import dataclass +from abc import ABC, abstractmethod +from typing import Callable, Any, Optional + +from src.config.api_ada_configs import ModelInfo, APIProvider +from ..payload_content.message import Message +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolCall + + +@dataclass +class UsageRecord: + """ + 使用记录类 + """ + + model_name: str + """模型名称""" + + provider_name: str + """提供商名称""" + + prompt_tokens: int + """提示token数""" + + completion_tokens: int + """完成token数""" + + total_tokens: int + """总token数""" + + +@dataclass +class APIResponse: + """ + API响应类 + """ + + content: str | None = None + """响应内容""" + + reasoning_content: str | None = None + """推理内容""" + + tool_calls: list[ToolCall] | None = None + """工具调用 [(工具名称, 工具参数), ...]""" + + embedding: list[float] | None = None + """嵌入向量""" + + usage: UsageRecord | None = None + """使用情况 (prompt_tokens, completion_tokens, total_tokens)""" + + raw_data: Any = None + """响应原始数据""" + + +class BaseClient(ABC): + """ + 基础客户端 + """ + + api_provider: APIProvider + + def __init__(self, api_provider: APIProvider): + self.api_provider = api_provider + + @abstractmethod + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Optional[ + Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]] + ] = None, + async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None, + interrupt_flag: asyncio.Event | None = None, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取对话响应 + :param model_info: 模型信息 + :param message_list: 对话体 + :param tool_options: 工具选项(可选,默认为None) + :param max_tokens: 最大token数(可选,默认为1024) + :param temperature: 温度(可选,默认为0.7) + :param response_format: 响应格式(可选,默认为 NotGiven ) + :param stream_response_handler: 流式响应处理函数(可选) + :param async_response_parser: 响应解析函数(可选) + :param interrupt_flag: 中断信号量(可选,默认为None) + :return: (响应文本, 推理文本, 工具调用, 其他数据) + """ + raise NotImplementedError("'get_response' method should be overridden in subclasses") + + @abstractmethod + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + raise NotImplementedError("'get_embedding' method should be overridden in subclasses") + + @abstractmethod + async def get_audio_transcriptions( + self, + model_info: ModelInfo, + audio_base64: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取音频转录 + :param model_info: 模型信息 + :param audio_base64: base64编码的音频数据 + :extra_params: 附加的请求参数 + :return: 音频转录响应 + """ + raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses") + + @abstractmethod + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses") + + +class ClientRegistry: + def __init__(self) -> None: + self.client_registry: dict[str, type[BaseClient]] = {} + + def register_client_class(self, client_type: str): + """ + 注册API客户端类 + Args: + client_class: API客户端类 + """ + + def decorator(cls: type[BaseClient]) -> type[BaseClient]: + if not issubclass(cls, BaseClient): + raise TypeError(f"{cls.__name__} is not a subclass of BaseClient") + self.client_registry[client_type] = cls + return cls + + return decorator + + def get_client_class(self, client_type: str) -> type[BaseClient]: + """ + 获取注册的API客户端类 + Args: + client_type: 客户端类型 + Returns: + type[BaseClient]: 注册的API客户端类 + """ + if client_type not in self.client_registry: + raise KeyError(f"'{client_type}' 类型的 Client 未注册") + return self.client_registry[client_type] + + +client_registry = ClientRegistry() diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py new file mode 100644 index 00000000..e4127029 --- /dev/null +++ b/src/llm_models/model_client/gemini_client.py @@ -0,0 +1,496 @@ +import asyncio +import io +import base64 +from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List + +from google import genai +from google.genai.types import ( + Content, + Part, + FunctionDeclaration, + GenerateContentResponse, + ContentListUnion, + ContentUnion, + ThinkingConfig, + Tool, + GenerateContentConfig, + EmbedContentResponse, + EmbedContentConfig, +) +from google.genai.errors import ( + ClientError, + ServerError, + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, +) + +from src.config.api_ada_configs import ModelInfo, APIProvider +from src.common.logger import get_logger + +from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat, RespFormatType +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + +logger = get_logger("Gemini客户端") + + +def _convert_messages( + messages: list[Message], +) -> tuple[ContentListUnion, list[str] | None]: + """ + 转换消息格式 - 将消息转换为Gemini API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表(和可能存在的system消息) + """ + + def _convert_message_item(message: Message) -> Content: + """ + 转换单个消息格式,除了system和tool类型的消息 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 将openai格式的角色重命名为gemini格式的角色 + if message.role == RoleType.Assistant: + role = "model" + elif message.role == RoleType.User: + role = "user" + + # 添加Content + if isinstance(message.content, str): + content = [Part.from_text(text=message.content)] + elif isinstance(message.content, list): + content: List[Part] = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}") + ) + elif isinstance(item, str): + content.append(Part.from_text(text=item)) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + return Content(role=role, parts=content) + + temp_list: list[ContentUnion] = [] + system_instructions: list[str] = [] + for message in messages: + if message.role == RoleType.System: + if isinstance(message.content, str): + system_instructions.append(message.content) + else: + raise ValueError("你tm怎么往system里面塞图片base64?") + elif message.role == RoleType.Tool: + if not message.tool_call_id: + raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + else: + temp_list.append(_convert_message_item(message)) + if system_instructions: + # 如果有system消息,就把它加上去 + ret: tuple = (temp_list, system_instructions) + else: + # 如果没有system消息,就直接返回 + ret: tuple = (temp_list, None) + + return ret + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]: + """ + 转换工具选项格式 - 将工具选项转换为Gemini API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具对象列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ + return_dict: dict[str, Any] = { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + if tool_option_param.enum_values: + return_dict["enum"] = tool_option_param.enum_values + return return_dict + + def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的Gemini工具选项对象 + """ + ret: dict[str, Any] = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": "object", + "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, + "required": [param.name for param in tool_option.params if param.required], + } + ret1 = FunctionDeclaration(**ret) + return ret1 + + return [_convert_tool_option_item(tool_option) for tool_option in tool_options] + + +def _process_delta( + delta: GenerateContentResponse, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, dict[str, Any]]], +): + if not hasattr(delta, "candidates") or not delta.candidates: + raise RespParseException(delta, "响应解析失败,缺失candidates字段") + + if delta.text: + fc_delta_buffer.write(delta.text) + + if delta.function_calls: # 为什么不用hasattr呢,是因为这个属性一定有,即使是个空的 + for call in delta.function_calls: + try: + if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了 + raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型") + if not call.id or not call.name: + raise RespParseException(delta, "响应解析失败,工具调用缺失id或name字段") + tool_calls_buffer.append( + ( + call.id, + call.name, + call.args or {}, # 如果args是None,则转换为一个空字典 + ) + ) + except Exception as e: + raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, dict]], +) -> APIResponse: + # sourcery skip: simplify-len-comparison, use-assigned-variable + resp = APIResponse() + + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if len(_tool_calls_buffer) > 0: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer is not None: + arguments = arguments_buffer + if not isinstance(arguments, dict): + raise RespParseException( + None, + f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}", + ) + else: + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _default_stream_response_handler( + resp_stream: AsyncIterator[GenerateContentResponse], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + """ + 流式响应处理函数 - 处理Gemini API的流式响应 + :param resp_stream: 流式响应对象,是一个神秘的iterator,我完全不知道这个玩意能不能跑,不过遍历一遍之后它就空了,如果跑不了一点的话可以考虑改成别的东西 + :return: APIResponse对象 + """ + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + + async for chunk in resp_stream: + # 检查是否有中断量 + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + raise ReqAbortException("请求被外部信号中断") + + _process_delta( + chunk, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if chunk.usage_metadata: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + chunk.usage_metadata.prompt_token_count or 0, + (chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0), + chunk.usage_metadata.total_token_count or 0, + ) + try: + return _build_stream_api_resp( + _fc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +def _default_normal_response_parser( + resp: GenerateContentResponse, +) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + """ + 解析对话补全响应 - 将Gemini API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "candidates") or not resp.candidates: + raise RespParseException(resp, "响应解析失败,缺失candidates字段") + try: + if resp.candidates[0].content and resp.candidates[0].content.parts: + for part in resp.candidates[0].content.parts: + if not part.text: + continue + if part.thought: + api_response.reasoning_content = ( + api_response.reasoning_content + part.text if api_response.reasoning_content else part.text + ) + except Exception as e: + logger.warning(f"解析思考内容时发生错误: {e},跳过解析") + + if resp.text: + api_response.content = resp.text + + if resp.function_calls: + api_response.tool_calls = [] + for call in resp.function_calls: + try: + if not isinstance(call.args, dict): + raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") + if not call.name: + raise RespParseException(resp, "响应解析失败,工具调用缺失name字段") + api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {})) + except Exception as e: + raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e + + if resp.usage_metadata: + _usage_record = ( + resp.usage_metadata.prompt_token_count or 0, + (resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0), + resp.usage_metadata.total_token_count or 0, + ) + else: + _usage_record = None + + api_response.raw_data = resp + + return api_response, _usage_record + + +@client_registry.register_client_class("gemini") +class GeminiClient(BaseClient): + client: genai.Client + + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + self.client = genai.Client( + api_key=api_provider.api_key, + ) # 这里和openai不一样,gemini会自己决定自己是否需要retry + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Optional[ + Callable[ + [AsyncIterator[GenerateContentResponse], asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], + ] + ] = None, + async_response_parser: Optional[ + Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]] + ] = None, + interrupt_flag: asyncio.Event | None = None, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取对话响应 + Args: + model_info: 模型信息 + message_list: 对话体 + tool_options: 工具选项(可选,默认为None) + max_tokens: 最大token数(可选,默认为1024) + temperature: 温度(可选,默认为0.7) + response_format: 响应格式(默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的,暂不支持其它相应格式输入) + stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) + async_response_parser: 响应解析函数(可选,默认为default_response_parser) + interrupt_flag: 中断信号量(可选,默认为None) + Returns: + APIResponse对象,包含响应内容、推理内容、工具调用等信息 + """ + if stream_response_handler is None: + stream_response_handler = _default_stream_response_handler + + if async_response_parser is None: + async_response_parser = _default_normal_response_parser + + # 将messages构造为Gemini API所需的格式 + messages = _convert_messages(message_list) + # 将tool_options转换为Gemini API所需的格式 + tools = _convert_tool_options(tool_options) if tool_options else None + # 将response_format转换为Gemini API所需的格式 + generation_config_dict = { + "max_output_tokens": max_tokens, + "temperature": temperature, + "response_modalities": ["TEXT"], + "thinking_config": ThinkingConfig( + include_thoughts=True, + thinking_budget=( + extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None + ), + ), + } + if tools: + generation_config_dict["tools"] = Tool(function_declarations=tools) + if messages[1]: + # 如果有system消息,则将其添加到配置中 + generation_config_dict["system_instructions"] = messages[1] + if response_format and response_format.format_type == RespFormatType.TEXT: + generation_config_dict["response_mime_type"] = "text/plain" + elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA): + generation_config_dict["response_mime_type"] = "application/json" + generation_config_dict["response_schema"] = response_format.to_dict() + + generation_config = GenerateContentConfig(**generation_config_dict) + + try: + if model_info.force_stream_mode: + req_task = asyncio.create_task( + self.client.aio.models.generate_content_stream( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 + resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) + else: + req_task = asyncio.create_task( + self.client.aio.models.generate_content( + model=model_info.model_identifier, + contents=messages[0], + config=generation_config, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + + resp, usage_record = async_response_parser(req_task.result()) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.code, e.message) from None + except ( + UnknownFunctionCallArgumentError, + UnsupportedFunctionError, + FunctionInvocationError, + ) as e: + raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None + except Exception as e: + raise NetworkConnectionError() from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + try: + raw_response: EmbedContentResponse = await self.client.aio.models.embed_content( + model=model_info.model_identifier, + contents=embedding_input, + config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"), + ) + except (ClientError, ServerError) as e: + # 重封装ClientError和ServerError为RespNotOkException + raise RespNotOkException(e.code) from None + except Exception as e: + raise NetworkConnectionError() from e + + response = APIResponse() + + # 解析嵌入响应和使用情况 + if hasattr(raw_response, "embeddings") and raw_response.embeddings: + response.embedding = raw_response.embeddings[0].values + else: + raise RespParseException(raw_response, "响应解析失败,缺失embeddings字段") + + response.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=len(embedding_input), + completion_tokens=0, + total_tokens=len(embedding_input), + ) + + return response + + def get_audio_transcriptions( + self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None + ) -> APIResponse: + raise NotImplementedError("尚未实现音频转录功能") + + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + return ["png", "jpg", "jpeg", "webp", "heic", "heif"] diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py new file mode 100644 index 00000000..ad9cbf17 --- /dev/null +++ b/src/llm_models/model_client/openai_client.py @@ -0,0 +1,580 @@ +import asyncio +import io +import json +import re +import base64 +from collections.abc import Iterable +from typing import Callable, Any, Coroutine, Optional +from json_repair import repair_json + +from openai import ( + AsyncOpenAI, + APIConnectionError, + APIStatusError, + NOT_GIVEN, + AsyncStream, +) +from openai.types.chat import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_chunk import ChoiceDelta + +from src.config.api_ada_configs import ModelInfo, APIProvider +from src.common.logger import get_logger +from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + +logger = get_logger("OpenAI客户端") + + +def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]: + """ + 转换消息格式 - 将消息转换为OpenAI API所需的格式 + :param messages: 消息列表 + :return: 转换后的消息列表 + """ + + def _convert_message_item(message: Message) -> ChatCompletionMessageParam: + """ + 转换单个消息格式 + :param message: 消息对象 + :return: 转换后的消息字典 + """ + + # 添加Content + content: str | list[dict[str, Any]] + if isinstance(message.content, str): + content = message.content + elif isinstance(message.content, list): + content = [] + for item in message.content: + if isinstance(item, tuple): + content.append( + { + "type": "image_url", + "image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"}, + } + ) + elif isinstance(item, str): + content.append({"type": "text", "text": item}) + else: + raise RuntimeError("无法触及的代码:请使用MessageBuilder类构建消息对象") + + ret = { + "role": message.role.value, + "content": content, + } + + # 添加工具调用ID + if message.role == RoleType.Tool: + if not message.tool_call_id: + raise ValueError("无法触及的代码:请使用MessageBuilder类构建消息对象") + ret["tool_call_id"] = message.tool_call_id + + return ret # type: ignore + + return [_convert_message_item(message) for message in messages] + + +def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]: + """ + 转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式 + :param tool_options: 工具选项列表 + :return: 转换后的工具选项列表 + """ + + def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]: + """ + 转换单个工具参数格式 + :param tool_option_param: 工具参数对象 + :return: 转换后的工具参数字典 + """ + return_dict: dict[str, Any] = { + "type": tool_option_param.param_type.value, + "description": tool_option_param.description, + } + if tool_option_param.enum_values: + return_dict["enum"] = tool_option_param.enum_values + return return_dict + + def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]: + """ + 转换单个工具项格式 + :param tool_option: 工具选项对象 + :return: 转换后的工具选项字典 + """ + ret: dict[str, Any] = { + "name": tool_option.name, + "description": tool_option.description, + } + if tool_option.params: + ret["parameters"] = { + "type": "object", + "properties": {param.name: _convert_tool_param(param) for param in tool_option.params}, + "required": [param.name for param in tool_option.params if param.required], + } + return ret + + return [ + { + "type": "function", + "function": _convert_tool_option_item(tool_option), + } + for tool_option in tool_options + ] + + +def _process_delta( + delta: ChoiceDelta, + has_rc_attr_flag: bool, + in_rc_flag: bool, + rc_delta_buffer: io.StringIO, + fc_delta_buffer: io.StringIO, + tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> bool: + # 接收content + if has_rc_attr_flag: + # 有独立的推理内容块,则无需考虑content内容的判读 + if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore + # 如果有推理内容,则将其写入推理内容缓冲区 + assert isinstance(delta.reasoning_content, str) # type: ignore + rc_delta_buffer.write(delta.reasoning_content) # type: ignore + elif delta.content: + # 如果有正式内容,则将其写入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + elif hasattr(delta, "content") and delta.content is not None: + # 没有独立的推理内容块,但有正式内容 + if in_rc_flag: + # 当前在推理内容块中 + if delta.content == "": + # 如果当前内容是,则将其视为推理内容的结束标记,退出推理内容块 + in_rc_flag = False + else: + # 其他情况视为推理内容,加入推理内容缓冲区 + rc_delta_buffer.write(delta.content) + elif delta.content == "" and not fc_delta_buffer.getvalue(): + # 如果当前内容是,且正式内容缓冲区为空,说明为输出的首个token + # 则将其视为推理内容的开始标记,进入推理内容块 + in_rc_flag = True + else: + # 其他情况视为正式内容,加入正式内容缓冲区 + fc_delta_buffer.write(delta.content) + # 接收tool_calls + if hasattr(delta, "tool_calls") and delta.tool_calls: + tool_call_delta = delta.tool_calls[0] + + if tool_call_delta.index >= len(tool_calls_buffer): + # 调用索引号大于等于缓冲区长度,说明是新的工具调用 + if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name: + tool_calls_buffer.append( + ( + tool_call_delta.id, + tool_call_delta.function.name, + io.StringIO(), + ) + ) + else: + logger.warning("工具调用索引号大于等于缓冲区长度,但缺少ID或函数信息。") + + if tool_call_delta.function and tool_call_delta.function.arguments: + # 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中 + tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments) + + return in_rc_flag + + +def _build_stream_api_resp( + _fc_delta_buffer: io.StringIO, + _rc_delta_buffer: io.StringIO, + _tool_calls_buffer: list[tuple[str, str, io.StringIO]], +) -> APIResponse: + resp = APIResponse() + + if _rc_delta_buffer.tell() > 0: + # 如果推理内容缓冲区不为空,则将其写入APIResponse对象 + resp.reasoning_content = _rc_delta_buffer.getvalue() + _rc_delta_buffer.close() + if _fc_delta_buffer.tell() > 0: + # 如果正式内容缓冲区不为空,则将其写入APIResponse对象 + resp.content = _fc_delta_buffer.getvalue() + _fc_delta_buffer.close() + if _tool_calls_buffer: + # 如果工具调用缓冲区不为空,则将其解析为ToolCall对象列表 + resp.tool_calls = [] + for call_id, function_name, arguments_buffer in _tool_calls_buffer: + if arguments_buffer.tell() > 0: + # 如果参数串缓冲区不为空,则解析为JSON对象 + raw_arg_data = arguments_buffer.getvalue() + arguments_buffer.close() + try: + arguments = json.loads(repair_json(raw_arg_data)) + if not isinstance(arguments, dict): + raise RespParseException( + None, + f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}", + ) + except json.JSONDecodeError as e: + raise RespParseException( + None, + f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}", + ) from e + else: + arguments_buffer.close() + arguments = None + + resp.tool_calls.append(ToolCall(call_id, function_name, arguments)) + + return resp + + +async def _default_stream_response_handler( + resp_stream: AsyncStream[ChatCompletionChunk], + interrupt_flag: asyncio.Event | None, +) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + """ + 流式响应处理函数 - 处理OpenAI API的流式响应 + :param resp_stream: 流式响应对象 + :return: APIResponse对象 + """ + + _has_rc_attr_flag = False # 标记是否有独立的推理内容块 + _in_rc_flag = False # 标记是否在推理内容块中 + _rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容 + _fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容 + _tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用 + _usage_record = None # 使用情况记录 + + def _insure_buffer_closed(): + # 确保缓冲区被关闭 + if _rc_delta_buffer and not _rc_delta_buffer.closed: + _rc_delta_buffer.close() + if _fc_delta_buffer and not _fc_delta_buffer.closed: + _fc_delta_buffer.close() + for _, _, buffer in _tool_calls_buffer: + if buffer and not buffer.closed: + buffer.close() + + async for event in resp_stream: + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量被设置,则抛出ReqAbortException + _insure_buffer_closed() + raise ReqAbortException("请求被外部信号中断") + + delta = event.choices[0].delta # 获取当前块的delta内容 + + if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore + # 标记:有独立的推理内容块 + _has_rc_attr_flag = True + + _in_rc_flag = _process_delta( + delta, + _has_rc_attr_flag, + _in_rc_flag, + _rc_delta_buffer, + _fc_delta_buffer, + _tool_calls_buffer, + ) + + if event.usage: + # 如果有使用情况,则将其存储在APIResponse对象中 + _usage_record = ( + event.usage.prompt_tokens or 0, + event.usage.completion_tokens or 0, + event.usage.total_tokens or 0, + ) + + try: + return _build_stream_api_resp( + _fc_delta_buffer, + _rc_delta_buffer, + _tool_calls_buffer, + ), _usage_record + except Exception: + # 确保缓冲区被关闭 + _insure_buffer_closed() + raise + + +pattern = re.compile( + r"(?P.*?)(?P.*)|(?P.*)|(?P.+)", + re.DOTALL, +) +"""用于解析推理内容的正则表达式""" + + +def _default_normal_response_parser( + resp: ChatCompletion, +) -> tuple[APIResponse, Optional[tuple[int, int, int]]]: + """ + 解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象 + :param resp: 响应对象 + :return: APIResponse对象 + """ + api_response = APIResponse() + + if not hasattr(resp, "choices") or len(resp.choices) == 0: + raise RespParseException(resp, "响应解析失败,缺失choices字段") + message_part = resp.choices[0].message + + if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore + # 有有效的推理字段 + api_response.content = message_part.content + api_response.reasoning_content = message_part.reasoning_content # type: ignore + elif message_part.content: + # 提取推理和内容 + match = pattern.match(message_part.content) + if not match: + raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容") + if match.group("think") is not None: + result = match.group("think").strip(), match.group("content").strip() + elif match.group("think_unclosed") is not None: + result = match.group("think_unclosed").strip(), None + else: + result = None, match.group("content_only").strip() + api_response.reasoning_content, api_response.content = result + + # 提取工具调用 + if message_part.tool_calls: + api_response.tool_calls = [] + for call in message_part.tool_calls: + try: + arguments = json.loads(repair_json(call.function.arguments)) + if not isinstance(arguments, dict): + raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型") + api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments)) + except json.JSONDecodeError as e: + raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e + + # 提取Usage信息 + if resp.usage: + _usage_record = ( + resp.usage.prompt_tokens or 0, + resp.usage.completion_tokens or 0, + resp.usage.total_tokens or 0, + ) + else: + _usage_record = None + + # 将原始响应存储在原始数据中 + api_response.raw_data = resp + + return api_response, _usage_record + + +@client_registry.register_client_class("openai") +class OpenaiClient(BaseClient): + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + self.client: AsyncOpenAI = AsyncOpenAI( + base_url=api_provider.base_url, + api_key=api_provider.api_key, + max_retries=0, + ) + + async def get_response( + self, + model_info: ModelInfo, + message_list: list[Message], + tool_options: list[ToolOption] | None = None, + max_tokens: int = 1024, + temperature: float = 0.7, + response_format: RespFormat | None = None, + stream_response_handler: Optional[ + Callable[ + [AsyncStream[ChatCompletionChunk], asyncio.Event | None], + Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]], + ] + ] = None, + async_response_parser: Optional[ + Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]] + ] = None, + interrupt_flag: asyncio.Event | None = None, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取对话响应 + Args: + model_info: 模型信息 + message_list: 对话体 + tool_options: 工具选项(可选,默认为None) + max_tokens: 最大token数(可选,默认为1024) + temperature: 温度(可选,默认为0.7) + response_format: 响应格式(可选,默认为 NotGiven ) + stream_response_handler: 流式响应处理函数(可选,默认为default_stream_response_handler) + async_response_parser: 响应解析函数(可选,默认为default_response_parser) + interrupt_flag: 中断信号量(可选,默认为None) + Returns: + (响应文本, 推理文本, 工具调用, 其他数据) + """ + if stream_response_handler is None: + stream_response_handler = _default_stream_response_handler + + if async_response_parser is None: + async_response_parser = _default_normal_response_parser + + # 将messages构造为OpenAI API所需的格式 + messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list) + # 将tool_options转换为OpenAI API所需的格式 + tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore + + try: + if model_info.force_stream_mode: + req_task = asyncio.create_task( + self.client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=True, + response_format=NOT_GIVEN, + extra_body=extra_params, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态 + + resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag) + else: + # 发送请求并获取响应 + req_task = asyncio.create_task( + self.client.chat.completions.create( + model=model_info.model_identifier, + messages=messages, + tools=tools, + temperature=temperature, + max_tokens=max_tokens, + stream=False, + response_format=NOT_GIVEN, + extra_body=extra_params, + ) + ) + while not req_task.done(): + if interrupt_flag and interrupt_flag.is_set(): + # 如果中断量存在且被设置,则取消任务并抛出异常 + req_task.cancel() + raise ReqAbortException("请求被外部信号中断") + await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态 + + resp, usage_record = async_response_parser(req_task.result()) + except APIConnectionError as e: + # 重封装APIConnectionError为NetworkConnectionError + raise NetworkConnectionError() from e + except APIStatusError as e: + # 重封装APIError为RespNotOkException + raise RespNotOkException(e.status_code, e.message) from e + + if usage_record: + resp.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=usage_record[0], + completion_tokens=usage_record[1], + total_tokens=usage_record[2], + ) + + return resp + + async def get_embedding( + self, + model_info: ModelInfo, + embedding_input: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取文本嵌入 + :param model_info: 模型信息 + :param embedding_input: 嵌入输入文本 + :return: 嵌入响应 + """ + try: + raw_response = await self.client.embeddings.create( + model=model_info.model_identifier, + input=embedding_input, + extra_body=extra_params, + ) + except APIConnectionError as e: + raise NetworkConnectionError() from e + except APIStatusError as e: + # 重封装APIError为RespNotOkException + raise RespNotOkException(e.status_code) from e + + response = APIResponse() + + # 解析嵌入响应 + if len(raw_response.data) > 0: + response.embedding = raw_response.data[0].embedding + else: + raise RespParseException( + raw_response, + "响应解析失败,缺失嵌入数据。", + ) + + # 解析使用情况 + if hasattr(raw_response, "usage"): + response.usage = UsageRecord( + model_name=model_info.name, + provider_name=model_info.api_provider, + prompt_tokens=raw_response.usage.prompt_tokens or 0, + completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore + total_tokens=raw_response.usage.total_tokens or 0, + ) + + return response + + async def get_audio_transcriptions( + self, + model_info: ModelInfo, + audio_base64: str, + extra_params: dict[str, Any] | None = None, + ) -> APIResponse: + """ + 获取音频转录 + :param model_info: 模型信息 + :param audio_base64: base64编码的音频数据 + :extra_params: 附加的请求参数 + :return: 音频转录响应 + """ + try: + raw_response = await self.client.audio.transcriptions.create( + model=model_info.model_identifier, + file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))), + extra_body=extra_params, + ) + except APIConnectionError as e: + raise NetworkConnectionError() from e + except APIStatusError as e: + # 重封装APIError为RespNotOkException + raise RespNotOkException(e.status_code) from e + response = APIResponse() + # 解析转录响应 + if hasattr(raw_response, "text"): + response.content = raw_response.text + else: + raise RespParseException( + raw_response, + "响应解析失败,缺失转录文本。", + ) + return response + + def get_support_image_formats(self) -> list[str]: + """ + 获取支持的图片格式 + :return: 支持的图片格式列表 + """ + return ["jpg", "jpeg", "png", "webp", "gif"] diff --git a/src/llm_models/payload_content/__init__.py b/src/llm_models/payload_content/__init__.py new file mode 100644 index 00000000..33e43c5e --- /dev/null +++ b/src/llm_models/payload_content/__init__.py @@ -0,0 +1,3 @@ +from .tool_option import ToolCall + +__all__ = ["ToolCall"] \ No newline at end of file diff --git a/src/llm_models/payload_content/message.py b/src/llm_models/payload_content/message.py new file mode 100644 index 00000000..f70c3ded --- /dev/null +++ b/src/llm_models/payload_content/message.py @@ -0,0 +1,107 @@ +from enum import Enum + + +# 设计这系列类的目的是为未来可能的扩展做准备 + + +class RoleType(Enum): + System = "system" + User = "user" + Assistant = "assistant" + Tool = "tool" + + +SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式 + + +class Message: + def __init__( + self, + role: RoleType, + content: str | list[tuple[str, str] | str], + tool_call_id: str | None = None, + ): + """ + 初始化消息对象 + (不应直接修改Message类,而应使用MessageBuilder类来构建对象) + """ + self.role: RoleType = role + self.content: str | list[tuple[str, str] | str] = content + self.tool_call_id: str | None = tool_call_id + + +class MessageBuilder: + def __init__(self): + self.__role: RoleType = RoleType.User + self.__content: list[tuple[str, str] | str] = [] + self.__tool_call_id: str | None = None + + def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder": + """ + 设置角色(默认为User) + :param role: 角色 + :return: MessageBuilder对象 + """ + self.__role = role + return self + + def add_text_content(self, text: str) -> "MessageBuilder": + """ + 添加文本内容 + :param text: 文本内容 + :return: MessageBuilder对象 + """ + self.__content.append(text) + return self + + def add_image_content( + self, + image_format: str, + image_base64: str, + support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式 + ) -> "MessageBuilder": + """ + 添加图片内容 + :param image_format: 图片格式 + :param image_base64: 图片的base64编码 + :return: MessageBuilder对象 + """ + if image_format.lower() not in support_formats: + raise ValueError("不受支持的图片格式") + if not image_base64: + raise ValueError("图片的base64编码不能为空") + self.__content.append((image_format, image_base64)) + return self + + def add_tool_call(self, tool_call_id: str) -> "MessageBuilder": + """ + 添加工具调用指令(调用时请确保已设置为Tool角色) + :param tool_call_id: 工具调用指令的id + :return: MessageBuilder对象 + """ + if self.__role != RoleType.Tool: + raise ValueError("仅当角色为Tool时才能添加工具调用ID") + if not tool_call_id: + raise ValueError("工具调用ID不能为空") + self.__tool_call_id = tool_call_id + return self + + def build(self) -> Message: + """ + 构建消息对象 + :return: Message对象 + """ + if len(self.__content) == 0: + raise ValueError("内容不能为空") + if self.__role == RoleType.Tool and self.__tool_call_id is None: + raise ValueError("Tool角色的工具调用ID不能为空") + + return Message( + role=self.__role, + content=( + self.__content[0] + if (len(self.__content) == 1 and isinstance(self.__content[0], str)) + else self.__content + ), + tool_call_id=self.__tool_call_id, + ) diff --git a/src/llm_models/payload_content/resp_format.py b/src/llm_models/payload_content/resp_format.py new file mode 100644 index 00000000..ab2e2edf --- /dev/null +++ b/src/llm_models/payload_content/resp_format.py @@ -0,0 +1,223 @@ +from enum import Enum +from typing import Optional, Any + +from pydantic import BaseModel +from typing_extensions import TypedDict, Required + + +class RespFormatType(Enum): + TEXT = "text" # 文本 + JSON_OBJ = "json_object" # JSON + JSON_SCHEMA = "json_schema" # JSON Schema + + +class JsonSchema(TypedDict, total=False): + name: Required[str] + """ + The name of the response format. + + Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length + of 64. + """ + + description: Optional[str] + """ + A description of what the response format is for, used by the model to determine + how to respond in the format. + """ + + schema: dict[str, object] + """ + The schema for the response format, described as a JSON Schema object. Learn how + to build JSON schemas [here](https://json-schema.org/). + """ + + strict: Optional[bool] + """ + Whether to enable strict schema adherence when generating the output. If set to + true, the model will always follow the exact schema defined in the `schema` + field. Only a subset of JSON Schema is supported when `strict` is `true`. To + learn more, read the + [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). + """ + + +def _json_schema_type_check(instance) -> str | None: + if "name" not in instance: + return "schema必须包含'name'字段" + elif not isinstance(instance["name"], str) or instance["name"].strip() == "": + return "schema的'name'字段必须是非空字符串" + if "description" in instance and ( + not isinstance(instance["description"], str) + or instance["description"].strip() == "" + ): + return "schema的'description'字段只能填入非空字符串" + if "schema" not in instance: + return "schema必须包含'schema'字段" + elif not isinstance(instance["schema"], dict): + return "schema的'schema'字段必须是字典,详见https://json-schema.org/" + if "strict" in instance and not isinstance(instance["strict"], bool): + return "schema的'strict'字段只能填入布尔值" + + return None + + +def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]: + """ + 递归移除JSON Schema中的title字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "title" in schema: + del schema["title"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]: + """ + 链接JSON Schema中的definitions字段 + """ + + def link_definitions_recursive( + path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any] + ) -> dict[str, Any]: + """ + 递归链接JSON Schema中的definitions字段 + :param path: 当前路径 + :param sub_schema: 子Schema + :param defs: Schema定义集 + :return: + """ + if isinstance(sub_schema, list): + # 如果当前Schema是列表,则遍历每个元素 + for i in range(len(sub_schema)): + if isinstance(sub_schema[i], dict): + sub_schema[i] = link_definitions_recursive( + f"{path}/{str(i)}", sub_schema[i], defs + ) + else: + # 否则为字典 + if "$defs" in sub_schema: + # 如果当前Schema有$def字段,则将其添加到defs中 + key_prefix = f"{path}/$defs/" + for key, value in sub_schema["$defs"].items(): + def_key = key_prefix + key + if def_key not in defs: + defs[def_key] = value + del sub_schema["$defs"] + if "$ref" in sub_schema: + # 如果当前Schema有$ref字段,则将其替换为defs中的定义 + def_key = sub_schema["$ref"] + if def_key in defs: + sub_schema = defs[def_key] + else: + raise ValueError(f"Schema中引用的定义'{def_key}'不存在") + # 遍历键值对 + for key, value in sub_schema.items(): + if isinstance(value, (dict, list)): + # 如果当前值是字典或列表,则递归调用 + sub_schema[key] = link_definitions_recursive( + f"{path}/{key}", value, defs + ) + + return sub_schema + + return link_definitions_recursive("#", schema, {}) + + +def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]: + """ + 递归移除JSON Schema中的$defs字段 + """ + if isinstance(schema, list): + # 如果当前Schema是列表,则对所有dict/list子元素递归调用 + for idx, item in enumerate(schema): + if isinstance(item, (dict, list)): + schema[idx] = _remove_title(item) + elif isinstance(schema, dict): + # 是字典,移除title字段,并对所有dict/list子元素递归调用 + if "$defs" in schema: + del schema["$defs"] + for key, value in schema.items(): + if isinstance(value, (dict, list)): + schema[key] = _remove_title(value) + + return schema + + +class RespFormat: + """ + 响应格式 + """ + + @staticmethod + def _generate_schema_from_model(schema): + json_schema = { + "name": schema.__name__, + "schema": _remove_defs( + _link_definitions(_remove_title(schema.model_json_schema())) + ), + "strict": False, + } + if schema.__doc__: + json_schema["description"] = schema.__doc__ + return json_schema + + def __init__( + self, + format_type: RespFormatType = RespFormatType.TEXT, + schema: type | JsonSchema | None = None, + ): + """ + 响应格式 + :param format_type: 响应格式类型(默认为文本) + :param schema: 模板类或JsonSchema(仅当format_type为JSON Schema时有效) + """ + self.format_type: RespFormatType = format_type + + if format_type == RespFormatType.JSON_SCHEMA: + if schema is None: + raise ValueError("当format_type为'JSON_SCHEMA'时,schema不能为空") + if isinstance(schema, dict): + if check_msg := _json_schema_type_check(schema): + raise ValueError(f"schema格式不正确,{check_msg}") + + self.schema = schema + elif issubclass(schema, BaseModel): + try: + json_schema = self._generate_schema_from_model(schema) + + self.schema = json_schema + except Exception as e: + raise ValueError( + f"自动生成JSON Schema时发生异常,请检查模型类{schema.__name__}的定义,详细信息:\n" + f"{schema.__name__}:\n" + ) from e + else: + raise ValueError("schema必须是BaseModel的子类或JsonSchema") + else: + self.schema = None + + def to_dict(self): + """ + 将响应格式转换为字典 + :return: 字典 + """ + if self.schema: + return { + "format_type": self.format_type.value, + "schema": self.schema, + } + else: + return { + "format_type": self.format_type.value, + } diff --git a/src/llm_models/payload_content/tool_option.py b/src/llm_models/payload_content/tool_option.py new file mode 100644 index 00000000..9fedbc86 --- /dev/null +++ b/src/llm_models/payload_content/tool_option.py @@ -0,0 +1,163 @@ +from enum import Enum + + +class ToolParamType(Enum): + """ + 工具调用参数类型 + """ + + STRING = "string" # 字符串 + INTEGER = "integer" # 整型 + FLOAT = "float" # 浮点型 + BOOLEAN = "bool" # 布尔型 + + +class ToolParam: + """ + 工具调用参数 + """ + + def __init__( + self, + name: str, + param_type: ToolParamType, + description: str, + required: bool, + enum_values: list[str] | None = None, + ): + """ + 初始化工具调用参数 + (不应直接修改ToolParam类,而应使用ToolOptionBuilder类来构建对象) + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填 + """ + self.name: str = name + self.param_type: ToolParamType = param_type + self.description: str = description + self.required: bool = required + self.enum_values: list[str] | None = enum_values + + +class ToolOption: + """ + 工具调用项 + """ + + def __init__( + self, + name: str, + description: str, + params: list[ToolParam] | None = None, + ): + """ + 初始化工具调用项 + (不应直接修改ToolOption类,而应使用ToolOptionBuilder类来构建对象) + :param name: 工具名称 + :param description: 工具描述 + :param params: 工具参数列表 + """ + self.name: str = name + self.description: str = description + self.params: list[ToolParam] | None = params + + +class ToolOptionBuilder: + """ + 工具调用项构建器 + """ + + def __init__(self): + self.__name: str = "" + self.__description: str = "" + self.__params: list[ToolParam] = [] + + def set_name(self, name: str) -> "ToolOptionBuilder": + """ + 设置工具名称 + :param name: 工具名称 + :return: ToolBuilder实例 + """ + if not name: + raise ValueError("工具名称不能为空") + self.__name = name + return self + + def set_description(self, description: str) -> "ToolOptionBuilder": + """ + 设置工具描述 + :param description: 工具描述 + :return: ToolBuilder实例 + """ + if not description: + raise ValueError("工具描述不能为空") + self.__description = description + return self + + def add_param( + self, + name: str, + param_type: ToolParamType, + description: str, + required: bool = False, + enum_values: list[str] | None = None, + ) -> "ToolOptionBuilder": + """ + 添加工具参数 + :param name: 参数名称 + :param param_type: 参数类型 + :param description: 参数描述 + :param required: 是否必填(默认为False) + :return: ToolBuilder实例 + """ + if not name or not description: + raise ValueError("参数名称/描述不能为空") + + self.__params.append( + ToolParam( + name=name, + param_type=param_type, + description=description, + required=required, + enum_values=enum_values, + ) + ) + + return self + + def build(self): + """ + 构建工具调用项 + :return: 工具调用项 + """ + if self.__name == "" or self.__description == "": + raise ValueError("工具名称/描述不能为空") + + return ToolOption( + name=self.__name, + description=self.__description, + params=None if len(self.__params) == 0 else self.__params, + ) + + +class ToolCall: + """ + 来自模型反馈的工具调用 + """ + + def __init__( + self, + call_id: str, + func_name: str, + args: dict | None = None, + ): + """ + 初始化工具调用 + :param call_id: 工具调用ID + :param func_name: 要调用的函数名称 + :param args: 工具调用参数 + """ + self.call_id: str = call_id + self.func_name: str = func_name + self.args: dict | None = args diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py new file mode 100644 index 00000000..52a6120c --- /dev/null +++ b/src/llm_models/utils.py @@ -0,0 +1,186 @@ +import base64 +import io + +from PIL import Image +from datetime import datetime + +from src.common.logger import get_logger +from src.common.database.database import db # 确保 db 被导入用于 create_tables +from src.common.database.database_model import LLMUsage +from src.config.api_ada_configs import ModelInfo +from .payload_content.message import Message, MessageBuilder +from .model_client.base_client import UsageRecord + +logger = get_logger("消息压缩工具") + + +def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]: + """ + 压缩消息列表中的图片 + :param messages: 消息列表 + :param img_target_size: 图片目标大小,默认1MB + :return: 压缩后的消息列表 + """ + + def reformat_static_image(image_data: bytes) -> bytes: + """ + 将静态图片转换为JPEG格式 + :param image_data: 图片数据 + :return: 转换后的图片数据 + """ + try: + image = Image.open(image_data) + + if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]): + # 静态图像,转换为JPEG格式 + reformated_image_data = io.BytesIO() + image.save(reformated_image_data, format="JPEG", quality=95, optimize=True) + image_data = reformated_image_data.getvalue() + + return image_data + except Exception as e: + logger.error(f"图片转换格式失败: {str(e)}") + return image_data + + def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]: + """ + 缩放图片 + :param image_data: 图片数据 + :param scale: 缩放比例 + :return: 缩放后的图片数据 + """ + try: + image = Image.open(image_data) + + # 原始尺寸 + original_size = (image.width, image.height) + + # 计算新的尺寸 + new_size = (int(original_size[0] * scale), int(original_size[1] * scale)) + + output_buffer = io.BytesIO() + + if getattr(image, "is_animated", False): + # 动态图片,处理所有帧 + frames = [] + new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折 + for frame_idx in range(getattr(image, "n_frames", 1)): + image.seek(frame_idx) + new_frame = image.copy() + new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS) + frames.append(new_frame) + + # 保存到缓冲区 + frames[0].save( + output_buffer, + format="GIF", + save_all=True, + append_images=frames[1:], + optimize=True, + duration=image.info.get("duration", 100), + loop=image.info.get("loop", 0), + ) + else: + # 静态图片,直接缩放保存 + resized_image = image.resize(new_size, Image.Resampling.LANCZOS) + resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True) + + return output_buffer.getvalue(), original_size, new_size + + except Exception as e: + logger.error(f"图片缩放失败: {str(e)}") + import traceback + + logger.error(traceback.format_exc()) + return image_data, None, None + + def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str: + original_b64_data_size = len(base64_data) # 计算原始数据大小 + + image_data = base64.b64decode(base64_data) + + # 先尝试转换格式为JPEG + image_data = reformat_static_image(image_data) + base64_data = base64.b64encode(image_data).decode("utf-8") + if len(base64_data) <= target_size: + # 如果转换后小于目标大小,直接返回 + logger.info(f"成功将图片转为JPEG格式,编码后大小: {len(base64_data) / 1024:.1f}KB") + return base64_data + + # 如果转换后仍然大于目标大小,进行尺寸压缩 + scale = min(1.0, target_size / len(base64_data)) + image_data, original_size, new_size = rescale_image(image_data, scale) + base64_data = base64.b64encode(image_data).decode("utf-8") + + if original_size and new_size: + logger.info( + f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n" + f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB" + ) + + return base64_data + + compressed_messages = [] + for message in messages: + if isinstance(message.content, list): + # 检查content,如有图片则压缩 + message_builder = MessageBuilder() + for content_item in message.content: + if isinstance(content_item, tuple): + # 图片,进行压缩 + message_builder.add_image_content( + content_item[0], + compress_base64_image(content_item[1], target_size=img_target_size), + ) + else: + message_builder.add_text_content(content_item) + compressed_messages.append(message_builder.build()) + else: + compressed_messages.append(message) + + return compressed_messages + + +class LLMUsageRecorder: + """ + LLM使用情况记录器 + """ + + def __init__(self): + try: + # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 + db.create_tables([LLMUsage], safe=True) + # logger.debug("LLMUsage 表已初始化/确保存在。") + except Exception as e: + logger.error(f"创建 LLMUsage 表失败: {str(e)}") + + def record_usage_to_database( + self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str + ): + input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in + output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out + total_cost = round(input_cost + output_cost, 6) + try: + # 使用 Peewee 模型创建记录 + LLMUsage.create( + model_name=model_info.model_identifier, + user_id=user_id, + request_type=request_type, + endpoint=endpoint, + prompt_tokens=model_usage.prompt_tokens or 0, + completion_tokens=model_usage.completion_tokens or 0, + total_tokens=model_usage.total_tokens or 0, + cost=total_cost or 0.0, + status="success", + timestamp=datetime.now(), # Peewee 会处理 DateTimeField + ) + logger.debug( + f"Token使用情况 - 模型: {model_usage.model_name}, " + f"用户: {user_id}, 类型: {request_type}, " + f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, " + f"总计: {model_usage.total_tokens}" + ) + except Exception as e: + logger.error(f"记录token使用情况失败: {str(e)}") + +llm_usage_recorder = LLMUsageRecorder() \ No newline at end of file diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 9aca329e..b6764064 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1,65 +1,29 @@ -import asyncio -import json import re -from datetime import datetime -from typing import Tuple, Union, Dict, Any, Callable -import aiohttp -from aiohttp.client import ClientResponse -from src.common.logger import get_logger -import base64 -from PIL import Image -import io -import os -import copy # 添加copy模块用于深拷贝 -from src.common.database.database import db # 确保 db 被导入用于 create_tables -from src.common.database.database_model import LLMUsage # 导入 LLMUsage 模型 -from src.config.config import global_config -from src.common.tcp_connector import get_tcp_connector +import copy +import asyncio + +from enum import Enum from rich.traceback import install +from typing import Tuple, List, Dict, Optional, Callable, Any + +from src.common.logger import get_logger +from src.config.config import model_config +from src.config.api_ada_configs import APIProvider, ModelInfo, TaskConfig +from .payload_content.message import MessageBuilder, Message +from .payload_content.resp_format import RespFormat +from .payload_content.tool_option import ToolOption, ToolCall, ToolOptionBuilder, ToolParamType +from .model_client.base_client import BaseClient, APIResponse, client_registry +from .utils import compress_messages, llm_usage_recorder +from .exceptions import NetworkConnectionError, ReqAbortException, RespNotOkException, RespParseException install(extra_lines=3) logger = get_logger("model_utils") - -class PayLoadTooLargeError(Exception): - """自定义异常类,用于处理请求体过大错误""" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self): - return "请求体过大,请尝试压缩图片或减少输入内容。" - - -class RequestAbortException(Exception): - """自定义异常类,用于处理请求中断异常""" - - def __init__(self, message: str, response: ClientResponse): - super().__init__(message) - self.message = message - self.response = response - - def __str__(self): - return self.message - - -class PermissionDeniedException(Exception): - """自定义异常类,用于处理访问拒绝的异常""" - - def __init__(self, message: str): - super().__init__(message) - self.message = message - - def __str__(self): - return self.message - - # 常见Error Code Mapping error_code_mapping = { 400: "参数不正确", - 401: "API key 错误,认证失败,请检查/config/bot_config.toml和.env中的配置是否正确哦~", + 401: "API key 错误,认证失败,请检查 config/model_config.toml 中的配置是否正确", 402: "账号余额不足", 403: "需要实名,或余额不足", 404: "Not Found", @@ -69,1013 +33,474 @@ error_code_mapping = { } -async def _safely_record(request_content: Dict[str, Any], payload: Dict[str, Any]): - """安全地记录请求体,用于调试日志,不会修改原始payload对象""" - # 创建payload的深拷贝,避免修改原始对象 - safe_payload = copy.deepcopy(payload) - - image_base64: str = request_content.get("image_base64") - image_format: str = request_content.get("image_format") - if ( - image_base64 - and safe_payload - and isinstance(safe_payload, dict) - and "messages" in safe_payload - and len(safe_payload["messages"]) > 0 - ): - if isinstance(safe_payload["messages"][0], dict) and "content" in safe_payload["messages"][0]: - content = safe_payload["messages"][0]["content"] - if isinstance(content, list) and len(content) > 1 and "image_url" in content[1]: - # 只修改拷贝的对象,用于安全的日志记录 - safe_payload["messages"][0]["content"][1]["image_url"]["url"] = ( - f"data:image/{image_format.lower() if image_format else 'jpeg'};base64," - f"{image_base64[:10]}...{image_base64[-10:]}" - ) - return safe_payload +class RequestType(Enum): + """请求类型枚举""" + + RESPONSE = "response" + EMBEDDING = "embedding" + AUDIO = "audio" class LLMRequest: - # 定义需要转换的模型列表,作为类变量避免重复 - MODELS_NEEDING_TRANSFORMATION = [ - "o1", - "o1-2024-12-17", - "o1-mini", - "o1-mini-2024-09-12", - "o1-preview", - "o1-preview-2024-09-12", - "o1-pro", - "o1-pro-2025-03-19", - "o3", - "o3-2025-04-16", - "o3-mini", - "o3-mini-2025-01-31", - "o4-mini", - "o4-mini-2025-04-16", - ] + """LLM请求类""" - def __init__(self, model: dict, **kwargs): - # 将大写的配置键转换为小写并从config中获取实际值 - logger.debug(f"🔍 [模型初始化] 开始初始化模型: {model.get('name', 'Unknown')}") - logger.debug(f"🔍 [模型初始化] 模型配置: {model}") - logger.debug(f"🔍 [模型初始化] 额外参数: {kwargs}") - - try: - # print(f"model['provider']: {model['provider']}") - self.api_key = os.environ[f"{model['provider']}_KEY"] - self.base_url = os.environ[f"{model['provider']}_BASE_URL"] - logger.debug(f"🔍 [模型初始化] 成功获取环境变量: {model['provider']}_KEY 和 {model['provider']}_BASE_URL") - except AttributeError as e: - logger.error(f"原始 model dict 信息:{model}") - logger.error(f"配置错误:找不到对应的配置项 - {str(e)}") - raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e - except KeyError: - logger.warning( - f"找不到{model['provider']}_KEY或{model['provider']}_BASE_URL环境变量,请检查配置文件或环境变量设置。" - ) - self.model_name: str = model["name"] - self.params = kwargs + def __init__(self, model_set: TaskConfig, request_type: str = "") -> None: + self.task_name = request_type + self.model_for_task = model_set + self.request_type = request_type + self.model_usage: Dict[str, Tuple[int, int]] = {model: (0, 0) for model in self.model_for_task.model_list} + """模型使用量记录,用于进行负载均衡,对应为(total_tokens, penalty),惩罚值是为了能在某个模型请求不给力的时候进行调整""" - # 记录配置文件中声明了哪些参数(不管值是什么) - self.has_enable_thinking = "enable_thinking" in model - self.has_thinking_budget = "thinking_budget" in model - - self.enable_thinking = model.get("enable_thinking", False) - self.temp = model.get("temp", 0.7) - self.thinking_budget = model.get("thinking_budget", 4096) - self.stream = model.get("stream", False) - self.pri_in = model.get("pri_in", 0) - self.pri_out = model.get("pri_out", 0) - self.max_tokens = model.get("max_tokens", global_config.model.model_max_output_length) - # print(f"max_tokens: {self.max_tokens}") - - logger.debug("🔍 [模型初始化] 模型参数设置完成:") - logger.debug(f" - model_name: {self.model_name}") - logger.debug(f" - has_enable_thinking: {self.has_enable_thinking}") - logger.debug(f" - enable_thinking: {self.enable_thinking}") - logger.debug(f" - has_thinking_budget: {self.has_thinking_budget}") - logger.debug(f" - thinking_budget: {self.thinking_budget}") - logger.debug(f" - temp: {self.temp}") - logger.debug(f" - stream: {self.stream}") - logger.debug(f" - max_tokens: {self.max_tokens}") - logger.debug(f" - base_url: {self.base_url}") + self.pri_in = 0 + self.pri_out = 0 - # 获取数据库实例 - self._init_database() - - # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default" - self.request_type = kwargs.pop("request_type", "default") - logger.debug(f"🔍 [模型初始化] 初始化完成,request_type: {self.request_type}") - - @staticmethod - def _init_database(): - """初始化数据库集合""" - try: - # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 - db.create_tables([LLMUsage], safe=True) - # logger.debug("LLMUsage 表已初始化/确保存在。") - except Exception as e: - logger.error(f"创建 LLMUsage 表失败: {str(e)}") - - def _record_usage( + async def generate_response_for_image( self, - prompt_tokens: int, - completion_tokens: int, - total_tokens: int, - user_id: str = "system", - request_type: str = None, - endpoint: str = "/chat/completions", - ): - """记录模型使用情况到数据库 - Args: - prompt_tokens: 输入token数 - completion_tokens: 输出token数 - total_tokens: 总token数 - user_id: 用户ID,默认为system - request_type: 请求类型 - endpoint: API端点 + prompt: str, + image_base64: str, + image_format: str, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: """ - # 如果 request_type 为 None,则使用实例变量中的值 - if request_type is None: - request_type = self.request_type - - try: - # 使用 Peewee 模型创建记录 - LLMUsage.create( - model_name=self.model_name, - user_id=user_id, - request_type=request_type, - endpoint=endpoint, - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - cost=self._calculate_cost(prompt_tokens, completion_tokens), - status="success", - timestamp=datetime.now(), # Peewee 会处理 DateTimeField - ) - logger.debug( - f"Token使用情况 - 模型: {self.model_name}, " - f"用户: {user_id}, 类型: {request_type}, " - f"提示词: {prompt_tokens}, 完成: {completion_tokens}, " - f"总计: {total_tokens}" - ) - except Exception as e: - logger.error(f"记录token使用情况失败: {str(e)}") - - def _calculate_cost(self, prompt_tokens: int, completion_tokens: int) -> float: - """计算API调用成本 - 使用模型的pri_in和pri_out价格计算输入和输出的成本 - + 为图像生成响应 Args: - prompt_tokens: 输入token数量 - completion_tokens: 输出token数量 - + prompt (str): 提示词 + image_base64 (str): 图像的Base64编码字符串 + image_format (str): 图像格式(如 'png', 'jpeg' 等) Returns: - float: 总成本(元) + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 """ - # 使用模型的pri_in和pri_out计算成本 - input_cost = (prompt_tokens / 1000000) * self.pri_in - output_cost = (completion_tokens / 1000000) * self.pri_out - return round(input_cost + output_cost, 6) + # 模型选择 + model_info, api_provider, client = self._select_model() - async def _prepare_request( - self, - endpoint: str, - prompt: str = None, - image_base64: str = None, - image_format: str = None, - file_bytes: bytes = None, - file_format: str = None, - payload: dict = None, - retry_policy: dict = None, - ) -> Dict[str, Any]: - """配置请求参数 + # 请求体构建 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + message_builder.add_image_content( + image_base64=image_base64, image_format=image_format, support_formats=client.get_support_image_formats() + ) + messages = [message_builder.build()] + + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + ) + content = response.content or "" + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions", + ) + return content, (reasoning_content, model_info.name, tool_calls) + + async def generate_response_for_voice(self, voice_base64: str) -> Optional[str]: + """ + 为语音生成响应 Args: - endpoint: API端点路径 (如 "chat/completions") - prompt: prompt文本 - image_base64: 图片的base64编码 - image_format: 图片格式 - file_bytes: 文件的二进制数据 - file_format: 文件格式 - payload: 请求体数据 - retry_policy: 自定义重试策略 - request_type: 请求类型 + voice_base64 (str): 语音的Base64编码字符串 + Returns: + (Optional[str]): 生成的文本描述或None """ + # 模型选择 + model_info, api_provider, client = self._select_model() - # 合并重试策略 - default_retry = { - "max_retries": 3, - "base_wait": 10, - "retry_codes": [429, 413, 500, 503], - "abort_codes": [400, 401, 402, 403], - } - policy = {**default_retry, **(retry_policy or {})} + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.AUDIO, + model_info=model_info, + audio_base64=voice_base64, + ) + return response.content or None - api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" + async def generate_response_async( + self, + prompt: str, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + tools: Optional[List[Dict[str, Any]]] = None, + ) -> Tuple[str, Tuple[str, str, Optional[List[ToolCall]]]]: + """ + 异步生成响应 + Args: + prompt (str): 提示词 + temperature (float, optional): 温度参数 + max_tokens (int, optional): 最大token数 + Returns: + (Tuple[str, str, str, Optional[List[ToolCall]]]): 响应内容、推理内容、模型名称、工具调用列表 + """ + # 请求体构建 + message_builder = MessageBuilder() + message_builder.add_text_content(prompt) + messages = [message_builder.build()] + tool_built = self._build_tool_options(tools) + # 模型选择 + model_info, api_provider, client = self._select_model() - stream_mode = self.stream + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.RESPONSE, + model_info=model_info, + message_list=messages, + temperature=temperature, + max_tokens=max_tokens, + tool_options=tool_built, + ) + content = response.content + reasoning_content = response.reasoning_content or "" + tool_calls = response.tool_calls + # 从内容中提取标签的推理内容(向后兼容) + if not reasoning_content and content: + content, extracted_reasoning = self._extract_reasoning(content) + reasoning_content = extracted_reasoning + if usage := response.usage: + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/chat/completions", + ) + if not content: + logger.warning("生成的响应为空") + content = "生成的响应为空,请检查模型配置或输入内容是否正确" - # 构建请求体 - if image_base64: - payload = await self._build_payload(prompt, image_base64, image_format) - elif file_bytes: - payload = await self._build_formdata_payload(file_bytes, file_format) - elif payload is None: - payload = await self._build_payload(prompt) + return content, (reasoning_content, model_info.name, tool_calls) - if not file_bytes: - if stream_mode: - payload["stream"] = stream_mode + async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]: + """获取嵌入向量 + Args: + embedding_input (str): 获取嵌入的目标 + Returns: + (Tuple[List[float], str]): (嵌入向量,使用的模型名称) + """ + # 无需构建消息体,直接使用输入文本 + model_info, api_provider, client = self._select_model() - if self.temp != 0.7: - payload["temperature"] = self.temp + # 请求并处理返回值 + response = await self._execute_request( + api_provider=api_provider, + client=client, + request_type=RequestType.EMBEDDING, + model_info=model_info, + embedding_input=embedding_input, + ) - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking + embedding = response.embedding - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget + if usage := response.usage: + llm_usage_recorder.record_usage_to_database( + model_info=model_info, + model_usage=usage, + user_id="system", + request_type=self.request_type, + endpoint="/embeddings", + ) - if self.max_tokens: - payload["max_tokens"] = self.max_tokens + if not embedding: + raise RuntimeError("获取embedding失败") - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - payload["max_completion_tokens"] = payload.pop("max_tokens") + return embedding, model_info.name - return { - "policy": policy, - "payload": payload, - "api_url": api_url, - "stream_mode": stream_mode, - "image_base64": image_base64, # 保留必要的exception处理所需的原始数据 - "image_format": image_format, - "file_bytes": file_bytes, - "file_format": file_format, - "prompt": prompt, - } + def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]: + """ + 根据总tokens和惩罚值选择的模型 + """ + least_used_model_name = min( + self.model_usage, key=lambda k: self.model_usage[k][0] + self.model_usage[k][1] * 300 + ) + model_info = model_config.get_model_info(least_used_model_name) + api_provider = model_config.get_provider(model_info.api_provider) + client = client_registry.get_client_class(api_provider.client_type)(copy.deepcopy(api_provider)) + logger.debug(f"选择请求模型: {model_info.name}") + return model_info, api_provider, client async def _execute_request( self, - endpoint: str, - prompt: str = None, - image_base64: str = None, - image_format: str = None, - file_bytes: bytes = None, - file_format: str = None, - payload: dict = None, - retry_policy: dict = None, - response_handler: Callable = None, - user_id: str = "system", - request_type: str = None, - ): - """统一请求执行入口 - Args: - endpoint: API端点路径 (如 "chat/completions") - prompt: prompt文本 - image_base64: 图片的base64编码 - image_format: 图片格式 - file_bytes: 文件的二进制数据 - file_format: 文件格式 - payload: 请求体数据 - retry_policy: 自定义重试策略 - response_handler: 自定义响应处理器 - user_id: 用户ID - request_type: 请求类型 + api_provider: APIProvider, + client: BaseClient, + request_type: RequestType, + model_info: ModelInfo, + message_list: List[Message] | None = None, + tool_options: list[ToolOption] | None = None, + response_format: RespFormat | None = None, + stream_response_handler: Optional[Callable] = None, + async_response_parser: Optional[Callable] = None, + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, + embedding_input: str = "", + audio_base64: str = "", + ) -> APIResponse: """ - # 获取请求配置 - request_content = await self._prepare_request( - endpoint, prompt, image_base64, image_format, file_bytes, file_format, payload, retry_policy - ) - if request_type is None: - request_type = self.request_type - for retry in range(request_content["policy"]["max_retries"]): + 实际执行请求的方法 + + 包含了重试和异常处理逻辑 + """ + retry_remain = api_provider.max_retry + compressed_messages: Optional[List[Message]] = None + while retry_remain > 0: try: - # 使用上下文管理器处理会话 - if file_bytes: - headers = await self._build_headers(is_formdata=True) - else: - headers = await self._build_headers(is_formdata=False) - # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 - if request_content["stream_mode"]: - headers["Accept"] = "text/event-stream" - - # 添加请求发送前的调试信息 - logger.debug(f"🔍 [请求调试] 模型 {self.model_name} 准备发送请求") - logger.debug(f"🔍 [请求调试] API URL: {request_content['api_url']}") - logger.debug(f"🔍 [请求调试] 请求头: {await self._build_headers(no_key=True, is_formdata=file_bytes is not None)}") - - if not file_bytes: - # 安全地记录请求体(隐藏敏感信息) - safe_payload = await _safely_record(request_content, request_content["payload"]) - logger.debug(f"🔍 [请求调试] 请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - else: - logger.debug(f"🔍 [请求调试] 文件上传请求,文件格式: {request_content['file_format']}") - - async with aiohttp.ClientSession(connector=await get_tcp_connector()) as session: - post_kwargs = {"headers": headers} - # form-data数据上传方式不同 - if file_bytes: - post_kwargs["data"] = request_content["payload"] - else: - post_kwargs["json"] = request_content["payload"] - - async with session.post(request_content["api_url"], **post_kwargs) as response: - handled_result = await self._handle_response( - response, request_content, retry, response_handler, user_id, request_type, endpoint - ) - return handled_result - - except Exception as e: - handled_payload, count_delta = await self._handle_exception(e, retry, request_content) - retry += count_delta # 降级不计入重试次数 - if handled_payload: - # 如果降级成功,重新构建请求体 - request_content["payload"] = handled_payload - continue - - logger.error(f"模型 {self.model_name} 达到最大重试次数,请求仍然失败") - raise RuntimeError(f"模型 {self.model_name} 达到最大重试次数,API请求仍然失败") - - async def _handle_response( - self, - response: ClientResponse, - request_content: Dict[str, Any], - retry_count: int, - response_handler: Callable, - user_id, - request_type, - endpoint, - ): - policy = request_content["policy"] - stream_mode = request_content["stream_mode"] - if response.status in policy["retry_codes"] or response.status in policy["abort_codes"]: - await self._handle_error_response(response, retry_count, policy) - return None - - response.raise_for_status() - result = {} - if stream_mode: - # 将流式输出转化为非流式输出 - result = await self._handle_stream_output(response) - else: - result = await response.json() - return ( - response_handler(result) - if response_handler - else self._default_response_handler(result, user_id, request_type, endpoint) - ) - - async def _handle_stream_output(self, response: ClientResponse) -> Dict[str, Any]: - flag_delta_content_finished = False - accumulated_content = "" - usage = None # 初始化usage变量,避免未定义错误 - reasoning_content = "" - content = "" - tool_calls = None # 初始化工具调用变量 - - async for line_bytes in response.content: - try: - line = line_bytes.decode("utf-8").strip() - if not line: - continue - if line.startswith("data:"): - data_str = line[5:].strip() - if data_str == "[DONE]": - break - try: - chunk = json.loads(data_str) - if flag_delta_content_finished: - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage # 获取token用量 - else: - delta = chunk["choices"][0]["delta"] - delta_content = delta.get("content") - if delta_content is None: - delta_content = "" - accumulated_content += delta_content - - # 提取工具调用信息 - if "tool_calls" in delta: - if tool_calls is None: - tool_calls = delta["tool_calls"] - else: - # 合并工具调用信息 - tool_calls.extend(delta["tool_calls"]) - - # 检测流式输出文本是否结束 - finish_reason = chunk["choices"][0].get("finish_reason") - if delta.get("reasoning_content", None): - reasoning_content += delta["reasoning_content"] - if finish_reason == "stop" or finish_reason == "tool_calls": - chunk_usage = chunk.get("usage", None) - if chunk_usage: - usage = chunk_usage - break - # 部分平台在文本输出结束前不会返回token用量,此时需要再获取一次chunk - flag_delta_content_finished = True - except Exception as e: - logger.exception(f"模型 {self.model_name} 解析流式输出错误: {str(e)}") - except Exception as e: - if isinstance(e, GeneratorExit): - log_content = f"模型 {self.model_name} 流式输出被中断,正在清理资源..." - else: - log_content = f"模型 {self.model_name} 处理流式输出时发生错误: {str(e)}" - logger.warning(log_content) - # 确保资源被正确清理 - try: - await response.release() - except Exception as cleanup_error: - logger.error(f"清理资源时发生错误: {cleanup_error}") - # 返回已经累积的内容 - content = accumulated_content - if not content: - content = accumulated_content - think_match = re.search(r"(.*?)", content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() - - # 构建消息对象 - message = { - "content": content, - "reasoning_content": reasoning_content, - } - - # 如果有工具调用,添加到消息中 - if tool_calls: - message["tool_calls"] = tool_calls - - result = { - "choices": [{"message": message}], - "usage": usage, - } - return result - - async def _handle_error_response(self, response: ClientResponse, retry_count: int, policy: Dict[str, Any]): - if response.status in policy["retry_codes"]: - wait_time = policy["base_wait"] * (2**retry_count) - logger.warning(f"模型 {self.model_name} 错误码: {response.status}, 等待 {wait_time}秒后重试") - if response.status == 413: - logger.warning("请求体过大,尝试压缩...") - raise PayLoadTooLargeError("请求体过大") - elif response.status in [500, 503]: - logger.error( - f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}" - ) - raise RuntimeError("服务器负载过高,模型回复失败QAQ") - else: - logger.warning(f"模型 {self.model_name} 请求限制(429),等待{wait_time}秒后重试...") - raise RuntimeError("请求限制(429)") - elif response.status in policy["abort_codes"]: - # 特别处理400错误,添加详细调试信息 - if response.status == 400: - logger.error(f"🔍 [调试信息] 模型 {self.model_name} 参数错误 (400) - 开始详细诊断") - logger.error(f"🔍 [调试信息] 模型名称: {self.model_name}") - logger.error(f"🔍 [调试信息] API地址: {self.base_url}") - logger.error("🔍 [调试信息] 模型配置参数:") - logger.error(f" - enable_thinking: {self.enable_thinking}") - logger.error(f" - temp: {self.temp}") - logger.error(f" - thinking_budget: {self.thinking_budget}") - logger.error(f" - stream: {self.stream}") - logger.error(f" - max_tokens: {self.max_tokens}") - logger.error(f" - pri_in: {self.pri_in}") - logger.error(f" - pri_out: {self.pri_out}") - logger.error(f"🔍 [调试信息] 原始params: {self.params}") - - # 尝试获取服务器返回的详细错误信息 - try: - error_text = await response.text() - logger.error(f"🔍 [调试信息] 服务器返回的原始错误内容: {error_text}") - - try: - error_json = json.loads(error_text) - logger.error(f"🔍 [调试信息] 解析后的错误JSON: {json.dumps(error_json, indent=2, ensure_ascii=False)}") - except json.JSONDecodeError: - logger.error("🔍 [调试信息] 错误响应不是有效的JSON格式") - except Exception as e: - logger.error(f"🔍 [调试信息] 无法读取错误响应内容: {str(e)}") - - raise RequestAbortException("参数错误,请检查调试信息", response) - elif response.status != 403: - raise RequestAbortException("请求出现错误,中断处理", response) - else: - raise PermissionDeniedException("模型禁止访问") - - async def _handle_exception( - self, exception, retry_count: int, request_content: Dict[str, Any] - ) -> Union[Tuple[Dict[str, Any], int], Tuple[None, int]]: - policy = request_content["policy"] - payload = request_content["payload"] - wait_time = policy["base_wait"] * (2**retry_count) - keep_request = False - if retry_count < policy["max_retries"] - 1: - keep_request = True - if isinstance(exception, RequestAbortException): - response = exception.response - logger.error( - f"模型 {self.model_name} 错误码: {response.status} - {error_code_mapping.get(response.status)}" - ) - - # 如果是400错误,额外输出请求体信息用于调试 - if response.status == 400: - logger.error("🔍 [异常调试] 400错误 - 请求体调试信息:") - try: - safe_payload = await _safely_record(request_content, payload) - logger.error(f"🔍 [异常调试] 发送的请求体: {json.dumps(safe_payload, indent=2, ensure_ascii=False)}") - except Exception as debug_error: - logger.error(f"🔍 [异常调试] 无法安全记录请求体: {str(debug_error)}") - logger.error(f"🔍 [异常调试] 原始payload类型: {type(payload)}") - if isinstance(payload, dict): - logger.error(f"🔍 [异常调试] 原始payload键: {list(payload.keys())}") - - # print(request_content) - # print(response) - # 尝试获取并记录服务器返回的详细错误信息 - try: - error_json = await response.json() - if error_json and isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj: dict = error_item["error"] - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error( - f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - # 处理单个错误对象的情况 - error_obj = error_json.get("error", {}) - error_code = error_obj.get("code") - error_message = error_obj.get("message") - error_status = error_obj.get("status") - logger.error(f"服务器错误详情: 代码={error_code}, 状态={error_status}, 消息={error_message}") - else: - # 记录原始错误响应内容 - logger.error(f"服务器错误响应: {error_json}") - except Exception as e: - logger.warning(f"无法解析服务器错误响应: {str(e)}") - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - - elif isinstance(exception, PermissionDeniedException): - # 只针对硅基流动的V3和R1进行降级处理 - if self.model_name.startswith("Pro/deepseek-ai") and self.base_url == "https://api.siliconflow.cn/v1/": - old_model_name = self.model_name - self.model_name = self.model_name[4:] # 移除"Pro/"前缀 - logger.warning(f"检测到403错误,模型从 {old_model_name} 降级为 {self.model_name}") - - # 对全局配置进行更新 - if global_config.model.replyer_2.get("name") == old_model_name: - global_config.model.replyer_2["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_normal 模型临时降级至{self.model_name}") - if global_config.model.replyer_1.get("name") == old_model_name: - global_config.model.replyer_1["name"] = self.model_name - logger.warning(f"将全局配置中的 llm_reasoning 模型临时降级至{self.model_name}") - - if payload and "model" in payload: - payload["model"] = self.model_name - - await asyncio.sleep(wait_time) - return payload, -1 - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(403)}") - - elif isinstance(exception, PayLoadTooLargeError): - if keep_request: - image_base64 = request_content["image_base64"] - compressed_image_base64 = compress_base64_image_by_scale(image_base64) - new_payload = await self._build_payload( - request_content["prompt"], compressed_image_base64, request_content["image_format"] - ) - return new_payload, 0 - else: - return None, 0 - - elif isinstance(exception, aiohttp.ClientError) or isinstance(exception, asyncio.TimeoutError): - if keep_request: - logger.error(f"模型 {self.model_name} 网络错误,等待{wait_time}秒后重试... 错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 网络错误达到最大重试次数: {str(exception)}") - raise RuntimeError(f"网络请求失败: {str(exception)}") - - elif isinstance(exception, aiohttp.ClientResponseError): - # 处理aiohttp抛出的,除了policy中的status的响应错误 - if keep_request: - logger.error( - f"模型 {self.model_name} HTTP响应错误,等待{wait_time}秒后重试... 状态码: {exception.status}, 错误: {exception.message}" - ) - try: - error_text = await exception.response.text() - error_json = json.loads(error_text) - if isinstance(error_json, list) and len(error_json) > 0: - # 处理多个错误的情况 - for error_item in error_json: - if "error" in error_item and isinstance(error_item["error"], dict): - error_obj = error_item["error"] - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - elif isinstance(error_json, dict) and "error" in error_json: - error_obj = error_json.get("error", {}) - logger.error( - f"模型 {self.model_name} 服务器错误详情: 代码={error_obj.get('code')}, " - f"状态={error_obj.get('status')}, " - f"消息={error_obj.get('message')}" - ) - else: - logger.error(f"模型 {self.model_name} 服务器错误响应: {error_json}") - except (json.JSONDecodeError, TypeError) as json_err: - logger.warning( - f"模型 {self.model_name} 响应不是有效的JSON: {str(json_err)}, 原始内容: {error_text[:200]}" + if request_type == RequestType.RESPONSE: + assert message_list is not None, "message_list cannot be None for response requests" + return await client.get_response( + model_info=model_info, + message_list=(compressed_messages or message_list), + tool_options=tool_options, + max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens, + temperature=self.model_for_task.temperature if temperature is None else temperature, + response_format=response_format, + stream_response_handler=stream_response_handler, + async_response_parser=async_response_parser, + extra_params=model_info.extra_params, ) - except Exception as parse_err: - logger.warning(f"模型 {self.model_name} 无法解析响应错误内容: {str(parse_err)}") + elif request_type == RequestType.EMBEDDING: + assert embedding_input, "embedding_input cannot be empty for embedding requests" + return await client.get_embedding( + model_info=model_info, + embedding_input=embedding_input, + extra_params=model_info.extra_params, + ) + elif request_type == RequestType.AUDIO: + assert audio_base64 is not None, "audio_base64 cannot be None for audio requests" + return await client.get_audio_transcriptions( + model_info=model_info, + audio_base64=audio_base64, + extra_params=model_info.extra_params, + ) + except Exception as e: + logger.debug(f"请求失败: {str(e)}") + # 处理异常 + total_tokens, penalty = self.model_usage[model_info.name] + self.model_usage[model_info.name] = (total_tokens, penalty + 1) - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical( - f"模型 {self.model_name} HTTP响应错误达到最大重试次数: 状态码: {exception.status}, 错误: {exception.message}" - ) - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError( - f"模型 {self.model_name} API请求失败: 状态码 {exception.status}, {exception.message}" + wait_interval, compressed_messages = self._default_exception_handler( + e, + self.task_name, + model_name=model_info.name, + remain_try=retry_remain, + retry_interval=api_provider.retry_interval, + messages=(message_list, compressed_messages is not None) if message_list else None, ) - else: - if keep_request: - logger.error(f"模型 {self.model_name} 请求失败,等待{wait_time}秒后重试... 错误: {str(exception)}") - await asyncio.sleep(wait_time) - return None, 0 - else: - logger.critical(f"模型 {self.model_name} 请求失败: {str(exception)}") - # 安全地检查和记录请求详情 - handled_payload = await _safely_record(request_content, payload) - logger.critical( - f"请求头: {await self._build_headers(no_key=True)} 请求体: {str(handled_payload)[:100]}" - ) - raise RuntimeError(f"模型 {self.model_name} API请求失败: {str(exception)}") + if wait_interval == -1: + retry_remain = 0 # 不再重试 + elif wait_interval > 0: + logger.info(f"等待 {wait_interval} 秒后重试...") + await asyncio.sleep(wait_interval) + finally: + # 放在finally防止死循环 + retry_remain -= 1 + logger.error(f"模型 '{model_info.name}' 请求失败,达到最大重试次数 {api_provider.max_retry} 次") + raise RuntimeError("请求失败,已达到最大重试次数") - async def _transform_parameters(self, params: dict) -> dict: + def _default_exception_handler( + self, + e: Exception, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: Tuple[List[Message], bool] | None = None, + ) -> Tuple[int, List[Message] | None]: """ - 根据模型名称转换参数: - - 对于需要转换的OpenAI CoT系列模型(例如 "o3-mini"),删除 'temperature' 参数, - 并将 'max_tokens' 重命名为 'max_completion_tokens' + 默认异常处理函数 + Args: + e (Exception): 异常对象 + task_name (str): 任务名称 + model_name (str): 模型名称 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) """ - # 复制一份参数,避免直接修改原始数据 - new_params = dict(params) - - logger.debug(f"🔍 [参数转换] 模型 {self.model_name} 开始参数转换") - logger.debug(f"🔍 [参数转换] 是否为CoT模型: {self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION}") - logger.debug(f"🔍 [参数转换] CoT模型列表: {self.MODELS_NEEDING_TRANSFORMATION}") - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION: - logger.debug("🔍 [参数转换] 检测到CoT模型,开始参数转换") - # 删除 'temperature' 参数(如果存在),但避免删除我们在_build_payload中添加的自定义温度 - if "temperature" in new_params and new_params["temperature"] == 0.7: - removed_temp = new_params.pop("temperature") - logger.debug(f"🔍 [参数转换] 移除默认temperature参数: {removed_temp}") - # 如果存在 'max_tokens',则重命名为 'max_completion_tokens' - if "max_tokens" in new_params: - old_value = new_params["max_tokens"] - new_params["max_completion_tokens"] = new_params.pop("max_tokens") - logger.debug(f"🔍 [参数转换] 参数重命名: max_tokens({old_value}) -> max_completion_tokens({new_params['max_completion_tokens']})") + if isinstance(e, NetworkConnectionError): # 网络连接错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 连接异常,超过最大重试次数,请检查网络连接状态或URL是否正确", + ) + elif isinstance(e, ReqAbortException): + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求被中断,详细信息-{str(e.message)}") + return -1, None # 不再重试请求该模型 + elif isinstance(e, RespNotOkException): + return self._handle_resp_not_ok( + e, + task_name, + model_name, + remain_try, + retry_interval, + messages, + ) + elif isinstance(e, RespParseException): + # 响应解析错误 + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 响应解析错误,错误信息-{e.message}") + logger.debug(f"附加内容: {str(e.ext_info)}") + return -1, None # 不再重试请求该模型 else: - logger.debug("🔍 [参数转换] 非CoT模型,无需参数转换") - - logger.debug(f"🔍 [参数转换] 转换前参数: {params}") - logger.debug(f"🔍 [参数转换] 转换后参数: {new_params}") - return new_params + logger.error(f"任务-'{task_name}' 模型-'{model_name}': 未知异常,错误信息-{str(e)}") + return -1, None # 不再重试请求该模型 - async def _build_formdata_payload(self, file_bytes: bytes, file_format: str) -> aiohttp.FormData: - """构建form-data请求体""" - # 目前只适配了音频文件 - # 如果后续要支持其他类型的文件,可以在这里添加更多的处理逻辑 - data = aiohttp.FormData() - content_type_list = { - "wav": "audio/wav", - "mp3": "audio/mpeg", - "ogg": "audio/ogg", - "flac": "audio/flac", - "aac": "audio/aac", - } + def _check_retry( + self, + remain_try: int, + retry_interval: int, + can_retry_msg: str, + cannot_retry_msg: str, + can_retry_callable: Callable | None = None, + **kwargs, + ) -> Tuple[int, List[Message] | None]: + """辅助函数:检查是否可以重试 + Args: + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + can_retry_msg (str): 可以重试时的提示信息 + cannot_retry_msg (str): 不可以重试时的提示信息 + can_retry_callable (Callable | None): 可以重试时调用的函数(如果有) + **kwargs: 其他参数 - content_type = content_type_list.get(file_format) - if not content_type: - logger.warning(f"暂不支持的文件类型: {file_format}") - - data.add_field( - "file", - io.BytesIO(file_bytes), - filename=f"file.{file_format}", - content_type=f"{content_type}", # 根据实际文件类型设置 - ) - data.add_field("model", self.model_name) - return data - - async def _build_payload(self, prompt: str, image_base64: str = None, image_format: str = None) -> dict: - """构建请求体""" - # 复制一份参数,避免直接修改 self.params - logger.debug(f"🔍 [参数构建] 模型 {self.model_name} 开始构建请求体") - logger.debug(f"🔍 [参数构建] 原始self.params: {self.params}") - - params_copy = await self._transform_parameters(self.params) - logger.debug(f"🔍 [参数构建] 转换后的params_copy: {params_copy}") - - if image_base64: - messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": prompt}, - { - "type": "image_url", - "image_url": {"url": f"data:image/{image_format.lower()};base64,{image_base64}"}, - }, - ], - } - ] - else: - messages = [{"role": "user", "content": prompt}] - - payload = { - "model": self.model_name, - "messages": messages, - **params_copy, - } - - logger.debug(f"🔍 [参数构建] 基础payload构建完成: {list(payload.keys())}") - - # 添加temp参数(如果不是默认值0.7) - if self.temp != 0.7: - payload["temperature"] = self.temp - logger.debug(f"🔍 [参数构建] 添加temperature参数: {self.temp}") - - # 添加enable_thinking参数(只有配置文件中声明了才添加,不管值是true还是false) - if self.has_enable_thinking: - payload["enable_thinking"] = self.enable_thinking - logger.debug(f"🔍 [参数构建] 添加enable_thinking参数: {self.enable_thinking}") - - # 添加thinking_budget参数(只有配置文件中声明了才添加) - if self.has_thinking_budget: - payload["thinking_budget"] = self.thinking_budget - logger.debug(f"🔍 [参数构建] 添加thinking_budget参数: {self.thinking_budget}") - - if self.max_tokens: - payload["max_tokens"] = self.max_tokens - logger.debug(f"🔍 [参数构建] 添加max_tokens参数: {self.max_tokens}") - - # if "max_tokens" not in payload and "max_completion_tokens" not in payload: - # payload["max_tokens"] = global_config.model.model_max_output_length - # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 - if self.model_name.lower() in self.MODELS_NEEDING_TRANSFORMATION and "max_tokens" in payload: - old_value = payload["max_tokens"] - payload["max_completion_tokens"] = payload.pop("max_tokens") - logger.debug(f"🔍 [参数构建] CoT模型参数转换: max_tokens({old_value}) -> max_completion_tokens({payload['max_completion_tokens']})") - - logger.debug(f"🔍 [参数构建] 最终payload键列表: {list(payload.keys())}") - return payload - - def _default_response_handler( - self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions" - ) -> Tuple: - """默认响应解析""" - if "choices" in result and result["choices"]: - message = result["choices"][0]["message"] - content = message.get("content", "") - content, reasoning = self._extract_reasoning(content) - reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") - if not reasoning_content: - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - reasoning_content = reasoning - - # 提取工具调用信息 - tool_calls = message.get("tool_calls", None) - - # 记录token使用情况 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - self._record_usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - user_id=user_id, - request_type=request_type if request_type is not None else self.request_type, - endpoint=endpoint, - ) - - # 只有当tool_calls存在且不为空时才返回 - if tool_calls: - logger.debug(f"检测到工具调用: {tool_calls}") - return content, reasoning_content, tool_calls + Returns: + (Tuple[int, List[Message] | None]): (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + if remain_try > 0: + # 还有重试机会 + logger.warning(f"{can_retry_msg}") + if can_retry_callable is not None: + return retry_interval, can_retry_callable(**kwargs) else: - return content, reasoning_content - elif "text" in result and result["text"]: - return result["text"] - return "没有返回结果", "" + return retry_interval, None + else: + # 达到最大重试次数 + logger.warning(f"{cannot_retry_msg}") + return -1, None # 不再重试请求该模型 + + def _handle_resp_not_ok( + self, + e: RespNotOkException, + task_name: str, + model_name: str, + remain_try: int, + retry_interval: int = 10, + messages: tuple[list[Message], bool] | None = None, + ): + """ + 处理响应错误异常 + Args: + e (RespNotOkException): 响应错误异常对象 + task_name (str): 任务名称 + model_name (str): 模型名称 + remain_try (int): 剩余尝试次数 + retry_interval (int): 重试间隔 + messages (tuple[list[Message], bool] | None): (消息列表, 是否已压缩过) + Returns: + (等待间隔(如果为0则不等待,为-1则不再请求该模型), 新的消息列表(适用于压缩消息)) + """ + # 响应错误 + if e.status_code in [400, 401, 402, 403, 404]: + # 客户端错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 请求失败,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None # 不再重试请求该模型 + elif e.status_code == 413: + if messages and not messages[1]: + # 消息列表不为空且未压缩,尝试压缩消息 + return self._check_retry( + remain_try, + 0, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,尝试压缩消息后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,压缩消息后仍然过大,放弃请求", + can_retry_callable=compress_messages, + messages=messages[0], + ) + # 没有消息可压缩 + logger.warning(f"任务-'{task_name}' 模型-'{model_name}': 请求体过大,无法压缩消息,放弃请求。") + return -1, None + elif e.status_code == 429: + # 请求过于频繁 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 请求过于频繁,超过最大重试次数,放弃请求", + ) + elif e.status_code >= 500: + # 服务器错误 + return self._check_retry( + remain_try, + retry_interval, + can_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,将于{retry_interval}秒后重试", + cannot_retry_msg=f"任务-'{task_name}' 模型-'{model_name}': 服务器错误,超过最大重试次数,请稍后再试", + ) + else: + # 未知错误 + logger.warning( + f"任务-'{task_name}' 模型-'{model_name}': 未知错误,错误代码-{e.status_code},错误信息-{e.message}" + ) + return -1, None + + def _build_tool_options(self, tools: Optional[List[Dict[str, Any]]]) -> Optional[List[ToolOption]]: + # sourcery skip: extract-method + """构建工具选项列表""" + if not tools: + return None + tool_options: List[ToolOption] = [] + for tool in tools: + tool_legal = True + tool_options_builder = ToolOptionBuilder() + tool_options_builder.set_name(tool.get("name", "")) + tool_options_builder.set_description(tool.get("description", "")) + parameters: List[Tuple[str, str, str, bool, List[str] | None]] = tool.get("parameters", []) + for param in parameters: + try: + assert isinstance(param, tuple) and len(param) == 5, "参数必须是包含5个元素的元组" + assert isinstance(param[0], str), "参数名称必须是字符串" + assert isinstance(param[1], ToolParamType), "参数类型必须是ToolParamType枚举" + assert isinstance(param[2], str), "参数描述必须是字符串" + assert isinstance(param[3], bool), "参数是否必填必须是布尔值" + assert isinstance(param[4], list) or param[4] is None, "参数枚举值必须是列表或None" + tool_options_builder.add_param( + name=param[0], + param_type=param[1], + description=param[2], + required=param[3], + enum_values=param[4], + ) + except AssertionError as ae: + tool_legal = False + logger.error(f"{param[0]} 参数定义错误: {str(ae)}") + except Exception as e: + tool_legal = False + logger.error(f"构建工具参数失败: {str(e)}") + if tool_legal: + tool_options.append(tool_options_builder.build()) + return tool_options or None @staticmethod def _extract_reasoning(content: str) -> Tuple[str, str]: - """CoT思维链提取""" + """CoT思维链提取,向后兼容""" match = re.search(r"(?:)?(.*?)", content, re.DOTALL) content = re.sub(r"(?:)?.*?", "", content, flags=re.DOTALL, count=1).strip() - if match: - reasoning = match.group(1).strip() - else: - reasoning = "" + reasoning = match[1].strip() if match else "" return content, reasoning - - async def _build_headers(self, no_key: bool = False, is_formdata: bool = False) -> dict: - """构建请求头""" - if no_key: - if is_formdata: - return {"Authorization": "Bearer **********"} - return {"Authorization": "Bearer **********", "Content-Type": "application/json"} - else: - if is_formdata: - return {"Authorization": f"Bearer {self.api_key}"} - return {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"} - # 防止小朋友们截图自己的key - - async def generate_response_for_image(self, prompt: str, image_base64: str, image_format: str) -> Tuple: - """根据输入的提示和图片生成模型的异步响应""" - - response = await self._execute_request( - endpoint="/chat/completions", prompt=prompt, image_base64=image_base64, image_format=image_format - ) - # 根据返回值的长度决定怎么处理 - if len(response) == 3: - content, reasoning_content, tool_calls = response - return content, reasoning_content, tool_calls - else: - content, reasoning_content = response - return content, reasoning_content - - async def generate_response_for_voice(self, voice_bytes: bytes) -> Tuple: - """根据输入的语音文件生成模型的异步响应""" - response = await self._execute_request( - endpoint="/audio/transcriptions", file_bytes=voice_bytes, file_format="wav" - ) - return response - - async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple]: - """异步方式根据输入的提示生成模型的响应""" - # 构建请求体,不硬编码max_tokens - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params, - **kwargs, - } - - response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt) - # 原样返回响应,不做处理 - - if len(response) == 3: - content, reasoning_content, tool_calls = response - return content, (reasoning_content, self.model_name, tool_calls) - else: - content, reasoning_content = response - return content, (reasoning_content, self.model_name) - - async def get_embedding(self, text: str) -> Union[list, None]: - """异步方法:获取文本的embedding向量 - - Args: - text: 需要获取embedding的文本 - - Returns: - list: embedding向量,如果失败则返回None - """ - - if len(text) < 1: - logger.debug("该消息没有长度,不再发送获取embedding向量的请求") - return None - - def embedding_handler(result): - """处理响应""" - if "data" in result and len(result["data"]) > 0: - # 提取 token 使用信息 - usage = result.get("usage", {}) - if usage: - prompt_tokens = usage.get("prompt_tokens", 0) - completion_tokens = usage.get("completion_tokens", 0) - total_tokens = usage.get("total_tokens", 0) - # 记录 token 使用情况 - self._record_usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - user_id="system", # 可以根据需要修改 user_id - # request_type="embedding", # 请求类型为 embedding - request_type=self.request_type, # 请求类型为 text - endpoint="/embeddings", # API 端点 - ) - return result["data"][0].get("embedding", None) - return result["data"][0].get("embedding", None) - return None - - embedding = await self._execute_request( - endpoint="/embeddings", - prompt=text, - payload={"model": self.model_name, "input": text, "encoding_format": "float"}, - retry_policy={"max_retries": 2, "base_wait": 6}, - response_handler=embedding_handler, - ) - return embedding - - -def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: - """压缩base64格式的图片到指定大小 - Args: - base64_data: base64编码的图片数据 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - str: 压缩后的base64图片数据 - """ - try: - # 将base64转换为字节数据 - # 确保base64字符串只包含ASCII字符 - if isinstance(base64_data, str): - base64_data = base64_data.encode("ascii", errors="ignore").decode("ascii") - image_data = base64.b64decode(base64_data) - - # 如果已经小于目标大小,直接返回原图 - if len(image_data) <= 2 * 1024 * 1024: - return base64_data - - # 将字节数据转换为图片对象 - img = Image.open(io.BytesIO(image_data)) - - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / len(image_data)) ** 0.5) - - # 计算新的尺寸 - new_width = int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - for frame_idx in range(img.n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width // 2, new_height // 2), Image.Resampling.LANCZOS) # 动图折上折 - frames.append(new_frame) - - # 保存到缓冲区 - frames[0].save( - output_buffer, - format="GIF", - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get("duration", 100), - loop=img.info.get("loop", 0), - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == "PNG" and img.mode in ("RGBA", "LA"): - resized_img.save(output_buffer, format="PNG", optimize=True) - else: - resized_img.save(output_buffer, format="JPEG", quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - logger.info(f"压缩图片: {original_width}x{original_height} -> {new_width}x{new_height}") - logger.info(f"压缩前大小: {len(image_data) / 1024:.1f}KB, 压缩后大小: {len(compressed_data) / 1024:.1f}KB") - - return base64.b64encode(compressed_data).decode("utf-8") - - except Exception as e: - logger.error(f"压缩图片失败: {str(e)}") - import traceback - - logger.error(traceback.format_exc()) - return base64_data diff --git a/src/mais4u/mai_think.py b/src/mais4u/mai_think.py index 867ba8be..5a1f5808 100644 --- a/src/mais4u/mai_think.py +++ b/src/mais4u/mai_think.py @@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager import time from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import model_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.common.logger import get_logger + logger = get_logger(__name__) + def init_prompt(): Prompt( """ @@ -32,10 +34,8 @@ def init_prompt(): ) - - class MaiThinking: - def __init__(self,chat_id): + def __init__(self, chat_id): self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.platform = self.chat_stream.platform @@ -44,11 +44,11 @@ class MaiThinking: self.is_group = True else: self.is_group = False - + self.s4u_message_processor = S4UMessageProcessor() - + self.mind = "" - + self.memory_block = "" self.relation_info_block = "" self.time_block = "" @@ -59,17 +59,13 @@ class MaiThinking: self.identity = "" self.sender = "" self.target = "" - - self.thinking_model = LLMRequest( - model=global_config.model.replyer_1, - request_type="thinking", - ) + + self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking") async def do_think_before_response(self): pass - async def do_think_after_response(self,reponse:str): - + async def do_think_after_response(self, reponse: str): prompt = await global_prompt_manager.format_prompt( "after_response_think_prompt", mind=self.mind, @@ -85,47 +81,44 @@ class MaiThinking: sender=self.sender, target=self.target, ) - + result, _ = await self.thinking_model.generate_response_async(prompt) self.mind = result - + logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}") # logger.info(f"[{self.chat_id}] 思考前prompt:{prompt}") logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}") - - + msg_recv = await self.build_internal_message_recv(self.mind) await self.s4u_message_processor.process_message(msg_recv) internal_manager.set_internal_state(self.mind) - - + async def do_think_when_receive_message(self): pass - - async def build_internal_message_recv(self,message_text:str): - + + async def build_internal_message_recv(self, message_text: str): msg_id = f"internal_{time.time()}" - + message_dict = { "message_info": { "message_id": msg_id, "time": time.time(), "user_info": { - "user_id": "internal", # 内部用户ID - "user_nickname": "内心", # 内部昵称 - "platform": self.platform, # 平台标记为 internal + "user_id": "internal", # 内部用户ID + "user_nickname": "内心", # 内部昵称 + "platform": self.platform, # 平台标记为 internal # 其他 user_info 字段按需补充 }, - "platform": self.platform, # 平台 + "platform": self.platform, # 平台 # 其他 message_info 字段按需补充 }, "message_segment": { - "type": "text", # 消息类型 - "data": message_text, # 消息内容 + "type": "text", # 消息类型 + "data": message_text, # 消息内容 # 其他 segment 字段按需补充 }, - "raw_message": message_text, # 原始消息内容 - "processed_plain_text": message_text, # 处理后的纯文本 + "raw_message": message_text, # 原始消息内容 + "processed_plain_text": message_text, # 处理后的纯文本 # 下面这些字段可选,根据 MessageRecv 需要 "is_emoji": False, "has_emoji": False, @@ -139,45 +132,36 @@ class MaiThinking: "priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级 "interest_value": 1.0, } - + if self.is_group: message_dict["message_info"]["group_info"] = { "platform": self.platform, "group_id": self.chat_stream.group_info.group_id, "group_name": self.chat_stream.group_info.group_name, } - + msg_recv = MessageRecvS4U(message_dict) msg_recv.chat_info = self.chat_info msg_recv.chat_stream = self.chat_stream msg_recv.is_internal = True - + return msg_recv - - - + class MaiThinkingManager: def __init__(self): self.mai_think_list = [] - - def get_mai_think(self,chat_id): + + def get_mai_think(self, chat_id): for mai_think in self.mai_think_list: if mai_think.chat_id == chat_id: return mai_think mai_think = MaiThinking(chat_id) self.mai_think_list.append(mai_think) return mai_think - + + mai_thinking_manager = MaiThinkingManager() - + init_prompt() - - - - - - - - diff --git a/src/mais4u/mais4u_chat/body_emotion_action_manager.py b/src/mais4u/mais4u_chat/body_emotion_action_manager.py index e7380822..8e05a025 100644 --- a/src/mais4u/mais4u_chat/body_emotion_action_manager.py +++ b/src/mais4u/mais4u_chat/body_emotion_action_manager.py @@ -1,14 +1,16 @@ import json import time + +from json_repair import repair_json from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api -from json_repair import repair_json + from src.mais4u.s4u_config import s4u_config logger = get_logger("action") @@ -32,7 +34,7 @@ BODY_CODE = { "帅气的姿势": "010_0190", "另一个帅气的姿势": "010_0191", "手掌朝前可爱": "010_0210", - "平静,双手后放":"平静,双手后放", + "平静,双手后放": "平静,双手后放", "思考": "思考", "优雅,左手放在腰上": "优雅,左手放在腰上", "一般": "一般", @@ -94,19 +96,15 @@ class ChatAction: self.body_action_cooldown: dict[str, int] = {} print(s4u_config.models.motion) - print(global_config.model.emotion) - - self.action_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="motion", - ) + print(model_config.model_task_config.emotion) - self.last_change_time = 0 + self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") + + self.last_change_time: float = 0 async def send_action_update(self): """发送动作更新到前端""" - + body_code = BODY_CODE.get(self.body_action, "") await send_api.custom_to_stream( message_type="body_action", @@ -115,13 +113,11 @@ class ChatAction: storage_message=False, show_log=True, ) - - async def update_action_by_message(self, message: MessageRecv): self.regression_count = 0 - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -147,13 +143,13 @@ class ChatAction: prompt_personality = global_config.personality.personality_core indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" - + try: # 冷却池处理:过滤掉冷却中的动作 self._update_body_action_cooldown() available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] all_actions = "\n".join(available_actions) - + prompt = await global_prompt_manager.format_prompt( "change_action_prompt", chat_talking_prompt=chat_talking_prompt, @@ -163,19 +159,18 @@ class ChatAction: ) logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.action_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"response: {response}") logger.info(f"reasoning_content: {reasoning_content}") - action_data = json.loads(repair_json(response)) - - if action_data: + if action_data := json.loads(repair_json(response)): # 记录原动作,切换后进入冷却 prev_body_action = self.body_action new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action: - if prev_body_action: - self.body_action_cooldown[prev_body_action] = 3 + if new_body_action != prev_body_action and prev_body_action: + self.body_action_cooldown[prev_body_action] = 3 self.body_action = new_body_action self.head_action = action_data.get("head_action", self.head_action) # 发送动作更新 @@ -213,7 +208,6 @@ class ChatAction: prompt_personality = global_config.personality.personality_core indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:" try: - # 冷却池处理:过滤掉冷却中的动作 self._update_body_action_cooldown() available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] @@ -228,17 +222,17 @@ class ChatAction: ) logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.action_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"response: {response}") logger.info(f"reasoning_content: {reasoning_content}") - action_data = json.loads(repair_json(response)) - if action_data: + if action_data := json.loads(repair_json(response)): prev_body_action = self.body_action new_body_action = action_data.get("body_action", self.body_action) - if new_body_action != prev_body_action: - if prev_body_action: - self.body_action_cooldown[prev_body_action] = 6 + if new_body_action != prev_body_action and prev_body_action: + self.body_action_cooldown[prev_body_action] = 6 self.body_action = new_body_action # 发送动作更新 await self.send_action_update() @@ -306,9 +300,6 @@ class ActionManager: return new_action_state - - - init_prompt() action_manager = ActionManager() diff --git a/src/mais4u/mais4u_chat/s4u_chat.py b/src/mais4u/mais4u_chat/s4u_chat.py index e447ae19..78df5e98 100644 --- a/src/mais4u/mais4u_chat/s4u_chat.py +++ b/src/mais4u/mais4u_chat/s4u_chat.py @@ -137,7 +137,7 @@ class MessageSenderContainer: await self.storage.store_message(bot_message, self.chat_stream) except Exception as e: - logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) + logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True) finally: # CRUCIAL: Always call task_done() for any item that was successfully retrieved. diff --git a/src/mais4u/mais4u_chat/s4u_mood_manager.py b/src/mais4u/mais4u_chat/s4u_mood_manager.py index c936cea1..11d8c7ca 100644 --- a/src/mais4u/mais4u_chat/s4u_mood_manager.py +++ b/src/mais4u/mais4u_chat/s4u_mood_manager.py @@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.manager.async_task_manager import AsyncTask, async_task_manager from src.plugin_system.apis import send_api @@ -114,18 +114,12 @@ class ChatMood: self.regression_count: int = 0 - self.mood_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="mood_text", - ) + self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text") self.mood_model_numerical = LLMRequest( - model=global_config.model.emotion, - temperature=0.4, - request_type="mood_numerical", + model_set=model_config.model_task_config.emotion, request_type="mood_numerical" ) - self.last_change_time = 0 + self.last_change_time: float = 0 # 发送初始情绪状态到ws端 asyncio.create_task(self.send_emotion_update(self.mood_values)) @@ -164,7 +158,7 @@ class ChatMood: async def update_mood_by_message(self, message: MessageRecv): self.regression_count = 0 - message_time = message.message_info.time + message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( chat_id=self.chat_id, timestamp_start=self.last_change_time, @@ -199,7 +193,9 @@ class ChatMood: mood_state=self.mood_state, ) logger.debug(f"text mood prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"text mood response: {response}") logger.debug(f"text mood reasoning_content: {reasoning_content}") return response @@ -216,8 +212,8 @@ class ChatMood: fear=self.mood_values["fear"], ) logger.debug(f"numerical mood prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( - prompt=prompt + response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( + prompt=prompt, temperature=0.4 ) logger.info(f"numerical mood response: {response}") logger.debug(f"numerical mood reasoning_content: {reasoning_content}") @@ -276,7 +272,9 @@ class ChatMood: mood_state=self.mood_state, ) logger.debug(f"text regress prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) logger.info(f"text regress response: {response}") logger.debug(f"text regress reasoning_content: {reasoning_content}") return response @@ -293,8 +291,9 @@ class ChatMood: fear=self.mood_values["fear"], ) logger.debug(f"numerical regress prompt: {prompt}") - response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( - prompt=prompt + response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async( + prompt=prompt, + temperature=0.4, ) logger.info(f"numerical regress response: {response}") logger.debug(f"numerical regress reasoning_content: {reasoning_content}") @@ -447,6 +446,7 @@ class MoodManager: # 发送初始情绪状态到ws端 asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values)) + if ENABLE_S4U: init_prompt() mood_manager = MoodManager() diff --git a/src/mais4u/mais4u_chat/s4u_prompt.py b/src/mais4u/mais4u_chat/s4u_prompt.py index d748c25e..72324d74 100644 --- a/src/mais4u/mais4u_chat/s4u_prompt.py +++ b/src/mais4u/mais4u_chat/s4u_prompt.py @@ -150,19 +150,18 @@ class PromptBuilder: relation_prompt = "" if global_config.relationship.enable_relationship and who_chat_in_group: relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id) - + # 将 (platform, user_id, nickname) 转换为 person_id person_ids = [] for person in who_chat_in_group: person_id = PersonInfoManager.get_person_id(person[0], person[1]) person_ids.append(person_id) - + # 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为 relation_info_list = await asyncio.gather( *[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] ) - relation_info = "".join(relation_info_list) - if relation_info: + if relation_info := "".join(relation_info_list): relation_prompt = await global_prompt_manager.format_prompt( "relation_prompt", relation_info=relation_info ) @@ -186,9 +185,9 @@ class PromptBuilder: timestamp=time.time(), limit=300, ) - - talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id) + + talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}" core_dialogue_list = [] background_dialogue_list = [] @@ -258,19 +257,19 @@ class PromptBuilder: all_msg_seg_list.append(msg_seg_str) for msg in all_msg_seg_list: core_msg_str += msg - - + + all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat( chat_id=chat_stream.stream_id, timestamp=time.time(), limit=20, - ) + ) all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt, timestamp_mode="normal_no_YMD", show_pic=False, ) - + return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py index 339b46c3..c0ca2658 100644 --- a/src/mais4u/mais4u_chat/s4u_stream_generator.py +++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py @@ -1,7 +1,7 @@ import os from typing import AsyncGenerator from src.mais4u.openai_client import AsyncOpenAIClient -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecvS4U from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder from src.common.logger import get_logger @@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator") class S4UStreamGenerator: def __init__(self): - replyer_1_config = global_config.model.replyer_1 - provider = replyer_1_config.get("provider") - if not provider: - logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段") - raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段") + replyer_1_config = model_config.model_task_config.replyer_1 + model_to_use = replyer_1_config.model_list[0] + model_info = model_config.get_model_info(model_to_use) + if not model_info: + logger.error(f"模型 {model_to_use} 在配置中未找到") + raise ValueError(f"模型 {model_to_use} 在配置中未找到") + provider_name = model_info.api_provider + provider_info = model_config.get_provider(provider_name) + if not provider_info: + logger.error("`replyer_1` 找不到对应的Provider") + raise ValueError("`replyer_1` 找不到对应的Provider") - api_key = os.environ.get(f"{provider.upper()}_KEY") - base_url = os.environ.get(f"{provider.upper()}_BASE_URL") + api_key = provider_info.api_key + base_url = provider_info.base_url if not api_key: - logger.error(f"环境变量 {provider.upper()}_KEY 未设置") - raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置") + logger.error(f"{provider_name}没有配置API KEY") + raise ValueError(f"{provider_name}没有配置API KEY") self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url) - self.model_1_name = replyer_1_config.get("name") - if not self.model_1_name: - logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段") - raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段") + self.model_1_name = model_to_use self.replyer_1_config = replyer_1_config self.current_model_name = "unknown model" @@ -44,10 +47,10 @@ class S4UStreamGenerator: r'[^.。!??!\n\r]+(?:[.。!??!\n\r](?![\'"])|$))', # 匹配直到句子结束符 re.UNICODE | re.DOTALL, ) - - self.chat_stream =None - - async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""): + + self.chat_stream = None + + async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""): # person_id = PersonInfoManager.get_person_id( # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id # ) @@ -71,14 +74,10 @@ class S4UStreamGenerator: [这是用户发来的新消息, 你需要结合上下文,对此进行回复]: {message.processed_plain_text} """ - return True,message_txt + return True, message_txt else: message_txt = message.processed_plain_text - return False,message_txt - - - - + return False, message_txt async def generate_response( self, message: MessageRecvS4U, previous_reply_context: str = "" @@ -88,7 +87,7 @@ class S4UStreamGenerator: self.partial_response = "" message_txt = message.processed_plain_text if not message.is_internal: - interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context) + interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context) if interupted: message_txt = message_txt_added @@ -105,7 +104,6 @@ class S4UStreamGenerator: current_client = self.client_1 self.current_model_name = self.model_1_name - extra_kwargs = {} if self.replyer_1_config.get("enable_thinking") is not None: extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking") diff --git a/src/mais4u/mais4u_chat/super_chat_manager.py b/src/mais4u/mais4u_chat/super_chat_manager.py index 528eaecc..a08d18cd 100644 --- a/src/mais4u/mais4u_chat/super_chat_manager.py +++ b/src/mais4u/mais4u_chat/super_chat_manager.py @@ -214,51 +214,49 @@ class SuperChatManager: def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: """构建SuperChat显示字符串""" superchats = self.get_superchats_by_chat(chat_id) - + if not superchats: return "" - + # 限制显示数量 display_superchats = superchats[:max_count] - - lines = [] - lines.append("📢 当前有效超级弹幕:") - + + lines = ["📢 当前有效超级弹幕:"] for i, sc in enumerate(display_superchats, 1): remaining_minutes = int(sc.remaining_time() / 60) remaining_seconds = int(sc.remaining_time() % 60) - + time_display = f"{remaining_minutes}分{remaining_seconds}秒" if remaining_minutes > 0 else f"{remaining_seconds}秒" - + line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" if len(line) > 100: # 限制单行长度 - line = line[:97] + "..." + line = f"{line[:97]}..." line += f" (剩余{time_display})" lines.append(line) - + if len(superchats) > max_count: lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") - + return "\n".join(lines) def build_superchat_summary_string(self, chat_id: str) -> str: """构建SuperChat摘要字符串""" superchats = self.get_superchats_by_chat(chat_id) - + if not superchats: return "当前没有有效的超级弹幕" lines = [] for sc in superchats: single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}" if len(single_sc_str) > 100: - single_sc_str = single_sc_str[:97] + "..." + single_sc_str = f"{single_sc_str[:97]}..." single_sc_str += f" (剩余{int(sc.remaining_time())}秒)" lines.append(single_sc_str) - + total_amount = sum(sc.price for sc in superchats) count = len(superchats) highest_amount = max(sc.price for sc in superchats) - + final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}元" if lines: final_str += "\n" + "\n".join(lines) @@ -287,7 +285,7 @@ class SuperChatManager: "lowest_amount": min(amounts) } - async def shutdown(self): + async def shutdown(self): # sourcery skip: use-contextlib-suppress """关闭管理器,清理资源""" if self._cleanup_task and not self._cleanup_task.done(): self._cleanup_task.cancel() @@ -300,6 +298,7 @@ class SuperChatManager: +# sourcery skip: assign-if-exp if ENABLE_S4U: super_chat_manager = SuperChatManager() else: diff --git a/src/mais4u/mais4u_chat/yes_or_no.py b/src/mais4u/mais4u_chat/yes_or_no.py index edc200f6..c71c160d 100644 --- a/src/mais4u/mais4u_chat/yes_or_no.py +++ b/src/mais4u/mais4u_chat/yes_or_no.py @@ -1,19 +1,14 @@ from src.llm_models.utils_model import LLMRequest from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import model_config from src.plugin_system.apis import send_api + logger = get_logger(__name__) -head_actions_list = [ - "不做额外动作", - "点头一次", - "点头两次", - "摇头", - "歪脑袋", - "低头望向一边" -] +head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"] -async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""): + +async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""): prompt = f""" {chat_history} 以上是对方的发言: @@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat 低头望向一边 请从上面的动作中选择一个,并输出,请只输出你选择的动作就好,不要输出其他内容。""" - model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="motion", - ) - + model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion") + try: # logger.info(f"prompt: {prompt}") - response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt) + response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7) logger.info(f"response: {response}") - - if response in head_actions_list: - head_action = response - else: - head_action = "不做额外动作" - + + head_action = response if response in head_actions_list else "不做额外动作" await send_api.custom_to_stream( message_type="head_action", content=head_action, @@ -53,11 +40,7 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat storage_message=False, show_log=True, ) - - - + except Exception as e: logger.error(f"yes_or_no_head error: {e}") return "不做额外动作" - - diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index eae0ea71..8daf38e6 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -3,13 +3,14 @@ import random import time from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.message_receive.message import MessageRecv +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.llm_models.utils_model import LLMRequest from src.manager.async_task_manager import AsyncTask, async_task_manager -from src.chat.message_receive.chat_stream import get_chat_manager + logger = get_logger("mood") @@ -49,7 +50,7 @@ class ChatMood: chat_manager = get_chat_manager() self.chat_stream = chat_manager.get_stream(self.chat_id) - + if not self.chat_stream: raise ValueError(f"Chat stream for chat_id {chat_id} not found") @@ -59,11 +60,7 @@ class ChatMood: self.regression_count: int = 0 - self.mood_model = LLMRequest( - model=global_config.model.emotion, - temperature=0.7, - request_type="mood", - ) + self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood") self.last_change_time: float = 0 @@ -83,12 +80,16 @@ class ChatMood: logger.debug( f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}" ) - update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier) + update_probability = global_config.mood.mood_update_threshold * min( + 1.0, base_probability * time_multiplier * interest_multiplier + ) if random.random() > update_probability: return - logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}") + logger.debug( + f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}" + ) message_time: float = message.message_info.time # type: ignore message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( @@ -124,7 +125,9 @@ class ChatMood: mood_state=self.mood_state, ) - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} response: {response}") @@ -171,7 +174,9 @@ class ChatMood: mood_state=self.mood_state, ) - response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) + response, (reasoning_content, _, _) = await self.mood_model.generate_response_async( + prompt=prompt, temperature=0.7 + ) if global_config.debug.show_prompt: logger.info(f"{self.log_prefix} prompt: {prompt}") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index 6be0ad27..4d5fe709 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -11,7 +11,7 @@ from src.common.logger import get_logger from src.common.database.database import db from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config """ @@ -54,11 +54,7 @@ person_info_default = { class PersonInfoManager: def __init__(self): self.person_name_list = {} - # TODO: API-Adapter修改标记 - self.qv_name_llm = LLMRequest( - model=global_config.model.utils, - request_type="relation.qv_name", - ) + self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") try: db.connect(reuse_if_open=True) # 设置连接池参数 @@ -199,7 +195,7 @@ class PersonInfoManager: if existing: logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") return True - + # 尝试创建 PersonInfo.create(**p_data) return True @@ -376,7 +372,7 @@ class PersonInfoManager: "nickname": "昵称", "reason": "理由" }""" - response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt) + response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt) # logger.info(f"取名提示词:{qv_name_prompt}\n取名回复:{response}") result = self._extract_json_from_text(response) @@ -592,7 +588,7 @@ class PersonInfoManager: record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) if record: return record, False # 记录存在,未创建 - + # 记录不存在,尝试创建 try: PersonInfo.create(**init_data) @@ -622,7 +618,7 @@ class PersonInfoManager: "points": [], "forgotten_points": [], } - + # 序列化JSON字段 for key in JSON_SERIALIZED_FIELDS: if key in initial_data: @@ -630,12 +626,12 @@ class PersonInfoManager: initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False) elif initial_data[key] is None: initial_data[key] = json.dumps([], ensure_ascii=False) - + model_fields = PersonInfo._meta.fields.keys() # type: ignore filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data) - + if was_created: logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py index 99f3be30..267ed96f 100644 --- a/src/person_info/relationship_fetcher.py +++ b/src/person_info/relationship_fetcher.py @@ -7,7 +7,7 @@ from typing import List, Dict, Any from json_repair import repair_json from src.common.logger import get_logger -from src.config.config import global_config +from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager @@ -73,14 +73,12 @@ class RelationshipFetcher: # LLM模型配置 self.llm_model = LLMRequest( - model=global_config.model.utils_small, - request_type="relation.fetcher", + model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher" ) # 小模型用于即时信息提取 self.instant_llm_model = LLMRequest( - model=global_config.model.utils_small, - request_type="relation.fetch", + model_set=model_config.model_task_config.utils_small, request_type="relation.fetch" ) name = get_chat_manager().get_stream_name(self.chat_id) @@ -96,7 +94,7 @@ class RelationshipFetcher: if not self.info_fetched_cache[person_id]: del self.info_fetched_cache[person_id] - async def build_relation_info(self, person_id, points_num = 3): + async def build_relation_info(self, person_id, points_num=3): # 清理过期的信息缓存 self._cleanup_expired_cache() @@ -361,7 +359,6 @@ class RelationshipFetcher: logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}") logger.error(traceback.format_exc()) - async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): # sourcery skip: use-next """将提取到的信息保存到 person_info 的 info_list 字段中 diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py index 6c269357..9d7a48b9 100644 --- a/src/person_info/relationship_manager.py +++ b/src/person_info/relationship_manager.py @@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager import time import random from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config from src.chat.utils.chat_message_builder import build_readable_messages import json from json_repair import repair_json @@ -20,9 +20,8 @@ logger = get_logger("relation") class RelationshipManager: def __init__(self): self.relationship_llm = LLMRequest( - model=global_config.model.utils, - request_type="relationship", # 用于动作规划 - ) + model_set=model_config.model_task_config.utils, request_type="relationship" + ) # 用于动作规划 @staticmethod async def is_known_some_one(platform, user_id): @@ -181,18 +180,14 @@ class RelationshipManager: try: points = repair_json(points) points_data = json.loads(points) - + # 只处理正确的格式,错误格式直接跳过 if points_data == "none" or not points_data: points_list = [] elif isinstance(points_data, str) and points_data.lower() == "none": points_list = [] elif isinstance(points_data, list): - # 正确格式:数组格式 [{"point": "...", "weight": 10}, ...] - if not points_data: # 空数组 - points_list = [] - else: - points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] + points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data] else: # 错误格式,直接跳过不解析 logger.warning(f"LLM返回了错误的JSON格式,跳过解析: {type(points_data)}, 内容: {points_data}") diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py index eb07dbc9..a102ecd0 100644 --- a/src/plugin_system/__init__.py +++ b/src/plugin_system/__init__.py @@ -9,6 +9,7 @@ from .base import ( BasePlugin, BaseAction, BaseCommand, + BaseTool, ConfigField, ComponentType, ActionActivationType, @@ -17,11 +18,13 @@ from .base import ( ActionInfo, CommandInfo, PluginInfo, + ToolInfo, PythonDependency, BaseEventHandler, EventHandlerInfo, EventType, MaiMessages, + ToolParamType, ) # 导入工具模块 @@ -34,6 +37,7 @@ from .utils import ( from .apis import ( chat_api, + tool_api, component_manage_api, config_api, database_api, @@ -44,17 +48,17 @@ from .apis import ( person_api, plugin_manage_api, send_api, - utils_api, register_plugin, get_logger, ) -__version__ = "1.0.0" +__version__ = "2.0.0" __all__ = [ # API 模块 "chat_api", + "tool_api", "component_manage_api", "config_api", "database_api", @@ -65,13 +69,13 @@ __all__ = [ "person_api", "plugin_manage_api", "send_api", - "utils_api", "register_plugin", "get_logger", # 基础类 "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "BaseEventHandler", # 类型定义 "ComponentType", @@ -81,9 +85,11 @@ __all__ = [ "ActionInfo", "CommandInfo", "PluginInfo", + "ToolInfo", "PythonDependency", "EventHandlerInfo", "EventType", + "ToolParamType", # 消息 "MaiMessages", # 装饰器 diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 0882fbdc..362c9858 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -17,7 +17,7 @@ from src.plugin_system.apis import ( person_api, plugin_manage_api, send_api, - utils_api, + tool_api, ) from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -35,7 +35,7 @@ __all__ = [ "person_api", "plugin_manage_api", "send_api", - "utils_api", "get_logger", "register_plugin", + "tool_api", ] diff --git a/src/plugin_system/apis/component_manage_api.py b/src/plugin_system/apis/component_manage_api.py index d9ea051d..1ffa0833 100644 --- a/src/plugin_system/apis/component_manage_api.py +++ b/src/plugin_system/apis/component_manage_api.py @@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import ( EventHandlerInfo, PluginInfo, ComponentType, + ToolInfo, ) @@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]: return component_registry.get_registered_command_info(command_name) +def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]: + """ + 获取指定 Tool 的注册信息。 + + Args: + tool_name (str): Tool 名称。 + + Returns: + ToolInfo: Tool 信息对象,如果 Tool 不存在则返回 None。 + """ + from src.plugin_system.core.component_registry import component_registry + + return component_registry.get_registered_tool_info(tool_name) + + # === EventHandler 特定查询方法 === def get_registered_event_handler_info( event_handler_name: str, @@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType, return global_announcement_manager.enable_specific_chat_action(stream_id, component_name) case ComponentType.COMMAND: return global_announcement_manager.enable_specific_chat_command(stream_id, component_name) + case ComponentType.TOOL: + return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name) case ComponentType.EVENT_HANDLER: return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name) case _: @@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType return global_announcement_manager.disable_specific_chat_action(stream_id, component_name) case ComponentType.COMMAND: return global_announcement_manager.disable_specific_chat_command(stream_id, component_name) + case ComponentType.TOOL: + return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name) case ComponentType.EVENT_HANDLER: return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name) case _: raise ValueError(f"未知 component type: {component_type}") + def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: """ 获取指定消息流中禁用的组件列表。 @@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp return global_announcement_manager.get_disabled_chat_actions(stream_id) case ComponentType.COMMAND: return global_announcement_manager.get_disabled_chat_commands(stream_id) + case ComponentType.TOOL: + return global_announcement_manager.get_disabled_chat_tools(stream_id) case ComponentType.EVENT_HANDLER: return global_announcement_manager.get_disabled_chat_event_handlers(stream_id) case _: - raise ValueError(f"未知 component type: {component_type}") \ No newline at end of file + raise ValueError(f"未知 component type: {component_type}") diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index d46bfba3..8b253806 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -152,10 +152,7 @@ async def db_query( except DoesNotExist: # 记录不存在 - if query_type == "get" and single_result: - return None - return [] - + return None if query_type == "get" and single_result else [] except Exception as e: logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") traceback.print_exc() @@ -170,7 +167,8 @@ async def db_query( async def db_save( model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None -) -> Union[Dict[str, Any], None]: +) -> Optional[Dict[str, Any]]: + # sourcery skip: inline-immediately-returned-variable """保存数据到数据库(创建或更新) 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; @@ -203,10 +201,9 @@ async def db_save( try: # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: - # 查找现有记录 - existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)) - - if existing_records: + if existing_records := list( + model_class.select().where(getattr(model_class, key_field) == key_value).limit(1) + ): # 更新现有记录 existing_record = existing_records[0] for field, value in data.items(): @@ -244,8 +241,8 @@ async def db_get( Args: model_class: Peewee模型类 filters: 过滤条件,字段名和值的字典 - order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 limit: 结果数量限制 + order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表 Returns: @@ -310,7 +307,7 @@ async def store_action_info( thinking_id: str = "", action_data: Optional[dict] = None, action_name: str = "", -) -> Union[Dict[str, Any], None]: +) -> Optional[Dict[str, Any]]: """存储动作信息到数据库 将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。 diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index cafb52df..479f3aec 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -65,14 +65,14 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]] return None -async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]: +async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: """随机获取指定数量的表情包 Args: count: 要获取的表情包数量,默认为1 Returns: - Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,如果失败则为None + List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表,失败则返回空列表 Raises: TypeError: 如果count不是整数类型 @@ -94,13 +94,13 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, if not all_emojis: logger.warning("[EmojiAPI] 没有可用的表情包") - return None + return [] # 过滤有效表情包 valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted] if not valid_emojis: logger.warning("[EmojiAPI] 没有有效的表情包") - return None + return [] if len(valid_emojis) < count: logger.warning( @@ -127,14 +127,14 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, if not results and count > 0: logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理") - return None + return [] logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包") return results except Exception as e: logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}") - return None + return [] async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: @@ -162,10 +162,11 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: # 筛选匹配情感的表情包 matching_emojis = [] - for emoji_obj in all_emojis: - if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]: - matching_emojis.append(emoji_obj) - + matching_emojis.extend( + emoji_obj + for emoji_obj in all_emojis + if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion] + ) if not matching_emojis: logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包") return None @@ -256,10 +257,11 @@ def get_descriptions() -> List[str]: emoji_manager = get_emoji_manager() descriptions = [] - for emoji_obj in emoji_manager.emoji_objects: - if not emoji_obj.is_deleted and emoji_obj.description: - descriptions.append(emoji_obj.description) - + descriptions.extend( + emoji_obj.description + for emoji_obj in emoji_manager.emoji_objects + if not emoji_obj.is_deleted and emoji_obj.description + ) return descriptions except Exception as e: logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}") diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index f911454c..0e6e6551 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -12,6 +12,7 @@ import traceback from typing import Tuple, Any, Dict, List, Optional from rich.traceback import install from src.common.logger import get_logger +from src.config.api_ada_configs import TaskConfig from src.chat.replyer.default_generator import DefaultReplyer from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils import process_llm_response @@ -31,7 +32,7 @@ logger = get_logger("generator_api") def get_replyer( chat_stream: Optional[ChatStream] = None, chat_id: Optional[str] = None, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, request_type: str = "replyer", ) -> Optional[DefaultReplyer]: """获取回复器对象 @@ -42,7 +43,7 @@ def get_replyer( Args: chat_stream: 聊天流对象(优先) chat_id: 聊天ID(实际上就是stream_id) - model_configs: 模型配置列表 + model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 request_type: 请求类型 Returns: @@ -58,7 +59,7 @@ def get_replyer( return replyer_manager.get_replyer( chat_stream=chat_stream, chat_id=chat_id, - model_configs=model_configs, + model_set_with_weight=model_set_with_weight, request_type=request_type, ) except Exception as e: @@ -83,31 +84,36 @@ async def generate_reply( enable_splitter: bool = True, enable_chinese_typo: bool = True, return_prompt: bool = False, - model_configs: Optional[List[Dict[str, Any]]] = None, - request_type: str = "", - enable_timeout: bool = False, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + request_type: str = "generator_api", ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """生成回复 Args: chat_stream: 聊天流对象(优先) chat_id: 聊天ID(备用) - action_data: 动作数据 + action_data: 动作数据(向下兼容,包含reply_to和extra_info) + reply_to: 回复对象,格式为 "发送者:消息内容" + extra_info: 额外信息,用于补充上下文 + available_actions: 可用动作 + enable_tool: 是否启用工具调用 enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 return_prompt: 是否返回提示词 + model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 + request_type: 请求类型(可选,记录LLM使用) Returns: Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词) """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") return False, [], None logger.debug("[GeneratorAPI] 开始生成回复") - + if not reply_to and action_data: reply_to = action_data.get("reply_to", "") if not extra_info and action_data: @@ -118,7 +124,6 @@ async def generate_reply( reply_to=reply_to, extra_info=extra_info, available_actions=available_actions, - enable_timeout=enable_timeout, enable_tool=enable_tool, ) reply_set = [] @@ -150,33 +155,35 @@ async def rewrite_reply( chat_id: Optional[str] = None, enable_splitter: bool = True, enable_chinese_typo: bool = True, - model_configs: Optional[List[Dict[str, Any]]] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, raw_reply: str = "", reason: str = "", reply_to: str = "", -) -> Tuple[bool, List[Tuple[str, Any]]]: + return_prompt: bool = False, +) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: """重写回复 Args: chat_stream: 聊天流对象(优先) - reply_data: 回复数据字典(备用,当其他参数缺失时从此获取) + reply_data: 回复数据字典(向下兼容备用,当其他参数缺失时从此获取) chat_id: 聊天ID(备用) enable_splitter: 是否启用消息分割器 enable_chinese_typo: 是否启用错字生成器 - model_configs: 模型配置列表 + model_set_with_weight: 模型配置列表,每个元素为 (TaskConfig, weight) 元组 raw_reply: 原始回复内容 reason: 回复原因 reply_to: 回复对象 + return_prompt: 是否返回提示词 Returns: Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) """ try: # 获取回复器 - replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs) + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight) if not replyer: logger.error("[GeneratorAPI] 无法获取回复器") - return False, [] + return False, [], None logger.info("[GeneratorAPI] 开始重写回复") @@ -187,10 +194,11 @@ async def rewrite_reply( reply_to = reply_to or reply_data.get("reply_to", "") # 调用回复器重写回复 - success, content = await replyer.rewrite_reply_with_context( + success, content, prompt = await replyer.rewrite_reply_with_context( raw_reply=raw_reply, reason=reason, reply_to=reply_to, + return_prompt=return_prompt, ) reply_set = [] if content: @@ -201,14 +209,14 @@ async def rewrite_reply( else: logger.warning("[GeneratorAPI] 重写回复失败") - return success, reply_set + return success, reply_set, prompt if return_prompt else None except ValueError as ve: raise ve except Exception as e: logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") - return False, [] + return False, [], None async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]: @@ -234,3 +242,27 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese except Exception as e: logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}") return [] + +async def generate_response_custom( + chat_stream: Optional[ChatStream] = None, + chat_id: Optional[str] = None, + model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None, + prompt: str = "", +) -> Optional[str]: + replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight) + if not replyer: + logger.error("[GeneratorAPI] 无法获取回复器") + return None + + try: + logger.debug("[GeneratorAPI] 开始生成自定义回复") + response, _, _, _ = await replyer.llm_generate_content(prompt) + if response: + logger.debug("[GeneratorAPI] 自定义回复生成成功") + return response + else: + logger.warning("[GeneratorAPI] 自定义回复生成失败") + return None + except Exception as e: + logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}") + return None \ No newline at end of file diff --git a/src/plugin_system/apis/llm_api.py b/src/plugin_system/apis/llm_api.py index 72b865b8..9d37a8e3 100644 --- a/src/plugin_system/apis/llm_api.py +++ b/src/plugin_system/apis/llm_api.py @@ -7,10 +7,12 @@ success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) """ -from typing import Tuple, Dict, Any +from typing import Tuple, Dict, List, Any, Optional from src.common.logger import get_logger +from src.llm_models.payload_content.tool_option import ToolCall from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config +from src.config.config import global_config, model_config +from src.config.api_ada_configs import TaskConfig logger = get_logger("llm_api") @@ -19,9 +21,7 @@ logger = get_logger("llm_api") # ============================================================================= - - -def get_available_models() -> Dict[str, Any]: +def get_available_models() -> Dict[str, TaskConfig]: """获取所有可用的模型配置 Returns: @@ -33,14 +33,14 @@ def get_available_models() -> Dict[str, Any]: return {} # 自动获取所有属性并转换为字典形式 - rets = {} - models = global_config.model + models = model_config.model_task_config attrs = dir(models) + rets: Dict[str, TaskConfig] = {} for attr in attrs: if not attr.startswith("__"): try: value = getattr(models, attr) - if not callable(value): # 排除方法 + if not callable(value) and isinstance(value, TaskConfig): rets[attr] = value except Exception as e: logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}") @@ -53,7 +53,11 @@ def get_available_models() -> Dict[str, Any]: async def generate_with_model( - prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs + prompt: str, + model_config: TaskConfig, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, ) -> Tuple[bool, str, str, str]: """使用指定模型生成内容 @@ -61,22 +65,62 @@ async def generate_with_model( prompt: 提示词 model_config: 模型配置(从 get_available_models 获取的模型配置) request_type: 请求类型标识 - **kwargs: 其他模型特定参数,如temperature、max_tokens等 Returns: Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) """ try: - model_name = model_config.get("name") - logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容") + model_name_list = model_config.model_list + logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") logger.debug(f"[LLMAPI] 完整提示词: {prompt}") - llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs) + llm_request = LLMRequest(model_set=model_config, request_type=request_type) - response, (reasoning, model_name) = await llm_request.generate_response_async(prompt) - return True, response, reasoning, model_name + response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens) + return True, response, reasoning_content, model_name except Exception as e: error_msg = f"生成内容时出错: {str(e)}" logger.error(f"[LLMAPI] {error_msg}") return False, error_msg, "", "" + +async def generate_with_model_with_tools( + prompt: str, + model_config: TaskConfig, + tool_options: List[Dict[str, Any]] | None = None, + request_type: str = "plugin.generate", + temperature: Optional[float] = None, + max_tokens: Optional[int] = None, +) -> Tuple[bool, str, str, str, List[ToolCall] | None]: + """使用指定模型和工具生成内容 + + Args: + prompt: 提示词 + model_config: 模型配置(从 get_available_models 获取的模型配置) + tool_options: 工具选项列表 + request_type: 请求类型标识 + temperature: 温度参数 + max_tokens: 最大token数 + + Returns: + Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) + """ + try: + model_name_list = model_config.model_list + logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容") + logger.debug(f"[LLMAPI] 完整提示词: {prompt}") + + llm_request = LLMRequest(model_set=model_config, request_type=request_type) + + response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async( + prompt, + tools=tool_options, + temperature=temperature, + max_tokens=max_tokens + ) + return True, response, reasoning_content, model_name, tool_call + + except Exception as e: + error_msg = f"生成内容时出错: {str(e)}" + logger.error(f"[LLMAPI] {error_msg}") + return False, error_msg, "", "", None diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 7794ee81..7cf9dc04 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -207,7 +207,7 @@ def get_random_chat_messages( def get_messages_by_time_for_users( - start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" + start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """ 获取指定用户在所有聊天中指定时间范围内的消息 @@ -287,7 +287,7 @@ def get_messages_before_time_in_chat( return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) -def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: +def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]: """ 获取指定用户在指定时间戳之前的消息 @@ -372,7 +372,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional return num_new_messages_since(chat_id, start_time, end_time) -def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int: +def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int: """ 计算指定聊天中指定用户从开始时间到结束时间的新消息数量 diff --git a/src/plugin_system/apis/plugin_manage_api.py b/src/plugin_system/apis/plugin_manage_api.py index 1c01119b..693e42b4 100644 --- a/src/plugin_system/apis/plugin_manage_api.py +++ b/src/plugin_system/apis/plugin_manage_api.py @@ -1,10 +1,12 @@ from typing import Tuple, List + + def list_loaded_plugins() -> List[str]: """ 列出所有当前加载的插件。 Returns: - list: 当前加载的插件名称列表。 + List[str]: 当前加载的插件名称列表。 """ from src.plugin_system.core.plugin_manager import plugin_manager @@ -16,17 +18,38 @@ def list_registered_plugins() -> List[str]: 列出所有已注册的插件。 Returns: - list: 已注册的插件名称列表。 + List[str]: 已注册的插件名称列表。 """ from src.plugin_system.core.plugin_manager import plugin_manager return plugin_manager.list_registered_plugins() +def get_plugin_path(plugin_name: str) -> str: + """ + 获取指定插件的路径。 + + Args: + plugin_name (str): 插件名称。 + + Returns: + str: 插件目录的绝对路径。 + + Raises: + ValueError: 如果插件不存在。 + """ + from src.plugin_system.core.plugin_manager import plugin_manager + + if plugin_path := plugin_manager.get_plugin_path(plugin_name): + return plugin_path + else: + raise ValueError(f"插件 '{plugin_name}' 不存在。") + + async def remove_plugin(plugin_name: str) -> bool: """ 卸载指定的插件。 - + **此函数是异步的,确保在异步环境中调用。** Args: @@ -43,7 +66,7 @@ async def remove_plugin(plugin_name: str) -> bool: async def reload_plugin(plugin_name: str) -> bool: """ 重新加载指定的插件。 - + **此函数是异步的,确保在异步环境中调用。** Args: @@ -71,6 +94,7 @@ def load_plugin(plugin_name: str) -> Tuple[bool, int]: return plugin_manager.load_registered_plugin_classes(plugin_name) + def add_plugin_directory(plugin_directory: str) -> bool: """ 添加插件目录。 @@ -84,6 +108,7 @@ def add_plugin_directory(plugin_directory: str) -> bool: return plugin_manager.add_plugin_directory(plugin_directory) + def rescan_plugin_directory() -> Tuple[int, int]: """ 重新扫描插件目录,加载新插件。 @@ -92,4 +117,4 @@ def rescan_plugin_directory() -> Tuple[int, int]: """ from src.plugin_system.core.plugin_manager import plugin_manager - return plugin_manager.rescan_plugin_directory() \ No newline at end of file + return plugin_manager.rescan_plugin_directory() diff --git a/src/plugin_system/apis/send_api.py b/src/plugin_system/apis/send_api.py index f7af0259..10fbd804 100644 --- a/src/plugin_system/apis/send_api.py +++ b/src/plugin_system/apis/send_api.py @@ -49,7 +49,7 @@ async def _send_to_target( display_message: str = "", typing: bool = False, reply_to: str = "", - reply_to_platform_id: str = "", + reply_to_platform_id: Optional[str] = None, storage_message: bool = True, show_log: bool = True, ) -> bool: @@ -60,8 +60,11 @@ async def _send_to_target( content: 消息内容 stream_id: 目标流ID display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息的格式,如"发送者:消息内容" + typing: 是否模拟打字等待。 + reply_to: 回复消息,格式为"发送者:消息内容" + reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!) + storage_message: 是否存储消息到数据库 + show_log: 发送是否显示日志 Returns: bool: 是否发送成功 @@ -97,6 +100,10 @@ async def _send_to_target( anchor_message = None if reply_to: anchor_message = await _find_reply_message(target_stream, reply_to) + if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id: + reply_to_platform_id = ( + f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}" + ) # 构建发送消息对象 bot_message = MessageSending( @@ -262,12 +269,22 @@ async def text_to_stream( stream_id: 聊天流ID typing: 是否显示正在输入 reply_to: 回复消息,格式为"发送者:消息内容" + reply_to_platform_id: 回复消息,格式为"平台:用户ID",如果不提供则自动查找(插件开发者禁用!) storage_message: 是否存储消息到数据库 Returns: bool: 是否发送成功 """ - return await _send_to_target("text", text, stream_id, "", typing, reply_to, reply_to_platform_id, storage_message) + return await _send_to_target( + "text", + text, + stream_id, + "", + typing, + reply_to, + reply_to_platform_id=reply_to_platform_id, + storage_message=storage_message, + ) async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool: @@ -318,7 +335,7 @@ async def command_to_stream( async def custom_to_stream( message_type: str, - content: str, + content: str | dict, stream_id: str, display_message: str = "", typing: bool = False, @@ -350,249 +367,3 @@ async def custom_to_stream( storage_message=storage_message, show_log=show_log, ) - - -async def text_to_group( - text: str, - group_id: str, - platform: str = "qq", - typing: bool = False, - reply_to: str = "", - storage_message: bool = True, -) -> bool: - """向群聊发送文本消息 - - Args: - text: 要发送的文本内容 - group_id: 群聊ID - platform: 平台,默认为"qq" - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) - - -async def text_to_user( - text: str, - user_id: str, - platform: str = "qq", - typing: bool = False, - reply_to: str = "", - storage_message: bool = True, -) -> bool: - """向用户发送私聊文本消息 - - Args: - text: 要发送的文本内容 - user_id: 用户ID - platform: 平台,默认为"qq" - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message) - - -async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向群聊发送表情包 - - Args: - emoji_base64: 表情包的base64编码 - group_id: 群聊ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message) - - -async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向用户发送表情包 - - Args: - emoji_base64: 表情包的base64编码 - user_id: 用户ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message) - - -async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向群聊发送图片 - - Args: - image_base64: 图片的base64编码 - group_id: 群聊ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message) - - -async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向用户发送图片 - - Args: - image_base64: 图片的base64编码 - user_id: 用户ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("image", image_base64, stream_id, "", typing=False) - - -async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向群聊发送命令 - - Args: - command: 命令 - group_id: 群聊ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message) - - -async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool: - """向用户发送命令 - - Args: - command: 命令 - user_id: 用户ID - platform: 平台,默认为"qq" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message) - - -# ============================================================================= -# 通用发送函数 - 支持任意消息类型 -# ============================================================================= - - -async def custom_to_group( - message_type: str, - content: str, - group_id: str, - platform: str = "qq", - display_message: str = "", - typing: bool = False, - reply_to: str = "", - storage_message: bool = True, -) -> bool: - """向群聊发送自定义类型消息 - - Args: - message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等 - content: 消息内容(通常是base64编码或文本) - group_id: 群聊ID - platform: 平台,默认为"qq" - display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, group_id, True) - return await _send_to_target( - message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message - ) - - -async def custom_to_user( - message_type: str, - content: str, - user_id: str, - platform: str = "qq", - display_message: str = "", - typing: bool = False, - reply_to: str = "", - storage_message: bool = True, -) -> bool: - """向用户发送自定义类型消息 - - Args: - message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"等 - content: 消息内容(通常是base64编码或文本) - user_id: 用户ID - platform: 平台,默认为"qq" - display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - - Returns: - bool: 是否发送成功 - """ - stream_id = get_chat_manager().get_stream_id(platform, user_id, False) - return await _send_to_target( - message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message - ) - - -async def custom_message( - message_type: str, - content: str, - target_id: str, - is_group: bool = True, - platform: str = "qq", - display_message: str = "", - typing: bool = False, - reply_to: str = "", - storage_message: bool = True, -) -> bool: - """发送自定义消息的通用接口 - - Args: - message_type: 消息类型,如"text"、"image"、"emoji"、"video"、"file"、"audio"等 - content: 消息内容 - target_id: 目标ID(群ID或用户ID) - is_group: 是否为群聊,True为群聊,False为私聊 - platform: 平台,默认为"qq" - display_message: 显示消息 - typing: 是否显示正在输入 - reply_to: 回复消息,格式为"发送者:消息内容" - - Returns: - bool: 是否发送成功 - - 示例: - # 发送视频到群聊 - await send_api.custom_message("video", video_base64, "123456", True) - - # 发送文件到用户 - await send_api.custom_message("file", file_base64, "987654", False) - - # 发送音频到群聊并回复特定消息 - await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好") - """ - stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group) - return await _send_to_target( - message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message - ) diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py new file mode 100644 index 00000000..a6704126 --- /dev/null +++ b/src/plugin_system/apis/tool_api.py @@ -0,0 +1,27 @@ +from typing import Optional, Type +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ComponentType + +from src.common.logger import get_logger + +logger = get_logger("tool_api") + + +def get_tool_instance(tool_name: str) -> Optional[BaseTool]: + """获取公开工具实例""" + from src.plugin_system.core import component_registry + + tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore + return tool_class() if tool_class else None + + +def get_llm_available_tool_definitions(): + """获取LLM可用的工具定义列表 + + Returns: + List[Tuple[str, Dict[str, Any]]]: 工具定义列表,为[("tool_name", 定义)] + """ + from src.plugin_system.core import component_registry + + llm_available_tools = component_registry.get_llm_available_tools() + return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()] diff --git a/src/plugin_system/apis/utils_api.py b/src/plugin_system/apis/utils_api.py deleted file mode 100644 index 45996df5..00000000 --- a/src/plugin_system/apis/utils_api.py +++ /dev/null @@ -1,168 +0,0 @@ -"""工具类API模块 - -提供了各种辅助功能 -使用方式: - from src.plugin_system.apis import utils_api - plugin_path = utils_api.get_plugin_path() - data = utils_api.read_json_file("data.json") - timestamp = utils_api.get_timestamp() -""" - -import os -import json -import time -import inspect -import datetime -import uuid -from typing import Any, Optional -from src.common.logger import get_logger - -logger = get_logger("utils_api") - - -# ============================================================================= -# 文件操作API函数 -# ============================================================================= - - -def get_plugin_path(caller_frame=None) -> str: - """获取调用者插件的路径 - - Args: - caller_frame: 调用者的栈帧,默认为None(自动获取) - - Returns: - str: 插件目录的绝对路径 - """ - try: - if caller_frame is None: - caller_frame = inspect.currentframe().f_back # type: ignore - - plugin_module_path = inspect.getfile(caller_frame) # type: ignore - plugin_dir = os.path.dirname(plugin_module_path) - return plugin_dir - except Exception as e: - logger.error(f"[UtilsAPI] 获取插件路径失败: {e}") - return "" - - -def read_json_file(file_path: str, default: Any = None) -> Any: - """读取JSON文件 - - Args: - file_path: 文件路径,可以是相对于插件目录的路径 - default: 如果文件不存在或读取失败时返回的默认值 - - Returns: - Any: JSON数据或默认值 - """ - try: - # 如果是相对路径,则相对于调用者的插件目录 - if not os.path.isabs(file_path): - caller_frame = inspect.currentframe().f_back # type: ignore - plugin_dir = get_plugin_path(caller_frame) - file_path = os.path.join(plugin_dir, file_path) - - if not os.path.exists(file_path): - logger.warning(f"[UtilsAPI] 文件不存在: {file_path}") - return default - - with open(file_path, "r", encoding="utf-8") as f: - return json.load(f) - except Exception as e: - logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}") - return default - - -def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool: - """写入JSON文件 - - Args: - file_path: 文件路径,可以是相对于插件目录的路径 - data: 要写入的数据 - indent: JSON缩进 - - Returns: - bool: 是否写入成功 - """ - try: - # 如果是相对路径,则相对于调用者的插件目录 - if not os.path.isabs(file_path): - caller_frame = inspect.currentframe().f_back # type: ignore - plugin_dir = get_plugin_path(caller_frame) - file_path = os.path.join(plugin_dir, file_path) - - # 确保目录存在 - os.makedirs(os.path.dirname(file_path), exist_ok=True) - - with open(file_path, "w", encoding="utf-8") as f: - json.dump(data, f, ensure_ascii=False, indent=indent) - return True - except Exception as e: - logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}") - return False - - -# ============================================================================= -# 时间相关API函数 -# ============================================================================= - - -def get_timestamp() -> int: - """获取当前时间戳 - - Returns: - int: 当前时间戳(秒) - """ - return int(time.time()) - - -def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: - """格式化时间 - - Args: - timestamp: 时间戳,如果为None则使用当前时间 - format_str: 时间格式字符串 - - Returns: - str: 格式化后的时间字符串 - """ - try: - if timestamp is None: - timestamp = time.time() - return datetime.datetime.fromtimestamp(timestamp).strftime(format_str) - except Exception as e: - logger.error(f"[UtilsAPI] 格式化时间失败: {e}") - return "" - - -def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int: - """解析时间字符串为时间戳 - - Args: - time_str: 时间字符串 - format_str: 时间格式字符串 - - Returns: - int: 时间戳(秒) - """ - try: - dt = datetime.datetime.strptime(time_str, format_str) - return int(dt.timestamp()) - except Exception as e: - logger.error(f"[UtilsAPI] 解析时间失败: {e}") - return 0 - - -# ============================================================================= -# 其他工具函数 -# ============================================================================= - - -def generate_unique_id() -> str: - """生成唯一ID - - Returns: - str: 唯一ID - """ - return str(uuid.uuid4()) diff --git a/src/plugin_system/base/__init__.py b/src/plugin_system/base/__init__.py index a95e05ae..bc63d35d 100644 --- a/src/plugin_system/base/__init__.py +++ b/src/plugin_system/base/__init__.py @@ -6,6 +6,7 @@ from .base_plugin import BasePlugin from .base_action import BaseAction +from .base_tool import BaseTool from .base_command import BaseCommand from .base_events_handler import BaseEventHandler from .component_types import ( @@ -15,11 +16,13 @@ from .component_types import ( ComponentInfo, ActionInfo, CommandInfo, + ToolInfo, PluginInfo, PythonDependency, EventHandlerInfo, EventType, MaiMessages, + ToolParamType, ) from .config_types import ConfigField @@ -27,12 +30,14 @@ __all__ = [ "BasePlugin", "BaseAction", "BaseCommand", + "BaseTool", "ComponentType", "ActionActivationType", "ChatMode", "ComponentInfo", "ActionInfo", "CommandInfo", + "ToolInfo", "PluginInfo", "PythonDependency", "ConfigField", @@ -40,4 +45,5 @@ __all__ = [ "EventType", "BaseEventHandler", "MaiMessages", + "ToolParamType", ] diff --git a/src/plugin_system/base/base_action.py b/src/plugin_system/base/base_action.py index 7acd14a4..66d723f5 100644 --- a/src/plugin_system/base/base_action.py +++ b/src/plugin_system/base/base_action.py @@ -208,7 +208,7 @@ class BaseAction(ABC): return False, f"等待新消息失败: {str(e)}" async def send_text( - self, content: str, reply_to: str = "", reply_to_platform_id: str = "", typing: bool = False + self, content: str, reply_to: str = "", typing: bool = False ) -> bool: """发送文本消息 @@ -227,7 +227,6 @@ class BaseAction(ABC): text=content, stream_id=self.chat_id, reply_to=reply_to, - reply_to_platform_id=reply_to_platform_id, typing=typing, ) diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py index 3cf82390..ea28c514 100644 --- a/src/plugin_system/base/base_plugin.py +++ b/src/plugin_system/base/base_plugin.py @@ -3,10 +3,11 @@ from typing import List, Type, Tuple, Union from .plugin_base import PluginBase from src.common.logger import get_logger -from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo +from src.plugin_system.base.component_types import ActionInfo, CommandInfo, EventHandlerInfo, ToolInfo from .base_action import BaseAction from .base_command import BaseCommand from .base_events_handler import BaseEventHandler +from .base_tool import BaseTool logger = get_logger("base_plugin") @@ -31,6 +32,7 @@ class BasePlugin(PluginBase): Tuple[ActionInfo, Type[BaseAction]], Tuple[CommandInfo, Type[BaseCommand]], Tuple[EventHandlerInfo, Type[BaseEventHandler]], + Tuple[ToolInfo, Type[BaseTool]], ] ]: """获取插件包含的组件列表 diff --git a/src/plugin_system/base/base_tool.py b/src/plugin_system/base/base_tool.py new file mode 100644 index 00000000..1d589eca --- /dev/null +++ b/src/plugin_system/base/base_tool.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from typing import Any, List, Tuple +from rich.traceback import install + +from src.common.logger import get_logger +from src.plugin_system.base.component_types import ComponentType, ToolInfo, ToolParamType + +install(extra_lines=3) + +logger = get_logger("base_tool") + + +class BaseTool(ABC): + """所有工具的基类""" + + name: str = "" + """工具的名称""" + description: str = "" + """工具的描述""" + parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = [] + """工具的参数定义,为[("param_name", param_type, "description", required, enum_values)]格式 + param_name: 参数名称 + param_type: 参数类型 + description: 参数描述 + required: 是否必填 + enum_values: 枚举值列表 + 例如: [("arg1", ToolParamType.STRING, "参数1描述", True, None), ("arg2", ToolParamType.INTEGER, "参数2描述", False, ["1", "2", "3"])] + """ + available_for_llm: bool = False + """是否可供LLM使用""" + + @classmethod + def get_tool_definition(cls) -> dict[str, Any]: + """获取工具定义,用于LLM工具调用 + + Returns: + dict: 工具定义字典 + """ + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return {"name": cls.name, "description": cls.description, "parameters": cls.parameters} + + @classmethod + def get_tool_info(cls) -> ToolInfo: + """获取工具信息""" + if not cls.name or not cls.description or not cls.parameters: + raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") + + return ToolInfo( + name=cls.name, + tool_description=cls.description, + enabled=cls.available_for_llm, + tool_parameters=cls.parameters, + component_type=ComponentType.TOOL, + ) + + @abstractmethod + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """执行工具函数(供llm调用) + 通过该方法,maicore会通过llm的tool call来调用工具 + 传入的是json格式的参数,符合parameters定义的格式 + + Args: + function_args: 工具调用参数 + + Returns: + dict: 工具执行结果 + """ + raise NotImplementedError("子类必须实现execute方法") + + async def direct_execute(self, **function_args: dict[str, Any]) -> dict[str, Any]: + """直接执行工具函数(供插件调用) + 通过该方法,插件可以直接调用工具,而不需要传入字典格式的参数 + 插件可以直接调用此方法,用更加明了的方式传入参数 + 示例: result = await tool.direct_execute(arg1="参数",arg2="参数2") + + 工具开发者可以重写此方法以实现与llm调用差异化的执行逻辑 + + Args: + **function_args: 工具调用参数 + + Returns: + dict: 工具执行结果 + """ + parameter_required = [param[0] for param in self.parameters if param[3]] # 获取所有必填参数名 + for param_name in parameter_required: + if param_name not in function_args: + raise ValueError(f"工具类 {self.__class__.__name__} 缺少必要参数: {param_name}") + + return await self.execute(function_args) diff --git a/src/plugin_system/base/component_types.py b/src/plugin_system/base/component_types.py index eeb2a5a0..7775f5fb 100644 --- a/src/plugin_system/base/component_types.py +++ b/src/plugin_system/base/component_types.py @@ -1,8 +1,9 @@ from enum import Enum -from typing import Dict, Any, List, Optional +from typing import Dict, Any, List, Optional, Tuple from dataclasses import dataclass, field from maim_message import Seg +from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType # 组件类型枚举 class ComponentType(Enum): @@ -10,6 +11,7 @@ class ComponentType(Enum): ACTION = "action" # 动作组件 COMMAND = "command" # 命令组件 + TOOL = "tool" # 服务组件(预留) SCHEDULER = "scheduler" # 定时任务组件(预留) EVENT_HANDLER = "event_handler" # 事件处理组件(预留) @@ -146,6 +148,18 @@ class CommandInfo(ComponentInfo): self.component_type = ComponentType.COMMAND +@dataclass +class ToolInfo(ComponentInfo): + """工具组件信息""" + + tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义 + tool_description: str = "" # 工具描述 + + def __post_init__(self): + super().__post_init__() + self.component_type = ComponentType.TOOL + + @dataclass class EventHandlerInfo(ComponentInfo): """事件处理器组件信息""" diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 2ea89b88..59a03b73 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -6,6 +6,7 @@ from src.common.logger import get_logger from src.plugin_system.base.component_types import ( ComponentInfo, ActionInfo, + ToolInfo, CommandInfo, EventHandlerInfo, PluginInfo, @@ -13,6 +14,7 @@ from src.plugin_system.base.component_types import ( ) from src.plugin_system.base.base_command import BaseCommand from src.plugin_system.base.base_action import BaseAction +from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_events_handler import BaseEventHandler logger = get_logger("component_registry") @@ -30,7 +32,7 @@ class ComponentRegistry: """组件注册表 命名空间式组件名 -> 组件信息""" self._components_by_type: Dict[ComponentType, Dict[str, ComponentInfo]] = {types: {} for types in ComponentType} """类型 -> 组件原名称 -> 组件信息""" - self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseEventHandler]]] = {} + self._components_classes: Dict[str, Type[Union[BaseCommand, BaseAction, BaseTool, BaseEventHandler]]] = {} """命名空间式组件名 -> 组件类""" # 插件注册表 @@ -49,6 +51,10 @@ class ComponentRegistry: self._command_patterns: Dict[Pattern, str] = {} """编译后的正则 -> command名""" + # 工具特定注册表 + self._tool_registry: Dict[str, Type[BaseTool]] = {} # 工具名 -> 工具类 + self._llm_available_tools: Dict[str, Type[BaseTool]] = {} # llm可用的工具名 -> 工具类 + # EventHandler特定注册表 self._event_handler_registry: Dict[str, Type[BaseEventHandler]] = {} """event_handler名 -> event_handler类""" @@ -79,7 +85,9 @@ class ComponentRegistry: return True def register_component( - self, component_info: ComponentInfo, component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler]] + self, + component_info: ComponentInfo, + component_class: Type[Union[BaseCommand, BaseAction, BaseEventHandler, BaseTool]], ) -> bool: """注册组件 @@ -125,6 +133,10 @@ class ComponentRegistry: assert isinstance(component_info, CommandInfo) assert issubclass(component_class, BaseCommand) ret = self._register_command_component(component_info, component_class) + case ComponentType.TOOL: + assert isinstance(component_info, ToolInfo) + assert issubclass(component_class, BaseTool) + ret = self._register_tool_component(component_info, component_class) case ComponentType.EVENT_HANDLER: assert isinstance(component_info, EventHandlerInfo) assert issubclass(component_class, BaseEventHandler) @@ -180,6 +192,18 @@ class ComponentRegistry: return True + def _register_tool_component(self, tool_info: ToolInfo, tool_class: Type[BaseTool]) -> bool: + """注册Tool组件到Tool特定注册表""" + tool_name = tool_info.name + + self._tool_registry[tool_name] = tool_class + + # 如果是llm可用的且启用的工具,添加到 llm可用工具列表 + if tool_info.enabled: + self._llm_available_tools[tool_name] = tool_class + + return True + def _register_event_handler_component( self, handler_info: EventHandlerInfo, handler_class: Type[BaseEventHandler] ) -> bool: @@ -222,6 +246,9 @@ class ComponentRegistry: keys_to_remove = [k for k, v in self._command_patterns.items() if v == component_name] for key in keys_to_remove: self._command_patterns.pop(key) + case ComponentType.TOOL: + self._tool_registry.pop(component_name) + self._llm_available_tools.pop(component_name) case ComponentType.EVENT_HANDLER: from .events_manager import events_manager # 延迟导入防止循环导入问题 @@ -234,13 +261,13 @@ class ComponentRegistry: self._components_classes.pop(namespaced_name) logger.info(f"组件 {component_name} 已移除") return True - except KeyError: - logger.warning(f"移除组件时未找到组件: {component_name}") + except KeyError as e: + logger.warning(f"移除组件时未找到组件: {component_name}, 发生错误: {e}") return False except Exception as e: logger.error(f"移除组件 {component_name} 时发生错误: {e}") return False - + def remove_plugin_registry(self, plugin_name: str) -> bool: """移除插件注册信息 @@ -281,6 +308,10 @@ class ComponentRegistry: assert isinstance(target_component_info, CommandInfo) pattern = target_component_info.command_pattern self._command_patterns[re.compile(pattern)] = component_name + case ComponentType.TOOL: + assert isinstance(target_component_info, ToolInfo) + assert issubclass(target_component_class, BaseTool) + self._llm_available_tools[component_name] = target_component_class case ComponentType.EVENT_HANDLER: assert isinstance(target_component_info, EventHandlerInfo) assert issubclass(target_component_class, BaseEventHandler) @@ -308,20 +339,29 @@ class ComponentRegistry: logger.warning(f"组件 {component_name} 未注册,无法禁用") return False target_component_info.enabled = False - match component_type: - case ComponentType.ACTION: - self._default_actions.pop(component_name, None) - case ComponentType.COMMAND: - self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name} - case ComponentType.EVENT_HANDLER: - self._enabled_event_handlers.pop(component_name, None) - from .events_manager import events_manager # 延迟导入防止循环导入问题 + try: + match component_type: + case ComponentType.ACTION: + self._default_actions.pop(component_name) + case ComponentType.COMMAND: + self._command_patterns = {k: v for k, v in self._command_patterns.items() if v != component_name} + case ComponentType.TOOL: + self._llm_available_tools.pop(component_name) + case ComponentType.EVENT_HANDLER: + self._enabled_event_handlers.pop(component_name) + from .events_manager import events_manager # 延迟导入防止循环导入问题 - await events_manager.unregister_event_subscriber(component_name) - self._components[component_name].enabled = False - self._components_by_type[component_type][component_name].enabled = False - logger.info(f"组件 {component_name} 已禁用") - return True + await events_manager.unregister_event_subscriber(component_name) + self._components[component_name].enabled = False + self._components_by_type[component_type][component_name].enabled = False + logger.info(f"组件 {component_name} 已禁用") + return True + except KeyError as e: + logger.warning(f"禁用组件时未找到组件或已禁用: {component_name}, 发生错误: {e}") + return False + except Exception as e: + logger.error(f"禁用组件 {component_name} 时发生错误: {e}") + return False # === 组件查询方法 === def get_component_info( @@ -371,7 +411,7 @@ class ComponentRegistry: self, component_name: str, component_type: Optional[ComponentType] = None, - ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler]]]: + ) -> Optional[Union[Type[BaseCommand], Type[BaseAction], Type[BaseEventHandler], Type[BaseTool]]]: """获取组件类,支持自动命名空间解析 Args: @@ -476,6 +516,27 @@ class ComponentRegistry: command_info, ) + # === Tool 特定查询方法 === + def get_tool_registry(self) -> Dict[str, Type[BaseTool]]: + """获取Tool注册表""" + return self._tool_registry.copy() + + def get_llm_available_tools(self) -> Dict[str, Type[BaseTool]]: + """获取LLM可用的Tool列表""" + return self._llm_available_tools.copy() + + def get_registered_tool_info(self, tool_name: str) -> Optional[ToolInfo]: + """获取Tool信息 + + Args: + tool_name: 工具名称 + + Returns: + ToolInfo: 工具信息对象,如果工具不存在则返回 None + """ + info = self.get_component_info(tool_name, ComponentType.TOOL) + return info if isinstance(info, ToolInfo) else None + # === EventHandler 特定查询方法 === def get_event_handler_registry(self) -> Dict[str, Type[BaseEventHandler]]: @@ -529,17 +590,21 @@ class ComponentRegistry: """获取注册中心统计信息""" action_components: int = 0 command_components: int = 0 + tool_components: int = 0 events_handlers: int = 0 for component in self._components.values(): if component.component_type == ComponentType.ACTION: action_components += 1 elif component.component_type == ComponentType.COMMAND: command_components += 1 + elif component.component_type == ComponentType.TOOL: + tool_components += 1 elif component.component_type == ComponentType.EVENT_HANDLER: events_handlers += 1 return { "action_components": action_components, "command_components": command_components, + "tool_components": tool_components, "event_handlers": events_handlers, "total_components": len(self._components), "total_plugins": len(self._plugins), diff --git a/src/plugin_system/core/global_announcement_manager.py b/src/plugin_system/core/global_announcement_manager.py index 9f7052f5..bb6f06b4 100644 --- a/src/plugin_system/core/global_announcement_manager.py +++ b/src/plugin_system/core/global_announcement_manager.py @@ -13,6 +13,8 @@ class GlobalAnnouncementManager: self._user_disabled_commands: Dict[str, List[str]] = {} # 用户禁用的事件处理器,chat_id -> [handler_name] self._user_disabled_event_handlers: Dict[str, List[str]] = {} + # 用户禁用的工具,chat_id -> [tool_name] + self._user_disabled_tools: Dict[str, List[str]] = {} def disable_specific_chat_action(self, chat_id: str, action_name: str) -> bool: """禁用特定聊天的某个动作""" @@ -77,6 +79,27 @@ class GlobalAnnouncementManager: return False return False + def disable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: + """禁用特定聊天的某个工具""" + if chat_id not in self._user_disabled_tools: + self._user_disabled_tools[chat_id] = [] + if tool_name in self._user_disabled_tools[chat_id]: + logger.warning(f"工具 {tool_name} 已经被禁用") + return False + self._user_disabled_tools[chat_id].append(tool_name) + return True + + def enable_specific_chat_tool(self, chat_id: str, tool_name: str) -> bool: + """启用特定聊天的某个工具""" + if chat_id in self._user_disabled_tools: + try: + self._user_disabled_tools[chat_id].remove(tool_name) + return True + except ValueError: + logger.warning(f"工具 {tool_name} 不在禁用列表中") + return False + return False + def get_disabled_chat_actions(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有动作""" return self._user_disabled_actions.get(chat_id, []).copy() @@ -88,6 +111,10 @@ class GlobalAnnouncementManager: def get_disabled_chat_event_handlers(self, chat_id: str) -> List[str]: """获取特定聊天禁用的所有事件处理器""" return self._user_disabled_event_handlers.get(chat_id, []).copy() + + def get_disabled_chat_tools(self, chat_id: str) -> List[str]: + """获取特定聊天禁用的所有工具""" + return self._user_disabled_tools.get(chat_id, []).copy() global_announcement_manager = GlobalAnnouncementManager() diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py index dfafda18..014b7a0c 100644 --- a/src/plugin_system/core/plugin_manager.py +++ b/src/plugin_system/core/plugin_manager.py @@ -224,6 +224,18 @@ class PluginManager: list: 已注册的插件类名称列表。 """ return list(self.plugin_classes.keys()) + + def get_plugin_path(self, plugin_name: str) -> Optional[str]: + """ + 获取指定插件的路径。 + + Args: + plugin_name: 插件名称 + + Returns: + Optional[str]: 插件目录的绝对路径,如果插件不存在则返回None。 + """ + return self.plugin_paths.get(plugin_name) # === 私有方法 === # == 目录管理 == @@ -346,6 +358,7 @@ class PluginManager: stats = component_registry.get_registry_stats() action_count = stats.get("action_components", 0) command_count = stats.get("command_components", 0) + tool_count = stats.get("tool_components", 0) event_handler_count = stats.get("event_handlers", 0) total_components = stats.get("total_components", 0) @@ -353,7 +366,7 @@ class PluginManager: if total_registered > 0: logger.info("🎉 插件系统加载完成!") logger.info( - f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, EventHandler: {event_handler_count})" + f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, EventHandler: {event_handler_count})" ) # 显示详细的插件列表 @@ -388,6 +401,9 @@ class PluginManager: command_components = [ c for c in plugin_info.components if c.component_type == ComponentType.COMMAND ] + tool_components = [ + c for c in plugin_info.components if c.component_type == ComponentType.TOOL + ] event_handler_components = [ c for c in plugin_info.components if c.component_type == ComponentType.EVENT_HANDLER ] @@ -399,7 +415,9 @@ class PluginManager: if command_components: command_names = [c.name for c in command_components] logger.info(f" ⚡ Command组件: {', '.join(command_names)}") - + if tool_components: + tool_names = [c.name for c in tool_components] + logger.info(f" 🛠️ Tool组件: {', '.join(tool_names)}") if event_handler_components: event_handler_names = [c.name for c in event_handler_components] logger.info(f" 📢 EventHandler组件: {', '.join(event_handler_names)}") diff --git a/src/tools/tool_executor.py b/src/plugin_system/core/tool_use.py similarity index 76% rename from src/tools/tool_executor.py rename to src/plugin_system/core/tool_use.py index 0f50ca2a..7a5eee31 100644 --- a/src/tools/tool_executor.py +++ b/src/plugin_system/core/tool_use.py @@ -1,14 +1,16 @@ -from src.llm_models.utils_model import LLMRequest -from src.config.config import global_config import time -from src.common.logger import get_logger +from typing import List, Dict, Tuple, Optional, Any +from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.core.global_announcement_manager import global_announcement_manager +from src.llm_models.utils_model import LLMRequest +from src.llm_models.payload_content import ToolCall +from src.config.config import global_config, model_config from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.tools.tool_use import ToolUser -from src.chat.utils.json_utils import process_llm_tool_calls -from typing import List, Dict, Tuple, Optional from src.chat.message_receive.chat_stream import get_chat_manager +from src.common.logger import get_logger -logger = get_logger("tool_executor") +logger = get_logger("tool_use") def init_tool_executor_prompt(): @@ -28,6 +30,10 @@ If you need to use a tool, please directly call the corresponding tool function. Prompt(tool_executor_prompt, "tool_executor_prompt") +# 初始化提示词 +init_tool_executor_prompt() + + class ToolExecutor: """独立的工具执行器组件 @@ -46,13 +52,7 @@ class ToolExecutor: self.chat_stream = get_chat_manager().get_stream(self.chat_id) self.log_prefix = f"[{get_chat_manager().get_stream_name(self.chat_id) or self.chat_id}]" - self.llm_model = LLMRequest( - model=global_config.model.tool_use, - request_type="tool_executor", - ) - - # 初始化工具实例 - self.tool_instance = ToolUser() + self.llm_model = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="tool_executor") # 缓存配置 self.enable_cache = enable_cache @@ -63,7 +63,7 @@ class ToolExecutor: async def execute_from_chat_message( self, target_message: str, chat_history: str, sender: str, return_details: bool = False - ) -> Tuple[List[Dict], List[str], str]: + ) -> Tuple[List[Dict[str, Any]], List[str], str]: """从聊天消息执行工具 Args: @@ -73,7 +73,7 @@ class ToolExecutor: return_details: 是否返回详细信息(使用的工具列表和提示词) Returns: - 如果return_details为False: List[Dict] - 工具执行结果列表 + 如果return_details为False: Tuple[List[Dict], List[str], str] - (工具执行结果列表, 空, 空) 如果return_details为True: Tuple[List[Dict], List[str], str] - (结果列表, 使用的工具, 提示词) """ @@ -82,15 +82,15 @@ class ToolExecutor: if cached_result := self._get_from_cache(cache_key): logger.info(f"{self.log_prefix}使用缓存结果,跳过工具执行") if not return_details: - return cached_result, [], "使用缓存结果" + return cached_result, [], "" # 从缓存结果中提取工具名称 used_tools = [result.get("tool_name", "unknown") for result in cached_result] - return cached_result, used_tools, "使用缓存结果" + return cached_result, used_tools, "" # 缓存未命中,执行工具调用 # 获取可用工具 - tools = self.tool_instance._define_tools() + tools = self._get_tool_definitions() # 获取当前时间 time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) @@ -110,17 +110,12 @@ class ToolExecutor: logger.debug(f"{self.log_prefix}开始LLM工具调用分析") # 调用LLM进行工具决策 - response, other_info = await self.llm_model.generate_response_async(prompt=prompt, tools=tools) - - # 解析LLM响应 - if len(other_info) == 3: - reasoning_content, model_name, tool_calls = other_info - else: - reasoning_content, model_name = other_info - tool_calls = None + response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async( + prompt=prompt, tools=tools + ) # 执行工具调用 - tool_results, used_tools = await self._execute_tool_calls(tool_calls) + tool_results, used_tools = await self.execute_tool_calls(tool_calls) # 缓存结果 if tool_results: @@ -134,7 +129,12 @@ class ToolExecutor: else: return tool_results, [], "" - async def _execute_tool_calls(self, tool_calls) -> Tuple[List[Dict], List[str]]: + def _get_tool_definitions(self) -> List[Dict[str, Any]]: + all_tools = get_llm_available_tool_definitions() + user_disabled_tools = global_announcement_manager.get_disabled_chat_tools(self.chat_id) + return [definition for name, definition in all_tools if name not in user_disabled_tools] + + async def execute_tool_calls(self, tool_calls: Optional[List[ToolCall]]) -> Tuple[List[Dict[str, Any]], List[str]]: """执行工具调用 Args: @@ -143,36 +143,23 @@ class ToolExecutor: Returns: Tuple[List[Dict], List[str]]: (工具执行结果列表, 使用的工具名称列表) """ - tool_results = [] + tool_results: List[Dict[str, Any]] = [] used_tools = [] if not tool_calls: logger.debug(f"{self.log_prefix}无需执行工具") - return tool_results, used_tools + return [], [] logger.info(f"{self.log_prefix}开始执行工具调用: {tool_calls}") - # 处理工具调用 - success, valid_tool_calls, error_msg = process_llm_tool_calls(tool_calls) - - if not success: - logger.error(f"{self.log_prefix}工具调用解析失败: {error_msg}") - return tool_results, used_tools - - if not valid_tool_calls: - logger.debug(f"{self.log_prefix}无有效工具调用") - return tool_results, used_tools - # 执行每个工具调用 - for tool_call in valid_tool_calls: + for tool_call in tool_calls: try: - tool_name = tool_call.get("name", "unknown_tool") - used_tools.append(tool_name) - + tool_name = tool_call.func_name logger.debug(f"{self.log_prefix}执行工具: {tool_name}") # 执行工具 - result = await self.tool_instance.execute_tool_call(tool_call) + result = await self.execute_tool_call(tool_call) if result: tool_info = { @@ -182,15 +169,15 @@ class ToolExecutor: "tool_name": tool_name, "timestamp": time.time(), } - tool_results.append(tool_info) - - logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") content = tool_info["content"] if not isinstance(content, (str, list, tuple)): - content = str(content) + tool_info["content"] = str(content) + + tool_results.append(tool_info) + used_tools.append(tool_name) + logger.info(f"{self.log_prefix}工具{tool_name}执行成功,类型: {tool_info['type']}") preview = content[:200] logger.debug(f"{self.log_prefix}工具{tool_name}结果内容: {preview}...") - except Exception as e: logger.error(f"{self.log_prefix}工具{tool_name}执行失败: {e}") # 添加错误信息到结果中 @@ -205,6 +192,42 @@ class ToolExecutor: return tool_results, used_tools + async def execute_tool_call(self, tool_call: ToolCall, tool_instance: Optional[BaseTool] = None) -> Optional[Dict[str, Any]]: + # sourcery skip: use-assigned-variable + """执行单个工具调用 + + Args: + tool_call: 工具调用对象 + + Returns: + Optional[Dict]: 工具调用结果,如果失败则返回None + """ + try: + function_name = tool_call.func_name + function_args = tool_call.args or {} + function_args["llm_called"] = True # 标记为LLM调用 + + # 获取对应工具实例 + tool_instance = tool_instance or get_tool_instance(function_name) + if not tool_instance: + logger.warning(f"未知工具名称: {function_name}") + return None + + # 执行工具 + result = await tool_instance.execute(function_args) + if result: + return { + "tool_call_id": tool_call.call_id, + "role": "tool", + "name": function_name, + "type": "function", + "content": result["content"], + } + return None + except Exception as e: + logger.error(f"执行工具调用时发生错误: {str(e)}") + raise e + def _generate_cache_key(self, target_message: str, chat_history: str, sender: str) -> str: """生成缓存键 @@ -272,18 +295,7 @@ class ToolExecutor: if expired_keys: logger.debug(f"{self.log_prefix}清理了{len(expired_keys)}个过期缓存") - def get_available_tools(self) -> List[str]: - """获取可用工具列表 - - Returns: - List[str]: 可用工具名称列表 - """ - tools = self.tool_instance._define_tools() - return [tool.get("function", {}).get("name", "unknown") for tool in tools] - - async def execute_specific_tool( - self, tool_name: str, tool_args: Dict, validate_args: bool = True - ) -> Optional[Dict]: + async def execute_specific_tool_simple(self, tool_name: str, tool_args: Dict) -> Optional[Dict]: """直接执行指定工具 Args: @@ -295,11 +307,15 @@ class ToolExecutor: Optional[Dict]: 工具执行结果,失败时返回None """ try: - tool_call = {"name": tool_name, "arguments": tool_args} + tool_call = ToolCall( + call_id=f"direct_tool_{time.time()}", + func_name=tool_name, + args=tool_args, + ) logger.info(f"{self.log_prefix}直接执行工具: {tool_name}") - result = await self.tool_instance.execute_tool_call(tool_call) + result = await self.execute_tool_call(tool_call) if result: tool_info = { @@ -366,12 +382,8 @@ class ToolExecutor: logger.info(f"{self.log_prefix}缓存TTL修改为: {cache_ttl}") -# 初始化提示词 -init_tool_executor_prompt() - - """ -使用示例: +ToolExecutor使用示例: # 1. 基础使用 - 从聊天消息执行工具(启用缓存,默认TTL=3) executor = ToolExecutor(executor_id="my_executor") @@ -394,13 +406,12 @@ results, used_tools, prompt = await executor.execute_from_chat_message( ) # 5. 直接执行特定工具 -result = await executor.execute_specific_tool( +result = await executor.execute_specific_tool_simple( tool_name="get_knowledge", tool_args={"query": "机器学习"} ) # 6. 缓存管理 -available_tools = executor.get_available_tools() cache_status = executor.get_cache_status() # 查看缓存状态 executor.clear_cache() # 清空缓存 executor.set_cache_config(cache_ttl=5) # 动态修改缓存配置 diff --git a/src/plugins/built_in/core_actions/emoji.py b/src/plugins/built_in/core_actions/emoji.py index fa922dc1..790f2096 100644 --- a/src/plugins/built_in/core_actions/emoji.py +++ b/src/plugins/built_in/core_actions/emoji.py @@ -58,6 +58,7 @@ class EmojiAction(BaseAction): associated_types = ["emoji"] async def execute(self) -> Tuple[bool, str]: + # sourcery skip: assign-if-exp, introduce-default-else, swap-if-else-branches, use-named-expression """执行表情动作""" logger.info(f"{self.log_prefix} 决定发送表情") diff --git a/src/tools/not_using/lpmm_get_knowledge.py b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py similarity index 83% rename from src/tools/not_using/lpmm_get_knowledge.py rename to src/plugins/built_in/knowledge/lpmm_get_knowledge.py index 467db6ed..fd3d811b 100644 --- a/src/tools/not_using/lpmm_get_knowledge.py +++ b/src/plugins/built_in/knowledge/lpmm_get_knowledge.py @@ -1,10 +1,9 @@ -from src.tools.tool_can_use.base_tool import BaseTool - -# from src.common.database import db -from src.common.logger import get_logger from typing import Dict, Any -from src.chat.knowledge.knowledge_lib import qa_manager +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.knowledge.knowledge_lib import qa_manager +from src.plugin_system import BaseTool, ToolParamType logger = get_logger("lpmm_get_knowledge_tool") @@ -14,14 +13,11 @@ class SearchKnowledgeFromLPMMTool(BaseTool): name = "lpmm_search_knowledge" description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } + parameters = [ + ("query", ToolParamType.STRING, "搜索查询关键词", True, None), + ("threshold", ToolParamType.FLOAT, "相似度阈值,0.0到1.0之间", False, None), + ] + available_for_llm = global_config.lpmm_knowledge.enable async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]: """执行知识库搜索 diff --git a/src/plugins/built_in/plugin_management/plugin.py b/src/plugins/built_in/plugin_management/plugin.py index de846dd5..c2489a38 100644 --- a/src/plugins/built_in/plugin_management/plugin.py +++ b/src/plugins/built_in/plugin_management/plugin.py @@ -11,6 +11,7 @@ from src.plugin_system import ( component_manage_api, ComponentInfo, ComponentType, + send_api, ) @@ -27,8 +28,15 @@ class ManagementCommand(BaseCommand): or not self.message.message_info.user_info or str(self.message.message_info.user_info.user_id) not in self.get_config("plugin.permission", []) # type: ignore ): - await self.send_text("你没有权限使用插件管理命令") + await self._send_message("你没有权限使用插件管理命令") return False, "没有权限", True + if not self.message.chat_stream: + await self._send_message("无法获取聊天流信息") + return False, "无法获取聊天流信息", True + self.stream_id = self.message.chat_stream.stream_id + if not self.stream_id: + await self._send_message("无法获取聊天流信息") + return False, "无法获取聊天流信息", True command_list = self.matched_groups["manage_command"].strip().split(" ") if len(command_list) == 1: await self.show_help("all") @@ -42,7 +50,7 @@ class ManagementCommand(BaseCommand): case "help": await self.show_help("all") case _: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if len(command_list) == 3: if command_list[1] == "plugin": @@ -56,7 +64,7 @@ class ManagementCommand(BaseCommand): case "rescan": await self._rescan_plugin_dirs() case _: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True elif command_list[1] == "component": if command_list[2] == "list": @@ -64,10 +72,10 @@ class ManagementCommand(BaseCommand): elif command_list[2] == "help": await self.show_help("component") else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if len(command_list) == 4: if command_list[1] == "plugin": @@ -81,28 +89,28 @@ class ManagementCommand(BaseCommand): case "add_dir": await self._add_dir(command_list[3]) case _: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True elif command_list[1] == "component": if command_list[2] != "list": - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if command_list[3] == "enabled": await self._list_enabled_components() elif command_list[3] == "disabled": await self._list_disabled_components() else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if len(command_list) == 5: if command_list[1] != "component": - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if command_list[2] != "list": - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if command_list[3] == "enabled": await self._list_enabled_components(target_type=command_list[4]) @@ -111,11 +119,11 @@ class ManagementCommand(BaseCommand): elif command_list[3] == "type": await self._list_registered_components_by_type(command_list[4]) else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if len(command_list) == 6: if command_list[1] != "component": - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True if command_list[2] == "enable": if command_list[3] == "global": @@ -123,7 +131,7 @@ class ManagementCommand(BaseCommand): elif command_list[3] == "local": await self._locally_enable_component(command_list[4], command_list[5]) else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True elif command_list[2] == "disable": if command_list[3] == "global": @@ -131,10 +139,10 @@ class ManagementCommand(BaseCommand): elif command_list[3] == "local": await self._locally_disable_component(command_list[4], command_list[5]) else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True else: - await self.send_text("插件管理命令不合法") + await self._send_message("插件管理命令不合法") return False, "命令不合法", True return True, "命令执行完成", True @@ -180,51 +188,51 @@ class ManagementCommand(BaseCommand): ) case _: return - await self.send_text(help_msg) + await self._send_message(help_msg) async def _list_loaded_plugins(self): plugins = plugin_manage_api.list_loaded_plugins() - await self.send_text(f"已加载的插件: {', '.join(plugins)}") + await self._send_message(f"已加载的插件: {', '.join(plugins)}") async def _list_registered_plugins(self): plugins = plugin_manage_api.list_registered_plugins() - await self.send_text(f"已注册的插件: {', '.join(plugins)}") + await self._send_message(f"已注册的插件: {', '.join(plugins)}") async def _rescan_plugin_dirs(self): plugin_manage_api.rescan_plugin_directory() - await self.send_text("插件目录重新扫描执行中") + await self._send_message("插件目录重新扫描执行中") async def _load_plugin(self, plugin_name: str): success, count = plugin_manage_api.load_plugin(plugin_name) if success: - await self.send_text(f"插件加载成功: {plugin_name}") + await self._send_message(f"插件加载成功: {plugin_name}") else: if count == 0: - await self.send_text(f"插件{plugin_name}为禁用状态") - await self.send_text(f"插件加载失败: {plugin_name}") + await self._send_message(f"插件{plugin_name}为禁用状态") + await self._send_message(f"插件加载失败: {plugin_name}") async def _unload_plugin(self, plugin_name: str): success = await plugin_manage_api.remove_plugin(plugin_name) if success: - await self.send_text(f"插件卸载成功: {plugin_name}") + await self._send_message(f"插件卸载成功: {plugin_name}") else: - await self.send_text(f"插件卸载失败: {plugin_name}") + await self._send_message(f"插件卸载失败: {plugin_name}") async def _reload_plugin(self, plugin_name: str): success = await plugin_manage_api.reload_plugin(plugin_name) if success: - await self.send_text(f"插件重新加载成功: {plugin_name}") + await self._send_message(f"插件重新加载成功: {plugin_name}") else: - await self.send_text(f"插件重新加载失败: {plugin_name}") + await self._send_message(f"插件重新加载失败: {plugin_name}") async def _add_dir(self, dir_path: str): - await self.send_text(f"正在添加插件目录: {dir_path}") + await self._send_message(f"正在添加插件目录: {dir_path}") success = plugin_manage_api.add_plugin_directory(dir_path) await asyncio.sleep(0.5) # 防止乱序发送 if success: - await self.send_text(f"插件目录添加成功: {dir_path}") + await self._send_message(f"插件目录添加成功: {dir_path}") else: - await self.send_text(f"插件目录添加失败: {dir_path}") + await self._send_message(f"插件目录添加失败: {dir_path}") def _fetch_all_registered_components(self) -> List[ComponentInfo]: all_plugin_info = component_manage_api.get_all_plugin_info() @@ -255,29 +263,29 @@ class ManagementCommand(BaseCommand): async def _list_all_registered_components(self): components_info = self._fetch_all_registered_components() if not components_info: - await self.send_text("没有注册的组件") + await self._send_message("没有注册的组件") return all_components_str = ", ".join( f"{component.name} ({component.component_type})" for component in components_info ) - await self.send_text(f"已注册的组件: {all_components_str}") + await self._send_message(f"已注册的组件: {all_components_str}") async def _list_enabled_components(self, target_type: str = "global"): components_info = self._fetch_all_registered_components() if not components_info: - await self.send_text("没有注册的组件") + await self._send_message("没有注册的组件") return if target_type == "global": enabled_components = [component for component in components_info if component.enabled] if not enabled_components: - await self.send_text("没有满足条件的已启用全局组件") + await self._send_message("没有满足条件的已启用全局组件") return enabled_components_str = ", ".join( f"{component.name} ({component.component_type})" for component in enabled_components ) - await self.send_text(f"满足条件的已启用全局组件: {enabled_components_str}") + await self._send_message(f"满足条件的已启用全局组件: {enabled_components_str}") elif target_type == "local": locally_disabled_components = self._fetch_locally_disabled_components() enabled_components = [ @@ -286,28 +294,28 @@ class ManagementCommand(BaseCommand): if (component.name not in locally_disabled_components and component.enabled) ] if not enabled_components: - await self.send_text("本聊天没有满足条件的已启用组件") + await self._send_message("本聊天没有满足条件的已启用组件") return enabled_components_str = ", ".join( f"{component.name} ({component.component_type})" for component in enabled_components ) - await self.send_text(f"本聊天满足条件的已启用组件: {enabled_components_str}") + await self._send_message(f"本聊天满足条件的已启用组件: {enabled_components_str}") async def _list_disabled_components(self, target_type: str = "global"): components_info = self._fetch_all_registered_components() if not components_info: - await self.send_text("没有注册的组件") + await self._send_message("没有注册的组件") return if target_type == "global": disabled_components = [component for component in components_info if not component.enabled] if not disabled_components: - await self.send_text("没有满足条件的已禁用全局组件") + await self._send_message("没有满足条件的已禁用全局组件") return disabled_components_str = ", ".join( f"{component.name} ({component.component_type})" for component in disabled_components ) - await self.send_text(f"满足条件的已禁用全局组件: {disabled_components_str}") + await self._send_message(f"满足条件的已禁用全局组件: {disabled_components_str}") elif target_type == "local": locally_disabled_components = self._fetch_locally_disabled_components() disabled_components = [ @@ -316,12 +324,12 @@ class ManagementCommand(BaseCommand): if (component.name in locally_disabled_components or not component.enabled) ] if not disabled_components: - await self.send_text("本聊天没有满足条件的已禁用组件") + await self._send_message("本聊天没有满足条件的已禁用组件") return disabled_components_str = ", ".join( f"{component.name} ({component.component_type})" for component in disabled_components ) - await self.send_text(f"本聊天满足条件的已禁用组件: {disabled_components_str}") + await self._send_message(f"本聊天满足条件的已禁用组件: {disabled_components_str}") async def _list_registered_components_by_type(self, target_type: str): match target_type: @@ -332,18 +340,18 @@ class ManagementCommand(BaseCommand): case "event_handler": component_type = ComponentType.EVENT_HANDLER case _: - await self.send_text(f"未知组件类型: {target_type}") + await self._send_message(f"未知组件类型: {target_type}") return components_info = component_manage_api.get_components_info_by_type(component_type) if not components_info: - await self.send_text(f"没有注册的 {target_type} 组件") + await self._send_message(f"没有注册的 {target_type} 组件") return components_str = ", ".join( f"{name} ({component.component_type})" for name, component in components_info.items() ) - await self.send_text(f"注册的 {target_type} 组件: {components_str}") + await self._send_message(f"注册的 {target_type} 组件: {components_str}") async def _globally_enable_component(self, component_name: str, component_type: str): match component_type: @@ -354,12 +362,12 @@ class ManagementCommand(BaseCommand): case "event_handler": target_component_type = ComponentType.EVENT_HANDLER case _: - await self.send_text(f"未知组件类型: {component_type}") + await self._send_message(f"未知组件类型: {component_type}") return if component_manage_api.globally_enable_component(component_name, target_component_type): - await self.send_text(f"全局启用组件成功: {component_name}") + await self._send_message(f"全局启用组件成功: {component_name}") else: - await self.send_text(f"全局启用组件失败: {component_name}") + await self._send_message(f"全局启用组件失败: {component_name}") async def _globally_disable_component(self, component_name: str, component_type: str): match component_type: @@ -370,13 +378,13 @@ class ManagementCommand(BaseCommand): case "event_handler": target_component_type = ComponentType.EVENT_HANDLER case _: - await self.send_text(f"未知组件类型: {component_type}") + await self._send_message(f"未知组件类型: {component_type}") return success = await component_manage_api.globally_disable_component(component_name, target_component_type) if success: - await self.send_text(f"全局禁用组件成功: {component_name}") + await self._send_message(f"全局禁用组件成功: {component_name}") else: - await self.send_text(f"全局禁用组件失败: {component_name}") + await self._send_message(f"全局禁用组件失败: {component_name}") async def _locally_enable_component(self, component_name: str, component_type: str): match component_type: @@ -387,16 +395,16 @@ class ManagementCommand(BaseCommand): case "event_handler": target_component_type = ComponentType.EVENT_HANDLER case _: - await self.send_text(f"未知组件类型: {component_type}") + await self._send_message(f"未知组件类型: {component_type}") return if component_manage_api.locally_enable_component( component_name, target_component_type, self.message.chat_stream.stream_id, ): - await self.send_text(f"本地启用组件成功: {component_name}") + await self._send_message(f"本地启用组件成功: {component_name}") else: - await self.send_text(f"本地启用组件失败: {component_name}") + await self._send_message(f"本地启用组件失败: {component_name}") async def _locally_disable_component(self, component_name: str, component_type: str): match component_type: @@ -407,16 +415,19 @@ class ManagementCommand(BaseCommand): case "event_handler": target_component_type = ComponentType.EVENT_HANDLER case _: - await self.send_text(f"未知组件类型: {component_type}") + await self._send_message(f"未知组件类型: {component_type}") return if component_manage_api.locally_disable_component( component_name, target_component_type, self.message.chat_stream.stream_id, ): - await self.send_text(f"本地禁用组件成功: {component_name}") + await self._send_message(f"本地禁用组件成功: {component_name}") else: - await self.send_text(f"本地禁用组件失败: {component_name}") + await self._send_message(f"本地禁用组件失败: {component_name}") + + async def _send_message(self, message: str): + await send_api.text_to_stream(message, self.stream_id, typing=False, storage_message=False) @register_plugin @@ -430,7 +441,9 @@ class PluginManagementPlugin(BasePlugin): "plugin": { "enabled": ConfigField(bool, default=False, description="是否启用插件"), "config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本"), - "permission": ConfigField(list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID"), + "permission": ConfigField( + list, default=[], description="有权限使用插件管理命令的用户列表,请填写字符串形式的用户ID" + ), }, } diff --git a/src/tools/not_using/get_knowledge.py b/src/tools/not_using/get_knowledge.py deleted file mode 100644 index c436d774..00000000 --- a/src/tools/not_using/get_knowledge.py +++ /dev/null @@ -1,133 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.chat.utils.utils import get_embedding -from src.common.database.database_model import Knowledges # Updated import -from src.common.logger import get_logger -from typing import Any, Union, List # Added List -import json # Added for parsing embedding -import math # Added for cosine similarity - -logger = get_logger("get_knowledge_tool") - - -class SearchKnowledgeTool(BaseTool): - """从知识库中搜索相关信息的工具""" - - name = "search_knowledge" - description = "使用工具从知识库中搜索相关信息" - parameters = { - "type": "object", - "properties": { - "query": {"type": "string", "description": "搜索查询关键词"}, - "threshold": {"type": "number", "description": "相似度阈值,0.0到1.0之间"}, - }, - "required": ["query"], - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行知识库搜索 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - query = "" # Initialize query to ensure it's defined in except block - try: - query = function_args.get("query") - threshold = function_args.get("threshold", 0.4) - - # 调用知识库搜索 - embedding = await get_embedding(query, request_type="info_retrieval") - if embedding: - knowledge_info = self.get_info_from_db(embedding, limit=3, threshold=threshold) - if knowledge_info: - content = f"你知道这些知识: {knowledge_info}" - else: - content = f"你不太了解有关{query}的知识" - return {"type": "knowledge", "id": query, "content": content} - return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你知识库炸了"} - except Exception as e: - logger.error(f"知识库搜索工具执行失败: {str(e)}") - return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"} - - @staticmethod - def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float: - """计算两个向量之间的余弦相似度""" - dot_product = sum(p * q for p, q in zip(vec1, vec2, strict=False)) - magnitude1 = math.sqrt(sum(p * p for p in vec1)) - magnitude2 = math.sqrt(sum(q * q for q in vec2)) - if magnitude1 == 0 or magnitude2 == 0: - return 0.0 - return dot_product / (magnitude1 * magnitude2) - - @staticmethod - def get_info_from_db( - query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False - ) -> Union[str, list]: - """从数据库中获取相关信息 - - Args: - query_embedding: 查询的嵌入向量 - limit: 最大返回结果数 - threshold: 相似度阈值 - return_raw: 是否返回原始结果 - - Returns: - Union[str, list]: 格式化的信息字符串或原始结果列表 - """ - if not query_embedding: - return "" if not return_raw else [] - - similar_items = [] - try: - all_knowledges = Knowledges.select() - for item in all_knowledges: - try: - item_embedding_str = item.embedding - if not item_embedding_str: - logger.warning(f"Knowledge item ID {item.id} has empty embedding string.") - continue - item_embedding = json.loads(item_embedding_str) - if not isinstance(item_embedding, list) or not all( - isinstance(x, (int, float)) for x in item_embedding - ): - logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.") - continue - except json.JSONDecodeError: - logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}") - continue - except AttributeError: - logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.") - continue - - similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding) - - if similarity >= threshold: - similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item}) - - # 按相似度降序排序 - similar_items.sort(key=lambda x: x["similarity"], reverse=True) - - # 应用限制 - results = similar_items[:limit] - logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}") - - except Exception as e: - logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}") - return "" if not return_raw else [] - - if not results: - return "" if not return_raw else [] - - if return_raw: - # Peewee 模型实例不能直接序列化为 JSON,如果需要原始模型,调用者需要处理 - # 这里返回包含内容和相似度的字典列表 - return [{"content": r["content"], "similarity": r["similarity"]} for r in results] - else: - # 返回所有找到的内容,用换行分隔 - return "\n".join(str(result["content"]) for result in results) - - -# 注册工具 -# register_tool(SearchKnowledgeTool) diff --git a/src/tools/tool_can_use/__init__.py b/src/tools/tool_can_use/__init__.py deleted file mode 100644 index 14bae04c..00000000 --- a/src/tools/tool_can_use/__init__.py +++ /dev/null @@ -1,20 +0,0 @@ -from src.tools.tool_can_use.base_tool import ( - BaseTool, - register_tool, - discover_tools, - get_all_tool_definitions, - get_tool_instance, - TOOL_REGISTRY, -) - -__all__ = [ - "BaseTool", - "register_tool", - "discover_tools", - "get_all_tool_definitions", - "get_tool_instance", - "TOOL_REGISTRY", -] - -# 自动发现并注册工具 -discover_tools() diff --git a/src/tools/tool_can_use/base_tool.py b/src/tools/tool_can_use/base_tool.py deleted file mode 100644 index 89d051dc..00000000 --- a/src/tools/tool_can_use/base_tool.py +++ /dev/null @@ -1,115 +0,0 @@ -from typing import List, Any, Optional, Type -import inspect -import importlib -import pkgutil -import os -from src.common.logger import get_logger -from rich.traceback import install - -install(extra_lines=3) - -logger = get_logger("base_tool") - -# 工具注册表 -TOOL_REGISTRY = {} - - -class BaseTool: - """所有工具的基类""" - - # 工具名称,子类必须重写 - name = None - # 工具描述,子类必须重写 - description = None - # 工具参数定义,子类必须重写 - parameters = None - - @classmethod - def get_tool_definition(cls) -> dict[str, Any]: - """获取工具定义,用于LLM工具调用 - - Returns: - dict: 工具定义字典 - """ - if not cls.name or not cls.description or not cls.parameters: - raise NotImplementedError(f"工具类 {cls.__name__} 必须定义 name, description 和 parameters 属性") - - return { - "type": "function", - "function": {"name": cls.name, "description": cls.description, "parameters": cls.parameters}, - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行工具函数 - - Args: - function_args: 工具调用参数 - - Returns: - dict: 工具执行结果 - """ - raise NotImplementedError("子类必须实现execute方法") - - -def register_tool(tool_class: Type[BaseTool]): - """注册工具到全局注册表 - - Args: - tool_class: 工具类 - """ - if not issubclass(tool_class, BaseTool): - raise TypeError(f"{tool_class.__name__} 不是 BaseTool 的子类") - - tool_name = tool_class.name - if not tool_name: - raise ValueError(f"工具类 {tool_class.__name__} 没有定义 name 属性") - - TOOL_REGISTRY[tool_name] = tool_class - logger.info(f"已注册: {tool_name}") - - -def discover_tools(): - """自动发现并注册tool_can_use目录下的所有工具""" - # 获取当前目录路径 - current_dir = os.path.dirname(os.path.abspath(__file__)) - package_name = os.path.basename(current_dir) - - # 遍历包中的所有模块 - for _, module_name, _ in pkgutil.iter_modules([current_dir]): - # 跳过当前模块和__pycache__ - if module_name == "base_tool" or module_name.startswith("__"): - continue - - # 导入模块 - module = importlib.import_module(f"src.tools.{package_name}.{module_name}") - - # 查找模块中的工具类 - for _, obj in inspect.getmembers(module): - if inspect.isclass(obj) and issubclass(obj, BaseTool) and obj != BaseTool: - register_tool(obj) - - logger.info(f"工具发现完成,共注册 {len(TOOL_REGISTRY)} 个工具") - - -def get_all_tool_definitions() -> List[dict[str, Any]]: - """获取所有已注册工具的定义 - - Returns: - List[dict]: 工具定义列表 - """ - return [tool_class().get_tool_definition() for tool_class in TOOL_REGISTRY.values()] - - -def get_tool_instance(tool_name: str) -> Optional[BaseTool]: - """获取指定名称的工具实例 - - Args: - tool_name: 工具名称 - - Returns: - Optional[BaseTool]: 工具实例,如果找不到则返回None - """ - tool_class = TOOL_REGISTRY.get(tool_name) - if not tool_class: - return None - return tool_class() diff --git a/src/tools/tool_can_use/compare_numbers_tool.py b/src/tools/tool_can_use/compare_numbers_tool.py deleted file mode 100644 index 236a4587..00000000 --- a/src/tools/tool_can_use/compare_numbers_tool.py +++ /dev/null @@ -1,45 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.common.logger import get_logger -from typing import Any - -logger = get_logger("compare_numbers_tool") - - -class CompareNumbersTool(BaseTool): - """比较两个数大小的工具""" - - name = "compare_numbers" - description = "使用工具 比较两个数的大小,返回较大的数" - parameters = { - "type": "object", - "properties": { - "num1": {"type": "number", "description": "第一个数字"}, - "num2": {"type": "number", "description": "第二个数字"}, - }, - "required": ["num1", "num2"], - } - - async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: - """执行比较两个数的大小 - - Args: - function_args: 工具参数 - - Returns: - dict: 工具执行结果 - """ - num1: int | float = function_args.get("num1") # type: ignore - num2: int | float = function_args.get("num2") # type: ignore - - try: - if num1 > num2: - result = f"{num1} 大于 {num2}" - elif num1 < num2: - result = f"{num1} 小于 {num2}" - else: - result = f"{num1} 等于 {num2}" - - return {"name": self.name, "content": result} - except Exception as e: - logger.error(f"比较数字失败: {str(e)}") - return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"} diff --git a/src/tools/tool_can_use/rename_person_tool.py b/src/tools/tool_can_use/rename_person_tool.py deleted file mode 100644 index 17e62468..00000000 --- a/src/tools/tool_can_use/rename_person_tool.py +++ /dev/null @@ -1,103 +0,0 @@ -from src.tools.tool_can_use.base_tool import BaseTool -from src.person_info.person_info import get_person_info_manager -from src.common.logger import get_logger - - -logger = get_logger("rename_person_tool") - - -class RenamePersonTool(BaseTool): - name = "rename_person" - description = ( - "这个工具可以改变用户的昵称。你可以选择改变对他人的称呼。你想给人改名,叫别人别的称呼,需要调用这个工具。" - ) - parameters = { - "type": "object", - "properties": { - "person_name": {"type": "string", "description": "需要重新取名的用户的当前昵称"}, - "message_content": { - "type": "string", - "description": "当前的聊天内容或特定要求,用于提供取名建议的上下文,尽可能详细。", - }, - }, - "required": ["person_name"], - } - - async def execute(self, function_args: dict): - """ - 执行取名工具逻辑 - - Args: - function_args (dict): 包含 'person_name' 和可选 'message_content' 的字典 - message_txt (str): 原始消息文本 (这里未使用,因为 message_content 更明确) - - Returns: - dict: 包含执行结果的字典 - """ - person_name_to_find = function_args.get("person_name") - request_context = function_args.get("message_content", "") # 如果没有提供,则为空字符串 - - if not person_name_to_find: - return {"name": self.name, "content": "错误:必须提供需要重命名的用户昵称 (person_name)。"} - person_info_manager = get_person_info_manager() - try: - # 1. 根据昵称查找用户信息 - logger.debug(f"尝试根据昵称 '{person_name_to_find}' 查找用户...") - person_info = await person_info_manager.get_person_info_by_name(person_name_to_find) - - if not person_info: - logger.info(f"未找到昵称为 '{person_name_to_find}' 的用户。") - return { - "name": self.name, - "content": f"找不到昵称为 '{person_name_to_find}' 的用户。请确保输入的是我之前为该用户取的昵称。", - } - - person_id = person_info.get("person_id") - user_nickname = person_info.get("nickname") # 这是用户原始昵称 - user_cardname = person_info.get("user_cardname") - user_avatar = person_info.get("user_avatar") - - if not person_id: - logger.error(f"找到了用户 '{person_name_to_find}' 但无法获取 person_id") - return {"name": self.name, "content": f"找到了用户 '{person_name_to_find}' 但获取内部ID时出错。"} - - # 2. 调用 qv_person_name 进行取名 - logger.debug( - f"为用户 {person_id} (原昵称: {person_name_to_find}) 调用 qv_person_name,请求上下文: '{request_context}'" - ) - result = await person_info_manager.qv_person_name( - person_id=person_id, - user_nickname=user_nickname, # type: ignore - user_cardname=user_cardname, # type: ignore - user_avatar=user_avatar, # type: ignore - request=request_context, - ) - - # 3. 处理结果 - if result and result.get("nickname"): - new_name = result["nickname"] - # reason = result.get("reason", "未提供理由") - logger.info(f"成功为用户 {person_id} 取了新昵称: {new_name}") - - content = f"已成功将用户 {person_name_to_find} 的备注名更新为 {new_name}" - logger.info(content) - return {"name": self.name, "content": content} - else: - logger.warning(f"为用户 {person_id} 调用 qv_person_name 后未能成功获取新昵称。") - # 尝试从内存中获取可能已经更新的名字 - current_name = await person_info_manager.get_value(person_id, "person_name") - if current_name and current_name != person_name_to_find: - return { - "name": self.name, - "content": f"尝试取新昵称时遇到一点小问题,但我已经将 '{person_name_to_find}' 的昵称更新为 '{current_name}' 了。", - } - else: - return { - "name": self.name, - "content": f"尝试为 '{person_name_to_find}' 取新昵称时遇到了问题,未能成功生成。可能需要稍后再试。", - } - - except Exception as e: - error_msg = f"重命名失败: {str(e)}" - logger.error(error_msg, exc_info=True) - return {"name": self.name, "content": error_msg} diff --git a/src/tools/tool_use.py b/src/tools/tool_use.py deleted file mode 100644 index 6a8cd48a..00000000 --- a/src/tools/tool_use.py +++ /dev/null @@ -1,56 +0,0 @@ -import json -from src.common.logger import get_logger -from src.tools.tool_can_use import get_all_tool_definitions, get_tool_instance - -logger = get_logger("tool_use") - - -class ToolUser: - @staticmethod - def _define_tools(): - """获取所有已注册工具的定义 - - Returns: - list: 工具定义列表 - """ - return get_all_tool_definitions() - - @staticmethod - async def execute_tool_call(tool_call): - # sourcery skip: use-assigned-variable - """执行特定的工具调用 - - Args: - tool_call: 工具调用对象 - message_txt: 原始消息文本 - - Returns: - dict: 工具调用结果 - """ - try: - function_name = tool_call["function"]["name"] - function_args = json.loads(tool_call["function"]["arguments"]) - - # 获取对应工具实例 - tool_instance = get_tool_instance(function_name) - if not tool_instance: - logger.warning(f"未知工具名称: {function_name}") - return None - - # 执行工具 - result = await tool_instance.execute(function_args) - if result: - # 直接使用 function_name 作为 tool_type - tool_type = function_name - - return { - "tool_call_id": tool_call["id"], - "role": "tool", - "name": function_name, - "type": tool_type, - "content": result["content"], - } - return None - except Exception as e: - logger.error(f"执行工具调用时发生错误: {str(e)}") - return None diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 39857d66..fae41f82 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "4.5.0" +version = "6.0.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请在修改后将version的值进行变更 @@ -213,134 +213,9 @@ file_log_level = "DEBUG" # 文件日志级别,可选: DEBUG, INFO, WARNING, ER suppress_libraries = ["faiss","httpx", "urllib3", "asyncio", "websockets", "httpcore", "requests", "peewee", "openai","uvicorn","jieba"] # 完全屏蔽的库 library_log_levels = { "aiohttp" = "WARNING"} # 设置特定库的日志级别 -#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写 - -# stream = : 用于指定模型是否是使用流式输出 -# pri_in = : 用于指定模型输入价格 -# pri_out = : 用于指定模型输出价格 -# temp = : 用于指定模型温度 -# enable_thinking = : 用于指定模型是否启用思考 -# thinking_budget = : 用于指定模型思考最长长度 - [debug] show_prompt = false # 是否显示prompt - -[model] -model_max_output_length = 1024 # 模型单次返回的最大token数 - -#------------必填:组件模型------------ - -[model.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 - -[model.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 -# 强烈建议使用免费的小模型 -name = "Qwen/Qwen3-8B" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 -temp = 0.7 -enable_thinking = false # 是否启用思考 - -[model.replyer_1] # 首要回复模型,还用于表达器和表达方式学习 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 - -[model.replyer_2] # 次要回复模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 #模型的输入价格(非必填,可以记录消耗) -pri_out = 8 #模型的输出价格(非必填,可以记录消耗) -#默认temp 0.2 如果你使用的是老V3或者其他模型,请自己修改temp参数 -temp = 0.2 #模型的温度,新V3建议0.1-0.3 - -[model.planner] #决策:负责决定麦麦该做什么的模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.3 - -[model.emotion] #负责麦麦的情绪变化 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.3 - - -[model.memory] # 记忆模型 -name = "Qwen/Qwen3-30B-A3B" -provider = "SILICONFLOW" -pri_in = 0.7 -pri_out = 2.8 -temp = 0.7 -enable_thinking = false # 是否启用思考 - -[model.vlm] # 图像识别模型 -name = "Pro/Qwen/Qwen2.5-VL-7B-Instruct" -provider = "SILICONFLOW" -pri_in = 0.35 -pri_out = 0.35 - -[model.voice] # 语音识别模型 -name = "FunAudioLLM/SenseVoiceSmall" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 - -[model.tool_use] #工具调用模型,需要使用支持工具调用的模型 -name = "Qwen/Qwen3-14B" -provider = "SILICONFLOW" -pri_in = 0.5 -pri_out = 2 -temp = 0.7 -enable_thinking = false # 是否启用思考(qwen3 only) - -#嵌入模型 -[model.embedding] -name = "BAAI/bge-m3" -provider = "SILICONFLOW" -pri_in = 0 -pri_out = 0 - - -#------------LPMM知识库模型------------ - -[model.lpmm_entity_extract] # 实体提取模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.2 - - -[model.lpmm_rdf_build] # RDF构建模型 -name = "Pro/deepseek-ai/DeepSeek-V3" -provider = "SILICONFLOW" -pri_in = 2 -pri_out = 8 -temp = 0.2 - - -[model.lpmm_qa] # 问答模型 -name = "Qwen/Qwen3-30B-A3B" -provider = "SILICONFLOW" -pri_in = 0.7 -pri_out = 2.8 -temp = 0.7 -enable_thinking = false # 是否启用思考 - [maim_message] auth_token = [] # 认证令牌,用于API验证,为空则不启用验证 # 以下项目若要使用需要打开use_custom,并单独配置maim_message的服务器 @@ -356,8 +231,4 @@ key_file = "" # SSL密钥文件路径,仅在use_wss=true时有效 enable = true [experimental] #实验性功能 -enable_friend_chat = false # 是否启用好友聊天 - - - - +enable_friend_chat = false # 是否启用好友聊天 \ No newline at end of file diff --git a/template/lpmm_config_template.toml b/template/lpmm_config_template.toml deleted file mode 100644 index 5bf24732..00000000 --- a/template/lpmm_config_template.toml +++ /dev/null @@ -1,60 +0,0 @@ -[lpmm] -version = "0.1.0" - -# LLM API 服务提供商,可配置多个 -[[llm_providers]] -name = "localhost" -base_url = "http://127.0.0.1:8888/v1/" -api_key = "lm_studio" - -[[llm_providers]] -name = "siliconflow" -base_url = "https://api.siliconflow.cn/v1/" -api_key = "" - -[entity_extract.llm] -# 设置用于实体提取的LLM模型 -provider = "siliconflow" # 服务提供商 -model = "deepseek-ai/DeepSeek-V3" # 模型名称 - -[rdf_build.llm] -# 设置用于RDF构建的LLM模型 -provider = "siliconflow" # 服务提供商 -model = "deepseek-ai/DeepSeek-V3" # 模型名称 - -[embedding] -# 设置用于文本嵌入的Embedding模型 -provider = "siliconflow" # 服务提供商 -model = "Pro/BAAI/bge-m3" # 模型名称 -dimension = 1024 # 嵌入维度 - -[rag.params] -# RAG参数配置 -synonym_search_top_k = 10 # 同义词搜索TopK -synonym_threshold = 0.8 # 同义词阈值(相似度高于此阈值的词语会被认为是同义词) - -[qa.llm] -# 设置用于QA的LLM模型 -provider = "siliconflow" # 服务提供商 -model = "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" # 模型名称 - -[info_extraction] -workers = 3 # 实体提取同时执行线程数,非Pro模型不要设置超过5 - -[qa.params] -# QA参数配置 -relation_search_top_k = 10 # 关系搜索TopK -relation_threshold = 0.5 # 关系阈值(相似度高于此阈值的关系会被认为是相关的关系) -paragraph_search_top_k = 1000 # 段落搜索TopK(不能过小,可能影响搜索结果) -paragraph_node_weight = 0.05 # 段落节点权重(在图搜索&PPR计算中的权重,当搜索仅使用DPR时,此参数不起作用) -ent_filter_top_k = 10 # 实体过滤TopK -ppr_damping = 0.8 # PPR阻尼系数 -res_top_k = 3 # 最终提供的文段TopK - -[persistence] -# 持久化配置(存储中间数据,防止重复计算) -data_root_path = "data" # 数据根目录 -imported_data_path = "data/imported_lpmm_data" # 转换为json的raw文件数据路径 -openie_data_path = "data/openie" # OpenIE数据路径 -embedding_data_dir = "data/embedding" # 嵌入数据目录 -rag_data_dir = "data/rag" # RAG数据目录 diff --git a/template/model_config_template.toml b/template/model_config_template.toml new file mode 100644 index 00000000..3dcff6f8 --- /dev/null +++ b/template/model_config_template.toml @@ -0,0 +1,171 @@ +[inner] +version = "1.1.1" + +# 配置文件版本号迭代规则同bot_config.toml + +[[api_providers]] # API服务提供商(可以配置多个) +name = "DeepSeek" # API服务商名称(可随意命名,在models的api-provider中需使用这个命名) +base_url = "https://api.deepseek.cn/v1" # API服务商的BaseURL +api_key = "your-api-key-here" # API密钥(请替换为实际的API密钥) +client_type = "openai" # 请求客户端(可选,默认值为"openai",使用gimini等Google系模型时请配置为"gemini") +max_retry = 2 # 最大重试次数(单个模型API调用失败,最多重试的次数) +timeout = 30 # API请求超时时间(单位:秒) +retry_interval = 10 # 重试间隔时间(单位:秒) + +[[api_providers]] # SiliconFlow的API服务商配置 +name = "SiliconFlow" +base_url = "https://api.siliconflow.cn/v1" +api_key = "your-siliconflow-api-key" +client_type = "openai" +max_retry = 2 +timeout = 30 +retry_interval = 10 + +[[api_providers]] # 特殊:Google的Gimini使用特殊API,与OpenAI格式不兼容,需要配置client为"gemini" +name = "Google" +base_url = "https://api.google.com/v1" +api_key = "your-google-api-key-1" +client_type = "gemini" +max_retry = 2 +timeout = 30 +retry_interval = 10 + + +[[models]] # 模型(可以配置多个) +model_identifier = "deepseek-chat" # 模型标识符(API服务商提供的模型标识符) +name = "deepseek-v3" # 模型名称(可随意命名,在后面中需使用这个命名) +api_provider = "DeepSeek" # API服务商名称(对应在api_providers中配置的服务商名称) +price_in = 2.0 # 输入价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0) +price_out = 8.0 # 输出价格(用于API调用统计,单位:元/ M token)(可选,若无该字段,默认值为0) +#force_stream_mode = true # 强制流式输出模式(若模型不支持非流式输出,请取消该注释,启用强制流式输出,若无该字段,默认值为false) + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-V3" +name = "siliconflow-deepseek-v3" +api_provider = "SiliconFlow" +price_in = 2.0 +price_out = 8.0 + +[[models]] +model_identifier = "Pro/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B" +name = "deepseek-r1-distill-qwen-32b" +api_provider = "SiliconFlow" +price_in = 4.0 +price_out = 16.0 + +[[models]] +model_identifier = "Qwen/Qwen3-8B" +name = "qwen3-8b" +api_provider = "SiliconFlow" +price_in = 0 +price_out = 0 +[models.extra_params] # 可选的额外参数配置 +enable_thinking = false # 不启用思考 + +[[models]] +model_identifier = "Qwen/Qwen3-14B" +name = "qwen3-14b" +api_provider = "SiliconFlow" +price_in = 0.5 +price_out = 2.0 +[models.extra_params] # 可选的额外参数配置 +enable_thinking = false # 不启用思考 + +[[models]] +model_identifier = "Qwen/Qwen3-30B-A3B" +name = "qwen3-30b" +api_provider = "SiliconFlow" +price_in = 0.7 +price_out = 2.8 +[models.extra_params] # 可选的额外参数配置 +enable_thinking = false # 不启用思考 + +[[models]] +model_identifier = "Qwen/Qwen2.5-VL-72B-Instruct" +name = "qwen2.5-vl-72b" +api_provider = "SiliconFlow" +price_in = 4.13 +price_out = 4.13 + +[[models]] +model_identifier = "FunAudioLLM/SenseVoiceSmall" +name = "sensevoice-small" +api_provider = "SiliconFlow" +price_in = 0 +price_out = 0 + +[[models]] +model_identifier = "BAAI/bge-m3" +name = "bge-m3" +api_provider = "SiliconFlow" +price_in = 0 +price_out = 0 + + +[model_task_config.utils] # 在麦麦的一些组件中使用的模型,例如表情包模块,取名模块,关系模块,是麦麦必须的模型 +model_list = ["siliconflow-deepseek-v3"] # 使用的模型列表,每个子项对应上面的模型名称(name) +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 # 最大输出token数 + +[model_task_config.utils_small] # 在麦麦的一些组件中使用的小模型,消耗量较大,建议使用速度较快的小模型 +model_list = ["qwen3-8b"] +temperature = 0.7 +max_tokens = 800 + +[model_task_config.replyer_1] # 首要回复模型,还用于表达器和表达方式学习 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 # 模型温度,新V3建议0.1-0.3 +max_tokens = 800 + +[model_task_config.replyer_2] # 次要回复模型 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.7 +max_tokens = 800 + +[model_task_config.planner] #决策:负责决定麦麦该做什么的模型 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.3 +max_tokens = 800 + +[model_task_config.emotion] #负责麦麦的情绪变化 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.3 +max_tokens = 800 + +[model_task_config.memory] # 记忆模型 +model_list = ["qwen3-30b"] +temperature = 0.7 +max_tokens = 800 + +[model_task_config.vlm] # 图像识别模型 +model_list = ["qwen2.5-vl-72b"] +max_tokens = 800 + +[model_task_config.voice] # 语音识别模型 +model_list = ["sensevoice-small"] + +[model_task_config.tool_use] #工具调用模型,需要使用支持工具调用的模型 +model_list = ["qwen3-14b"] +temperature = 0.7 +max_tokens = 800 + +#嵌入模型 +[model_task_config.embedding] +model_list = ["bge-m3"] + +#------------LPMM知识库模型------------ + +[model_task_config.lpmm_entity_extract] # 实体提取模型 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 + +[model_task_config.lpmm_rdf_build] # RDF构建模型 +model_list = ["siliconflow-deepseek-v3"] +temperature = 0.2 +max_tokens = 800 + +[model_task_config.lpmm_qa] # 问答模型 +model_list = ["deepseek-r1-distill-qwen-32b"] +temperature = 0.7 +max_tokens = 800 diff --git a/template/template.env b/template/template.env index 4718203d..d9b6e2bd 100644 --- a/template/template.env +++ b/template/template.env @@ -1,23 +1,2 @@ HOST=127.0.0.1 -PORT=8000 - -# 密钥和url - -# 硅基流动 -SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ -# DeepSeek官方 -DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 -# 阿里百炼 -BAILIAN_BASE_URL = https://dashscope.aliyuncs.com/compatible-mode/v1 -# 火山引擎 -HUOSHAN_BASE_URL = -# xxxxx平台 -xxxxxxx_BASE_URL=https://xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx - -# 定义你要用的api的key(需要去对应网站申请哦) -DEEP_SEEK_KEY= -CHAT_ANY_WHERE_KEY= -SILICONFLOW_KEY= -BAILIAN_KEY = -HUOSHAN_KEY = -xxxxxxx_KEY= +PORT=8000 \ No newline at end of file