Merge pull request #1 from MaiM-with-u/dev

Dev
pull/1160/head
小飞 2025-08-04 13:49:56 +08:00 committed by GitHub
commit 422fe87c7d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
121 changed files with 6248 additions and 6394 deletions

4
.gitignore vendored
View File

@ -41,6 +41,7 @@ config/bot_config.toml.bak
config/lpmm_config.toml config/lpmm_config.toml
config/lpmm_config.toml.bak config/lpmm_config.toml.bak
template/compare/bot_config_template.toml template/compare/bot_config_template.toml
template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat (测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat (临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py src/plugins/utils/statistic.py
@ -321,4 +322,5 @@ run_pet.bat
config.toml config.toml
interested_rates.txt interested_rates.txt
MaiBot.code-workspace

33
bot.py
View File

@ -74,36 +74,6 @@ def easter_egg():
print(rainbow_text) print(rainbow_text)
def scan_provider(env_config: dict):
provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
async def graceful_shutdown(): async def graceful_shutdown():
try: try:
@ -229,9 +199,6 @@ def raw_main():
easter_egg() easter_egg()
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 返回MainSystem实例 # 返回MainSystem实例
return MainSystem() return MainSystem()

View File

@ -1,5 +1,17 @@
# Changelog # Changelog
## [0.10.0] - 2025-7-1
### 主要功能更改
- 工具系统重构,现在合并到了插件系统中
- 彻底重构了整个LLM Request了现在支持模型轮询和更多灵活的参数
- 同时重构了整个模型配置系统升级需要重新配置llm配置文件
- 随着LLM Request的重构插件系统彻底重构完成。插件系统进入稳定状态仅增加新的API
- 具体相比于之前的更改可以查看[changes.md](./changes.md)
### 细节优化
- 修复了lint爆炸的问题代码更加规范了
- 修改了log的颜色更加护眼
## [0.9.1] - 2025-7-26 ## [0.9.1] - 2025-7-26
### 主要修复和优化 ### 主要修复和优化

View File

@ -25,6 +25,7 @@
- 这意味着你终于可以动态控制是否继续后续消息的处理了。 - 这意味着你终于可以动态控制是否继续后续消息的处理了。
8. 移除了dependency_manager但是依然保留了`python_dependencies`属性,等待后续重构。 8. 移除了dependency_manager但是依然保留了`python_dependencies`属性,等待后续重构。
- 一并移除了文档有关manager的内容。 - 一并移除了文档有关manager的内容。
9. 增加了工具的有关api
# 插件系统修改 # 插件系统修改
1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)** 1. 现在所有的匹配模式不再是关键字了,而是枚举类。**(可能有遗漏)**
@ -57,30 +58,12 @@
15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。 15. 实现了组件的局部禁用,也就是针对某一个聊天禁用的功能。
- 通过`disable_specific_chat_action``enable_specific_chat_action``disable_specific_chat_command``enable_specific_chat_command``disable_specific_chat_event_handler``enable_specific_chat_event_handler`来操作 - 通过`disable_specific_chat_action``enable_specific_chat_action``disable_specific_chat_command``enable_specific_chat_command``disable_specific_chat_event_handler``enable_specific_chat_event_handler`来操作
- 同样不保存到配置文件~ - 同样不保存到配置文件~
16. 把`BaseTool`一并合并进入了插件系统
# 官方插件修改 # 官方插件修改
1. `HelloWorld`插件现在有一个样例的`EventHandler`。 1. `HelloWorld`插件现在有一个样例的`EventHandler`。
2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。 2. 内置插件增加了一个通过`Command`来管理插件的功能。具体是使用`/pm`命令唤起。(需要自行启用)
3. `HelloWorld`插件现在有一个样例的`CompareNumbersTool`。
### TODO
把这个看起来就很别扭的config获取方式改一下
# 吐槽
```python
plugin_path = Path(plugin_file)
if plugin_path.parent.name != "plugins":
# 插件包格式parent_dir.plugin
module_name = f"plugins.{plugin_path.parent.name}.plugin"
else:
# 单文件格式plugins.filename
module_name = f"plugins.{plugin_path.stem}"
```
```python
plugin_path = Path(plugin_file)
module_name = ".".join(plugin_path.parent.parts)
```
这两个区别很大的。
### 执笔BGM ### 执笔BGM
塞壬唱片! 塞壬唱片!

BIN
docs/image-1.png 100644

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

BIN
docs/image.png 100644

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.9 KiB

View File

@ -0,0 +1,331 @@
# 模型配置指南
本文档将指导您如何配置 `model_config.toml` 文件,该文件用于配置 MaiBot 的各种AI模型和API服务提供商。
## 配置文件结构
配置文件主要包含以下几个部分:
- 版本信息
- API服务提供商配置
- 模型配置
- 模型任务配置
## 1. 版本信息
```toml
[inner]
version = "1.1.1"
```
用于标识配置文件的版本,遵循语义化版本规则。
## 2. API服务提供商配置
### 2.1 基本配置
使用 `[[api_providers]]` 数组配置多个API服务提供商
```toml
[[api_providers]]
name = "DeepSeek" # 服务商名称(自定义)
base_url = "https://api.deepseek.cn/v1" # API服务的基础URL
api_key = "your-api-key-here" # API密钥
client_type = "openai" # 客户端类型
max_retry = 2 # 最大重试次数
timeout = 30 # 超时时间(秒)
retry_interval = 10 # 重试间隔(秒)
```
### 2.2 配置参数说明
| 参数 | 必填 | 说明 | 默认值 |
|------|------|------|--------|
| `name` | ✅ | 服务商名称,需要在模型配置中引用 | - |
| `base_url` | ✅ | API服务的基础URL | - |
| `api_key` | ✅ | API密钥请替换为实际密钥 | - |
| `client_type` | ❌ | 客户端类型:`openai`OpenAI格式`gemini`Gemini格式现在支持不良好 | `openai` |
| `max_retry` | ❌ | API调用失败时的最大重试次数 | 2 |
| `timeout` | ❌ | API请求超时时间 | 30 |
| `retry_interval` | ❌ | 重试间隔时间(秒) | 10 |
### 2.3 支持的服务商示例
#### DeepSeek
```toml
[[api_providers]]
name = "DeepSeek"
base_url = "https://api.deepseek.cn/v1"
api_key = "your-deepseek-api-key"
client_type = "openai"
```
#### SiliconFlow
```toml
[[api_providers]]
name = "SiliconFlow"
base_url = "https://api.siliconflow.cn/v1"
api_key = "your-siliconflow-api-key"
client_type = "openai"
```
#### Google Gemini
```toml
[[api_providers]]
name = "Google"
base_url = "https://api.google.com/v1"
api_key = "your-google-api-key"
client_type = "gemini" # 注意Gemini需要使用特殊客户端
```
## 3. 模型配置
### 3.1 基本模型配置
使用 `[[models]]` 数组配置多个模型:
```toml
[[models]]
model_identifier = "deepseek-chat" # 模型在API服务商中的标识符
name = "deepseek-v3" # 自定义模型名称
api_provider = "DeepSeek" # 引用的API服务商名称
price_in = 2.0 # 输入价格(元/M token
price_out = 8.0 # 输出价格(元/M token
```
### 3.2 高级模型配置
#### 强制流式输出
对于不支持非流式输出的模型:
```toml
[[models]]
model_identifier = "some-model"
name = "custom-name"
api_provider = "Provider"
force_stream_mode = true # 启用强制流式输出
```
#### 额外参数配置`extra_params`
```toml
[[models]]
model_identifier = "Qwen/Qwen3-8B"
name = "qwen3-8b"
api_provider = "SiliconFlow"
[models.extra_params]
enable_thinking = false # 禁用思考
```
这里的 `extra_params` 可以包含任何API服务商支持的额外参数配置**配置时应参考相应的API文档**。
比如上面就是参考SiliconFlow的文档配置配置的`Qwen3`禁用思考参数。
![SiliconFlow文档截图](image-1.png)
以豆包文档为另一个例子
![豆包文档截图](image.png)
得到豆包`"doubao-seed-1-6-250615"`的禁用思考配置方法为
```toml
[[models]]
# 你的模型
[models.extra_params]
thinking = {type = "disabled"} # 禁用思考
```
请注意,`extra_params` 的配置应该构成一个合法的TOML字典结构具体内容取决于API服务商的要求。
### 3.3 配置参数说明
| 参数 | 必填 | 说明 |
|------|------|------|
| `model_identifier` | ✅ | API服务商提供的模型标识符 |
| `name` | ✅ | 自定义模型名称,用于在任务配置中引用 |
| `api_provider` | ✅ | 对应的API服务商名称 |
| `price_in` | ❌ | 输入价格(元/M token用于成本统计 |
| `price_out` | ❌ | 输出价格(元/M token用于成本统计 |
| `force_stream_mode` | ❌ | 是否强制使用流式输出 |
| `extra_params` | ❌ | 额外的模型参数配置 |
## 4. 模型任务配置
### utils - 工具模型
用于表情包模块、取名模块、关系模块等核心功能:
```toml
[model_task_config.utils]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### utils_small - 小型工具模型
用于高频率调用的场景,建议使用速度快的小模型:
```toml
[model_task_config.utils_small]
model_list = ["qwen3-8b"]
temperature = 0.7
max_tokens = 800
```
### replyer_1 - 主要回复模型
首要回复模型,也用于表达器和表达方式学习:
```toml
[model_task_config.replyer_1]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### replyer_2 - 次要回复模型
```toml
[model_task_config.replyer_2]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.7
max_tokens = 800
```
### planner - 决策模型
负责决定MaiBot该做什么
```toml
[model_task_config.planner]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### emotion - 情绪模型
负责MaiBot的情绪变化
```toml
[model_task_config.emotion]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.3
max_tokens = 800
```
### memory - 记忆模型
```toml
[model_task_config.memory]
model_list = ["qwen3-30b"]
temperature = 0.7
max_tokens = 800
```
### vlm - 视觉语言模型
用于图像识别:
```toml
[model_task_config.vlm]
model_list = ["qwen2.5-vl-72b"]
max_tokens = 800
```
### voice - 语音识别模型
```toml
[model_task_config.voice]
model_list = ["sensevoice-small"]
```
### embedding - 嵌入模型
```toml
[model_task_config.embedding]
model_list = ["bge-m3"]
```
### tool_use - 工具调用模型
需要使用支持工具调用的模型:
```toml
[model_task_config.tool_use]
model_list = ["qwen3-14b"]
temperature = 0.7
max_tokens = 800
```
### lpmm_entity_extract - 实体提取模型
```toml
[model_task_config.lpmm_entity_extract]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_rdf_build - RDF构建模型
```toml
[model_task_config.lpmm_rdf_build]
model_list = ["siliconflow-deepseek-v3"]
temperature = 0.2
max_tokens = 800
```
### lpmm_qa - 问答模型
```toml
[model_task_config.lpmm_qa]
model_list = ["deepseek-r1-distill-qwen-32b"]
temperature = 0.7
max_tokens = 800
```
## 5. 配置建议
### 5.1 Temperature 参数选择
| 任务类型 | 推荐温度 | 说明 |
|----------|----------|------|
| 精确任务(工具调用、实体提取) | 0.1-0.3 | 需要准确性和一致性 |
| 创意任务(对话、记忆) | 0.5-0.8 | 需要多样性和创造性 |
| 平衡任务(决策、情绪) | 0.3-0.5 | 平衡准确性和灵活性 |
### 5.2 模型选择建议
| 任务类型 | 推荐模型类型 | 示例 |
|----------|--------------|------|
| 高精度任务 | 大模型 | DeepSeek-V3, GPT-4 |
| 高频率任务 | 小模型 | Qwen3-8B |
| 多模态任务 | 专用模型 | Qwen2.5-VL, SenseVoice |
| 工具调用 | 支持Function Call的模型 | Qwen3-14B |
### 5.3 成本优化
1. **分层使用**:核心功能使用高质量模型,辅助功能使用经济模型
2. **合理配置max_tokens**:根据实际需求设置,避免浪费
3. **选择免费模型**对于测试环境优先使用price为0的模型
## 6. 配置验证
### 6.1 必要检查项
1. ✅ API密钥是否正确配置
2. ✅ 模型标识符是否与API服务商提供的一致
3. ✅ 任务配置中引用的模型名称是否在models中定义
4. ✅ 多模态任务是否配置了对应的专用模型
### 6.2 测试配置
建议在正式使用前:
1. 使用少量测试数据验证配置
2. 检查API调用是否正常
3. 确认成本统计功能正常工作
## 7. 故障排除
### 7.1 常见问题
**问题1**: API调用失败
- 检查API密钥是否正确
- 确认base_url是否可访问
- 检查模型标识符是否正确
**问题2**: 模型未找到
- 确认模型名称在任务配置和模型定义中一致
- 检查api_provider名称是否匹配
**问题3**: 响应异常
- 检查温度参数是否合理0-1之间
- 确认max_tokens设置是否合适
- 验证模型是否支持所需功能
### 7.2 日志查看
查看 `logs/` 目录下的日志文件,寻找相关错误信息。
## 8. 更新和维护
1. **定期更新**: 关注API服务商的模型更新及时调整配置
2. **性能监控**: 监控模型调用的成本和性能
3. **备份配置**: 在修改前备份当前配置文件

View File

@ -0,0 +1,194 @@
# 组件管理API
组件管理API模块提供了对插件组件的查询和管理功能使得插件能够获取和使用组件相关的信息。
## 导入方式
```python
from src.plugin_system.apis import component_manage_api
# 或者
from src.plugin_system import component_manage_api
```
## 功能概述
组件管理API主要提供以下功能
- **插件信息查询** - 获取所有插件或指定插件的信息。
- **组件查询** - 按名称或类型查询组件信息。
- **组件管理** - 启用或禁用组件,支持全局和局部操作。
## 主要功能
### 1. 获取所有插件信息
```python
def get_all_plugin_info() -> Dict[str, PluginInfo]:
```
获取所有插件的信息。
**Returns:**
- `Dict[str, PluginInfo]` - 包含所有插件信息的字典,键为插件名称,值为 `PluginInfo` 对象。
### 2. 获取指定插件信息
```python
def get_plugin_info(plugin_name: str) -> Optional[PluginInfo]:
```
获取指定插件的信息。
**Args:**
- `plugin_name` (str): 插件名称。
**Returns:**
- `Optional[PluginInfo]`: 插件信息对象,如果插件不存在则返回 `None`
### 3. 获取指定组件信息
```python
def get_component_info(component_name: str, component_type: ComponentType) -> Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定组件的信息。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Optional[Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 组件信息对象,如果组件不存在则返回 `None`
### 4. 获取指定类型的所有组件信息
```python
def get_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定类型的所有组件信息。
**Args:**
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型组件信息的字典,键为组件名称,值为对应的组件信息对象。
### 5. 获取指定类型的所有启用的组件信息
```python
def get_enabled_components_info_by_type(component_type: ComponentType) -> Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]:
```
获取指定类型的所有启用的组件信息。
**Args:**
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `Dict[str, Union[CommandInfo, ActionInfo, EventHandlerInfo]]`: 包含指定类型启用组件信息的字典,键为组件名称,值为对应的组件信息对象。
### 6. 获取指定 Action 的注册信息
```python
def get_registered_action_info(action_name: str) -> Optional[ActionInfo]:
```
获取指定 Action 的注册信息。
**Args:**
- `action_name` (str): Action 名称。
**Returns:**
- `Optional[ActionInfo]` - Action 信息对象,如果 Action 不存在则返回 `None`
### 7. 获取指定 Command 的注册信息
```python
def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
```
获取指定 Command 的注册信息。
**Args:**
- `command_name` (str): Command 名称。
**Returns:**
- `Optional[CommandInfo]` - Command 信息对象,如果 Command 不存在则返回 `None`
### 8. 获取指定 Tool 的注册信息
```python
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
```
获取指定 Tool 的注册信息。
**Args:**
- `tool_name` (str): Tool 名称。
**Returns:**
- `Optional[ToolInfo]` - Tool 信息对象,如果 Tool 不存在则返回 `None`
### 9. 获取指定 EventHandler 的注册信息
```python
def get_registered_event_handler_info(event_handler_name: str) -> Optional[EventHandlerInfo]:
```
获取指定 EventHandler 的注册信息。
**Args:**
- `event_handler_name` (str): EventHandler 名称。
**Returns:**
- `Optional[EventHandlerInfo]` - EventHandler 信息对象,如果 EventHandler 不存在则返回 `None`
### 10. 全局启用指定组件
```python
def globally_enable_component(component_name: str, component_type: ComponentType) -> bool:
```
全局启用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `bool` - 启用成功返回 `True`,否则返回 `False`
### 11. 全局禁用指定组件
```python
async def globally_disable_component(component_name: str, component_type: ComponentType) -> bool:
```
全局禁用指定组件。
**此函数是异步的,确保在异步环境中调用。**
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `bool` - 禁用成功返回 `True`,否则返回 `False`
### 12. 局部启用指定组件
```python
def locally_enable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
```
局部启用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
- `stream_id` (str): 消息流 ID。
**Returns:**
- `bool` - 启用成功返回 `True`,否则返回 `False`
### 13. 局部禁用指定组件
```python
def locally_disable_component(component_name: str, component_type: ComponentType, stream_id: str) -> bool:
```
局部禁用指定组件。
**Args:**
- `component_name` (str): 组件名称。
- `component_type` (ComponentType): 组件类型。
- `stream_id` (str): 消息流 ID。
**Returns:**
- `bool` - 禁用成功返回 `True`,否则返回 `False`
### 14. 获取指定消息流中禁用的组件列表
```python
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
```
获取指定消息流中禁用的组件列表。
**Args:**
- `stream_id` (str): 消息流 ID。
- `component_type` (ComponentType): 组件类型。
**Returns:**
- `list[str]` - 禁用的组件名称列表。

View File

@ -1,6 +1,6 @@
# 配置API # 配置API
配置API模块提供了配置读取和用户信息获取等功能,让插件能够安全地访问全局配置和用户信息 配置API模块提供了配置读取功能让插件能够安全地访问全局配置和插件配置
## 导入方式 ## 导入方式

View File

@ -6,72 +6,51 @@
```python ```python
from src.plugin_system.apis import database_api from src.plugin_system.apis import database_api
# 或者
from src.plugin_system import database_api
``` ```
## 主要功能 ## 主要功能
### 1. 通用数据库查询 ### 1. 通用数据库操作
#### `db_query(model_class, query_type="get", filters=None, data=None, limit=None, order_by=None, single_result=False)`
执行数据库查询操作的通用接口
**参数:**
- `model_class`Peewee模型类如ActionRecords、Messages等
- `query_type`:查询类型,可选值: "get", "create", "update", "delete", "count"
- `filters`:过滤条件字典,键为字段名,值为要匹配的值
- `data`:用于创建或更新的数据字典
- `limit`:限制结果数量
- `order_by`:排序字段列表,使用字段名,前缀'-'表示降序
- `single_result`:是否只返回单个结果
**返回:**
根据查询类型返回不同的结果:
- "get":返回查询结果列表或单个结果
- "create":返回创建的记录
- "update":返回受影响的行数
- "delete":返回受影响的行数
- "count":返回记录数量
### 2. 便捷查询函数
#### `db_save(model_class, data, key_field=None, key_value=None)`
保存数据到数据库(创建或更新)
**参数:**
- `model_class`Peewee模型类
- `data`:要保存的数据字典
- `key_field`:用于查找现有记录的字段名
- `key_value`:用于查找现有记录的字段值
**返回:**
- `Dict[str, Any]`保存后的记录数据失败时返回None
#### `db_get(model_class, filters=None, order_by=None, limit=None)`
简化的查询函数
**参数:**
- `model_class`Peewee模型类
- `filters`:过滤条件字典
- `order_by`:排序字段
- `limit`:限制结果数量
**返回:**
- `Union[List[Dict], Dict, None]`:查询结果
### 3. 专用函数
#### `store_action_info(...)`
存储动作信息的专用函数
## 使用示例
### 1. 基本查询操作
```python ```python
from src.plugin_system.apis import database_api async def db_query(
from src.common.database.database_model import Messages, ActionRecords model_class: Type[Model],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
```
执行数据库查询操作的通用接口。
# 查询最近10条消息 **Args:**
- `model_class`: Peewee模型类。
- Peewee模型类可以在`src.common.database.database_model`模块中找到,如`ActionRecords`、`Messages`等。
- `data`: 用于创建或更新的数据
- `query_type`: 查询类型
- 可选值: `get`, `create`, `update`, `delete`, `count`
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
- `limit`: 限制结果数量。
- `order_by`: 排序字段列表,使用字段名,前缀'-'表示降序。
- 排序字段,前缀`-`表示降序,例如`-time`表示按时间字段(即`time`字段)降序
- `single_result`: 是否只返回单个结果。
**Returns:**
- 根据查询类型返回不同的结果:
- `get`: 返回查询结果列表或单个结果。(如果 `single_result=True`
- `create`: 返回创建的记录。
- `update`: 返回受影响的行数。
- `delete`: 返回受影响的行数。
- `count`: 返回记录数量。
#### 示例
1. 查询最近10条消息
```python
messages = await database_api.db_query( messages = await database_api.db_query(
Messages, Messages,
query_type="get", query_type="get",
@ -79,180 +58,159 @@ messages = await database_api.db_query(
limit=10, limit=10,
order_by=["-time"] order_by=["-time"]
) )
# 查询单条记录
message = await database_api.db_query(
Messages,
query_type="get",
filters={"message_id": "msg_123"},
single_result=True
)
``` ```
2. 创建一条记录
### 2. 创建记录
```python ```python
# 创建新的动作记录
new_record = await database_api.db_query( new_record = await database_api.db_query(
ActionRecords, ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create", query_type="create",
data={
"action_id": "action_123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
}
) )
print(f"创建了记录: {new_record['id']}")
``` ```
3. 更新记录
### 3. 更新记录
```python ```python
# 更新动作状态
updated_count = await database_api.db_query( updated_count = await database_api.db_query(
ActionRecords, ActionRecords,
data={"action_done": True},
query_type="update", query_type="update",
filters={"action_id": "action_123"}, filters={"action_id": "123"},
data={"action_done": True, "completion_time": time.time()}
) )
print(f"更新了 {updated_count} 条记录")
``` ```
4. 删除记录
### 4. 删除记录
```python ```python
# 删除过期记录
deleted_count = await database_api.db_query( deleted_count = await database_api.db_query(
ActionRecords, ActionRecords,
query_type="delete", query_type="delete",
filters={"time__lt": time.time() - 86400} # 删除24小时前的记录 filters={"action_id": "123"}
) )
print(f"删除了 {deleted_count} 条过期记录")
``` ```
5. 计数
### 5. 统计查询
```python ```python
# 统计消息数量 count = await database_api.db_query(
message_count = await database_api.db_query(
Messages, Messages,
query_type="count", query_type="count",
filters={"chat_id": chat_stream.stream_id} filters={"chat_id": chat_stream.stream_id}
) )
print(f"该聊天有 {message_count} 条消息")
``` ```
### 6. 使用便捷函数 ### 2. 数据库保存
```python
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
```
保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录。
**Args:**
- `model_class`: Peewee模型类。
- `data`: 要保存的数据字典。
- `key_field`: 用于查找现有记录的字段名,例如"action_id"。
- `key_value`: 用于查找现有记录的字段值。
**Returns:**
- `Optional[Dict[str, Any]]`: 保存后的记录数据失败时返回None。
#### 示例
创建或更新一条记录
```python ```python
# 使用db_save进行创建或更新
record = await database_api.db_save( record = await database_api.db_save(
ActionRecords, ActionRecords,
{ {
"action_id": "action_123", "action_id": "123",
"time": time.time(), "time": time.time(),
"action_name": "TestAction", "action_name": "TestAction",
"action_done": True "action_done": True
}, },
key_field="action_id", key_field="action_id",
key_value="action_123" key_value="123"
) )
```
# 使用db_get进行简单查询 ### 3. 数据库获取
recent_messages = await database_api.db_get( ```python
async def db_get(
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
```
从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作。
**Args:**
- `model_class`: Peewee模型类。
- `filters`: 过滤条件字典,键为字段名,值为要匹配的值。
- `limit`: 限制结果数量。
- `order_by`: 排序字段,使用字段名,前缀'-'表示降序。
- `single_result`: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
**Returns:**
- `Union[List[Dict], Dict, None]`: 查询结果列表或单个结果(如果`single_result=True`失败时返回None。
#### 示例
1. 获取单个记录
```python
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
```
2. 获取最近10条记录
```python
records = await database_api.db_get(
Messages, Messages,
filters={"chat_id": chat_stream.stream_id}, filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time", order_by="-time",
limit=5
) )
``` ```
## 高级用法 ### 4. 动作信息存储
### 复杂查询示例
```python ```python
# 查询特定用户在特定时间段的消息 async def store_action_info(
user_messages = await database_api.db_query( chat_stream=None,
Messages, action_build_into_prompt: bool = False,
query_type="get", action_prompt_display: str = "",
filters={ action_done: bool = True,
"user_id": "123456", thinking_id: str = "",
"time__gte": start_time, # 大于等于开始时间 action_data: Optional[dict] = None,
"time__lt": end_time # 小于结束时间 action_name: str = "",
}, ) -> Optional[Dict[str, Any]]:
order_by=["-time"], ```
limit=50 存储动作信息到数据库,是一种针对 Action 的 `db_save()` 的封装函数。
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建。
**Args:**
- `chat_stream`: 聊天流对象包含聊天ID等信息。
- `action_build_into_prompt`: 是否将动作信息构建到提示中。
- `action_prompt_display`: 动作提示的显示文本。
- `action_done`: 动作是否完成。
- `thinking_id`: 思考过程的ID。
- `action_data`: 动作的数据字典。
- `action_name`: 动作的名称。
**Returns:**
- `Optional[Dict[str, Any]]`: 存储后的记录数据失败时返回None。
#### 示例
```python
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
) )
```
# 批量处理
for message in user_messages:
print(f"消息内容: {message['plain_text']}")
print(f"发送时间: {message['time']}")
```
### 插件中的数据持久化
```python
from src.plugin_system.base import BasePlugin
from src.plugin_system.apis import database_api
class DataPlugin(BasePlugin):
async def handle_action(self, action_data, chat_stream):
# 保存插件数据
plugin_data = {
"plugin_name": self.plugin_name,
"chat_id": chat_stream.stream_id,
"data": json.dumps(action_data),
"created_time": time.time()
}
# 使用自定义表模型(需要先定义)
record = await database_api.db_save(
PluginData, # 假设的插件数据模型
plugin_data,
key_field="plugin_name",
key_value=self.plugin_name
)
return {"success": True, "record_id": record["id"]}
```
## 数据模型
### 常用模型类
系统提供了以下常用的数据模型:
- `Messages`:消息记录
- `ActionRecords`:动作记录
- `UserInfo`:用户信息
- `GroupInfo`:群组信息
### 字段说明
#### Messages模型主要字段
- `message_id`消息ID
- `chat_id`聊天ID
- `user_id`用户ID
- `plain_text`:纯文本内容
- `time`:时间戳
#### ActionRecords模型主要字段
- `action_id`动作ID
- `action_name`:动作名称
- `action_done`:是否完成
- `time`:创建时间
## 注意事项
1. **异步操作**所有数据库API都是异步的必须使用`await`
2. **错误处理**函数内置错误处理失败时返回None或空列表
3. **数据类型**:返回的都是字典格式的数据,不是模型对象
4. **性能考虑**:使用`limit`参数避免查询大量数据
5. **过滤条件**支持简单的等值过滤复杂查询需要使用原生Peewee语法
6. **事务**如需事务支持建议直接使用Peewee的事务功能

View File

@ -6,11 +6,13 @@
```python ```python
from src.plugin_system.apis import emoji_api from src.plugin_system.apis import emoji_api
# 或者
from src.plugin_system import emoji_api
``` ```
## 🆕 **二步走识别优化** ## 二步走识别优化
新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案: 从新版本开始,表情包识别系统采用了**二步走识别 + 智能缓存**的优化方案:
### **收到表情包时的识别流程** ### **收到表情包时的识别流程**
1. **第一步**VLM视觉分析 - 生成详细描述 1. **第一步**VLM视觉分析 - 生成详细描述
@ -30,217 +32,84 @@ from src.plugin_system.apis import emoji_api
## 主要功能 ## 主要功能
### 1. 表情包获取 ### 1. 表情包获取
```python
#### `get_by_description(description: str) -> Optional[Tuple[str, str, str]]` async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]:
```
根据场景描述选择表情包 根据场景描述选择表情包
**参数** **Args**
- `description`场景描述文本,例如"开心的大笑"、"轻微的讽刺"、"表示无奈和沮丧"等 - `description`表情包的描述文本,例如"开心"、"难过"、"愤怒"等
**返回** **Returns**
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 匹配的场景) 或 None - `Optional[Tuple[str, str, str]]`一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到匹配的表情包则返回None
**示例:** #### 示例
```python ```python
emoji_result = await emoji_api.get_by_description("开心的大笑") emoji_result = await emoji_api.get_by_description("大笑")
if emoji_result: if emoji_result:
emoji_base64, description, matched_scene = emoji_result emoji_base64, description, matched_scene = emoji_result
print(f"获取到表情包: {description}, 场景: {matched_scene}") print(f"获取到表情包: {description}, 场景: {matched_scene}")
# 可以将emoji_base64用于发送表情包 # 可以将emoji_base64用于发送表情包
``` ```
#### `get_random() -> Optional[Tuple[str, str, str]]` ### 2. 随机获取表情包
随机获取表情包
**返回:**
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 随机场景) 或 None
**示例:**
```python ```python
random_emoji = await emoji_api.get_random() async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
if random_emoji:
emoji_base64, description, scene = random_emoji
print(f"随机表情包: {description}")
``` ```
随机获取指定数量的表情包
#### `get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]` **Args**
根据场景关键词获取表情包 - `count`要获取的表情包数量默认为1
**参数** **Returns**
- `emotion`:场景关键词,如"大笑"、"讽刺"、"无奈"等 - `List[Tuple[str, str, str]]`:一个包含多个表情包的列表,每个元素是一个元组: (表情包的base64编码, 描述, 情感标签),如果未找到或出错则返回空列表
**返回:** ### 3. 根据情感获取表情包
- `Optional[Tuple[str, str, str]]`(base64编码, 表情包描述, 匹配的场景) 或 None
**示例:**
```python ```python
emoji_result = await emoji_api.get_by_emotion("讽刺") async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
if emoji_result:
emoji_base64, description, scene = emoji_result
# 发送讽刺表情包
``` ```
根据情感标签获取表情包
### 2. 表情包信息查询 **Args**
- `emotion`:情感标签,例如"开心"、"悲伤"、"愤怒"等
#### `get_count() -> int` **Returns**
获取表情包数量 - `Optional[Tuple[str, str, str]]`:一个元组: (表情包的base64编码, 描述, 情感标签)如果未找到则返回None
**返回:** ### 4. 获取表情包数量
- `int`:当前可用的表情包数量 ```python
def get_count() -> int:
```
获取当前可用表情包的数量
#### `get_info() -> dict` ### 5. 获取表情包系统信息
获取表情包系统信息 ```python
def get_info() -> Dict[str, Any]:
```
获取表情包系统的基本信息
**返回:** **Returns**
- `dict`:包含表情包数量、最大数量等信息 - `Dict[str, Any]`:包含表情包数量、描述等信息的字典,包含以下键:
- `current_count`:当前表情包数量
- `max_count`:最大表情包数量
- `available_emojis`:当前可用的表情包数量
**返回字典包含:** ### 6. 获取所有可用的情感标签
- `current_count`:当前表情包数量 ```python
- `max_count`:最大表情包数量 def get_emotions() -> List[str]:
- `available_emojis`:可用表情包数量 ```
获取所有可用的情感标签 **(已经去重)**
#### `get_emotions() -> list` ### 7. 获取所有表情包描述
获取所有可用的场景关键词 ```python
def get_descriptions() -> List[str]:
**返回:** ```
- `list`:所有表情包的场景关键词列表(去重)
#### `get_descriptions() -> list`
获取所有表情包的描述列表 获取所有表情包的描述列表
**返回:**
- `list`:所有表情包的描述文本列表
## 使用示例
### 1. 智能表情包选择
```python
from src.plugin_system.apis import emoji_api
async def send_emotion_response(message_text: str, chat_stream):
"""根据消息内容智能选择表情包回复"""
# 分析消息场景
if "哈哈" in message_text or "好笑" in message_text:
emoji_result = await emoji_api.get_by_description("开心的大笑")
elif "无语" in message_text or "算了" in message_text:
emoji_result = await emoji_api.get_by_description("表示无奈和沮丧")
elif "呵呵" in message_text or "是吗" in message_text:
emoji_result = await emoji_api.get_by_description("轻微的讽刺")
elif "生气" in message_text or "愤怒" in message_text:
emoji_result = await emoji_api.get_by_description("愤怒和不满")
else:
# 随机选择一个表情包
emoji_result = await emoji_api.get_random()
if emoji_result:
emoji_base64, description, scene = emoji_result
# 使用send_api发送表情包
from src.plugin_system.apis import send_api
success = await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
return success
return False
```
### 2. 表情包管理功能
```python
async def show_emoji_stats():
"""显示表情包统计信息"""
# 获取基本信息
count = emoji_api.get_count()
info = emoji_api.get_info()
scenes = emoji_api.get_emotions() # 实际返回的是场景关键词
stats = f"""
📊 表情包统计信息:
- 总数量: {count}
- 可用数量: {info['available_emojis']}
- 最大容量: {info['max_count']}
- 支持场景: {len(scenes)}种
🎭 支持的场景关键词: {', '.join(scenes[:10])}{'...' if len(scenes) > 10 else ''}
"""
return stats
```
### 3. 表情包测试功能
```python
async def test_emoji_system():
"""测试表情包系统的各种功能"""
print("=== 表情包系统测试 ===")
# 测试场景描述查找
test_descriptions = ["开心的大笑", "轻微的讽刺", "表示无奈和沮丧", "愤怒和不满"]
for desc in test_descriptions:
result = await emoji_api.get_by_description(desc)
if result:
_, description, scene = result
print(f"✅ 场景'{desc}' -> {description} ({scene})")
else:
print(f"❌ 场景'{desc}' -> 未找到")
# 测试关键词查找
scenes = emoji_api.get_emotions()
if scenes:
test_scene = scenes[0]
result = await emoji_api.get_by_emotion(test_scene)
if result:
print(f"✅ 关键词'{test_scene}' -> 找到匹配表情包")
# 测试随机获取
random_result = await emoji_api.get_random()
if random_result:
print("✅ 随机获取 -> 成功")
print(f"📊 系统信息: {emoji_api.get_info()}")
```
### 4. 在Action中使用表情包
```python
from src.plugin_system.base import BaseAction
class EmojiAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 从action_data获取场景描述或关键词
scene_keyword = action_data.get("scene", "")
scene_description = action_data.get("description", "")
emoji_result = None
# 优先使用具体的场景描述
if scene_description:
emoji_result = await emoji_api.get_by_description(scene_description)
# 其次使用场景关键词
elif scene_keyword:
emoji_result = await emoji_api.get_by_emotion(scene_keyword)
# 最后随机选择
else:
emoji_result = await emoji_api.get_random()
if emoji_result:
emoji_base64, description, scene = emoji_result
return {
"success": True,
"emoji_base64": emoji_base64,
"description": description,
"scene": scene
}
return {"success": False, "message": "未找到合适的表情包"}
```
## 场景描述说明 ## 场景描述说明
### 常用场景描述 ### 常用场景描述
表情包系统支持多种具体的场景描述,常见的包括 表情包系统支持多种具体的场景描述,举例如下:
- **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈 - **开心类场景**:开心的大笑、满意的微笑、兴奋的手舞足蹈
- **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头 - **无奈类场景**:表示无奈和沮丧、轻微的讽刺、无语的摇头
@ -248,8 +117,8 @@ class EmojiAction(BaseAction):
- **惊讶类场景**:震惊的表情、意外的发现、困惑的思考 - **惊讶类场景**:震惊的表情、意外的发现、困惑的思考
- **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子 - **可爱类场景**:卖萌的表情、撒娇的动作、害羞的样子
### 场景关键词示例 ### 情感关键词示例
系统支持的场景关键词包括 系统支持的情感关键词举例如下
- 大笑、微笑、兴奋、手舞足蹈 - 大笑、微笑、兴奋、手舞足蹈
- 无奈、沮丧、讽刺、无语、摇头 - 无奈、沮丧、讽刺、无语、摇头
- 愤怒、不满、生气、瞪视、抓狂 - 愤怒、不满、生气、瞪视、抓狂
@ -263,9 +132,9 @@ class EmojiAction(BaseAction):
## 注意事项 ## 注意事项
1. **异步函数**获取表情包的函数都是异步的,需要使用 `await` 1. **异步函数**部分函数是异步的,需要使用 `await`
2. **返回格式**表情包以base64编码返回可直接用于发送 2. **返回格式**表情包以base64编码返回可直接用于发送
3. **错误处理**所有函数都有错误处理失败时返回None或默认值 3. **错误处理**所有函数都有错误处理失败时返回None,空列表或默认值
4. **使用统计**:系统会记录表情包的使用次数 4. **使用统计**:系统会记录表情包的使用次数
5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在 5. **文件依赖**:表情包依赖于本地文件,确保表情包文件存在
6. **编码格式**返回的是base64编码的图片数据可直接用于网络传输 6. **编码格式**返回的是base64编码的图片数据可直接用于网络传输

View File

@ -6,241 +6,151 @@
```python ```python
from src.plugin_system.apis import generator_api from src.plugin_system.apis import generator_api
# 或者
from src.plugin_system import generator_api
``` ```
## 主要功能 ## 主要功能
### 1. 回复器获取 ### 1. 回复器获取
```python
#### `get_replyer(chat_stream=None, platform=None, chat_id=None, is_group=True)` def get_replyer(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer",
) -> Optional[DefaultReplyer]:
```
获取回复器对象 获取回复器对象
**参数:** 优先使用chat_stream如果没有则使用chat_id直接查找。
- `chat_stream`:聊天流对象(优先)
- `platform`:平台名称,如"qq"
- `chat_id`聊天ID群ID或用户ID
- `is_group`:是否为群聊
**返回:** 使用 ReplyerManager 来管理实例,避免重复创建。
- `DefaultReplyer`回复器对象如果获取失败则返回None
**示例:** **Args:**
- `chat_stream`: 聊天流对象
- `chat_id`: 聊天ID实际上就是`stream_id`
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
- `request_type`: 请求类型用于记录LLM使用情况可以不写
**Returns:**
- `DefaultReplyer`: 回复器对象如果获取失败则返回None
#### 示例
```python ```python
# 使用聊天流获取回复器 # 使用聊天流获取回复器
replyer = generator_api.get_replyer(chat_stream=chat_stream) replyer = generator_api.get_replyer(chat_stream=chat_stream)
# 使用平台和ID获取回复器 # 使用平台和ID获取回复器
replyer = generator_api.get_replyer( replyer = generator_api.get_replyer(chat_id="123456789")
platform="qq",
chat_id="123456789",
is_group=True
)
``` ```
### 2. 回复生成 ### 2. 回复生成
```python
#### `generate_reply(chat_stream=None, action_data=None, platform=None, chat_id=None, is_group=True)` async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_to: str = "",
extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
return_prompt: bool = False,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "generator_api",
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
```
生成回复 生成回复
**参数:** 优先使用chat_stream如果没有则使用chat_id直接查找。
- `chat_stream`:聊天流对象(优先)
- `action_data`:动作数据
- `platform`:平台名称(备用)
- `chat_id`聊天ID备用
- `is_group`:是否为群聊(备用)
**返回:** **Args:**
- `Tuple[bool, List[Tuple[str, Any]]]`(是否成功, 回复集合) - `chat_stream`: 聊天流对象
- `chat_id`: 聊天ID实际上就是`stream_id`
- `action_data`: 动作数据(向下兼容,包含`reply_to`和`extra_info`
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
- `extra_info`: 附加信息
- `available_actions`: 可用动作字典,格式为 `{"action_name": ActionInfo}`
- `enable_tool`: 是否启用工具
- `enable_splitter`: 是否启用分割器
- `enable_chinese_typo`: 是否启用中文错别字
- `return_prompt`: 是否返回提示词
- `model_set_with_weight`: 模型配置列表,每个元素为 `(TaskConfig, weight)` 元组
- `request_type`: 请求类型可选记录LLM使用
- `request_type`: 请求类型用于记录LLM使用情况
**示例:** **Returns:**
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
#### 示例
```python ```python
success, reply_set = await generator_api.generate_reply( success, reply_set, prompt = await generator_api.generate_reply(
chat_stream=chat_stream, chat_stream=chat_stream,
action_data={"message": "你好", "intent": "greeting"} action_data=action_data,
reply_to="麦麦:你好",
available_actions=action_info,
enable_tool=True,
return_prompt=True
) )
if success: if success:
for reply_type, reply_content in reply_set: for reply_type, reply_content in reply_set:
print(f"回复类型: {reply_type}, 内容: {reply_content}") print(f"回复类型: {reply_type}, 内容: {reply_content}")
if prompt:
print(f"使用的提示词: {prompt}")
``` ```
#### `rewrite_reply(chat_stream=None, reply_data=None, platform=None, chat_id=None, is_group=True)` ### 3. 回复重写
重写回复
**参数:**
- `chat_stream`:聊天流对象(优先)
- `reply_data`:回复数据
- `platform`:平台名称(备用)
- `chat_id`聊天ID备用
- `is_group`:是否为群聊(备用)
**返回:**
- `Tuple[bool, List[Tuple[str, Any]]]`(是否成功, 回复集合)
**示例:**
```python ```python
success, reply_set = await generator_api.rewrite_reply( async def rewrite_reply(
chat_stream: Optional[ChatStream] = None,
reply_data: Optional[Dict[str, Any]] = None,
chat_id: Optional[str] = None,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "",
reason: str = "",
reply_to: str = "",
return_prompt: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
```
重写回复,使用新的内容替换旧的回复内容。
优先使用chat_stream如果没有则使用chat_id直接查找。
**Args:**
- `chat_stream`: 聊天流对象
- `reply_data`: 回复数据,包含`raw_reply`, `reason`和`reply_to`**(向下兼容备用,当其他参数缺失时从此获取)**
- `chat_id`: 聊天ID实际上就是`stream_id`
- `enable_splitter`: 是否启用分割器
- `enable_chinese_typo`: 是否启用中文错别字
- `model_set_with_weight`: 模型配置列表,每个元素为 (TaskConfig, weight) 元组
- `raw_reply`: 原始回复内容
- `reason`: 重写原因
- `reply_to`: 回复目标,格式为 `{发送者的person_name:消息内容}`
**Returns:**
- `Tuple[bool, List[Tuple[str, Any]], Optional[str]]`: (是否成功, 回复集合, 提示词)
#### 示例
```python
success, reply_set, prompt = await generator_api.rewrite_reply(
chat_stream=chat_stream, chat_stream=chat_stream,
reply_data={"original_text": "原始回复", "style": "more_friendly"} raw_reply="原始回复内容",
reason="重写原因",
reply_to="麦麦:你好",
return_prompt=True
) )
if success:
for reply_type, reply_content in reply_set:
print(f"回复类型: {reply_type}, 内容: {reply_content}")
if prompt:
print(f"使用的提示词: {prompt}")
``` ```
## 使用示例 ## 回复集合`reply_set`格式
### 1. 基础回复生成
```python
from src.plugin_system.apis import generator_api
async def generate_greeting_reply(chat_stream, user_name):
"""生成问候回复"""
action_data = {
"intent": "greeting",
"user_name": user_name,
"context": "morning_greeting"
}
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
if success and reply_set:
# 获取第一个回复
reply_type, reply_content = reply_set[0]
return reply_content
return "你好!" # 默认回复
```
### 2. 在Action中使用回复生成器
```python
from src.plugin_system.base import BaseAction
class ChatAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 准备回复数据
reply_context = {
"message_type": "response",
"user_input": action_data.get("user_message", ""),
"intent": action_data.get("intent", ""),
"entities": action_data.get("entities", {}),
"context": self.get_conversation_context(chat_stream)
}
# 生成回复
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=reply_context
)
if success:
return {
"success": True,
"replies": reply_set,
"generated_count": len(reply_set)
}
return {
"success": False,
"error": "回复生成失败",
"fallback_reply": "抱歉,我现在无法理解您的消息。"
}
```
### 3. 多样化回复生成
```python
async def generate_diverse_replies(chat_stream, topic, count=3):
"""生成多个不同风格的回复"""
styles = ["formal", "casual", "humorous"]
all_replies = []
for i, style in enumerate(styles[:count]):
action_data = {
"topic": topic,
"style": style,
"variation": i
}
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
if success and reply_set:
all_replies.extend(reply_set)
return all_replies
```
### 4. 回复重写功能
```python
async def improve_reply(chat_stream, original_reply, improvement_type="more_friendly"):
"""改进原始回复"""
reply_data = {
"original_text": original_reply,
"improvement_type": improvement_type,
"target_audience": "young_users",
"tone": "positive"
}
success, improved_replies = await generator_api.rewrite_reply(
chat_stream=chat_stream,
reply_data=reply_data
)
if success and improved_replies:
# 返回改进后的第一个回复
_, improved_content = improved_replies[0]
return improved_content
return original_reply # 如果改进失败,返回原始回复
```
### 5. 条件回复生成
```python
async def conditional_reply_generation(chat_stream, user_message, user_emotion):
"""根据用户情感生成条件回复"""
# 根据情感调整回复策略
if user_emotion == "sad":
action_data = {
"intent": "comfort",
"tone": "empathetic",
"style": "supportive"
}
elif user_emotion == "angry":
action_data = {
"intent": "calm",
"tone": "peaceful",
"style": "understanding"
}
else:
action_data = {
"intent": "respond",
"tone": "neutral",
"style": "helpful"
}
action_data["user_message"] = user_message
action_data["user_emotion"] = user_emotion
success, reply_set = await generator_api.generate_reply(
chat_stream=chat_stream,
action_data=action_data
)
return reply_set if success else []
```
## 回复集合格式
### 回复类型 ### 回复类型
生成的回复集合包含多种类型的回复: 生成的回复集合包含多种类型的回复:
@ -260,82 +170,32 @@ reply_set = [
] ]
``` ```
## 高级用法 ### 4. 自定义提示词回复
### 1. 自定义回复器配置
```python ```python
async def generate_with_custom_config(chat_stream, action_data): async def generate_response_custom(
"""使用自定义配置生成回复""" chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
# 获取回复器 model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
replyer = generator_api.get_replyer(chat_stream=chat_stream) prompt: str = "",
) -> Optional[str]:
if replyer:
# 可以访问回复器的内部方法
success, reply_set = await replyer.generate_reply_with_context(
reply_data=action_data,
# 可以传递额外的配置参数
)
return success, reply_set
return False, []
``` ```
生成自定义提示词回复
### 2. 回复质量评估 优先使用chat_stream如果没有则使用chat_id直接查找。
```python **Args:**
async def generate_and_evaluate_replies(chat_stream, action_data): - `chat_stream`: 聊天流对象
"""生成回复并评估质量""" - `chat_id`: 聊天ID备用
- `model_set_with_weight`: 模型集合配置列表
success, reply_set = await generator_api.generate_reply( - `prompt`: 自定义提示词
chat_stream=chat_stream,
action_data=action_data
)
if success:
evaluated_replies = []
for reply_type, reply_content in reply_set:
# 简单的质量评估
quality_score = evaluate_reply_quality(reply_content)
evaluated_replies.append({
"type": reply_type,
"content": reply_content,
"quality": quality_score
})
# 按质量排序
evaluated_replies.sort(key=lambda x: x["quality"], reverse=True)
return evaluated_replies
return []
def evaluate_reply_quality(reply_content): **Returns:**
"""简单的回复质量评估""" - `Optional[str]`: 生成的自定义回复内容如果生成失败则返回None
if not reply_content:
return 0
score = 50 # 基础分
# 长度适中加分
if 5 <= len(reply_content) <= 100:
score += 20
# 包含积极词汇加分
positive_words = ["好", "棒", "不错", "感谢", "开心"]
for word in positive_words:
if word in reply_content:
score += 10
break
return min(score, 100)
```
## 注意事项 ## 注意事项
1. **异步操作**:所有生成函数都是异步的,必须使用`await` 1. **异步操作**:部分函数是异步的,须使用`await`
2. **错误处理**函数内置错误处理失败时返回False和空列表 2. **聊天流依赖**:需要有效的聊天流对象才能正常工作
3. **聊天流依赖**:需要有效的聊天流对象才能正常工作 3. **性能考虑**回复生成可能需要一些时间特别是使用LLM时
4. **性能考虑**回复生成可能需要一些时间特别是使用LLM时 4. **回复格式**:返回的回复集合是元组列表,包含类型和内容
5. **回复格式**:返回的回复集合是元组列表,包含类型和内容 5. **上下文感知**:生成器会考虑聊天上下文和历史消息,除非你用的是自定义提示词。
6. **上下文感知**:生成器会考虑聊天上下文和历史消息

View File

@ -6,239 +6,34 @@ LLM API模块提供与大语言模型交互的功能让插件能够使用系
```python ```python
from src.plugin_system.apis import llm_api from src.plugin_system.apis import llm_api
# 或者
from src.plugin_system import llm_api
``` ```
## 主要功能 ## 主要功能
### 1. 模型管理 ### 1. 查询可用模型
#### `get_available_models() -> Dict[str, Any]`
获取所有可用的模型配置
**返回:**
- `Dict[str, Any]`模型配置字典key为模型名称value为模型配置
**示例:**
```python ```python
models = llm_api.get_available_models() def get_available_models() -> Dict[str, TaskConfig]:
for model_name, model_config in models.items():
print(f"模型: {model_name}")
print(f"配置: {model_config}")
``` ```
获取所有可用的模型配置。
### 2. 内容生成 **Return**
- `Dict[str, TaskConfig]`模型配置字典key为模型名称value为模型配置对象。
#### `generate_with_model(prompt, model_config, request_type="plugin.generate", **kwargs)` ### 2. 使用模型生成内容
使用指定模型生成内容
**参数:**
- `prompt`:提示词
- `model_config`:模型配置(从 get_available_models 获取)
- `request_type`:请求类型标识
- `**kwargs`其他模型特定参数如temperature、max_tokens等
**返回:**
- `Tuple[bool, str, str, str]`(是否成功, 生成的内容, 推理过程, 模型名称)
**示例:**
```python ```python
models = llm_api.get_available_models() async def generate_with_model(
default_model = models.get("default") prompt: str, model_config: TaskConfig, request_type: str = "plugin.generate", **kwargs
) -> Tuple[bool, str, str, str]:
if default_model:
success, response, reasoning, model_name = await llm_api.generate_with_model(
prompt="请写一首关于春天的诗",
model_config=default_model,
temperature=0.7,
max_tokens=200
)
if success:
print(f"生成内容: {response}")
print(f"使用模型: {model_name}")
``` ```
使用指定模型生成内容。
## 使用示例 **Args:**
- `prompt`:提示词。
- `model_config`:模型配置对象(从 `get_available_models` 获取)。
- `request_type`:请求类型标识,默认为 `"plugin.generate"`
- `**kwargs`:其他模型特定参数,如 `temperature`、`max_tokens` 等。
### 1. 基础文本生成 **Return**
- `Tuple[bool, str, str, str]`:返回一个元组,包含(是否成功, 生成的内容, 推理过程, 模型名称)。
```python
from src.plugin_system.apis import llm_api
async def generate_story(topic: str):
"""生成故事"""
models = llm_api.get_available_models()
model = models.get("default")
if not model:
return "未找到可用模型"
prompt = f"请写一个关于{topic}的短故事大约100字左右。"
success, story, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=model,
request_type="story.generate",
temperature=0.8,
max_tokens=150
)
return story if success else "故事生成失败"
```
### 2. 在Action中使用LLM
```python
from src.plugin_system.base import BaseAction
class LLMAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 获取用户输入
user_input = action_data.get("user_message", "")
intent = action_data.get("intent", "chat")
# 获取模型配置
models = llm_api.get_available_models()
model = models.get("default")
if not model:
return {"success": False, "error": "未配置LLM模型"}
# 构建提示词
prompt = self.build_prompt(user_input, intent)
# 生成回复
success, response, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=model,
request_type=f"plugin.{self.plugin_name}",
temperature=0.7
)
if success:
return {
"success": True,
"response": response,
"model_used": model_name,
"reasoning": reasoning
}
return {"success": False, "error": response}
def build_prompt(self, user_input: str, intent: str) -> str:
"""构建提示词"""
base_prompt = "你是一个友善的AI助手。"
if intent == "question":
return f"{base_prompt}\n\n用户问题{user_input}\n\n请提供准确、有用的回答"
elif intent == "chat":
return f"{base_prompt}\n\n用户说{user_input}\n\n请进行自然的对话"
else:
return f"{base_prompt}\n\n用户输入{user_input}\n\n请回复"
```
### 3. 多模型对比
```python
async def compare_models(prompt: str):
"""使用多个模型生成内容并对比"""
models = llm_api.get_available_models()
results = {}
for model_name, model_config in models.items():
success, response, reasoning, actual_model = await llm_api.generate_with_model(
prompt=prompt,
model_config=model_config,
request_type="comparison.test"
)
results[model_name] = {
"success": success,
"response": response,
"model": actual_model,
"reasoning": reasoning
}
return results
```
### 4. 智能对话插件
```python
class ChatbotPlugin(BasePlugin):
async def handle_action(self, action_data, chat_stream):
user_message = action_data.get("message", "")
# 获取历史对话上下文
context = self.get_conversation_context(chat_stream)
# 构建对话提示词
prompt = self.build_conversation_prompt(user_message, context)
# 获取模型配置
models = llm_api.get_available_models()
chat_model = models.get("chat", models.get("default"))
if not chat_model:
return {"success": False, "message": "聊天模型未配置"}
# 生成回复
success, response, reasoning, model_name = await llm_api.generate_with_model(
prompt=prompt,
model_config=chat_model,
request_type="chat.conversation",
temperature=0.8,
max_tokens=500
)
if success:
# 保存对话历史
self.save_conversation(chat_stream, user_message, response)
return {
"success": True,
"reply": response,
"model": model_name
}
return {"success": False, "message": "回复生成失败"}
def build_conversation_prompt(self, user_message: str, context: list) -> str:
"""构建对话提示词"""
prompt = "你是一个有趣、友善的聊天机器人。请自然地回复用户的消息。\n\n"
# 添加历史对话
if context:
prompt += "对话历史:\n"
for msg in context[-5:]: # 只保留最近5条
prompt += f"用户: {msg['user']}\n机器人: {msg['bot']}\n"
prompt += "\n"
prompt += f"用户: {user_message}\n机器人: "
return prompt
```
## 模型配置说明
### 常用模型类型
- `default`:默认模型
- `chat`:聊天专用模型
- `creative`:创意生成模型
- `code`:代码生成模型
### 配置参数
LLM模型支持的常用参数
- `temperature`控制输出随机性0.0-1.0
- `max_tokens`:最大生成长度
- `top_p`:核采样参数
- `frequency_penalty`:频率惩罚
- `presence_penalty`:存在惩罚
## 注意事项
1. **异步操作**LLM生成是异步的必须使用`await`
2. **错误处理**生成失败时返回False和错误信息
3. **配置依赖**:需要正确配置模型才能使用
4. **请求类型**建议为不同用途设置不同的request_type
5. **性能考虑**LLM调用可能较慢考虑超时和缓存
6. **成本控制**注意控制max_tokens以控制成本

View File

@ -0,0 +1,29 @@
# Logging API
Logging API模块提供了获取本体logger的功能允许插件记录日志信息。
## 导入方式
```python
from src.plugin_system.apis import get_logger
# 或者
from src.plugin_system import get_logger
```
## 主要功能
### 1. 获取本体logger
```python
def get_logger(name: str) -> structlog.stdlib.BoundLogger:
```
获取本体logger实例。
**Args:**
- `name` (str): 日志记录器的名称。
**Returns:**
- 一个logger实例有以下方法:
- `debug`
- `info`
- `warning`
- `error`
- `critical`

View File

@ -1,11 +1,13 @@
# 消息API # 消息API
> 消息API提供了强大的消息查询、计数和格式化功能让你轻松处理聊天消息数据。 消息API提供了强大的消息查询、计数和格式化功能让你轻松处理聊天消息数据。
## 导入方式 ## 导入方式
```python ```python
from src.plugin_system.apis import message_api from src.plugin_system.apis import message_api
# 或者
from src.plugin_system import message_api
``` ```
## 功能概述 ## 功能概述
@ -15,297 +17,356 @@ from src.plugin_system.apis import message_api
- **消息计数** - 统计新消息数量 - **消息计数** - 统计新消息数量
- **消息格式化** - 将消息转换为可读格式 - **消息格式化** - 将消息转换为可读格式
--- ## 主要功能
## 消息查询API ### 1. 按照事件查询消息
```python
def get_messages_by_time(
start_time: float, end_time: float, limit: int = 0, limit_mode: str = "latest", filter_mai: bool = False
) -> List[Dict[str, Any]]:
```
获取指定时间范围内的消息。
### 按时间查询消息 **Args:**
#### `get_messages_by_time(start_time, end_time, limit=0, limit_mode="latest")`
获取指定时间范围内的消息
**参数:**
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳 - `end_time` (float): 结束时间戳
- `limit` (int): 限制返回消息数量0为不限制 - `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录 - `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**返回:** `List[Dict[str, Any]]` - 消息列表 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**示例:** 消息列表中包含的键与`Messages`类的属性一致。(位于`src.common.database.database_model`
### 2. 获取指定聊天中指定时间范围内的信息
```python ```python
import time def get_messages_by_time_in_chat(
chat_id: str,
# 获取最近24小时的消息 start_time: float,
now = time.time() end_time: float,
yesterday = now - 24 * 3600 limit: int = 0,
messages = message_api.get_messages_by_time(yesterday, now, limit=50) limit_mode: str = "latest",
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
``` ```
获取指定聊天中指定时间范围内的消息。
### 按聊天查询消息 **Args:**
#### `get_messages_by_time_in_chat(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
获取指定聊天中指定时间范围内的消息
**参数:**
- `chat_id` (str): 聊天ID
- 其他参数同上
**示例:**
```python
# 获取某个群聊最近的100条消息
messages = message_api.get_messages_by_time_in_chat(
chat_id="123456789",
start_time=yesterday,
end_time=now,
limit=100
)
```
#### `get_messages_by_time_in_chat_inclusive(chat_id, start_time, end_time, limit=0, limit_mode="latest")`
获取指定聊天中指定时间范围内的消息(包含边界时间点)
`get_messages_by_time_in_chat` 类似,但包含边界时间戳的消息。
#### `get_recent_messages(chat_id, hours=24.0, limit=100, limit_mode="latest")`
获取指定聊天中最近一段时间的消息(便捷方法)
**参数:**
- `chat_id` (str): 聊天ID
- `hours` (float): 最近多少小时默认24小时
- `limit` (int): 限制返回消息数量默认100条
- `limit_mode` (str): 限制模式
**示例:**
```python
# 获取最近6小时的消息
recent_messages = message_api.get_recent_messages(
chat_id="123456789",
hours=6.0,
limit=50
)
```
### 按用户查询消息
#### `get_messages_by_time_in_chat_for_users(chat_id, start_time, end_time, person_ids, limit=0, limit_mode="latest")`
获取指定聊天中指定用户在指定时间范围内的消息
**参数:**
- `chat_id` (str): 聊天ID - `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳 - `end_time` (float): 结束时间戳
- `person_ids` (list): 用户ID列表 - `limit` (int): 限制返回消息数量0为不限制
- `limit` (int): 限制返回消息数量 - `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `limit_mode` (str): 限制模式 - `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**示例:** **Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 3. 获取指定聊天中指定时间范围内的信息(包含边界)
```python ```python
# 获取特定用户的消息 def get_messages_by_time_in_chat_inclusive(
user_messages = message_api.get_messages_by_time_in_chat_for_users( chat_id: str,
chat_id="123456789", start_time: float,
start_time=yesterday, end_time: float,
end_time=now, limit: int = 0,
person_ids=["user1", "user2"] limit_mode: str = "latest",
) filter_mai: bool = False,
filter_command: bool = False,
) -> List[Dict[str, Any]]:
``` ```
获取指定聊天中指定时间范围内的消息(包含边界)。
#### `get_messages_by_time_for_users(start_time, end_time, person_ids, limit=0, limit_mode="latest")` **Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳(包含)
- `end_time` (float): 结束时间戳(包含)
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
- `filter_command` (bool): 是否过滤命令消息默认False
获取指定用户在所有聊天中指定时间范围内的消息 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 其他查询方法
#### `get_random_chat_messages(start_time, end_time, limit=0, limit_mode="latest")` ### 4. 获取指定聊天中指定用户在指定时间范围内的消息
```python
def get_messages_by_time_in_chat_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
```
获取指定聊天中指定用户在指定时间范围内的消息。
随机选择一个聊天,返回该聊天在指定时间范围内的消息 **Args:**
#### `get_messages_before_time(timestamp, limit=0)`
获取指定时间戳之前的消息
#### `get_messages_before_time_in_chat(chat_id, timestamp, limit=0)`
获取指定聊天中指定时间戳之前的消息
#### `get_messages_before_time_for_users(timestamp, person_ids, limit=0)`
获取指定用户在指定时间戳之前的消息
---
## 消息计数API
### `count_new_messages(chat_id, start_time=0.0, end_time=None)`
计算指定聊天中从开始时间到结束时间的新消息数量
**参数:**
- `chat_id` (str): 聊天ID - `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳 - `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳如果为None则使用当前时间 - `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
**返回:** `int` - 新消息数量 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**示例:**
### 5. 随机选择一个聊天,返回该聊天在指定时间范围内的消息
```python ```python
# 计算最近1小时的新消息数 def get_random_chat_messages(
import time start_time: float,
now = time.time() end_time: float,
hour_ago = now - 3600 limit: int = 0,
new_count = message_api.count_new_messages("123456789", hour_ago, now) limit_mode: str = "latest",
print(f"最近1小时有{new_count}条新消息") filter_mai: bool = False,
) -> List[Dict[str, Any]]:
``` ```
随机选择一个聊天,返回该聊天在指定时间范围内的消息。
### `count_new_messages_for_users(chat_id, start_time, end_time, person_ids)` **Args:**
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
计算指定聊天中指定用户从开始时间到结束时间的新消息数量 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
---
## 消息格式化API ### 6. 获取指定用户在所有聊天中指定时间范围内的消息
```python
def get_messages_by_time_for_users(
start_time: float,
end_time: float,
person_ids: List[str],
limit: int = 0,
limit_mode: str = "latest",
) -> List[Dict[str, Any]]:
```
获取指定用户在所有聊天中指定时间范围内的消息。
### `build_readable_messages_to_str(messages, **options)` **Args:**
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
将消息列表构建成可读的字符串 **Returns:**
- `List[Dict[str, Any]]` - 消息列表
**参数:**
### 7. 获取指定时间戳之前的消息
```python
def get_messages_before_time(
timestamp: float,
limit: int = 0,
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定时间戳之前的消息。
**Args:**
- `timestamp` (float): 时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 8. 获取指定聊天中指定时间戳之前的消息
```python
def get_messages_before_time_in_chat(
chat_id: str,
timestamp: float,
limit: int = 0,
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定聊天中指定时间戳之前的消息。
**Args:**
- `chat_id` (str): 聊天ID
- `timestamp` (float): 时间戳
- `limit` (int): 限制返回消息数量0为不限制
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 9. 获取指定用户在指定时间戳之前的消息
```python
def get_messages_before_time_for_users(
timestamp: float,
person_ids: List[str],
limit: int = 0,
) -> List[Dict[str, Any]]:
```
获取指定用户在指定时间戳之前的消息。
**Args:**
- `timestamp` (float): 时间戳
- `person_ids` (List[str]): 用户ID列表
- `limit` (int): 限制返回消息数量0为不限制
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 10. 获取指定聊天中最近一段时间的消息
```python
def get_recent_messages(
chat_id: str,
hours: float = 24.0,
limit: int = 100,
limit_mode: str = "latest",
filter_mai: bool = False,
) -> List[Dict[str, Any]]:
```
获取指定聊天中最近一段时间的消息。
**Args:**
- `chat_id` (str): 聊天ID
- `hours` (float): 最近多少小时默认24小时
- `limit` (int): 限制返回消息数量默认100条
- `limit_mode` (str): 限制模式,`"earliest"`获取最早记录,`"latest"`获取最新记录
- `filter_mai` (bool): 是否过滤掉机器人的消息默认False
**Returns:**
- `List[Dict[str, Any]]` - 消息列表
### 11. 计算指定聊天中从开始时间到结束时间的新消息数量
```python
def count_new_messages(
chat_id: str,
start_time: float = 0.0,
end_time: Optional[float] = None,
) -> int:
```
计算指定聊天中从开始时间到结束时间的新消息数量。
**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
- `end_time` (Optional[float]): 结束时间戳如果为None则使用当前时间
**Returns:**
- `int` - 新消息数量
### 12. 计算指定聊天中指定用户从开始时间到结束时间的新消息数量
```python
def count_new_messages_for_users(
chat_id: str,
start_time: float,
end_time: float,
person_ids: List[str],
) -> int:
```
计算指定聊天中指定用户从开始时间到结束时间的新消息数量。
**Args:**
- `chat_id` (str): 聊天ID
- `start_time` (float): 开始时间戳
- `end_time` (float): 结束时间戳
- `person_ids` (List[str]): 用户ID列表
**Returns:**
- `int` - 新消息数量
### 13. 将消息列表构建成可读的字符串
```python
def build_readable_messages_to_str(
messages: List[Dict[str, Any]],
replace_bot_name: bool = True,
merge_messages: bool = False,
timestamp_mode: str = "relative",
read_mark: float = 0.0,
truncate: bool = False,
show_actions: bool = False,
) -> str:
```
将消息列表构建成可读的字符串。
**Args:**
- `messages` (List[Dict[str, Any]]): 消息列表 - `messages` (List[Dict[str, Any]]): 消息列表
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"默认True - `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
- `merge_messages` (bool): 是否合并连续消息默认False - `merge_messages` (bool): 是否合并连续消息
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`,默认`"relative"` - `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
- `read_mark` (float): 已读标记时间戳用于分割已读和未读消息默认0.0 - `read_mark` (float): 已读标记时间戳,用于分割已读和未读消息
- `truncate` (bool): 是否截断长消息默认False - `truncate` (bool): 是否截断长消息
- `show_actions` (bool): 是否显示动作记录默认False - `show_actions` (bool): 是否显示动作记录
**返回:** `str` - 格式化后的可读字符串 **Returns:**
- `str` - 格式化后的可读字符串
**示例:**
### 14. 将消息列表构建成可读的字符串,并返回详细信息
```python ```python
# 获取消息并格式化为可读文本 async def build_readable_messages_with_details(
messages = message_api.get_recent_messages("123456789", hours=2) messages: List[Dict[str, Any]],
readable_text = message_api.build_readable_messages_to_str( replace_bot_name: bool = True,
messages, merge_messages: bool = False,
replace_bot_name=True, timestamp_mode: str = "relative",
merge_messages=True, truncate: bool = False,
timestamp_mode="relative" ) -> Tuple[str, List[Tuple[float, str, str]]]:
)
print(readable_text)
``` ```
将消息列表构建成可读的字符串,并返回详细信息。
### `build_readable_messages_with_details(messages, **options)` 异步 **Args:**
- `messages` (List[Dict[str, Any]]): 消息列表
- `replace_bot_name` (bool): 是否将机器人的名称替换为"你"
- `merge_messages` (bool): 是否合并连续消息
- `timestamp_mode` (str): 时间戳显示模式,`"relative"`或`"absolute"`
- `truncate` (bool): 是否截断长消息
将消息列表构建成可读的字符串,并返回详细信息 **Returns:**
- `Tuple[str, List[Tuple[float, str, str]]]` - 格式化后的可读字符串和详细信息元组列表(时间戳, 昵称, 内容)
**参数:** 与 `build_readable_messages_to_str` 类似,但不包含 `read_mark``show_actions`
**返回:** `Tuple[str, List[Tuple[float, str, str]]]` - 格式化字符串和详细信息元组列表(时间戳, 昵称, 内容) ### 15. 从消息列表中提取不重复的用户ID列表
**示例:**
```python ```python
# 异步获取详细格式化信息 async def get_person_ids_from_messages(
readable_text, details = await message_api.build_readable_messages_with_details( messages: List[Dict[str, Any]],
messages, ) -> List[str]:
timestamp_mode="absolute"
)
for timestamp, nickname, content in details:
print(f"{timestamp}: {nickname} 说: {content}")
``` ```
从消息列表中提取不重复的用户ID列表。
### `get_person_ids_from_messages(messages)` 异步 **Args:**
从消息列表中提取不重复的用户ID列表
**参数:**
- `messages` (List[Dict[str, Any]]): 消息列表 - `messages` (List[Dict[str, Any]]): 消息列表
**返回:** `List[str]` - 用户ID列表 **Returns:**
- `List[str]` - 用户ID列表
**示例:**
### 16. 从消息列表中移除机器人的消息
```python ```python
# 获取参与对话的所有用户ID def filter_mai_messages(
messages = message_api.get_recent_messages("123456789") messages: List[Dict[str, Any]],
person_ids = await message_api.get_person_ids_from_messages(messages) ) -> List[Dict[str, Any]]:
print(f"参与对话的用户: {person_ids}")
``` ```
从消息列表中移除机器人的消息。
--- **Args:**
- `messages` (List[Dict[str, Any]]): 消息列表,每个元素是消息字典
## 完整使用示例 **Returns:**
- `List[Dict[str, Any]]` - 过滤后的消息列表
### 场景1统计活跃度
```python
import time
from src.plugin_system.apis import message_api
async def analyze_chat_activity(chat_id: str):
"""分析聊天活跃度"""
now = time.time()
day_ago = now - 24 * 3600
# 获取最近24小时的消息
messages = message_api.get_recent_messages(chat_id, hours=24)
# 统计消息数量
total_count = len(messages)
# 获取参与用户
person_ids = await message_api.get_person_ids_from_messages(messages)
# 格式化消息内容
readable_text = message_api.build_readable_messages_to_str(
messages[-10:], # 最后10条消息
merge_messages=True,
timestamp_mode="relative"
)
return {
"total_messages": total_count,
"active_users": len(person_ids),
"recent_chat": readable_text
}
```
### 场景2查看特定用户的历史消息
```python
def get_user_history(chat_id: str, user_id: str, days: int = 7):
"""获取用户最近N天的消息历史"""
now = time.time()
start_time = now - days * 24 * 3600
# 获取特定用户的消息
user_messages = message_api.get_messages_by_time_in_chat_for_users(
chat_id=chat_id,
start_time=start_time,
end_time=now,
person_ids=[user_id],
limit=100
)
# 格式化为可读文本
readable_history = message_api.build_readable_messages_to_str(
user_messages,
replace_bot_name=False,
timestamp_mode="absolute"
)
return readable_history
```
---
## 注意事项 ## 注意事项
1. **时间戳格式**所有时间参数都使用Unix时间戳float类型 1. **时间戳格式**所有时间参数都使用Unix时间戳float类型
2. **异步函数**`build_readable_messages_with_details` 和 `get_person_ids_from_messages` 是异步函数,需要使用 `await` 2. **异步函数**:部分函数是异步函数,需要使用 `await`
3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数 3. **性能考虑**:查询大量消息时建议设置合理的 `limit` 参数
4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息 4. **消息格式**:返回的消息是字典格式,包含时间戳、发送者、内容等信息
5. **用户ID**`person_ids` 参数接受字符串列表,用于筛选特定用户的消息 5. **用户ID**`person_ids` 参数接受字符串列表,用于筛选特定用户的消息

View File

@ -6,59 +6,65 @@
```python ```python
from src.plugin_system.apis import person_api from src.plugin_system.apis import person_api
# 或者
from src.plugin_system import person_api
``` ```
## 主要功能 ## 主要功能
### 1. Person ID管理 ### 1. Person ID 获取
```python
#### `get_person_id(platform: str, user_id: int) -> str` def get_person_id(platform: str, user_id: int) -> str:
```
根据平台和用户ID获取person_id 根据平台和用户ID获取person_id
**参数:** **Args:**
- `platform`:平台名称,如 "qq", "telegram" 等 - `platform`:平台名称,如 "qq", "telegram" 等
- `user_id`用户ID - `user_id`用户ID
**返回:** **Returns:**
- `str`唯一的person_idMD5哈希值 - `str`唯一的person_idMD5哈希值
**示例:** #### 示例
```python ```python
person_id = person_api.get_person_id("qq", 123456) person_id = person_api.get_person_id("qq", 123456)
print(f"Person ID: {person_id}")
``` ```
### 2. 用户信息查询 ### 2. 用户信息查询
```python
async def get_person_value(person_id: str, field_name: str, default: Any = None) -> Any:
```
查询单个用户信息字段值
#### `get_person_value(person_id: str, field_name: str, default: Any = None) -> Any` **Args:**
根据person_id和字段名获取某个值
**参数:**
- `person_id`用户的唯一标识ID - `person_id`用户的唯一标识ID
- `field_name`:要获取的字段名,如 "nickname", "impression" 等 - `field_name`:要获取的字段名
- `default`字段不存在或获取失败返回的默认值 - `default`:字段值不存在时的默认值
**返回:** **Returns:**
- `Any`:字段值或默认值 - `Any`:字段值或默认值
**示例:** #### 示例
```python ```python
nickname = await person_api.get_person_value(person_id, "nickname", "未知用户") nickname = await person_api.get_person_value(person_id, "nickname", "未知用户")
impression = await person_api.get_person_value(person_id, "impression") impression = await person_api.get_person_value(person_id, "impression")
``` ```
#### `get_person_values(person_id: str, field_names: list, default_dict: dict = None) -> dict` ### 3. 批量用户信息查询
```python
async def get_person_values(person_id: str, field_names: list, default_dict: Optional[dict] = None) -> dict:
```
批量获取用户信息字段值 批量获取用户信息字段值
**参数:** **Args:**
- `person_id`用户的唯一标识ID - `person_id`用户的唯一标识ID
- `field_names`:要获取的字段名列表 - `field_names`:要获取的字段名列表
- `default_dict`:默认值字典,键为字段名,值为默认值 - `default_dict`:默认值字典,键为字段名,值为默认值
**返回:** **Returns:**
- `dict`:字段名到值的映射字典 - `dict`:字段名到值的映射字典
**示例:** #### 示例
```python ```python
values = await person_api.get_person_values( values = await person_api.get_person_values(
person_id, person_id,
@ -67,204 +73,31 @@ values = await person_api.get_person_values(
) )
``` ```
### 3. 用户状态查询 ### 4. 判断用户是否已知
```python
#### `is_person_known(platform: str, user_id: int) -> bool` async def is_person_known(platform: str, user_id: int) -> bool:
```
判断是否认识某个用户 判断是否认识某个用户
**参数:** **Args:**
- `platform`:平台名称 - `platform`:平台名称
- `user_id`用户ID - `user_id`用户ID
**返回:** **Returns:**
- `bool`:是否认识该用户 - `bool`:是否认识该用户
**示例:** ### 5. 根据用户名获取Person ID
```python ```python
known = await person_api.is_person_known("qq", 123456) def get_person_id_by_name(person_name: str) -> str:
if known:
print("这个用户我认识")
``` ```
### 4. 用户名查询
#### `get_person_id_by_name(person_name: str) -> str`
根据用户名获取person_id 根据用户名获取person_id
**参数:** **Args:**
- `person_name`:用户名 - `person_name`:用户名
**返回:** **Returns:**
- `str`person_id如果未找到返回空字符串 - `str`person_id如果未找到返回空字符串
**示例:**
```python
person_id = person_api.get_person_id_by_name("张三")
if person_id:
print(f"找到用户: {person_id}")
```
## 使用示例
### 1. 基础用户信息获取
```python
from src.plugin_system.apis import person_api
async def get_user_info(platform: str, user_id: int):
"""获取用户基本信息"""
# 获取person_id
person_id = person_api.get_person_id(platform, user_id)
# 获取用户信息
user_info = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times", "last_seen"],
{
"nickname": "未知用户",
"impression": "",
"know_times": 0,
"last_seen": 0
}
)
return {
"person_id": person_id,
"nickname": user_info["nickname"],
"impression": user_info["impression"],
"know_times": user_info["know_times"],
"last_seen": user_info["last_seen"]
}
```
### 2. 在Action中使用用户信息
```python
from src.plugin_system.base import BaseAction
class PersonalizedAction(BaseAction):
async def execute(self, action_data, chat_stream):
# 获取发送者信息
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
# 获取person_id
person_id = person_api.get_person_id(platform, user_id)
# 获取用户昵称和印象
nickname = await person_api.get_person_value(person_id, "nickname", "朋友")
impression = await person_api.get_person_value(person_id, "impression", "")
# 根据用户信息个性化回复
if impression:
response = f"你好 {nickname}!根据我对你的了解:{impression}"
else:
response = f"你好 {nickname}!很高兴见到你。"
return {
"success": True,
"response": response,
"user_info": {
"nickname": nickname,
"impression": impression
}
}
```
### 3. 用户识别和欢迎
```python
async def welcome_user(chat_stream):
"""欢迎用户,区分新老用户"""
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
# 检查是否认识这个用户
is_known = await person_api.is_person_known(platform, user_id)
if is_known:
# 老用户,获取详细信息
person_id = person_api.get_person_id(platform, user_id)
nickname = await person_api.get_person_value(person_id, "nickname", "老朋友")
know_times = await person_api.get_person_value(person_id, "know_times", 0)
welcome_msg = f"欢迎回来,{nickname}!我们已经聊过 {know_times} 次了。"
else:
# 新用户
welcome_msg = "你好很高兴认识你我是MaiBot。"
return welcome_msg
```
### 4. 用户搜索功能
```python
async def find_user_by_name(name: str):
"""根据名字查找用户"""
person_id = person_api.get_person_id_by_name(name)
if not person_id:
return {"found": False, "message": f"未找到名为 '{name}' 的用户"}
# 获取用户详细信息
user_info = await person_api.get_person_values(
person_id,
["nickname", "platform", "user_id", "impression", "know_times"],
{}
)
return {
"found": True,
"person_id": person_id,
"info": user_info
}
```
### 5. 用户印象分析
```python
async def analyze_user_relationship(chat_stream):
"""分析用户关系"""
user_id = chat_stream.user_info.user_id
platform = chat_stream.platform
person_id = person_api.get_person_id(platform, user_id)
# 获取关系相关信息
relationship_info = await person_api.get_person_values(
person_id,
["nickname", "impression", "know_times", "relationship_level", "last_interaction"],
{
"nickname": "未知",
"impression": "",
"know_times": 0,
"relationship_level": "stranger",
"last_interaction": 0
}
)
# 分析关系程度
know_times = relationship_info["know_times"]
if know_times == 0:
relationship = "陌生人"
elif know_times < 5:
relationship = "新朋友"
elif know_times < 20:
relationship = "熟人"
else:
relationship = "老朋友"
return {
"nickname": relationship_info["nickname"],
"relationship": relationship,
"impression": relationship_info["impression"],
"interaction_count": know_times
}
```
## 常用字段说明 ## 常用字段说明
### 基础信息字段 ### 基础信息字段
@ -274,69 +107,13 @@ async def analyze_user_relationship(chat_stream):
### 关系信息字段 ### 关系信息字段
- `impression`:对用户的印象 - `impression`:对用户的印象
- `know_times`:交互次数 - `points`: 用户特征点
- `relationship_level`:关系等级
- `last_seen`:最后见面时间
- `last_interaction`:最后交互时间
### 个性化字段 其他字段可以参考`PersonInfo`类的属性(位于`src.common.database.database_model`
- `preferences`:用户偏好
- `interests`:兴趣爱好
- `mood_history`:情绪历史
- `topic_interests`:话题兴趣
## 最佳实践
### 1. 错误处理
```python
async def safe_get_user_info(person_id: str, field: str):
"""安全获取用户信息"""
try:
value = await person_api.get_person_value(person_id, field)
return value if value is not None else "未设置"
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return "获取失败"
```
### 2. 批量操作
```python
async def get_complete_user_profile(person_id: str):
"""获取完整用户档案"""
# 一次性获取所有需要的字段
fields = [
"nickname", "impression", "know_times",
"preferences", "interests", "relationship_level"
]
defaults = {
"nickname": "用户",
"impression": "",
"know_times": 0,
"preferences": "{}",
"interests": "[]",
"relationship_level": "stranger"
}
profile = await person_api.get_person_values(person_id, fields, defaults)
# 处理JSON字段
try:
profile["preferences"] = json.loads(profile["preferences"])
profile["interests"] = json.loads(profile["interests"])
except:
profile["preferences"] = {}
profile["interests"] = []
return profile
```
## 注意事项 ## 注意事项
1. **异步操作**:大部分查询函数都是异步的,需要使用`await` 1. **异步操作**:部分查询函数都是异步的,需要使用`await`
2. **错误处理**:所有函数都有错误处理,失败时记录日志并返回默认值 2. **性能考虑**:批量查询优于单个查询
3. **数据类型**返回的数据可能是字符串、数字或JSON需要适当处理 3. **隐私保护**:确保用户信息的使用符合隐私政策
4. **性能考虑**:批量查询优于单个查询 4. **数据一致性**person_id是用户的唯一标识应妥善保存和使用
5. **隐私保护**:确保用户信息的使用符合隐私政策
6. **数据一致性**person_id是用户的唯一标识应妥善保存和使用

View File

@ -0,0 +1,105 @@
# 插件管理API
插件管理API模块提供了对插件的加载、卸载、重新加载以及目录管理功能。
## 导入方式
```python
from src.plugin_system.apis import plugin_manage_api
# 或者
from src.plugin_system import plugin_manage_api
```
## 功能概述
插件管理API主要提供以下功能
- **插件查询** - 列出当前加载的插件或已注册的插件。
- **插件管理** - 加载、卸载、重新加载插件。
- **插件目录管理** - 添加插件目录并重新扫描。
## 主要功能
### 1. 列出当前加载的插件
```python
def list_loaded_plugins() -> List[str]:
```
列出所有当前加载的插件。
**Returns:**
- `List[str]` - 当前加载的插件名称列表。
### 2. 列出所有已注册的插件
```python
def list_registered_plugins() -> List[str]:
```
列出所有已注册的插件。
**Returns:**
- `List[str]` - 已注册的插件名称列表。
### 3. 获取插件路径
```python
def get_plugin_path(plugin_name: str) -> str:
```
获取指定插件的路径。
**Args:**
- `plugin_name` (str): 要查询的插件名称。
**Returns:**
- `str` - 插件的路径,如果插件不存在则 raise ValueError。
### 4. 卸载指定的插件
```python
async def remove_plugin(plugin_name: str) -> bool:
```
卸载指定的插件。
**Args:**
- `plugin_name` (str): 要卸载的插件名称。
**Returns:**
- `bool` - 卸载是否成功。
### 5. 重新加载指定的插件
```python
async def reload_plugin(plugin_name: str) -> bool:
```
重新加载指定的插件。
**Args:**
- `plugin_name` (str): 要重新加载的插件名称。
**Returns:**
- `bool` - 重新加载是否成功。
### 6. 加载指定的插件
```python
def load_plugin(plugin_name: str) -> Tuple[bool, int]:
```
加载指定的插件。
**Args:**
- `plugin_name` (str): 要加载的插件名称。
**Returns:**
- `Tuple[bool, int]` - 加载是否成功,成功或失败的个数。
### 7. 添加插件目录
```python
def add_plugin_directory(plugin_directory: str) -> bool:
```
添加插件目录。
**Args:**
- `plugin_directory` (str): 要添加的插件目录路径。
**Returns:**
- `bool` - 添加是否成功。
### 8. 重新扫描插件目录
```python
def rescan_plugin_directory() -> Tuple[int, int]:
```
重新扫描插件目录,加载新插件。
**Returns:**
- `Tuple[int, int]` - 成功加载的插件数量和失败的插件数量。

View File

@ -6,86 +6,108 @@
```python ```python
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
# 或者
from src.plugin_system import send_api
``` ```
## 主要功能 ## 主要功能
### 1. 文本消息发送 ### 1. 发送文本消息
```python
async def text_to_stream(
text: str,
stream_id: str,
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
```
发送文本消息到指定的流
#### `text_to_group(text, group_id, platform="qq", typing=False, reply_to="", storage_message=True)` **Args:**
向群聊发送文本消息 - `text` (str): 要发送的文本内容
- `stream_id` (str): 聊天流ID
- `typing` (bool): 是否显示正在输入
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
- `storage_message` (bool): 是否存储消息到数据库
**参数:** **Returns:**
- `text`:要发送的文本内容 - `bool` - 是否发送成功
- `group_id`群聊ID
- `platform`:平台,默认为"qq"
- `typing`:是否显示正在输入
- `reply_to`:回复消息的格式,如"发送者:消息内容"
- `storage_message`:是否存储到数据库
**返回:** ### 2. 发送表情包
- `bool`:是否发送成功 ```python
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
```
向指定流发送表情包。
#### `text_to_user(text, user_id, platform="qq", typing=False, reply_to="", storage_message=True)` **Args:**
向用户发送私聊文本消息 - `emoji_base64` (str): 表情包的base64编码
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
**参数与返回值同上** **Returns:**
- `bool` - 是否发送成功
### 2. 表情包发送 ### 3. 发送图片
```python
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True) -> bool:
```
向指定流发送图片。
#### `emoji_to_group(emoji_base64, group_id, platform="qq", storage_message=True)` **Args:**
向群聊发送表情包 - `image_base64` (str): 图片的base64编码
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
**参数:** **Returns:**
- `emoji_base64`表情包的base64编码 - `bool` - 是否发送成功
- `group_id`群聊ID
- `platform`:平台,默认为"qq"
- `storage_message`:是否存储到数据库
#### `emoji_to_user(emoji_base64, user_id, platform="qq", storage_message=True)` ### 4. 发送命令
向用户发送表情包 ```python
async def command_to_stream(command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "") -> bool:
```
向指定流发送命令。
### 3. 图片发送 **Args:**
- `command` (Union[str, dict]): 命令内容
- `stream_id` (str): 聊天流ID
- `storage_message` (bool): 是否存储消息到数据库
- `display_message` (str): 显示消息
#### `image_to_group(image_base64, group_id, platform="qq", storage_message=True)` **Returns:**
向群聊发送图片 - `bool` - 是否发送成功
#### `image_to_user(image_base64, user_id, platform="qq", storage_message=True)` ### 5. 发送自定义类型消息
向用户发送图片 ```python
async def custom_to_stream(
message_type: str,
content: str,
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
show_log: bool = True,
) -> bool:
```
向指定流发送自定义类型消息。
### 4. 命令发送 **Args:**
- `message_type` (str): 消息类型,如"text"、"image"、"emoji"、"video"、"file"等
- `content` (str): 消息内容通常是base64编码或文本
- `stream_id` (str): 聊天流ID
- `display_message` (str): 显示消息
- `typing` (bool): 是否显示正在输入
- `reply_to` (str): 回复消息,格式为"发送者:消息内容"
- `storage_message` (bool): 是否存储消息到数据库
- `show_log` (bool): 是否显示日志
#### `command_to_group(command, group_id, platform="qq", storage_message=True)` **Returns:**
向群聊发送命令 - `bool` - 是否发送成功
#### `command_to_user(command, user_id, platform="qq", storage_message=True)`
向用户发送命令
### 5. 自定义消息发送
#### `custom_to_group(message_type, content, group_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
向群聊发送自定义类型消息
#### `custom_to_user(message_type, content, user_id, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
向用户发送自定义类型消息
#### `custom_message(message_type, content, target_id, is_group=True, platform="qq", display_message="", typing=False, reply_to="", storage_message=True)`
通用的自定义消息发送
**参数:**
- `message_type`:消息类型,如"text"、"image"、"emoji"等
- `content`:消息内容
- `target_id`目标ID群ID或用户ID
- `is_group`:是否为群聊
- `platform`:平台
- `display_message`:显示消息
- `typing`:是否显示正在输入
- `reply_to`:回复消息
- `storage_message`:是否存储
## 使用示例 ## 使用示例
### 1. 基础文本发送 ### 1. 基础文本发送,并回复消息
```python ```python
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@ -93,57 +115,23 @@ from src.plugin_system.apis import send_api
async def send_hello(chat_stream): async def send_hello(chat_stream):
"""发送问候消息""" """发送问候消息"""
if chat_stream.group_info: success = await send_api.text_to_stream(
# 群聊 text="Hello, world!",
success = await send_api.text_to_group( stream_id=chat_stream.stream_id,
text="大家好!", typing=True,
group_id=chat_stream.group_info.group_id, reply_to="User:How are you?",
typing=True storage_message=True
) )
else:
# 私聊
success = await send_api.text_to_user(
text="你好!",
user_id=chat_stream.user_info.user_id,
typing=True
)
return success return success
``` ```
### 2. 回复特定消息 ### 2. 发送表情包
```python
async def reply_to_message(chat_stream, reply_text, original_sender, original_message):
"""回复特定消息"""
# 构建回复格式
reply_to = f"{original_sender}:{original_message}"
if chat_stream.group_info:
success = await send_api.text_to_group(
text=reply_text,
group_id=chat_stream.group_info.group_id,
reply_to=reply_to
)
else:
success = await send_api.text_to_user(
text=reply_text,
user_id=chat_stream.user_info.user_id,
reply_to=reply_to
)
return success
```
### 3. 发送表情包
```python ```python
from src.plugin_system.apis import emoji_api
async def send_emoji_reaction(chat_stream, emotion): async def send_emoji_reaction(chat_stream, emotion):
"""根据情感发送表情包""" """根据情感发送表情包"""
from src.plugin_system.apis import emoji_api
# 获取表情包 # 获取表情包
emoji_result = await emoji_api.get_by_emotion(emotion) emoji_result = await emoji_api.get_by_emotion(emotion)
if not emoji_result: if not emoji_result:
@ -152,107 +140,10 @@ async def send_emoji_reaction(chat_stream, emotion):
emoji_base64, description, matched_emotion = emoji_result emoji_base64, description, matched_emotion = emoji_result
# 发送表情包 # 发送表情包
if chat_stream.group_info: success = await send_api.emoji_to_stream(
success = await send_api.emoji_to_group( emoji_base64=emoji_base64,
emoji_base64=emoji_base64, stream_id=chat_stream.stream_id,
group_id=chat_stream.group_info.group_id storage_message=False # 不存储到数据库
)
else:
success = await send_api.emoji_to_user(
emoji_base64=emoji_base64,
user_id=chat_stream.user_info.user_id
)
return success
```
### 4. 在Action中发送消息
```python
from src.plugin_system.base import BaseAction
class MessageAction(BaseAction):
async def execute(self, action_data, chat_stream):
message_type = action_data.get("type", "text")
content = action_data.get("content", "")
if message_type == "text":
success = await self.send_text(chat_stream, content)
elif message_type == "emoji":
success = await self.send_emoji(chat_stream, content)
elif message_type == "image":
success = await self.send_image(chat_stream, content)
else:
success = False
return {"success": success}
async def send_text(self, chat_stream, text):
if chat_stream.group_info:
return await send_api.text_to_group(text, chat_stream.group_info.group_id)
else:
return await send_api.text_to_user(text, chat_stream.user_info.user_id)
async def send_emoji(self, chat_stream, emoji_base64):
if chat_stream.group_info:
return await send_api.emoji_to_group(emoji_base64, chat_stream.group_info.group_id)
else:
return await send_api.emoji_to_user(emoji_base64, chat_stream.user_info.user_id)
async def send_image(self, chat_stream, image_base64):
if chat_stream.group_info:
return await send_api.image_to_group(image_base64, chat_stream.group_info.group_id)
else:
return await send_api.image_to_user(image_base64, chat_stream.user_info.user_id)
```
### 5. 批量发送消息
```python
async def broadcast_message(message: str, target_groups: list):
"""向多个群组广播消息"""
results = {}
for group_id in target_groups:
try:
success = await send_api.text_to_group(
text=message,
group_id=group_id,
typing=True
)
results[group_id] = success
except Exception as e:
results[group_id] = False
print(f"发送到群 {group_id} 失败: {e}")
return results
```
### 6. 智能消息发送
```python
async def smart_send(chat_stream, message_data):
"""智能发送不同类型的消息"""
message_type = message_data.get("type", "text")
content = message_data.get("content", "")
options = message_data.get("options", {})
# 根据聊天流类型选择发送方法
target_id = (chat_stream.group_info.group_id if chat_stream.group_info
else chat_stream.user_info.user_id)
is_group = chat_stream.group_info is not None
# 使用通用发送方法
success = await send_api.custom_message(
message_type=message_type,
content=content,
target_id=target_id,
is_group=is_group,
typing=options.get("typing", False),
reply_to=options.get("reply_to", ""),
display_message=options.get("display_message", "")
) )
return success return success
@ -273,90 +164,6 @@ async def smart_send(chat_stream, message_data):
系统会自动查找匹配的原始消息并进行回复。 系统会自动查找匹配的原始消息并进行回复。
## 高级用法
### 1. 消息发送队列
```python
import asyncio
class MessageQueue:
def __init__(self):
self.queue = asyncio.Queue()
self.running = False
async def add_message(self, chat_stream, message_type, content, options=None):
"""添加消息到队列"""
message_item = {
"chat_stream": chat_stream,
"type": message_type,
"content": content,
"options": options or {}
}
await self.queue.put(message_item)
async def process_queue(self):
"""处理消息队列"""
self.running = True
while self.running:
try:
message_item = await asyncio.wait_for(self.queue.get(), timeout=1.0)
# 发送消息
success = await smart_send(
message_item["chat_stream"],
{
"type": message_item["type"],
"content": message_item["content"],
"options": message_item["options"]
}
)
# 标记任务完成
self.queue.task_done()
# 发送间隔
await asyncio.sleep(0.5)
except asyncio.TimeoutError:
continue
except Exception as e:
print(f"处理消息队列出错: {e}")
```
### 2. 消息模板系统
```python
class MessageTemplate:
def __init__(self):
self.templates = {
"welcome": "欢迎 {nickname} 加入群聊!",
"goodbye": "{nickname} 离开了群聊。",
"notification": "🔔 通知:{message}",
"error": "❌ 错误:{error_message}",
"success": "✅ 成功:{message}"
}
def format_message(self, template_name: str, **kwargs) -> str:
"""格式化消息模板"""
template = self.templates.get(template_name, "{message}")
return template.format(**kwargs)
async def send_template(self, chat_stream, template_name: str, **kwargs):
"""发送模板消息"""
message = self.format_message(template_name, **kwargs)
if chat_stream.group_info:
return await send_api.text_to_group(message, chat_stream.group_info.group_id)
else:
return await send_api.text_to_user(message, chat_stream.user_info.user_id)
# 使用示例
template_system = MessageTemplate()
await template_system.send_template(chat_stream, "welcome", nickname="张三")
```
## 注意事项 ## 注意事项
1. **异步操作**:所有发送函数都是异步的,必须使用`await` 1. **异步操作**:所有发送函数都是异步的,必须使用`await`

View File

@ -0,0 +1,55 @@
# 工具API
工具API模块提供了获取和管理工具实例的功能让插件能够访问系统中注册的工具。
## 导入方式
```python
from src.plugin_system.apis import tool_api
# 或者
from src.plugin_system import tool_api
```
## 主要功能
### 1. 获取工具实例
```python
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
```
获取指定名称的工具实例。
**Args**:
- `tool_name`: 工具名称字符串
**Returns**:
- `Optional[BaseTool]`: 工具实例,如果工具不存在则返回 None
### 2. 获取LLM可用的工具定义
```python
def get_llm_available_tool_definitions():
```
获取所有LLM可用的工具定义列表。
**Returns**:
- `List[Tuple[str, Dict[str, Any]]]`: 工具定义列表,每个元素为 `(工具名称, 工具定义字典)` 的元组
- 其具体定义请参照[tool-components.md](../tool-components.md)中的工具定义格式。
#### 示例:
```python
# 获取所有LLM可用的工具定义
tools = tool_api.get_llm_available_tool_definitions()
for tool_name, tool_definition in tools:
print(f"工具: {tool_name}")
print(f"定义: {tool_definition}")
```
## 注意事项
1. **工具存在性检查**:使用前请检查工具实例是否为 None
2. **权限控制**:某些工具可能有使用权限限制
3. **异步调用**:大多数工具方法是异步的,需要使用 await
4. **错误处理**:调用工具时请做好异常处理

View File

@ -1,435 +0,0 @@
# 工具API
工具API模块提供了各种辅助功能包括文件操作、时间处理、唯一ID生成等常用工具函数。
## 导入方式
```python
from src.plugin_system.apis import utils_api
```
## 主要功能
### 1. 文件操作
#### `get_plugin_path(caller_frame=None) -> str`
获取调用者插件的路径
**参数:**
- `caller_frame`调用者的栈帧默认为None自动获取
**返回:**
- `str`:插件目录的绝对路径
**示例:**
```python
plugin_path = utils_api.get_plugin_path()
print(f"插件路径: {plugin_path}")
```
#### `read_json_file(file_path: str, default: Any = None) -> Any`
读取JSON文件
**参数:**
- `file_path`:文件路径,可以是相对于插件目录的路径
- `default`:如果文件不存在或读取失败时返回的默认值
**返回:**
- `Any`JSON数据或默认值
**示例:**
```python
# 读取插件配置文件
config = utils_api.read_json_file("config.json", {})
settings = utils_api.read_json_file("data/settings.json", {"enabled": True})
```
#### `write_json_file(file_path: str, data: Any, indent: int = 2) -> bool`
写入JSON文件
**参数:**
- `file_path`:文件路径,可以是相对于插件目录的路径
- `data`:要写入的数据
- `indent`JSON缩进
**返回:**
- `bool`:是否写入成功
**示例:**
```python
data = {"name": "test", "value": 123}
success = utils_api.write_json_file("output.json", data)
```
### 2. 时间相关
#### `get_timestamp() -> int`
获取当前时间戳
**返回:**
- `int`:当前时间戳(秒)
#### `format_time(timestamp: Optional[int] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str`
格式化时间
**参数:**
- `timestamp`时间戳如果为None则使用当前时间
- `format_str`:时间格式字符串
**返回:**
- `str`:格式化后的时间字符串
#### `parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int`
解析时间字符串为时间戳
**参数:**
- `time_str`:时间字符串
- `format_str`:时间格式字符串
**返回:**
- `int`:时间戳(秒)
### 3. 其他工具
#### `generate_unique_id() -> str`
生成唯一ID
**返回:**
- `str`唯一ID
## 使用示例
### 1. 插件数据管理
```python
from src.plugin_system.apis import utils_api
class DataPlugin(BasePlugin):
def __init__(self):
self.plugin_path = utils_api.get_plugin_path()
self.data_file = "plugin_data.json"
self.load_data()
def load_data(self):
"""加载插件数据"""
default_data = {
"users": {},
"settings": {"enabled": True},
"stats": {"message_count": 0}
}
self.data = utils_api.read_json_file(self.data_file, default_data)
def save_data(self):
"""保存插件数据"""
return utils_api.write_json_file(self.data_file, self.data)
async def handle_action(self, action_data, chat_stream):
# 更新统计信息
self.data["stats"]["message_count"] += 1
self.data["stats"]["last_update"] = utils_api.get_timestamp()
# 保存数据
if self.save_data():
return {"success": True, "message": "数据已保存"}
else:
return {"success": False, "message": "数据保存失败"}
```
### 2. 日志记录系统
```python
class PluginLogger:
def __init__(self, plugin_name: str):
self.plugin_name = plugin_name
self.log_file = f"{plugin_name}_log.json"
self.logs = utils_api.read_json_file(self.log_file, [])
def log_event(self, event_type: str, message: str, data: dict = None):
"""记录事件"""
log_entry = {
"id": utils_api.generate_unique_id(),
"timestamp": utils_api.get_timestamp(),
"formatted_time": utils_api.format_time(),
"event_type": event_type,
"message": message,
"data": data or {}
}
self.logs.append(log_entry)
# 保持最新的100条记录
if len(self.logs) > 100:
self.logs = self.logs[-100:]
# 保存到文件
utils_api.write_json_file(self.log_file, self.logs)
def get_logs_by_type(self, event_type: str) -> list:
"""获取指定类型的日志"""
return [log for log in self.logs if log["event_type"] == event_type]
def get_recent_logs(self, count: int = 10) -> list:
"""获取最近的日志"""
return self.logs[-count:]
# 使用示例
logger = PluginLogger("my_plugin")
logger.log_event("user_action", "用户发送了消息", {"user_id": "123", "message": "hello"})
```
### 3. 配置管理系统
```python
class ConfigManager:
def __init__(self, config_file: str = "plugin_config.json"):
self.config_file = config_file
self.default_config = {
"enabled": True,
"debug": False,
"max_users": 100,
"response_delay": 1.0,
"features": {
"auto_reply": True,
"logging": True
}
}
self.config = self.load_config()
def load_config(self) -> dict:
"""加载配置"""
return utils_api.read_json_file(self.config_file, self.default_config)
def save_config(self) -> bool:
"""保存配置"""
return utils_api.write_json_file(self.config_file, self.config, indent=4)
def get(self, key: str, default=None):
"""获取配置值,支持嵌套访问"""
keys = key.split('.')
value = self.config
for k in keys:
if isinstance(value, dict) and k in value:
value = value[k]
else:
return default
return value
def set(self, key: str, value):
"""设置配置值,支持嵌套设置"""
keys = key.split('.')
config = self.config
for k in keys[:-1]:
if k not in config:
config[k] = {}
config = config[k]
config[keys[-1]] = value
def update_config(self, updates: dict):
"""批量更新配置"""
def deep_update(base, updates):
for key, value in updates.items():
if isinstance(value, dict) and key in base and isinstance(base[key], dict):
deep_update(base[key], value)
else:
base[key] = value
deep_update(self.config, updates)
# 使用示例
config = ConfigManager()
print(f"调试模式: {config.get('debug', False)}")
print(f"自动回复: {config.get('features.auto_reply', True)}")
config.set('features.new_feature', True)
config.save_config()
```
### 4. 缓存系统
```python
class PluginCache:
def __init__(self, cache_file: str = "plugin_cache.json", ttl: int = 3600):
self.cache_file = cache_file
self.ttl = ttl # 缓存过期时间(秒)
self.cache = self.load_cache()
def load_cache(self) -> dict:
"""加载缓存"""
return utils_api.read_json_file(self.cache_file, {})
def save_cache(self):
"""保存缓存"""
return utils_api.write_json_file(self.cache_file, self.cache)
def get(self, key: str):
"""获取缓存值"""
if key not in self.cache:
return None
item = self.cache[key]
current_time = utils_api.get_timestamp()
# 检查是否过期
if current_time - item["timestamp"] > self.ttl:
del self.cache[key]
return None
return item["value"]
def set(self, key: str, value):
"""设置缓存值"""
self.cache[key] = {
"value": value,
"timestamp": utils_api.get_timestamp()
}
self.save_cache()
def clear_expired(self):
"""清理过期缓存"""
current_time = utils_api.get_timestamp()
expired_keys = []
for key, item in self.cache.items():
if current_time - item["timestamp"] > self.ttl:
expired_keys.append(key)
for key in expired_keys:
del self.cache[key]
if expired_keys:
self.save_cache()
return len(expired_keys)
# 使用示例
cache = PluginCache(ttl=1800) # 30分钟过期
cache.set("user_data_123", {"name": "张三", "score": 100})
user_data = cache.get("user_data_123")
```
### 5. 时间处理工具
```python
class TimeHelper:
@staticmethod
def get_time_info():
"""获取当前时间的详细信息"""
timestamp = utils_api.get_timestamp()
return {
"timestamp": timestamp,
"datetime": utils_api.format_time(timestamp),
"date": utils_api.format_time(timestamp, "%Y-%m-%d"),
"time": utils_api.format_time(timestamp, "%H:%M:%S"),
"year": utils_api.format_time(timestamp, "%Y"),
"month": utils_api.format_time(timestamp, "%m"),
"day": utils_api.format_time(timestamp, "%d"),
"weekday": utils_api.format_time(timestamp, "%A")
}
@staticmethod
def time_ago(timestamp: int) -> str:
"""计算时间差"""
current = utils_api.get_timestamp()
diff = current - timestamp
if diff < 60:
return f"{diff}秒前"
elif diff < 3600:
return f"{diff // 60}分钟前"
elif diff < 86400:
return f"{diff // 3600}小时前"
else:
return f"{diff // 86400}天前"
@staticmethod
def parse_duration(duration_str: str) -> int:
"""解析时间段字符串,返回秒数"""
import re
pattern = r'(\d+)([smhd])'
matches = re.findall(pattern, duration_str.lower())
total_seconds = 0
for value, unit in matches:
value = int(value)
if unit == 's':
total_seconds += value
elif unit == 'm':
total_seconds += value * 60
elif unit == 'h':
total_seconds += value * 3600
elif unit == 'd':
total_seconds += value * 86400
return total_seconds
# 使用示例
time_info = TimeHelper.get_time_info()
print(f"当前时间: {time_info['datetime']}")
last_seen = 1699000000
print(f"最后见面: {TimeHelper.time_ago(last_seen)}")
duration = TimeHelper.parse_duration("1h30m") # 1小时30分钟 = 5400秒
```
## 最佳实践
### 1. 错误处理
```python
def safe_file_operation(file_path: str, data: dict):
"""安全的文件操作"""
try:
success = utils_api.write_json_file(file_path, data)
if not success:
logger.warning(f"文件写入失败: {file_path}")
return success
except Exception as e:
logger.error(f"文件操作出错: {e}")
return False
```
### 2. 路径处理
```python
import os
def get_data_path(filename: str) -> str:
"""获取数据文件的完整路径"""
plugin_path = utils_api.get_plugin_path()
data_dir = os.path.join(plugin_path, "data")
# 确保数据目录存在
os.makedirs(data_dir, exist_ok=True)
return os.path.join(data_dir, filename)
```
### 3. 定期清理
```python
async def cleanup_old_files():
"""清理旧文件"""
plugin_path = utils_api.get_plugin_path()
current_time = utils_api.get_timestamp()
for filename in os.listdir(plugin_path):
if filename.endswith('.tmp'):
file_path = os.path.join(plugin_path, filename)
file_time = os.path.getmtime(file_path)
# 删除超过24小时的临时文件
if current_time - file_time > 86400:
os.remove(file_path)
```
## 注意事项
1. **相对路径**:文件路径支持相对于插件目录的路径
2. **自动创建目录**:写入文件时会自动创建必要的目录
3. **错误处理**:所有函数都有错误处理,失败时返回默认值
4. **编码格式**文件读写使用UTF-8编码
5. **时间格式**:时间戳使用秒为单位
6. **JSON格式**JSON文件使用可读性好的缩进格式

View File

@ -10,6 +10,7 @@
- [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件 - [🧱 Action组件详解](action-components.md) - 掌握最核心的Action组件
- [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件 - [💻 Command组件详解](command-components.md) - 学习直接响应命令的组件
- [🔧 Tool组件详解](tool-components.md) - 了解如何扩展信息获取能力
- [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件 - [⚙️ 配置文件系统指南](configuration-guide.md) - 学会使用自动生成的插件配置文件
- [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构 - [📄 Manifest系统指南](manifest-guide.md) - 了解插件元数据管理和配置架构
@ -43,24 +44,24 @@ Command vs Action 选择指南
- [LLM API](api/llm-api.md) - 大语言模型交互接口可以使用内置LLM生成内容 - [LLM API](api/llm-api.md) - 大语言模型交互接口可以使用内置LLM生成内容
- [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器 - [✨ 回复生成器API](api/generator-api.md) - 智能回复生成接口,可以使用内置风格化生成器
### 表情包api ### 表情包API
- [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口 - [😊 表情包API](api/emoji-api.md) - 表情包选择和管理接口
### 关系系统api ### 关系系统API
- [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口 - [人物信息API](api/person-api.md) - 用户信息,处理麦麦认识的人和关系的接口
### 数据与配置API ### 数据与配置API
- [🗄️ 数据库API](api/database-api.md) - 数据库操作接口 - [🗄️ 数据库API](api/database-api.md) - 数据库操作接口
- [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口 - [⚙️ 配置API](api/config-api.md) - 配置读取和用户信息接口
### 插件和组件管理API
- [🔌 插件API](api/plugin-manage-api.md) - 插件加载和管理接口
- [🧩 组件API](api/component-manage-api.md) - 组件注册和管理接口
### 日志API
- [📜 日志API](api/logging-api.md) - logger实例获取接口
### 工具API ### 工具API
- [工具API](api/utils-api.md) - 文件操作、时间处理等工具函数 - [🔧 工具API](api/tool-api.md) - tool获取接口
## 实验性
这些功能将在未来重构或移除
- [🔧 工具系统详解](tool-system.md) - 工具系统的使用和开发

View File

@ -1,10 +1,10 @@
# 🔧 工具系统详解 # 🔧 工具组件详解
## 📖 什么是工具系统 ## 📖 什么是工具
工具系统是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。 工具是MaiBot的信息获取能力扩展组件。如果说Action组件功能五花八门可以拓展麦麦能做的事情那么Tool就是在某个过程中拓宽了麦麦能够获得的信息量。
### 🎯 工具系统的特点 ### 🎯 工具的特点
- 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力 - 🔍 **信息获取增强**:扩展麦麦获取外部信息的能力
- 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据 - 📊 **数据丰富**:帮助麦麦获得更多背景信息和实时数据
@ -20,14 +20,11 @@
| **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 | | **目标** | 让麦麦做更多事情 | 提供具体功能 | 让麦麦知道更多信息 |
| **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 | | **使用场景** | 增强交互体验 | 功能服务 | 信息查询和分析 |
## 🏗️ 工具基本结构 ## 🏗️ Tool组件的基本结构
### 必要组件
每个工具必须继承 `BaseTool` 基类并实现以下属性和方法: 每个工具必须继承 `BaseTool` 基类并实现以下属性和方法:
```python ```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool from src.plugin_system import BaseTool, ToolParamType
class MyTool(BaseTool): class MyTool(BaseTool):
# 工具名称,必须唯一 # 工具名称,必须唯一
@ -36,21 +33,29 @@ class MyTool(BaseTool):
# 工具描述告诉LLM这个工具的用途 # 工具描述告诉LLM这个工具的用途
description = "这个工具用于获取特定类型的信息" description = "这个工具用于获取特定类型的信息"
# 参数定义遵循JSONSchema格式 # 参数定义,仅定义参数
parameters = { # 比如想要定义一个类似下面的openai格式的参数表则可以这么定义:
"type": "object", # {
"properties": { # "type": "object",
"query": { # "properties": {
"type": "string", # "query": {
"description": "查询参数" # "type": "string",
}, # "description": "查询参数"
"limit": { # },
"type": "integer", # "limit": {
"description": "结果数量限制" # "type": "integer",
} # "description": "结果数量限制"
}, # "enum": [10, 20, 50] # 可选值
"required": ["query"] # }
} # },
# "required": ["query"]
# }
parameters = [
("query", ToolParamType.STRING, "查询参数", True, None), # 必填参数
("limit", ToolParamType.INTEGER, "结果数量限制", False, ["10", "20", "50"]) # 可选参数
]
available_for_llm = True # 是否对LLM可用
async def execute(self, function_args: Dict[str, Any]): async def execute(self, function_args: Dict[str, Any]):
"""执行工具逻辑""" """执行工具逻辑"""
@ -69,7 +74,12 @@ class MyTool(BaseTool):
|-----|------|------| |-----|------|------|
| `name` | str | 工具的唯一标识名称 | | `name` | str | 工具的唯一标识名称 |
| `description` | str | 工具功能描述帮助LLM理解用途 | | `description` | str | 工具功能描述帮助LLM理解用途 |
| `parameters` | dict | JSONSchema格式的参数定义 | | `parameters` | list[tuple] | 参数定义 |
其构造而成的工具定义为:
```python
{"name": cls.name, "description": cls.description, "parameters": cls.parameters}
```
### 方法说明 ### 方法说明
@ -77,15 +87,6 @@ class MyTool(BaseTool):
|-----|------|--------|------| |-----|------|--------|------|
| `execute` | `function_args` | `dict` | 执行工具核心逻辑 | | `execute` | `function_args` | `dict` | 执行工具核心逻辑 |
## 🔄 自动注册机制
工具系统采用自动发现和注册机制:
1. **文件扫描**:系统自动遍历 `tool_can_use` 目录中的所有Python文件
2. **类识别**:寻找继承自 `BaseTool` 的工具类
3. **自动注册**:只需要实现对应的类并把文件放在正确文件夹中就可自动注册
4. **即用即加载**:工具在需要时被实例化和调用
--- ---
## 🎨 完整工具示例 ## 🎨 完整工具示例
@ -93,7 +94,7 @@ class MyTool(BaseTool):
完成一个天气查询工具 完成一个天气查询工具
```python ```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool from src.plugin_system import BaseTool
import aiohttp import aiohttp
import json import json
@ -102,23 +103,13 @@ class WeatherTool(BaseTool):
name = "weather_query" name = "weather_query"
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等" description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况等"
available_for_llm = True # 允许LLM调用此工具
parameters = [
("city", ToolParamType.STRING, "要查询天气的城市名称,如:北京、上海、纽约", True, None),
("country", ToolParamType.STRING, "国家代码CN、US可选参数", False, None)
]
parameters = { async def execute(self, function_args: dict):
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "要查询天气的城市名称,如:北京、上海、纽约"
},
"country": {
"type": "string",
"description": "国家代码CN、US可选参数"
}
},
"required": ["city"]
}
async def execute(self, function_args, message_txt=""):
"""执行天气查询""" """执行天气查询"""
try: try:
city = function_args.get("city") city = function_args.get("city")
@ -177,55 +168,12 @@ class WeatherTool(BaseTool):
--- ---
## 📊 工具开发步骤
### 1. 创建工具文件
`src/tools/tool_can_use/` 目录下创建新的Python文件
```bash
# 例如创建 my_new_tool.py
touch src/tools/tool_can_use/my_new_tool.py
```
### 2. 实现工具类
```python
from src.tools.tool_can_use.base_tool import BaseTool, register_tool
class MyNewTool(BaseTool):
name = "my_new_tool"
description = "新工具的功能描述"
parameters = {
"type": "object",
"properties": {
# 定义参数
},
"required": []
}
async def execute(self, function_args, message_txt=""):
# 实现工具逻辑
return {
"name": self.name,
"content": "执行结果"
}
```
### 3. 系统集成
工具创建完成后,系统会自动发现和注册,无需额外配置。
---
## 🚨 注意事项和限制 ## 🚨 注意事项和限制
### 当前限制 ### 当前限制
1. **独立开发**:需要单独编写,暂未完全融入插件系统 1. **适用范围**:主要适用于信息获取场景
2. **适用范围**:主要适用于信息获取场景 2. **配置要求**:必须开启工具处理器
3. **配置要求**:必须开启工具处理器
### 开发建议 ### 开发建议
@ -238,66 +186,49 @@ class MyNewTool(BaseTool):
## 🎯 最佳实践 ## 🎯 最佳实践
### 1. 工具命名规范 ### 1. 工具命名规范
#### ✅ 好的命名
```python ```python
# ✅ 好的命名
name = "weather_query" # 清晰表达功能 name = "weather_query" # 清晰表达功能
name = "knowledge_search" # 描述性强 name = "knowledge_search" # 描述性强
name = "stock_price_check" # 功能明确 name = "stock_price_check" # 功能明确
```
# ❌ 避免的命名 #### ❌ 避免的命名
```python
name = "tool1" # 无意义 name = "tool1" # 无意义
name = "wq" # 过于简短 name = "wq" # 过于简短
name = "weather_and_news" # 功能过于复杂 name = "weather_and_news" # 功能过于复杂
``` ```
### 2. 描述规范 ### 2. 描述规范
#### ✅ 良好的描述
```python ```python
# ✅ 好的描述
description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况" description = "查询指定城市的实时天气信息,包括温度、湿度、天气状况"
```
# ❌ 避免的描述 #### ❌ 避免的描述
```python
description = "天气" # 过于简单 description = "天气" # 过于简单
description = "获取信息" # 不够具体 description = "获取信息" # 不够具体
``` ```
### 3. 参数设计 ### 3. 参数设计
#### ✅ 合理的参数设计
```python ```python
# ✅ 合理的参数设计 parameters = [
parameters = { ("city", ToolParamType.STRING, "城市名称,如:北京、上海", True, None),
"type": "object", ("unit", ToolParamType.STRING, "温度单位celsius 或 fahrenheit", False, ["celsius", "fahrenheit"])
"properties": { ]
"city": { ```
"type": "string", #### ❌ 避免的参数设计
"description": "城市名称,如:北京、上海" ```python
}, parameters = [
"unit": { ("data", "string", "数据", True) # 参数过于模糊
"type": "string", ]
"description": "温度单位celsius(摄氏度) 或 fahrenheit(华氏度)",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["city"]
}
# ❌ 避免的参数设计
parameters = {
"type": "object",
"properties": {
"data": {
"type": "string",
"description": "数据" # 描述不清晰
}
}
}
``` ```
### 4. 结果格式化 ### 4. 结果格式化
#### ✅ 良好的结果格式
```python ```python
# ✅ 良好的结果格式
def _format_result(self, data): def _format_result(self, data):
return f""" return f"""
🔍 查询结果 🔍 查询结果
@ -307,12 +238,9 @@ def _format_result(self, data):
📝 说明: {data['description']} 📝 说明: {data['description']}
━━━━━━━━━━━━ ━━━━━━━━━━━━
""".strip() """.strip()
```
# ❌ 避免的结果格式 #### ❌ 避免的结果格式
```python
def _format_result(self, data): def _format_result(self, data):
return str(data) # 直接返回原始数据 return str(data) # 直接返回原始数据
``` ```
---
🎉 **工具系统为麦麦提供了强大的信息获取能力!合理使用工具可以让麦麦变得更加智能和博学。**

View File

@ -1,18 +1,55 @@
from typing import List, Tuple, Type from typing import List, Tuple, Type, Any
from src.plugin_system import ( from src.plugin_system import (
BasePlugin, BasePlugin,
register_plugin, register_plugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ComponentInfo, ComponentInfo,
ActionActivationType, ActionActivationType,
ConfigField, ConfigField,
BaseEventHandler, BaseEventHandler,
EventType, EventType,
MaiMessages, MaiMessages,
ToolParamType
) )
class CompareNumbersTool(BaseTool):
"""比较两个数大小的工具"""
name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数"
parameters = [
("num1", ToolParamType.FLOAT, "第一个数字", True, None),
("num2", ToolParamType.FLOAT, "第二个数字", True, None),
]
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小
Args:
function_args: 工具参数
Returns:
dict: 工具执行结果
"""
num1: int | float = function_args.get("num1") # type: ignore
num2: int | float = function_args.get("num2") # type: ignore
try:
if num1 > num2:
result = f"{num1} 大于 {num2}"
elif num1 < num2:
result = f"{num1} 小于 {num2}"
else:
result = f"{num1} 等于 {num2}"
return {"name": self.name, "content": result}
except Exception as e:
return {"name": self.name, "content": f"比较数字失败,炸了: {str(e)}"}
# ===== Action组件 ===== # ===== Action组件 =====
class HelloAction(BaseAction): class HelloAction(BaseAction):
"""问候Action - 简单的问候动作""" """问候Action - 简单的问候动作"""
@ -132,7 +169,9 @@ class HelloWorldPlugin(BasePlugin):
"enabled": ConfigField(type=bool, default=False, description="是否启用插件"), "enabled": ConfigField(type=bool, default=False, description="是否启用插件"),
}, },
"greeting": { "greeting": {
"message": ConfigField(type=str, default="嗨!很开心见到你!😊", description="默认问候消息"), "message": ConfigField(
type=list, default=["嗨!很开心见到你!😊", "Ciallo(∠・ω< )⌒★"], description="默认问候消息"
),
"enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"), "enable_emoji": ConfigField(type=bool, default=True, description="是否启用表情符号"),
}, },
"time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")}, "time": {"format": ConfigField(type=str, default="%Y-%m-%d %H:%M:%S", description="时间显示格式")},
@ -142,6 +181,7 @@ class HelloWorldPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]: def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
return [ return [
(HelloAction.get_action_info(), HelloAction), (HelloAction.get_action_info(), HelloAction),
(CompareNumbersTool.get_tool_info(), CompareNumbersTool), # 添加比较数字工具
(ByeAction.get_action_info(), ByeAction), # 添加告别Action (ByeAction.get_action_info(), ByeAction), # 添加告别Action
(TimeCommand.get_command_info(), TimeCommand), (TimeCommand.get_command_info(), TimeCommand),
(PrintMessage.get_handler_info(), PrintMessage), (PrintMessage.get_handler_info(), PrintMessage),

View File

@ -15,6 +15,7 @@ matplotlib
networkx networkx
numpy numpy
openai openai
google-genai
pandas pandas
peewee peewee
pyarrow pyarrow

View File

@ -24,46 +24,6 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入") logger = get_logger("OpenIE导入")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_openie_dir(): def ensure_openie_dir():
"""确保OpenIE数据目录存在""" """确保OpenIE数据目录存在"""
if not os.path.exists(OPENIE_DIR): if not os.path.exists(OPENIE_DIR):
@ -214,8 +174,6 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
def main(): # sourcery skip: dict-comprehension def main(): # sourcery skip: dict-comprehension
# 新增确认提示 # 新增确认提示
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
print("=== 重要操作确认 ===") print("=== 重要操作确认 ===")
print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型") print("OpenIE导入时会大量发送请求可能会撞到请求速度上限请注意选用的模型")
print("同之前样例在本地模型下在70分钟内我们发送了约8万条请求在网络允许下速度会更快") print("同之前样例在本地模型下在70分钟内我们发送了约8万条请求在网络允许下速度会更快")

View File

@ -25,9 +25,8 @@ from rich.progress import (
TextColumn, TextColumn,
) )
from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from dotenv import load_dotenv
logger = get_logger("LPMM知识库-信息提取") logger = get_logger("LPMM知识库-信息提取")
@ -36,45 +35,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp") TEMP_DIR = os.path.join(ROOT_PATH, "temp")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
ENV_FILE = os.path.join(ROOT_PATH, ".env")
if os.path.exists(".env"):
load_dotenv(".env", override=True)
print("成功加载环境变量配置")
else:
print("未找到.env文件请确保程序所需的环境变量被正确设置")
raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
env_mask = {key: os.getenv(key) for key in os.environ}
def scan_provider(env_config: dict):
provider = {}
# 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
# 避免 GPG_KEY 这样的变量干扰检查
env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
# 遍历 env_config 的所有键
for key in env_config:
# 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
if key.endswith("_BASE_URL") or key.endswith("_KEY"):
# 提取 provider 名称
provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分
# 初始化 provider 的字典(如果尚未初始化)
if provider_name not in provider:
provider[provider_name] = {"url": None, "key": None}
# 根据键的类型填充 url 或 key
if key.endswith("_BASE_URL"):
provider[provider_name]["url"] = env_config[key]
elif key.endswith("_KEY"):
provider[provider_name]["key"] = env_config[key]
# 检查每个 provider 是否同时存在 url 和 key
for provider_name, config in provider.items():
if config["url"] is None or config["key"] is None:
logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")
def ensure_dirs(): def ensure_dirs():
"""确保临时目录和输出目录存在""" """确保临时目录和输出目录存在"""
@ -96,11 +56,11 @@ open_ie_doc_lock = Lock()
shutdown_event = Event() shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest( lpmm_entity_extract_llm = LLMRequest(
model=global_config.model.lpmm_entity_extract, model_set=model_config.model_task_config.lpmm_entity_extract,
request_type="lpmm.entity_extract" request_type="lpmm.entity_extract"
) )
lpmm_rdf_build_llm = LLMRequest( lpmm_rdf_build_llm = LLMRequest(
model=global_config.model.lpmm_rdf_build, model_set=model_config.model_task_config.lpmm_rdf_build,
request_type="lpmm.rdf_build" request_type="lpmm.rdf_build"
) )
def process_single_text(pg_hash, raw_data): def process_single_text(pg_hash, raw_data):
@ -158,8 +118,6 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
# 设置信号处理器 # 设置信号处理器
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # 确保目录存在 ensure_dirs() # 确保目录存在
env_config = {key: os.getenv(key) for key in os.environ}
scan_provider(env_config)
# 新增用户确认提示 # 新增用户确认提示
print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("=== 重要操作确认,请认真阅读以下内容哦 ===")
print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。") print("实体提取操作将会花费较多api余额和时间建议在空闲时段执行。")

View File

@ -414,7 +414,7 @@ class HeartFChatting:
else: else:
logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容") logger.warning(f"{self.log_prefix} 预生成的回复任务未生成有效内容")
action_message: Dict[str, Any] = message_data or target_message # type: ignore action_message = message_data or target_message
if action_type == "reply": if action_type == "reply":
# 等待回复生成完毕 # 等待回复生成完毕
if self.loop_mode == ChatMode.NORMAL: if self.loop_mode == ChatMode.NORMAL:

View File

@ -8,15 +8,15 @@ import traceback
import io import io
import re import re
import binascii import binascii
from typing import Optional, Tuple, List, Any from typing import Optional, Tuple, List, Any
from PIL import Image from PIL import Image
from rich.traceback import install from rich.traceback import install
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
from src.common.database.database import db as peewee_db from src.common.database.database import db as peewee_db
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.utils_image import image_path_to_base64, get_image_manager from src.chat.utils.utils_image import image_path_to_base64, get_image_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@ -379,9 +379,9 @@ class EmojiManager:
self._scan_task = None self._scan_task = None
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.3, max_tokens=1000, request_type="emoji") self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="emoji")
self.llm_emotion_judge = LLMRequest( self.llm_emotion_judge = LLMRequest(
model=global_config.model.utils, max_tokens=600, request_type="emoji" model_set=model_config.model_task_config.utils, request_type="emoji"
) # 更高的温度更少的token后续可以根据情绪来调整温度 ) # 更高的温度更少的token后续可以根据情绪来调整温度
self.emoji_num = 0 self.emoji_num = 0
@ -492,6 +492,7 @@ class EmojiManager:
return None return None
def _levenshtein_distance(self, s1: str, s2: str) -> int: def _levenshtein_distance(self, s1: str, s2: str) -> int:
# sourcery skip: simplify-empty-collection-comparison, simplify-len-comparison, simplify-str-len-comparison
"""计算两个字符串的编辑距离 """计算两个字符串的编辑距离
Args: Args:
@ -629,11 +630,11 @@ class EmojiManager:
if success: if success:
# 注册成功则跳出循环 # 注册成功则跳出循环
break break
else:
# 注册失败则删除对应文件 # 注册失败则删除对应文件
file_path = os.path.join(EMOJI_DIR, filename) file_path = os.path.join(EMOJI_DIR, filename)
os.remove(file_path) os.remove(file_path)
logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}") logger.warning(f"[清理] 删除注册失败的表情包文件: {filename}")
except Exception as e: except Exception as e:
logger.error(f"[错误] 扫描表情包目录失败: {str(e)}") logger.error(f"[错误] 扫描表情包目录失败: {str(e)}")
@ -694,6 +695,7 @@ class EmojiManager:
return [] return []
async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]: async def get_emoji_from_manager(self, emoji_hash: str) -> Optional["MaiEmoji"]:
# sourcery skip: use-next
"""从内存中的 emoji_objects 列表获取表情包 """从内存中的 emoji_objects 列表获取表情包
参数: 参数:
@ -709,10 +711,10 @@ class EmojiManager:
async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]: async def get_emoji_description_by_hash(self, emoji_hash: str) -> Optional[str]:
"""根据哈希值获取已注册表情包的描述 """根据哈希值获取已注册表情包的描述
Args: Args:
emoji_hash: 表情包的哈希值 emoji_hash: 表情包的哈希值
Returns: Returns:
Optional[str]: 表情包描述如果未找到则返回None Optional[str]: 表情包描述如果未找到则返回None
""" """
@ -722,7 +724,7 @@ class EmojiManager:
if emoji and emoji.description: if emoji and emoji.description:
logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...") logger.info(f"[缓存命中] 从内存获取表情包描述: {emoji.description[:50]}...")
return emoji.description return emoji.description
# 如果内存中没有,从数据库查找 # 如果内存中没有,从数据库查找
self._ensure_db() self._ensure_db()
try: try:
@ -732,9 +734,9 @@ class EmojiManager:
return emoji_record.description return emoji_record.description
except Exception as e: except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}") logger.error(f"从数据库查询表情包描述时出错: {e}")
return None return None
except Exception as e: except Exception as e:
logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}") logger.error(f"获取表情包描述失败 (Hash: {emoji_hash}): {str(e)}")
return None return None
@ -779,6 +781,7 @@ class EmojiManager:
return False return False
async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool: async def replace_a_emoji(self, new_emoji: "MaiEmoji") -> bool:
# sourcery skip: use-getitem-for-re-match-groups
"""替换一个表情包 """替换一个表情包
Args: Args:
@ -820,7 +823,7 @@ class EmojiManager:
) )
# 调用大模型进行决策 # 调用大模型进行决策
decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8) decision, _ = await self.llm_emotion_judge.generate_response_async(prompt, temperature=0.8, max_tokens=600)
logger.info(f"[决策] 结果: {decision}") logger.info(f"[决策] 结果: {decision}")
# 解析决策结果 # 解析决策结果
@ -828,9 +831,7 @@ class EmojiManager:
logger.info("[决策] 不删除任何表情包") logger.info("[决策] 不删除任何表情包")
return False return False
# 尝试从决策中提取表情包编号 if match := re.search(r"删除编号(\d+)", decision):
match = re.search(r"删除编号(\d+)", decision)
if match:
emoji_index = int(match.group(1)) - 1 # 转换为0-based索引 emoji_index = int(match.group(1)) - 1 # 转换为0-based索引
# 检查索引是否有效 # 检查索引是否有效
@ -889,6 +890,7 @@ class EmojiManager:
existing_description = None existing_description = None
try: try:
from src.common.database.database_model import Images from src.common.database.database_model import Images
existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji")) existing_image = Images.get_or_none((Images.emoji_hash == image_hash) & (Images.type == "emoji"))
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
@ -902,15 +904,21 @@ class EmojiManager:
logger.info("[优化] 复用已有的详细描述跳过VLM调用") logger.info("[优化] 复用已有的详细描述跳过VLM调用")
else: else:
logger.info("[VLM分析] 生成新的详细描述") logger.info("[VLM分析] 生成新的详细描述")
if image_format == "gif" or image_format == "GIF": if image_format in ["gif", "GIF"]:
image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore image_base64 = get_image_manager().transform_gif(image_base64) # type: ignore
if not image_base64: if not image_base64:
raise RuntimeError("GIF表情包转换失败") raise RuntimeError("GIF表情包转换失败")
prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, "jpg") description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, "jpg", temperature=0.3, max_tokens=1000
)
else: else:
prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" prompt = (
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
# 审核表情包 # 审核表情包
if global_config.emoji.content_filtration: if global_config.emoji.content_filtration:
@ -922,7 +930,9 @@ class EmojiManager:
4. 不要出现5个以上文字 4. 不要出现5个以上文字
请回答这个表情包是否满足上述要求是则回答是否则回答否不要出现任何其他内容 请回答这个表情包是否满足上述要求是则回答是否则回答否不要出现任何其他内容
''' '''
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) content, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.3, max_tokens=1000
)
if content == "": if content == "":
return "", [] return "", []
@ -933,7 +943,9 @@ class EmojiManager:
你可以关注其幽默和讽刺意味动用贴吧微博小红书的知识必须从互联网梗,meme的角度去分析 你可以关注其幽默和讽刺意味动用贴吧微博小红书的知识必须从互联网梗,meme的角度去分析
请直接输出描述不要出现任何其他内容如果有多个描述可以用逗号分隔 请直接输出描述不要出现任何其他内容如果有多个描述可以用逗号分隔
""" """
emotions_text, _ = await self.llm_emotion_judge.generate_response_async(emotion_prompt, temperature=0.7) emotions_text, _ = await self.llm_emotion_judge.generate_response_async(
emotion_prompt, temperature=0.7, max_tokens=600
)
# 处理情感列表 # 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()] emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]

View File

@ -7,12 +7,12 @@ from datetime import datetime
from typing import List, Dict, Optional, Any, Tuple from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import model_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_random, build_anonymous_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.database.database_model import Expression
MAX_EXPRESSION_COUNT = 300 MAX_EXPRESSION_COUNT = 300
@ -80,11 +80,8 @@ def init_prompt() -> None:
class ExpressionLearner: class ExpressionLearner:
def __init__(self) -> None: def __init__(self) -> None:
# TODO: API-Adapter修改标记
self.express_learn_model: LLMRequest = LLMRequest( self.express_learn_model: LLMRequest = LLMRequest(
model=global_config.model.replyer_1, model_set=model_config.model_task_config.replyer_1, request_type="expressor.learner"
temperature=0.3,
request_type="expressor.learner",
) )
self.llm_model = None self.llm_model = None
self._ensure_expression_directories() self._ensure_expression_directories()
@ -101,7 +98,7 @@ class ExpressionLearner:
os.path.join(base_dir, "learnt_style"), os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"), os.path.join(base_dir, "learnt_grammar"),
] ]
for directory in directories_to_create: for directory in directories_to_create:
try: try:
os.makedirs(directory, exist_ok=True) os.makedirs(directory, exist_ok=True)
@ -116,7 +113,7 @@ class ExpressionLearner:
""" """
base_dir = os.path.join("data", "expression") base_dir = os.path.join("data", "expression")
done_flag = os.path.join(base_dir, "done.done") done_flag = os.path.join(base_dir, "done.done")
# 确保基础目录存在 # 确保基础目录存在
try: try:
os.makedirs(base_dir, exist_ok=True) os.makedirs(base_dir, exist_ok=True)
@ -124,28 +121,28 @@ class ExpressionLearner:
except Exception as e: except Exception as e:
logger.error(f"创建表达方式目录失败: {e}") logger.error(f"创建表达方式目录失败: {e}")
return return
if os.path.exists(done_flag): if os.path.exists(done_flag):
logger.info("表达方式JSON已迁移无需重复迁移。") logger.info("表达方式JSON已迁移无需重复迁移。")
return return
logger.info("开始迁移表达方式JSON到数据库...") logger.info("开始迁移表达方式JSON到数据库...")
migrated_count = 0 migrated_count = 0
for type in ["learnt_style", "learnt_grammar"]: for type in ["learnt_style", "learnt_grammar"]:
type_str = "style" if type == "learnt_style" else "grammar" type_str = "style" if type == "learnt_style" else "grammar"
type_dir = os.path.join(base_dir, type) type_dir = os.path.join(base_dir, type)
if not os.path.exists(type_dir): if not os.path.exists(type_dir):
logger.debug(f"目录不存在,跳过: {type_dir}") logger.debug(f"目录不存在,跳过: {type_dir}")
continue continue
try: try:
chat_ids = os.listdir(type_dir) chat_ids = os.listdir(type_dir)
logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录") logger.debug(f"{type_dir} 中找到 {len(chat_ids)} 个聊天ID目录")
except Exception as e: except Exception as e:
logger.error(f"读取目录失败 {type_dir}: {e}") logger.error(f"读取目录失败 {type_dir}: {e}")
continue continue
for chat_id in chat_ids: for chat_id in chat_ids:
expr_file = os.path.join(type_dir, chat_id, "expressions.json") expr_file = os.path.join(type_dir, chat_id, "expressions.json")
if not os.path.exists(expr_file): if not os.path.exists(expr_file):
@ -153,24 +150,24 @@ class ExpressionLearner:
try: try:
with open(expr_file, "r", encoding="utf-8") as f: with open(expr_file, "r", encoding="utf-8") as f:
expressions = json.load(f) expressions = json.load(f)
if not isinstance(expressions, list): if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}") logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
continue continue
for expr in expressions: for expr in expressions:
if not isinstance(expr, dict): if not isinstance(expr, dict):
continue continue
situation = expr.get("situation") situation = expr.get("situation")
style_val = expr.get("style") style_val = expr.get("style")
count = expr.get("count", 1) count = expr.get("count", 1)
last_active_time = expr.get("last_active_time", time.time()) last_active_time = expr.get("last_active_time", time.time())
if not situation or not style_val: if not situation or not style_val:
logger.warning(f"表达方式缺少必要字段,跳过: {expr}") logger.warning(f"表达方式缺少必要字段,跳过: {expr}")
continue continue
# 查重同chat_id+type+situation+style # 查重同chat_id+type+situation+style
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
@ -201,7 +198,7 @@ class ExpressionLearner:
logger.error(f"JSON解析失败 {expr_file}: {e}") logger.error(f"JSON解析失败 {expr_file}: {e}")
except Exception as e: except Exception as e:
logger.error(f"迁移表达方式 {expr_file} 失败: {e}") logger.error(f"迁移表达方式 {expr_file} 失败: {e}")
# 标记迁移完成 # 标记迁移完成
try: try:
# 确保done.done文件的父目录存在 # 确保done.done文件的父目录存在
@ -209,7 +206,7 @@ class ExpressionLearner:
if not os.path.exists(done_parent_dir): if not os.path.exists(done_parent_dir):
os.makedirs(done_parent_dir, exist_ok=True) os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}") logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f: with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n") f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件") logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
@ -229,13 +226,13 @@ class ExpressionLearner:
# 查找所有create_date为空的表达方式 # 查找所有create_date为空的表达方式
old_expressions = Expression.select().where(Expression.create_date.is_null()) old_expressions = Expression.select().where(Expression.create_date.is_null())
updated_count = 0 updated_count = 0
for expr in old_expressions: for expr in old_expressions:
# 使用last_active_time作为create_date # 使用last_active_time作为create_date
expr.create_date = expr.last_active_time expr.create_date = expr.last_active_time
expr.save() expr.save()
updated_count += 1 updated_count += 1
if updated_count > 0: if updated_count > 0:
logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期") logger.info(f"已为 {updated_count} 个老的表达方式设置创建日期")
except Exception as e: except Exception as e:
@ -287,25 +284,29 @@ class ExpressionLearner:
获取指定chat_id的表达方式创建信息按创建日期排序 获取指定chat_id的表达方式创建信息按创建日期排序
""" """
try: try:
expressions = (Expression.select() expressions = (
.where(Expression.chat_id == chat_id) Expression.select()
.order_by(Expression.create_date.desc()) .where(Expression.chat_id == chat_id)
.limit(limit)) .order_by(Expression.create_date.desc())
.limit(limit)
)
result = [] result = []
for expr in expressions: for expr in expressions:
create_date = expr.create_date if expr.create_date is not None else expr.last_active_time create_date = expr.create_date if expr.create_date is not None else expr.last_active_time
result.append({ result.append(
"situation": expr.situation, {
"style": expr.style, "situation": expr.situation,
"type": expr.type, "style": expr.style,
"count": expr.count, "type": expr.type,
"create_date": create_date, "count": expr.count,
"create_date_formatted": format_create_date(create_date), "create_date": create_date,
"last_active_time": expr.last_active_time, "create_date_formatted": format_create_date(create_date),
"last_active_formatted": format_create_date(expr.last_active_time), "last_active_time": expr.last_active_time,
}) "last_active_formatted": format_create_date(expr.last_active_time),
}
)
return result return result
except Exception as e: except Exception as e:
logger.error(f"获取表达方式创建信息失败: {e}") logger.error(f"获取表达方式创建信息失败: {e}")
@ -355,19 +356,19 @@ class ExpressionLearner:
try: try:
# 获取所有表达方式 # 获取所有表达方式
all_expressions = Expression.select() all_expressions = Expression.select()
updated_count = 0 updated_count = 0
deleted_count = 0 deleted_count = 0
for expr in all_expressions: for expr in all_expressions:
# 计算时间差 # 计算时间差
last_active = expr.last_active_time last_active = expr.last_active_time
time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天 time_diff_days = (current_time - last_active) / (24 * 3600) # 转换为天
# 计算衰减值 # 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days) decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value) new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01: if new_count <= 0.01:
# 如果count太小删除这个表达方式 # 如果count太小删除这个表达方式
expr.delete_instance() expr.delete_instance()
@ -377,10 +378,10 @@ class ExpressionLearner:
expr.count = new_count expr.count = new_count
expr.save() expr.save()
updated_count += 1 updated_count += 1
if updated_count > 0 or deleted_count > 0: if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
except Exception as e: except Exception as e:
logger.error(f"数据库全局衰减失败: {e}") logger.error(f"数据库全局衰减失败: {e}")
@ -527,7 +528,7 @@ class ExpressionLearner:
logger.debug(f"学习{type_str}的prompt: {prompt}") logger.debug(f"学习{type_str}的prompt: {prompt}")
try: try:
response, _ = await self.express_learn_model.generate_response_async(prompt) response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e: except Exception as e:
logger.error(f"学习{type_str}失败: {e}") logger.error(f"学习{type_str}失败: {e}")
return None return None

View File

@ -1,16 +1,17 @@
import json import json
import time import time
import random import random
import hashlib
from typing import List, Dict, Tuple, Optional, Any from typing import List, Dict, Tuple, Optional, Any
from json_repair import repair_json from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .expression_learner import get_expression_learner from .expression_learner import get_expression_learner
from src.common.database.database_model import Expression
logger = get_logger("expression_selector") logger = get_logger("expression_selector")
@ -75,10 +76,8 @@ def weighted_sample(population: List[Dict], weights: List[float], k: int) -> Lis
class ExpressionSelector: class ExpressionSelector:
def __init__(self): def __init__(self):
self.expression_learner = get_expression_learner() self.expression_learner = get_expression_learner()
# TODO: API-Adapter修改标记
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small, request_type="expression.selector"
request_type="expression.selector",
) )
@staticmethod @staticmethod
@ -92,7 +91,6 @@ class ExpressionSelector:
id_str = parts[1] id_str = parts[1]
stream_type = parts[2] stream_type = parts[2]
is_group = stream_type == "group" is_group = stream_type == "group"
import hashlib
if is_group: if is_group:
components = [platform, str(id_str)] components = [platform, str(id_str)]
else: else:
@ -108,8 +106,7 @@ class ExpressionSelector:
for group in groups: for group in groups:
group_chat_ids = [] group_chat_ids = []
for stream_config_str in group: for stream_config_str in group:
chat_id_candidate = self._parse_stream_config_to_chat_id(stream_config_str) if chat_id_candidate := self._parse_stream_config_to_chat_id(stream_config_str):
if chat_id_candidate:
group_chat_ids.append(chat_id_candidate) group_chat_ids.append(chat_id_candidate)
if chat_id in group_chat_ids: if chat_id in group_chat_ids:
return group_chat_ids return group_chat_ids
@ -118,9 +115,10 @@ class ExpressionSelector:
def get_random_expressions( def get_random_expressions(
self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float self, chat_id: str, total_num: int, style_percentage: float, grammar_percentage: float
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
# sourcery skip: extract-duplicate-method, move-assign
# 支持多chat_id合并抽选 # 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id) related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式 # 优化一次性查询所有相关chat_id的表达方式
style_query = Expression.select().where( style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style") (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "style")
@ -128,7 +126,7 @@ class ExpressionSelector:
grammar_query = Expression.select().where( grammar_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar") (Expression.chat_id.in_(related_chat_ids)) & (Expression.type == "grammar")
) )
style_exprs = [ style_exprs = [
{ {
"situation": expr.situation, "situation": expr.situation,
@ -138,9 +136,10 @@ class ExpressionSelector:
"source_id": expr.chat_id, "source_id": expr.chat_id,
"type": "style", "type": "style",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in style_query }
for expr in style_query
] ]
grammar_exprs = [ grammar_exprs = [
{ {
"situation": expr.situation, "situation": expr.situation,
@ -150,9 +149,10 @@ class ExpressionSelector:
"source_id": expr.chat_id, "source_id": expr.chat_id,
"type": "grammar", "type": "grammar",
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} for expr in grammar_query }
for expr in grammar_query
] ]
style_num = int(total_num * style_percentage) style_num = int(total_num * style_percentage)
grammar_num = int(total_num * grammar_percentage) grammar_num = int(total_num * grammar_percentage)
# 按权重抽样使用count作为权重 # 按权重抽样使用count作为权重
@ -174,22 +174,22 @@ class ExpressionSelector:
return return
updates_by_key = {} updates_by_key = {}
for expr in expressions_to_update: for expr in expressions_to_update:
source_id = expr.get("source_id") source_id: str = expr.get("source_id") # type: ignore
expr_type = expr.get("type", "style") expr_type: str = expr.get("type", "style")
situation = expr.get("situation") situation: str = expr.get("situation") # type: ignore
style = expr.get("style") style: str = expr.get("style") # type: ignore
if not source_id or not situation or not style: if not source_id or not situation or not style:
logger.warning(f"表达方式缺少必要字段,无法更新: {expr}") logger.warning(f"表达方式缺少必要字段,无法更新: {expr}")
continue continue
key = (source_id, expr_type, situation, style) key = (source_id, expr_type, situation, style)
if key not in updates_by_key: if key not in updates_by_key:
updates_by_key[key] = expr updates_by_key[key] = expr
for (chat_id, expr_type, situation, style), _expr in updates_by_key.items(): for chat_id, expr_type, situation, style in updates_by_key:
query = Expression.select().where( query = Expression.select().where(
(Expression.chat_id == chat_id) & (Expression.chat_id == chat_id)
(Expression.type == expr_type) & & (Expression.type == expr_type)
(Expression.situation == situation) & & (Expression.situation == situation)
(Expression.style == style) & (Expression.style == style)
) )
if query.exists(): if query.exists():
expr_obj = query.get() expr_obj = query.get()
@ -264,7 +264,7 @@ class ExpressionSelector:
# 4. 调用LLM # 4. 调用LLM
try: try:
content, (_, _) = await self.llm_model.generate_response_async(prompt=prompt) content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix} LLM返回结果: {content}") # logger.info(f"{self.log_prefix} LLM返回结果: {content}")

View File

@ -3,6 +3,7 @@ import json
import os import os
import math import math
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Tuple from typing import Dict, List, Tuple
import numpy as np import numpy as np
@ -11,8 +12,6 @@ import pandas as pd
# import tqdm # import tqdm
import faiss import faiss
# from .llm_client import LLMClient
# from .lpmmconfig import global_config
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .global_logger import logger from .global_logger import logger
from rich.traceback import install from rich.traceback import install
@ -26,12 +25,20 @@ from rich.progress import (
SpinnerColumn, SpinnerColumn,
TextColumn, TextColumn,
) )
from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config
install(extra_lines=3) install(extra_lines=3)
# 多线程embedding配置常量
DEFAULT_MAX_WORKERS = 10 # 默认最大线程数
DEFAULT_CHUNK_SIZE = 10 # 默认每个线程处理的数据块大小
MIN_CHUNK_SIZE = 1 # 最小分块大小
MAX_CHUNK_SIZE = 50 # 最大分块大小
MIN_WORKERS = 1 # 最小线程数
MAX_WORKERS = 20 # 最大线程数
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding") EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/") EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
@ -87,13 +94,23 @@ class EmbeddingStoreItem:
class EmbeddingStore: class EmbeddingStore:
def __init__(self, namespace: str, dir_path: str): def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
self.namespace = namespace self.namespace = namespace
self.dir = dir_path self.dir = dir_path
self.embedding_file_path = f"{dir_path}/{namespace}.parquet" self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index" self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json" self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
# 多线程配置参数验证和设置
self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
# 如果配置值被调整,记录日志
if self.max_workers != max_workers:
logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
if self.chunk_size != chunk_size:
logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
self.store = {} self.store = {}
self.faiss_index = None self.faiss_index = None
@ -125,16 +142,134 @@ class EmbeddingStore:
return [] return []
return result return result
def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
"""使用多线程批量获取嵌入向量
Args:
strs: 要获取嵌入的字符串列表
chunk_size: 每个线程处理的数据块大小
max_workers: 最大线程数
progress_callback: 进度回调函数接收一个参数表示完成的数量
Returns:
包含(原始字符串, 嵌入向量)的元组列表保持与输入顺序一致
"""
if not strs:
return []
# 分块
chunks = []
for i in range(0, len(strs), chunk_size):
chunk = strs[i:i + chunk_size]
chunks.append((i, chunk)) # 保存起始索引以维持顺序
# 结果存储,使用字典按索引存储以保证顺序
results = {}
def process_chunk(chunk_data):
"""处理单个数据块的函数"""
start_idx, chunk_strs = chunk_data
chunk_results = []
# 为每个线程创建独立的LLMRequest实例
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config
try:
# 创建线程专用的LLM实例
llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
for i, s in enumerate(chunk_strs):
try:
# 直接使用异步函数
embedding = asyncio.run(llm.get_embedding(s))
if embedding and len(embedding) > 0:
chunk_results.append((start_idx + i, s, embedding[0])) # embedding[0] 是实际的向量
else:
logger.error(f"获取嵌入失败: {s}")
chunk_results.append((start_idx + i, s, []))
# 每完成一个嵌入立即更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
except Exception as e:
logger.error(f"创建LLM实例失败: {e}")
# 如果创建LLM实例失败返回空结果
for i, s in enumerate(chunk_strs):
chunk_results.append((start_idx + i, s, []))
# 即使失败也要更新进度
if progress_callback:
progress_callback(1)
return chunk_results
# 使用线程池处理
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务
future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
# 收集结果进度已在process_chunk中实时更新
for future in as_completed(future_to_chunk):
try:
chunk_results = future.result()
for idx, s, embedding in chunk_results:
results[idx] = (s, embedding)
except Exception as e:
chunk = future_to_chunk[future]
logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
# 为失败的块添加空结果
start_idx, chunk_strs = chunk
for i, s in enumerate(chunk_strs):
results[start_idx + i] = (s, [])
# 按原始顺序返回结果
ordered_results = []
for i in range(len(strs)):
if i in results:
ordered_results.append(results[i])
else:
# 防止遗漏
ordered_results.append((strs[i], []))
return ordered_results
def get_test_file_path(self): def get_test_file_path(self):
return EMBEDDING_TEST_FILE return EMBEDDING_TEST_FILE
def save_embedding_test_vectors(self): def save_embedding_test_vectors(self):
"""保存测试字符串的嵌入到本地""" """保存测试字符串的嵌入到本地(使用多线程优化)"""
logger.info("开始保存测试字符串的嵌入向量...")
# 使用多线程批量获取测试字符串的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
)
# 构建测试向量字典
test_vectors = {} test_vectors = {}
for idx, s in enumerate(EMBEDDING_TEST_STRINGS): for idx, (s, embedding) in enumerate(embedding_results):
test_vectors[str(idx)] = self._get_embedding(s) if embedding:
test_vectors[str(idx)] = embedding
else:
logger.error(f"获取测试字符串嵌入失败: {s}")
# 使用原始单线程方法作为后备
test_vectors[str(idx)] = self._get_embedding(s)
with open(self.get_test_file_path(), "w", encoding="utf-8") as f: with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
json.dump(test_vectors, f, ensure_ascii=False, indent=2) json.dump(test_vectors, f, ensure_ascii=False, indent=2)
logger.info("测试字符串嵌入向量保存完成")
def load_embedding_test_vectors(self): def load_embedding_test_vectors(self):
"""加载本地保存的测试字符串嵌入""" """加载本地保存的测试字符串嵌入"""
@ -145,29 +280,64 @@ class EmbeddingStore:
return json.load(f) return json.load(f)
def check_embedding_model_consistency(self): def check_embedding_model_consistency(self):
"""校验当前模型与本地嵌入模型是否一致""" """校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
local_vectors = self.load_embedding_test_vectors() local_vectors = self.load_embedding_test_vectors()
if local_vectors is None: if local_vectors is None:
logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。") logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
local_emb = local_vectors.get(str(idx)) # 检查本地向量完整性
if local_emb is None: for idx in range(len(EMBEDDING_TEST_STRINGS)):
if local_vectors.get(str(idx)) is None:
logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。") logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
self.save_embedding_test_vectors() self.save_embedding_test_vectors()
return True return True
new_emb = self._get_embedding(s)
logger.info("开始检验嵌入模型一致性...")
# 使用多线程批量获取当前模型的嵌入
embedding_results = self._get_embeddings_batch_threaded(
EMBEDDING_TEST_STRINGS,
chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
)
# 检查一致性
for idx, (s, new_emb) in enumerate(embedding_results):
local_emb = local_vectors.get(str(idx))
if not new_emb:
logger.error(f"获取测试字符串嵌入失败: {s}")
return False
sim = cosine_similarity(local_emb, new_emb) sim = cosine_similarity(local_emb, new_emb)
if sim < EMBEDDING_SIM_THRESHOLD: if sim < EMBEDDING_SIM_THRESHOLD:
logger.error("嵌入模型一致性校验失败") logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
return False return False
logger.info("嵌入模型一致性校验通过。") logger.info("嵌入模型一致性校验通过。")
return True return True
def batch_insert_strs(self, strs: List[str], times: int) -> None: def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""向库中存入字符串""" """向库中存入字符串(使用多线程优化)"""
if not strs:
return
total = len(strs) total = len(strs)
# 过滤已存在的字符串
new_strs = []
for s in strs:
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash not in self.store:
new_strs.append(s)
if not new_strs:
logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
return
logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
with Progress( with Progress(
SpinnerColumn(), SpinnerColumn(),
TextColumn("[progress.description]{task.description}"), TextColumn("[progress.description]{task.description}"),
@ -181,19 +351,38 @@ class EmbeddingStore:
transient=False, transient=False,
) as progress: ) as progress:
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total) task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs:
# 计算hash去重 # 首先更新已存在项的进度
item_hash = self.namespace + "-" + get_sha256(s) already_processed = total - len(new_strs)
if item_hash in self.store: if already_processed > 0:
progress.update(task, advance=1) progress.update(task, advance=already_processed)
continue
if new_strs:
# 获取embedding # 使用实例配置的参数,智能调整分块和线程数
embedding = self._get_embedding(s) optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
# 存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s) logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
progress.update(task, advance=1)
# 定义进度更新回调函数
def update_progress(count):
progress.update(task, advance=count)
# 批量获取嵌入,并实时更新进度
embedding_results = self._get_embeddings_batch_threaded(
new_strs,
chunk_size=optimal_chunk_size,
max_workers=optimal_max_workers,
progress_callback=update_progress
)
# 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
for s, embedding in embedding_results:
item_hash = self.namespace + "-" + get_sha256(s)
if embedding: # 只有成功获取到嵌入才存入
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
else:
logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
def save_to_file(self) -> None: def save_to_file(self) -> None:
"""保存到文件""" """保存到文件"""
@ -316,31 +505,37 @@ class EmbeddingStore:
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
"""
初始化EmbeddingManager
Args:
max_workers: 最大线程数
chunk_size: 每个线程处理的数据块大小
"""
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "paragraph", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "entity", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage["pg_namespace"], # type: ignore "relation", # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
max_workers=max_workers,
chunk_size=chunk_size,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()
def check_all_embedding_model_consistency(self): def check_all_embedding_model_consistency(self):
"""对所有嵌入库做模型一致性校验""" """对所有嵌入库做模型一致性校验"""
for store in [ return self.paragraphs_embedding_store.check_embedding_model_consistency()
self.paragraphs_embedding_store,
self.entities_embedding_store,
self.relation_embedding_store,
]:
if not store.check_embedding_model_consistency():
return False
return True
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]): def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""将段落编码存入Embedding库""" """将段落编码存入Embedding库"""

View File

@ -8,12 +8,15 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json from json_repair import repair_json
def _extract_json_from_text(text: str): def _extract_json_from_text(text: str):
# sourcery skip: assign-if-exp, extract-method
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
if text is None: if text is None:
logger.error("输入文本为None") logger.error("输入文本为None")
return [] return []
try: try:
fixed_json = repair_json(text) fixed_json = repair_json(text)
if isinstance(fixed_json, str): if isinstance(fixed_json, str):
@ -24,7 +27,7 @@ def _extract_json_from_text(text: str):
# 如果是列表,直接返回 # 如果是列表,直接返回
if isinstance(parsed_json, list): if isinstance(parsed_json, list):
return parsed_json return parsed_json
# 如果是字典且只有一个项目,可能包装了列表 # 如果是字典且只有一个项目,可能包装了列表
if isinstance(parsed_json, dict): if isinstance(parsed_json, dict):
# 如果字典只有一个键,并且值是列表,返回那个列表 # 如果字典只有一个键,并且值是列表,返回那个列表
@ -33,7 +36,7 @@ def _extract_json_from_text(text: str):
if isinstance(value, list): if isinstance(value, list):
return value return value
return parsed_json return parsed_json
# 其他情况,尝试转换为列表 # 其他情况,尝试转换为列表
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}") logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
return [] return []
@ -42,44 +45,40 @@ def _extract_json_from_text(text: str):
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return [] return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取返回提取出的实体列表JSON格式""" """对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph) entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
# 使用 asyncio.run 来运行异步方法 # 使用 asyncio.run 来运行异步方法
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
llm_req.generate_response_async(entity_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
llm_req.generate_response_async(entity_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}") logger.debug(f"LLM返回的原始响应: {response}")
entity_extract_result = _extract_json_from_text(response) entity_extract_result = _extract_json_from_text(response)
# 检查返回的是否为有效的实体列表 # 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list): if not isinstance(entity_extract_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(entity_extract_result, dict):
if isinstance(entity_extract_result, dict): raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 尝试常见的键名
for key in ['entities', 'result', 'data', 'items']: # 尝试常见的键名
if key in entity_extract_result and isinstance(entity_extract_result[key], list): for key in ["entities", "result", "data", "items"]:
entity_extract_result = entity_extract_result[key] if key in entity_extract_result and isinstance(entity_extract_result[key], list):
break entity_extract_result = entity_extract_result[key]
else: break
# 如果找不到合适的列表,抛出异常
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
else: else:
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") # 如果找不到合适的列表,抛出异常
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体 # 过滤无效实体
entity_extract_result = [ entity_extract_result = [
entity entity
@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY) if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
] ]
if len(entity_extract_result) == 0: if not entity_extract_result:
raise Exception("实体提取结果为空") raise ValueError("实体提取结果为空")
return entity_extract_result return entity_extract_result
@ -98,45 +97,44 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
rdf_extract_context = prompt_template.build_rdf_triple_extract_context( rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False) paragraph, entities=json.dumps(entities, ensure_ascii=False)
) )
# 使用 asyncio.run 来运行异步方法 # 使用 asyncio.run 来运行异步方法
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
llm_req.generate_response_async(rdf_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
llm_req.generate_response_async(rdf_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}") logger.debug(f"RDF LLM返回的原始响应: {response}")
rdf_triple_result = _extract_json_from_text(response) rdf_triple_result = _extract_json_from_text(response)
# 检查返回的是否为有效的三元组列表 # 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list): if not isinstance(rdf_triple_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(rdf_triple_result, dict):
if isinstance(rdf_triple_result, dict): raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 尝试常见的键名
for key in ['triples', 'result', 'data', 'items']: # 尝试常见的键名
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): for key in ["triples", "result", "data", "items"]:
rdf_triple_result = rdf_triple_result[key] if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
break rdf_triple_result = rdf_triple_result[key]
else: break
# 如果找不到合适的列表,抛出异常
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
else: else:
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}") # 如果找不到合适的列表,抛出异常
raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式 # 验证三元组格式
for triple in rdf_triple_result: for triple in rdf_triple_result:
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: if (
raise Exception("RDF提取结果格式错误") not isinstance(triple, list)
or len(triple) != 3
or (triple[0] is None or triple[1] is None or triple[2] is None)
or "" in triple
):
raise ValueError("RDF提取结果格式错误")
return rdf_triple_result return rdf_triple_result

View File

@ -20,8 +20,7 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256 from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import global_config from src.config.config import global_config
from src.manager.local_store_manager import local_storage
from .global_logger import logger from .global_logger import logger
@ -30,19 +29,9 @@ def _get_kg_dir():
""" """
安全地获取KG数据目录路径 安全地获取KG数据目录路径
""" """
root_path: str = local_storage["root_path"] current_dir = os.path.dirname(os.path.abspath(__file__))
if root_path is None: root_path: str = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用 kg_dir = os.path.join(root_path, "data/rag")
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录
rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag")
else:
kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/") return str(kg_dir).replace("\\", "/")
@ -65,9 +54,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径 # 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str() self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" self.graph_data_path = self.dir_path + "/" + "rag-graph" + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + "rag-ent-cnt" + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" self.pg_hash_file_path = self.dir_path + "/" + "rag-pg-hash" + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@ -122,8 +111,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) hash_key1 = "entity" + "-" + get_sha256(triple[0])
hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) hash_key2 = "entity" + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@ -141,8 +130,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) ent_hash_key = "entity" + "-" + get_sha256(triple[0])
pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) pg_hash_key = "paragraph" + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@ -157,12 +146,12 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) ent_hash_list.add("entity" + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) ent_hash_list.add("entity" + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
synonym_result = dict() synonym_result = {}
# rich 进度条 # rich 进度条
total = len(ent_hash_list) total = len(ent_hash_list)
@ -190,14 +179,14 @@ class KGManager:
assert isinstance(ent, EmbeddingStoreItem) assert isinstance(ent, EmbeddingStoreItem)
# 查询相似实体 # 查询相似实体
similar_ents = embedding_manager.entities_embedding_store.search_top_k( similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"] ent.embedding, global_config.lpmm_knowledge.rag_synonym_search_top_k
) )
res_ent = [] # Debug res_ent = [] # Debug
for res_ent_hash, similarity in similar_ents: for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash: if res_ent_hash == ent_hash:
# 避免自连接 # 避免自连接
continue continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]: if similarity < global_config.lpmm_knowledge.rag_synonym_threshold:
# 相似度阈值 # 相似度阈值
continue continue
node_to_node[(res_ent_hash, ent_hash)] = similarity node_to_node[(res_ent_hash, ent_hash)] = similarity
@ -263,7 +252,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(local_storage["ent_namespace"]): if node_hash.startswith("entity"):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash) node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None: if node is None:
@ -275,7 +264,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(local_storage["pg_namespace"]): elif node_hash.startswith("paragraph"):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None: if node is None:
@ -359,7 +348,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) ent_hash = "entity" + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@ -380,7 +369,7 @@ class KGManager:
for ent_hash in ent_weights.keys(): for ent_hash in ent_weights.keys():
ent_weights[ent_hash] = 1.0 ent_weights[ent_hash] = 1.0
else: else:
down_edge = global_config["qa"]["params"]["paragraph_node_weight"] down_edge = global_config.lpmm_knowledge.qa_paragraph_node_weight
# 缩放取值区间至[down_edge, 1] # 缩放取值区间至[down_edge, 1]
for ent_hash, score in ent_weights.items(): for ent_hash, score in ent_weights.items():
# 缩放相似度 # 缩放相似度
@ -389,7 +378,7 @@ class KGManager:
) + down_edge ) + down_edge
# 取平均相似度的top_k实体 # 取平均相似度的top_k实体
top_k = global_config["qa"]["params"]["ent_filter_top_k"] top_k = global_config.lpmm_knowledge.qa_ent_filter_top_k
if len(ent_mean_scores) > top_k: if len(ent_mean_scores) > top_k:
# 从大到小排序取后len - k个 # 从大到小排序取后len - k个
ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)} ent_mean_scores = {k: v for k, v in sorted(ent_mean_scores.items(), key=lambda item: item[1], reverse=True)}
@ -418,7 +407,7 @@ class KGManager:
for pg_hash, score in pg_sim_scores.items(): for pg_hash, score in pg_sim_scores.items():
pg_weights[pg_hash] = ( pg_weights[pg_hash] = (
score * global_config["qa"]["params"]["paragraph_node_weight"] score * global_config.lpmm_knowledge.qa_paragraph_node_weight
) # 文段权重 = 归一化相似度 * 文段节点权重参数 ) # 文段权重 = 归一化相似度 * 文段节点权重参数
del pg_sim_scores del pg_sim_scores
@ -431,7 +420,7 @@ class KGManager:
self.graph, self.graph,
personalization=ppr_node_weights, personalization=ppr_node_weights,
max_iter=100, max_iter=100,
alpha=global_config["qa"]["params"]["ppr_damping"], alpha=global_config.lpmm_knowledge.qa_ppr_damping,
) )
# 获取最终结果 # 获取最终结果
@ -439,7 +428,7 @@ class KGManager:
passage_node_res = [ passage_node_res = [
(node_key, score) (node_key, score)
for node_key, score in ppr_res.items() for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"]) if node_key.startswith("paragraph")
] ]
del ppr_res del ppr_res

View File

@ -1,12 +1,8 @@
from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
from src.chat.knowledge.qa_manager import QAManager from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger from src.chat.knowledge.global_logger import logger
from src.config.config import global_config as bot_global_config from src.config.config import global_config
from src.manager.local_store_manager import local_storage
import os import os
INVALID_ENTITY = [ INVALID_ENTITY = [
@ -21,9 +17,6 @@ INVALID_ENTITY = [
"她们", "她们",
"它们", "它们",
] ]
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph" RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt" RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
@ -34,67 +27,13 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..",
DATA_PATH = os.path.join(ROOT_PATH, "data") DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage():
"""
初始化知识库相关的本地存储配置
使用字典批量设置避免重复的if判断
"""
# 定义所有需要初始化的配置项
default_configs = {
# 路径配置
"root_path": ROOT_PATH,
"data_path": f"{ROOT_PATH}/data",
# 实体和命名空间配置
"lpmm_invalid_entity": INVALID_ENTITY,
"pg_namespace": PG_NAMESPACE,
"ent_namespace": ENT_NAMESPACE,
"rel_namespace": REL_NAMESPACE,
# RAG相关命名空间配置
"rag_graph_namespace": RAG_GRAPH_NAMESPACE,
"rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
"rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
}
# 日志级别映射重要配置用info其他用debug
important_configs = {"root_path", "data_path"}
# 批量设置配置项
initialized_count = 0
for key, default_value in default_configs.items():
if local_storage[key] is None:
local_storage[key] = default_value
# 根据重要性选择日志级别
if key in important_configs:
logger.info(f"设置{key}: {default_value}")
else:
logger.debug(f"设置{key}: {default_value}")
initialized_count += 1
if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else:
logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径
# sourcery skip: dict-comprehension
_initialize_knowledge_local_storage()
qa_manager = None qa_manager = None
inspire_manager = None inspire_manager = None
# 检查LPMM知识库是否启用 # 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable: if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM") logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端") logger.info("创建LLM客户端")
llm_client_list = {}
for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"], # type: ignore
global_config["llm_providers"][key]["api_key"], # type: ignore
)
# 初始化Embedding库 # 初始化Embedding库
embed_manager = EmbeddingManager() embed_manager = EmbeddingManager()
@ -120,7 +59,7 @@ if bot_global_config.lpmm_knowledge.enable:
# 数据比对Embedding库与KG的段落hash集合 # 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes: for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{PG_NAMESPACE}-{pg_hash}" key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes: if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}") logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
@ -130,11 +69,11 @@ if bot_global_config.lpmm_knowledge.enable:
kg_manager, kg_manager,
) )
# 记忆激活(用于记忆库) # # 记忆激活(用于记忆库)
inspire_manager = MemoryActiveManager( # inspire_manager = MemoryActiveManager(
embed_manager, # embed_manager,
llm_client_list[global_config["embedding"]["provider"]], # llm_client_list[global_config["embedding"]["provider"]],
) # )
else: else:
logger.info("LPMM知识库已禁用跳过初始化") logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误 # 创建空的占位符对象,避免导入错误

View File

@ -1,45 +0,0 @@
from openai import OpenAI
class LLMMessage:
def __init__(self, role, content):
self.role = role
self.content = content
def to_dict(self):
return {"role": self.role, "content": self.content}
class LLMClient:
"""LLM客户端对应一个API服务商"""
def __init__(self, url, api_key):
self.client = OpenAI(
base_url=url,
api_key=api_key,
)
def send_chat_request(self, model, messages):
"""发送对话请求,等待返回结果"""
response = self.client.chat.completions.create(model=model, messages=messages, stream=False)
if hasattr(response.choices[0].message, "reasoning_content"):
# 有单独的推理内容块
reasoning_content = response.choices[0].message.reasoning_content
content = response.choices[0].message.content
else:
# 无单独的推理内容块
response = response.choices[0].message.content.split("<think>")[-1].split("</think>")
# 如果有推理内容,则分割推理内容和内容
if len(response) == 2:
reasoning_content = response[0]
content = response[1]
else:
reasoning_content = None
content = response[0]
return reasoning_content, content
def send_embedding_request(self, model, text):
"""发送嵌入请求,等待返回结果"""
text = text.replace("\n", " ")
return self.client.embeddings.create(input=[text], model=model).data[0].embedding

View File

@ -1,137 +0,0 @@
import os
import toml
import sys
# import argparse
from .global_logger import logger
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
# 无效实体
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
def _load_config(config, config_file_path):
"""读取TOML格式的配置文件"""
if not os.path.exists(config_file_path):
return
with open(config_file_path, "r", encoding="utf-8") as f:
file_config = toml.load(f)
# Check if all top-level keys from default config exist in the file config
for key in config.keys():
if key not in file_config:
logger.critical(f"警告: 配置文件 '{config_file_path}' 缺少必需的顶级键: '{key}'。请检查配置文件。")
logger.critical("请通过template/lpmm_config_template.toml文件进行更新")
sys.exit(1)
if "llm_providers" in file_config:
for provider in file_config["llm_providers"]:
if provider["name"] not in config["llm_providers"]:
config["llm_providers"][provider["name"]] = {}
config["llm_providers"][provider["name"]]["base_url"] = provider["base_url"]
config["llm_providers"][provider["name"]]["api_key"] = provider["api_key"]
if "entity_extract" in file_config:
config["entity_extract"] = file_config["entity_extract"]
if "rdf_build" in file_config:
config["rdf_build"] = file_config["rdf_build"]
if "embedding" in file_config:
config["embedding"] = file_config["embedding"]
if "rag" in file_config:
config["rag"] = file_config["rag"]
if "qa" in file_config:
config["qa"] = file_config["qa"]
if "persistence" in file_config:
config["persistence"] = file_config["persistence"]
# print(config)
logger.info(f"从文件中读取配置: {config_file_path}")
global_config = dict(
{
"lpmm": {
"version": "0.1.0",
},
"llm_providers": {
"localhost": {
"base_url": "https://api.siliconflow.cn/v1",
"api_key": "sk-ospynxadyorf",
}
},
"entity_extract": {
"llm": {
"provider": "localhost",
"model": "Pro/deepseek-ai/DeepSeek-V3",
}
},
"rdf_build": {
"llm": {
"provider": "localhost",
"model": "Pro/deepseek-ai/DeepSeek-V3",
}
},
"embedding": {
"provider": "localhost",
"model": "Pro/BAAI/bge-m3",
"dimension": 1024,
},
"rag": {
"params": {
"synonym_search_top_k": 10,
"synonym_threshold": 0.75,
}
},
"qa": {
"params": {
"relation_search_top_k": 10,
"relation_threshold": 0.75,
"paragraph_search_top_k": 10,
"paragraph_node_weight": 0.05,
"ent_filter_top_k": 10,
"ppr_damping": 0.8,
"res_top_k": 10,
},
"llm": {
"provider": "localhost",
"model": "qa",
},
},
"persistence": {
"data_root_path": "data",
"raw_data_path": "data/raw.json",
"openie_data_path": "data/openie.json",
"embedding_data_dir": "data/embedding",
"rag_data_dir": "data/rag",
},
"info_extraction": {
"workers": 10,
},
}
)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
config_path = os.path.join(ROOT_PATH, "config", "lpmm_config.toml")
_load_config(global_config, config_path)

View File

@ -1,3 +1,4 @@
raise DeprecationWarning("MemoryActiveManager is not used yet, please do not import it")
from .lpmmconfig import global_config from .lpmmconfig import global_config
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
from .llm_client import LLMClient from .llm_client import LLMClient

View File

@ -2,16 +2,14 @@ import time
from typing import Tuple, List, Dict, Optional from typing import Tuple, List, Dict, Optional
from .global_logger import logger from .global_logger import logger
# from . import prompt_template
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
# from .llm_client import LLMClient
from .kg_manager import KGManager from .kg_manager import KGManager
# from .lpmmconfig import global_config # from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@ -21,17 +19,12 @@ class QAManager:
self, self,
embed_manager: EmbeddingManager, embed_manager: EmbeddingManager,
kg_manager: KGManager, kg_manager: KGManager,
): ):
self.embed_manager = embed_manager self.embed_manager = embed_manager
self.kg_manager = kg_manager self.kg_manager = kg_manager
# TODO: API-Adapter修改标记 self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
self.qa_model = LLMRequest(
model=global_config.model.lpmm_qa,
request_type="lpmm.qa"
)
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: async def process_query(self, question: str) -> Optional[Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]]:
"""处理查询""" """处理查询"""
# 生成问题的Embedding # 生成问题的Embedding
@ -49,66 +42,70 @@ class QAManager:
question_embedding, question_embedding,
global_config.lpmm_knowledge.qa_relation_search_top_k, global_config.lpmm_knowledge.qa_relation_search_top_k,
) )
if relation_search_res is not None: if relation_search_res is None:
# 过滤阈值 return None
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果 # 过滤阈值
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0) # 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold: relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
# 未找到相关关系 if not relation_search_res or relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
logger.debug("未找到相关关系,跳过关系检索") # 未找到相关关系
relation_search_res = [] logger.debug("未找到相关关系,跳过关系检索")
relation_search_res = []
part_end_time = time.perf_counter() part_end_time = time.perf_counter()
logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s") logger.debug(f"关系检索用时:{part_end_time - part_start_time:.5f}s")
for res in relation_search_res: for res in relation_search_res:
rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str rel_str = self.embed_manager.relation_embedding_store.store.get(res[0]).str
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}") print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
# TODO: 使用LLM过滤三元组结果 # TODO: 使用LLM过滤三元组结果
# logger.info(f"LLM过滤三元组用时{time.time() - part_start_time:.2f}s") # logger.info(f"LLM过滤三元组用时{time.time() - part_start_time:.2f}s")
# part_start_time = time.time() # part_start_time = time.time()
# 根据问题Embedding查询Paragraph Embedding库 # 根据问题Embedding查询Paragraph Embedding库
part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
question_embedding,
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
)
part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
if len(relation_search_res) != 0:
logger.info("找到相关关系将使用RAG进行检索")
# 使用KG检索
part_start_time = time.perf_counter() part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k( result, ppr_node_weights = self.kg_manager.kg_search(
question_embedding, relation_search_res, paragraph_search_res, self.embed_manager
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
) )
part_end_time = time.perf_counter() part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s") logger.info(f"RAG检索用时{part_end_time - part_start_time:.5f}s")
if len(relation_search_res) != 0:
logger.info("找到相关关系将使用RAG进行检索")
# 使用KG检索
part_start_time = time.perf_counter()
result, ppr_node_weights = self.kg_manager.kg_search(
relation_search_res, paragraph_search_res, self.embed_manager
)
part_end_time = time.perf_counter()
logger.info(f"RAG检索用时{part_end_time - part_start_time:.5f}s")
else:
logger.info("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res
ppr_node_weights = None
# 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
else: else:
return None logger.info("未找到相关关系,将使用文段检索结果")
result = paragraph_search_res
ppr_node_weights = None
async def get_knowledge(self, question: str) -> str: # 过滤阈值
result = dyn_select_top_k(result, 0.5, 1.0)
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[str]:
"""获取知识""" """获取知识"""
# 处理查询 # 处理查询
processed_result = await self.process_query(question) processed_result = await self.process_query(question)
if processed_result is not None: if processed_result is not None:
query_res = processed_result[0] query_res = processed_result[0]
# 检查查询结果是否为空
if not query_res:
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
return None
knowledge = [ knowledge = [
( (
self.embed_manager.paragraphs_embedding_store.store[res[0]].str, self.embed_manager.paragraphs_embedding_store.store[res[0]].str,

View File

@ -1,48 +0,0 @@
import json
import os
from .global_logger import logger
from .lpmmconfig import global_config
from src.chat.knowledge.utils.hash import get_sha256
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
"""加载原始数据文件
读取原始数据文件将原始数据加载到内存中
Args:
path: 可选指定要读取的json文件绝对路径
Returns:
- raw_data: 原始数据列表
- sha256_list: 原始数据的SHA256集合
"""
# 读取指定路径或默认路径的json文件
json_path = path if path else global_config["persistence"]["raw_data_path"]
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as f:
import_json = json.loads(f.read())
else:
raise Exception(f"原始数据文件读取失败: {json_path}")
"""
import_json 内容示例
import_json = ["The capital of China is Beijing. The capital of France is Paris.",]
"""
raw_data = []
sha256_list = []
sha256_set = set()
for item in import_json:
if not isinstance(item, str):
logger.warning("数据类型错误:{}".format(item))
continue
pg_hash = get_sha256(item)
if pg_hash in sha256_set:
logger.warning("重复数据:{}".format(item))
continue
sha256_set.add(pg_hash)
sha256_list.append(pg_hash)
raw_data.append(item)
logger.info("共读取到{}条数据".format(len(raw_data)))
return sha256_list, raw_data

View File

@ -5,6 +5,10 @@ def dyn_select_top_k(
score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float score: List[Tuple[Any, float]], jmp_factor: float, var_factor: float
) -> List[Tuple[Any, float, float]]: ) -> List[Tuple[Any, float, float]]:
"""动态TopK选择""" """动态TopK选择"""
# 检查输入列表是否为空
if not score:
return []
# 按照分数排序(降序) # 按照分数排序(降序)
sorted_score = sorted(score, key=lambda x: x[1], reverse=True) sorted_score = sorted(score, key=lambda x: x[1], reverse=True)

View File

@ -1,17 +0,0 @@
import networkx as nx
from matplotlib import pyplot as plt
def draw_graph_and_show(graph):
"""绘制图并显示画布大小1280*1280"""
fig = plt.figure(1, figsize=(12.8, 12.8), dpi=100)
nx.draw_networkx(
graph,
node_size=100,
width=0.5,
with_labels=True,
labels=nx.get_node_attributes(graph, "content"),
font_family="Sarasa Mono SC",
font_size=8,
)
fig.show()

View File

@ -5,25 +5,27 @@ import random
import time import time
import re import re
import json import json
from itertools import combinations
import jieba import jieba
import networkx as nx import networkx as nx
import numpy as np import numpy as np
from itertools import combinations
from typing import List, Tuple, Coroutine, Any, Set
from collections import Counter from collections import Counter
from ...llm_models.utils_model import LLMRequest from rich.traceback import install
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器 from src.chat.memory_system.sample_distribution import MemoryBuildScheduler # 分布生成器
from ..utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp, get_raw_msg_by_timestamp,
build_readable_messages, build_readable_messages,
get_raw_msg_by_timestamp_with_chat, get_raw_msg_by_timestamp_with_chat,
) # 导入 build_readable_messages ) # 导入 build_readable_messages
from ..utils.utils import translate_timestamp_to_human_readable from src.chat.utils.utils import translate_timestamp_to_human_readable
from rich.traceback import install
from ...config.config import global_config
from src.common.database.database_model import Messages, GraphNodes, GraphEdges # Peewee Models导入
install(extra_lines=3) install(extra_lines=3)
@ -198,8 +200,7 @@ class Hippocampus:
self.parahippocampal_gyrus = ParahippocampalGyrus(self) self.parahippocampal_gyrus = ParahippocampalGyrus(self)
# 从数据库加载记忆图 # 从数据库加载记忆图
self.entorhinal_cortex.sync_memory_from_db() self.entorhinal_cortex.sync_memory_from_db()
# TODO: API-Adapter修改标记 self.model_summary = LLMRequest(model_set=model_config.model_task_config.memory, request_type="memory.builder")
self.model_summary = LLMRequest(global_config.model.memory, request_type="memory.builder")
def get_all_node_names(self) -> list: def get_all_node_names(self) -> list:
"""获取记忆图中所有节点的名字列表""" """获取记忆图中所有节点的名字列表"""
@ -339,9 +340,7 @@ class Hippocampus:
else: else:
topic_num = 5 # 51+字符: 5个关键词 (其余长文本) topic_num = 5 # 51+字符: 5个关键词 (其余长文本)
topics_response, (reasoning_content, model_name) = await self.model_summary.generate_response_async( topics_response, _ = await self.model_summary.generate_response_async(self.find_topic_llm(text, topic_num))
self.find_topic_llm(text, topic_num)
)
# 提取关键词 # 提取关键词
keywords = re.findall(r"<([^>]+)>", topics_response) keywords = re.findall(r"<([^>]+)>", topics_response)
@ -353,12 +352,11 @@ class Hippocampus:
for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",") for keyword in ",".join(keywords).replace("", ",").replace("", ",").replace(" ", ",").split(",")
if keyword.strip() if keyword.strip()
] ]
if keywords: if keywords:
logger.info(f"提取关键词: {keywords}") logger.info(f"提取关键词: {keywords}")
return keywords return keywords
async def get_memory_from_text( async def get_memory_from_text(
self, self,
@ -1245,7 +1243,7 @@ class ParahippocampalGyrus:
# 2. 使用LLM提取关键主题 # 2. 使用LLM提取关键主题
topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate) topic_num = self.hippocampus.calculate_topic_num(input_text, compress_rate)
topics_response, (reasoning_content, model_name) = await self.hippocampus.model_summary.generate_response_async( topics_response, _ = await self.hippocampus.model_summary.generate_response_async(
self.hippocampus.find_topic_llm(input_text, topic_num) self.hippocampus.find_topic_llm(input_text, topic_num)
) )
@ -1269,7 +1267,7 @@ class ParahippocampalGyrus:
logger.debug(f"过滤后话题: {filtered_topics}") logger.debug(f"过滤后话题: {filtered_topics}")
# 4. 创建所有话题的摘要生成任务 # 4. 创建所有话题的摘要生成任务
tasks = [] tasks: List[Tuple[str, Coroutine[Any, Any, Tuple[str, Tuple[str, str, List | None]]]]] = []
for topic in filtered_topics: for topic in filtered_topics:
# 调用修改后的 topic_what不再需要 time_info # 调用修改后的 topic_what不再需要 time_info
topic_what_prompt = self.hippocampus.topic_what(input_text, topic) topic_what_prompt = self.hippocampus.topic_what(input_text, topic)
@ -1281,7 +1279,7 @@ class ParahippocampalGyrus:
continue continue
# 等待所有任务完成 # 等待所有任务完成
compressed_memory = set() compressed_memory: Set[Tuple[str, str]] = set()
similar_topics_dict = {} similar_topics_dict = {}
for topic, task in tasks: for topic, task in tasks:

View File

@ -3,13 +3,16 @@ import time
import re import re
import json import json
import ast import ast
from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
import traceback import traceback
from src.config.config import global_config from json_repair import repair_json
from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入 from src.common.database.database_model import Memory # Peewee Models导入
from src.config.config import model_config
logger = get_logger(__name__) logger = get_logger(__name__)
@ -35,8 +38,7 @@ class InstantMemory:
self.chat_id = chat_id self.chat_id = chat_id
self.last_view_time = time.time() self.last_view_time = time.time()
self.summary_model = LLMRequest( self.summary_model = LLMRequest(
model=global_config.model.memory, model_set=model_config.model_task_config.memory,
temperature=0.5,
request_type="memory.summary", request_type="memory.summary",
) )
@ -48,14 +50,11 @@ class InstantMemory:
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) print(prompt)
print(response) print(response)
if "1" in response: return "1" in response
return True
else:
return False
except Exception as e: except Exception as e:
logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}") logger.error(f"判断是否需要记忆出现错误:{str(e)} {traceback.format_exc()}")
return False return False
@ -71,9 +70,9 @@ class InstantMemory:
}} }}
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) # print(prompt)
print(response) # print(response)
if not response: if not response:
return None return None
try: try:
@ -142,7 +141,7 @@ class InstantMemory:
请只输出json格式不要输出其他多余内容 请只输出json格式不要输出其他多余内容
""" """
try: try:
response, _ = await self.summary_model.generate_response_async(prompt) response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt) print(prompt)
print(response) print(response)
if not response: if not response:
@ -177,7 +176,7 @@ class InstantMemory:
for mem in query: for mem in query:
# 对每条记忆 # 对每条记忆
mem_keywords = mem.keywords or [] mem_keywords = mem.keywords or ""
parsed = ast.literal_eval(mem_keywords) parsed = ast.literal_eval(mem_keywords)
if isinstance(parsed, list): if isinstance(parsed, list):
mem_keywords = [str(k).strip() for k in parsed if str(k).strip()] mem_keywords = [str(k).strip() for k in parsed if str(k).strip()]
@ -201,6 +200,7 @@ class InstantMemory:
return None return None
def _parse_time_range(self, time_str): def _parse_time_range(self, time_str):
# sourcery skip: extract-duplicate-method, use-contextlib-suppress
""" """
支持解析如下格式 支持解析如下格式
- 具体日期时间YYYY-MM-DD HH:MM:SS - 具体日期时间YYYY-MM-DD HH:MM:SS
@ -208,8 +208,6 @@ class InstantMemory:
- 相对时间今天昨天前天N天前N个月前 - 相对时间今天昨天前天N天前N个月前
- 空字符串返回(None, None) - 空字符串返回(None, None)
""" """
from datetime import datetime, timedelta
now = datetime.now() now = datetime.now()
if not time_str: if not time_str:
return 0, now return 0, now
@ -239,14 +237,12 @@ class InstantMemory:
start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) end = start + timedelta(days=1)
return start, end return start, end
m = re.match(r"(\d+)天前", time_str) if m := re.match(r"(\d+)天前", time_str):
if m:
days = int(m.group(1)) days = int(m.group(1))
start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=days)).replace(hour=0, minute=0, second=0, microsecond=0)
end = start + timedelta(days=1) end = start + timedelta(days=1)
return start, end return start, end
m = re.match(r"(\d+)个月前", time_str) if m := re.match(r"(\d+)个月前", time_str):
if m:
months = int(m.group(1)) months = int(m.group(1))
# 近似每月30天 # 近似每月30天
start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0) start = (now - timedelta(days=months * 30)).replace(hour=0, minute=0, second=0, microsecond=0)

View File

@ -1,13 +1,15 @@
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from datetime import datetime
from src.chat.memory_system.Hippocampus import hippocampus_manager
from typing import List, Dict
import difflib import difflib
import json import json
from json_repair import repair_json from json_repair import repair_json
from typing import List, Dict
from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.memory_system.Hippocampus import hippocampus_manager
logger = get_logger("memory_activator") logger = get_logger("memory_activator")
@ -61,11 +63,8 @@ def init_prompt():
class MemoryActivator: class MemoryActivator:
def __init__(self): def __init__(self):
# TODO: API-Adapter修改标记
self.key_words_model = LLMRequest( self.key_words_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small,
temperature=0.5,
request_type="memory.activator", request_type="memory.activator",
) )
@ -92,7 +91,9 @@ class MemoryActivator:
# logger.debug(f"prompt: {prompt}") # logger.debug(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.key_words_model.generate_response_async(prompt) response, (reasoning_content, model_name, _) = await self.key_words_model.generate_response_async(
prompt, temperature=0.5
)
keywords = list(get_keywords_from_json(response)) keywords = list(get_keywords_from_json(response))

View File

@ -203,7 +203,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False self.is_superchat = False
self.gift_info = None self.gift_info = None
self.gift_name = None self.gift_name = None
self.gift_count = None self.gift_count: Optional[str] = None
self.superchat_info = None self.superchat_info = None
self.superchat_price = None self.superchat_price = None
self.superchat_message_text = None self.superchat_message_text = None
@ -444,7 +444,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False, is_emoji: bool = False,
thinking_start_time: float = 0, thinking_start_time: float = 0,
apply_set_reply_logic: bool = False, apply_set_reply_logic: bool = False,
reply_to: str = None, # type: ignore reply_to: Optional[str] = None,
): ):
# 调用父类初始化 # 调用父类初始化
super().__init__( super().__init__(

View File

@ -1,9 +1,10 @@
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from src.plugin_system.base.base_action import BaseAction
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction
logger = get_logger("action_manager") logger = get_logger("action_manager")

View File

@ -5,7 +5,7 @@ import time
from typing import List, Any, Dict, TYPE_CHECKING, Tuple from typing import List, Any, Dict, TYPE_CHECKING, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext from src.chat.message_receive.chat_stream import get_chat_manager, ChatMessageContext
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
@ -36,10 +36,7 @@ class ActionModifier:
self.action_manager = action_manager self.action_manager = action_manager
# 用于LLM判定的小模型 # 用于LLM判定的小模型
self.llm_judge = LLMRequest( self.llm_judge = LLMRequest(model_set=model_config.model_task_config.utils_small, request_type="action.judge")
model=global_config.model.utils_small,
request_type="action.judge",
)
# 缓存相关属性 # 缓存相关属性
self._llm_judge_cache = {} # 缓存LLM判定结果 self._llm_judge_cache = {} # 缓存LLM判定结果
@ -438,4 +435,4 @@ class ActionModifier:
return True return True
else: else:
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}") logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
return False return False

View File

@ -7,7 +7,7 @@ from datetime import datetime
from json_repair import repair_json from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
@ -36,8 +36,6 @@ def init_prompt():
{chat_context_description}以下是具体的聊天内容 {chat_context_description}以下是具体的聊天内容
{chat_content_block} {chat_content_block}
{moderation_prompt} {moderation_prompt}
现在请你根据{by_what}选择合适的action和触发action的消息: 现在请你根据{by_what}选择合适的action和触发action的消息:
@ -73,10 +71,7 @@ class ActionPlanner:
self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]" self.log_prefix = f"[{get_chat_manager().get_stream_name(chat_id) or chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager
# LLM规划器配置 # LLM规划器配置
self.planner_llm = LLMRequest( self.planner_llm = LLMRequest(model_set=model_config.model_task_config.planner, request_type="planner") # 用于动作规划
model=global_config.model.planner,
request_type="planner", # 用于动作规划
)
self.last_obs_time_mark = 0.0 self.last_obs_time_mark = 0.0
@ -140,7 +135,7 @@ class ActionPlanner:
# --- 调用 LLM (普通文本生成) --- # --- 调用 LLM (普通文本生成) ---
llm_content = None llm_content = None
try: try:
llm_content, (reasoning_content, _) = await self.planner_llm.generate_response_async(prompt=prompt) llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}") logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")

View File

@ -8,7 +8,8 @@ from typing import List, Optional, Dict, Any, Tuple
from datetime import datetime from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
from src.individuality.individuality import get_individuality from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@ -23,14 +24,13 @@ from src.chat.utils.chat_message_builder import (
replace_user_references_sync, replace_user_references_sync,
) )
from src.chat.express.expression_selector import expression_selector from src.chat.express.expression_selector import expression_selector
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.memory_system.memory_activator import MemoryActivator from src.chat.memory_system.memory_activator import MemoryActivator
from src.chat.memory_system.instant_memory import InstantMemory from src.chat.memory_system.instant_memory import InstantMemory
from src.mood.mood_manager import mood_manager from src.mood.mood_manager import mood_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from src.tools.tool_executor import ToolExecutor
from src.plugin_system.base.component_types import ActionInfo from src.plugin_system.base.component_types import ActionInfo
from src.plugin_system.apis import llm_api
logger = get_logger("replyer") logger = get_logger("replyer")
@ -40,7 +40,7 @@ def init_prompt():
Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1") Prompt("你正在和{sender_name}聊天,这是你们之前聊的内容:", "chat_target_private1")
Prompt("在群里聊天", "chat_target_group2") Prompt("在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2") Prompt("{sender_name}聊天", "chat_target_private2")
Prompt( Prompt(
""" """
{expression_habits_block} {expression_habits_block}
@ -102,36 +102,57 @@ def init_prompt():
"s4u_style_prompt", "s4u_style_prompt",
) )
Prompt(
"""
你是一个专门获取知识的助手你的名字是{bot_name}现在是{time_now}
群里正在进行的聊天内容
{chat_history}
现在{sender}发送了内容:{target_message},你想要回复ta
请仔细分析聊天内容考虑以下几点
1. 内容中是否包含需要查询信息的问题
2. 是否有明确的知识获取指令
If you need to use the search tool, please directly call the function "lpmm_search_knowledge". If you do not need to use any tool, simply output "No tool needed".
""",
name="lpmm_get_knowledge_prompt",
)
class DefaultReplyer: class DefaultReplyer:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: ChatStream,
model_configs: Optional[List[Dict[str, Any]]] = None, model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "focus.replyer", request_type: str = "focus.replyer",
): ):
self.request_type = request_type self.request_type = request_type
if model_configs: if model_set_with_weight:
self.express_model_configs = model_configs # self.express_model_configs = model_configs
self.model_set: List[Tuple[TaskConfig, float]] = model_set_with_weight
else: else:
# 当未提供配置时,使用默认配置并赋予默认权重 # 当未提供配置时,使用默认配置并赋予默认权重
model_config_1 = global_config.model.replyer_1.copy() # model_config_1 = global_config.model.replyer_1.copy()
model_config_2 = global_config.model.replyer_2.copy() # model_config_2 = global_config.model.replyer_2.copy()
prob_first = global_config.chat.replyer_random_probability prob_first = global_config.chat.replyer_random_probability
model_config_1["weight"] = prob_first # model_config_1["weight"] = prob_first
model_config_2["weight"] = 1.0 - prob_first # model_config_2["weight"] = 1.0 - prob_first
self.express_model_configs = [model_config_1, model_config_2] # self.express_model_configs = [model_config_1, model_config_2]
self.model_set = [
(model_config.model_task_config.replyer_1, prob_first),
(model_config.model_task_config.replyer_2, 1.0 - prob_first),
]
if not self.express_model_configs: # if not self.express_model_configs:
logger.warning("未找到有效的模型配置,回复生成可能会失败。") # logger.warning("未找到有效的模型配置,回复生成可能会失败。")
# 提供一个最终的回退,以防止在空列表上调用 random.choice # # 提供一个最终的回退,以防止在空列表上调用 random.choice
fallback_config = global_config.model.replyer_1.copy() # fallback_config = global_config.model.replyer_1.copy()
fallback_config.setdefault("weight", 1.0) # fallback_config.setdefault("weight", 1.0)
self.express_model_configs = [fallback_config] # self.express_model_configs = [fallback_config]
self.chat_stream = chat_stream self.chat_stream = chat_stream
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id) self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_stream.stream_id)
@ -139,13 +160,16 @@ class DefaultReplyer:
self.heart_fc_sender = HeartFCSender() self.heart_fc_sender = HeartFCSender()
self.memory_activator = MemoryActivator() self.memory_activator = MemoryActivator()
self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id) self.instant_memory = InstantMemory(chat_id=self.chat_stream.stream_id)
from src.plugin_system.core.tool_use import ToolExecutor # 延迟导入ToolExecutor不然会循环依赖
self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3) self.tool_executor = ToolExecutor(chat_id=self.chat_stream.stream_id, enable_cache=True, cache_ttl=3)
def _select_weighted_model_config(self) -> Dict[str, Any]: def _select_weighted_models_config(self) -> Tuple[TaskConfig, float]:
"""使用加权随机选择来挑选一个模型配置""" """使用加权随机选择来挑选一个模型配置"""
configs = self.express_model_configs configs = self.model_set
# 提取权重,如果模型配置中没有'weight'键则默认为1.0 # 提取权重,如果模型配置中没有'weight'键则默认为1.0
weights = [config.get("weight", 1.0) for config in configs] weights = [weight for _, weight in configs]
return random.choices(population=configs, weights=weights, k=1)[0] return random.choices(population=configs, weights=weights, k=1)[0]
@ -155,18 +179,16 @@ class DefaultReplyer:
extra_info: str = "", extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_tool: bool = True, enable_tool: bool = True,
enable_timeout: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]: ) -> Tuple[bool, Optional[str], Optional[str]]:
""" """
回复器 (Replier): 负责生成回复文本的核心逻辑 回复器 (Replier): 负责生成回复文本的核心逻辑
Args: Args:
reply_to: 回复对象格式为 "发送者:消息内容" reply_to: 回复对象格式为 "发送者:消息内容"
extra_info: 额外信息用于补充上下文 extra_info: 额外信息用于补充上下文
available_actions: 可用的动作信息字典 available_actions: 可用的动作信息字典
enable_tool: 是否启用工具调用 enable_tool: 是否启用工具调用
enable_timeout: 是否启用超时处理
Returns: Returns:
Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt) Tuple[bool, Optional[str], Optional[str]]: (是否成功, 生成的回复内容, 使用的prompt)
""" """
@ -177,13 +199,12 @@ class DefaultReplyer:
# 3. 构建 Prompt # 3. 构建 Prompt
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_reply_context( prompt = await self.build_prompt_reply_context(
reply_to = reply_to, reply_to=reply_to,
extra_info=extra_info, extra_info=extra_info,
available_actions=available_actions, available_actions=available_actions,
enable_timeout=enable_timeout,
enable_tool=enable_tool, enable_tool=enable_tool,
) )
if not prompt: if not prompt:
logger.warning("构建prompt失败跳过回复生成") logger.warning("构建prompt失败跳过回复生成")
return False, None, None return False, None, None
@ -194,26 +215,8 @@ class DefaultReplyer:
model_name = "unknown_model" model_name = "unknown_model"
try: try:
with Timer("LLM生成", {}): # 内部计时器,可选保留 content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
# 加权随机选择一个模型配置 logger.debug(f"replyer生成内容: {content}")
selected_model_config = self._select_weighted_model_config()
logger.info(
f"使用模型生成回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
logger.debug(f"replyer生成内容: {content}")
except Exception as llm_e: except Exception as llm_e:
# 精简报错信息 # 精简报错信息
@ -232,22 +235,21 @@ class DefaultReplyer:
raw_reply: str = "", raw_reply: str = "",
reason: str = "", reason: str = "",
reply_to: str = "", reply_to: str = "",
) -> Tuple[bool, Optional[str]]: return_prompt: bool = False,
) -> Tuple[bool, Optional[str], Optional[str]]:
""" """
表达器 (Expressor): 负责重写和优化回复文本 表达器 (Expressor): 负责重写和优化回复文本
Args: Args:
raw_reply: 原始回复内容 raw_reply: 原始回复内容
reason: 回复原因 reason: 回复原因
reply_to: 回复对象格式为 "发送者:消息内容" reply_to: 回复对象格式为 "发送者:消息内容"
relation_info: 关系信息 relation_info: 关系信息
Returns: Returns:
Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容) Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
""" """
try: try:
with Timer("构建Prompt", {}): # 内部计时器,可选保留 with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt = await self.build_prompt_rewrite_context( prompt = await self.build_prompt_rewrite_context(
raw_reply=raw_reply, raw_reply=raw_reply,
@ -260,36 +262,23 @@ class DefaultReplyer:
model_name = "unknown_model" model_name = "unknown_model"
if not prompt: if not prompt:
logger.error("Prompt 构建失败,无法生成回复。") logger.error("Prompt 构建失败,无法生成回复。")
return False, None return False, None, None
try: try:
with Timer("LLM生成", {}): # 内部计时器,可选保留 content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
# 加权随机选择一个模型配置 logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
selected_model_config = self._select_weighted_model_config()
logger.info(
f"使用模型重写回复: {selected_model_config.get('name', 'N/A')} (选中概率: {selected_model_config.get('weight', 1.0)})"
)
express_model = LLMRequest(
model=selected_model_config,
request_type=self.request_type,
)
content, (reasoning_content, model_name) = await express_model.generate_response_async(prompt)
logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
except Exception as llm_e: except Exception as llm_e:
# 精简报错信息 # 精简报错信息
logger.error(f"LLM 生成失败: {llm_e}") logger.error(f"LLM 生成失败: {llm_e}")
return False, None # LLM 调用失败则无法生成回复 return False, None, prompt if return_prompt else None # LLM 调用失败则无法生成回复
return True, content return True, content, prompt if return_prompt else None
except Exception as e: except Exception as e:
logger.error(f"回复生成意外失败: {e}") logger.error(f"回复生成意外失败: {e}")
traceback.print_exc() traceback.print_exc()
return False, None return False, None, prompt if return_prompt else None
async def build_relation_info(self, reply_to: str = ""): async def build_relation_info(self, reply_to: str = ""):
if not global_config.relationship.enable_relationship: if not global_config.relationship.enable_relationship:
@ -313,11 +302,11 @@ class DefaultReplyer:
async def build_expression_habits(self, chat_history: str, target: str) -> str: async def build_expression_habits(self, chat_history: str, target: str) -> str:
"""构建表达习惯块 """构建表达习惯块
Args: Args:
chat_history: 聊天历史记录 chat_history: 聊天历史记录
target: 目标消息内容 target: 目标消息内容
Returns: Returns:
str: 表达习惯信息字符串 str: 表达习惯信息字符串
""" """
@ -366,17 +355,15 @@ class DefaultReplyer:
if style_habits_str.strip() and grammar_habits_str.strip(): if style_habits_str.strip() and grammar_habits_str.strip():
expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:" expression_habits_title = "你可以参考以下的语言习惯和句法,如果情景合适就使用,不要盲目使用,不要生硬使用,以合理的方式结合到你的回复中:"
expression_habits_block = f"{expression_habits_title}\n{expression_habits_block}" return f"{expression_habits_title}\n{expression_habits_block}"
return expression_habits_block
async def build_memory_block(self, chat_history: str, target: str) -> str: async def build_memory_block(self, chat_history: str, target: str) -> str:
"""构建记忆块 """构建记忆块
Args: Args:
chat_history: 聊天历史记录 chat_history: 聊天历史记录
target: 目标消息内容 target: 目标消息内容
Returns: Returns:
str: 记忆信息字符串 str: 记忆信息字符串
""" """
@ -441,7 +428,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "info") result_type = tool_result.get("type", "tool_result")
tool_info_str += f"- 【{tool_name}{result_type}: {content}\n" tool_info_str += f"- 【{tool_name}{result_type}: {content}\n"
@ -459,10 +446,10 @@ class DefaultReplyer:
def _parse_reply_target(self, target_message: str) -> Tuple[str, str]: def _parse_reply_target(self, target_message: str) -> Tuple[str, str]:
"""解析回复目标消息 """解析回复目标消息
Args: Args:
target_message: 目标消息格式为 "发送者:消息内容" "发送者:消息内容" target_message: 目标消息格式为 "发送者:消息内容" "发送者:消息内容"
Returns: Returns:
Tuple[str, str]: (发送者名称, 消息内容) Tuple[str, str]: (发送者名称, 消息内容)
""" """
@ -481,10 +468,10 @@ class DefaultReplyer:
async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str: async def build_keywords_reaction_prompt(self, target: Optional[str]) -> str:
"""构建关键词反应提示 """构建关键词反应提示
Args: Args:
target: 目标消息内容 target: 目标消息内容
Returns: Returns:
str: 关键词反应提示字符串 str: 关键词反应提示字符串
""" """
@ -523,11 +510,11 @@ class DefaultReplyer:
async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]: async def _time_and_run_task(self, coroutine, name: str) -> Tuple[str, Any, float]:
"""计时并运行异步任务的辅助函数 """计时并运行异步任务的辅助函数
Args: Args:
coroutine: 要执行的协程 coroutine: 要执行的协程
name: 任务名称 name: 任务名称
Returns: Returns:
Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时) Tuple[str, Any, float]: (任务名称, 任务结果, 执行耗时)
""" """
@ -537,7 +524,9 @@ class DefaultReplyer:
duration = end_time - start_time duration = end_time - start_time
return name, result, duration return name, result, duration
def build_s4u_chat_history_prompts(self, message_list_before_now: List[Dict[str, Any]], target_user_id: str) -> Tuple[str, str]: def build_s4u_chat_history_prompts(
self, message_list_before_now: List[Dict[str, Any]], target_user_id: str
) -> Tuple[str, str]:
""" """
构建 s4u 风格的分离对话 prompt 构建 s4u 风格的分离对话 prompt
@ -612,7 +601,7 @@ class DefaultReplyer:
chat_info: str, chat_info: str,
) -> Any: ) -> Any:
"""构建 mai_think 上下文信息 """构建 mai_think 上下文信息
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
memory_block: 记忆块内容 memory_block: 记忆块内容
@ -625,7 +614,7 @@ class DefaultReplyer:
sender: 发送者名称 sender: 发送者名称
target: 目标消息内容 target: 目标消息内容
chat_info: 聊天信息 chat_info: 聊天信息
Returns: Returns:
Any: mai_think 实例 Any: mai_think 实例
""" """
@ -647,19 +636,17 @@ class DefaultReplyer:
reply_to: str, reply_to: str,
extra_info: str = "", extra_info: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None, available_actions: Optional[Dict[str, ActionInfo]] = None,
enable_timeout: bool = False,
enable_tool: bool = True, enable_tool: bool = True,
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
""" """
构建回复器上下文 构建回复器上下文
Args: Args:
reply_data: 回复数据 reply_to: 回复对象格式为 "发送者:消息内容"
replay_data 包含以下字段 extra_info: 额外信息用于补充上下文
structured_info: 结构化信息一般是工具调用获得的信息
reply_to: 回复对象
extra_info/extra_info_block: 额外信息
available_actions: 可用动作 available_actions: 可用动作
enable_timeout: 是否启用超时处理
enable_tool: 是否启用工具调用
Returns: Returns:
str: 构建好的上下文 str: 构建好的上下文
@ -727,7 +714,7 @@ class DefaultReplyer:
self._time_and_run_task( self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info" self.build_tool_info(chat_talking_prompt_short, reply_to, enable_tool=enable_tool), "tool_info"
), ),
self._time_and_run_task(get_prompt_info(target, threshold=0.38), "prompt_info"), self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, reply_to), "prompt_info"),
) )
# 任务名称中英文映射 # 任务名称中英文映射
@ -877,7 +864,7 @@ class DefaultReplyer:
raw_reply: str, raw_reply: str,
reason: str, reason: str,
reply_to: str, reply_to: str,
) -> str: ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream chat_stream = self.chat_stream
chat_id = chat_stream.stream_id chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info) is_group_chat = bool(chat_stream.group_info)
@ -1011,6 +998,81 @@ class DefaultReplyer:
display_message=display_message, display_message=display_message,
) )
async def llm_generate_content(self, prompt: str):
with Timer("LLM生成", {}): # 内部计时器,可选保留
# 加权随机选择一个模型配置
selected_model_config, weight = self._select_weighted_models_config()
logger.info(f"使用模型集生成回复: {selected_model_config} (选中概率: {weight})")
express_model = LLMRequest(model_set=selected_model_config, request_type=self.request_type)
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\n{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await express_model.generate_response_async(prompt)
logger.debug(f"replyer生成内容: {content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, reply_to: str):
related_info = ""
start_time = time.time()
from src.plugins.built_in.knowledge.lpmm_get_knowledge import SearchKnowledgeFromLPMMTool
if not reply_to:
logger.debug("没有回复对象,跳过获取知识库内容")
return ""
sender, content = self._parse_reply_target(reply_to)
if not content:
logger.debug("回复对象内容为空,跳过获取知识库内容")
return ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
prompt = await global_prompt_manager.format_prompt(
"lpmm_get_knowledge_prompt",
bot_name=bot_name,
time_now=time_now,
chat_history=message,
sender=sender,
target_message=content,
)
_, _, _, _, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
end_time = time.time()
if not result or not result.get("content"):
logger.debug("从LPMM知识库获取知识失败返回空知识...")
return ""
found_knowledge_from_lpmm = result.get("content", "")
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
return f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return ""
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
def weighted_sample_no_replacement(items, weights, k) -> list: def weighted_sample_no_replacement(items, weights, k) -> list:
""" """
@ -1046,38 +1108,4 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
return selected return selected
async def get_prompt_info(message: str, threshold: float):
related_info = ""
start_time = time.time()
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 从LPMM知识库获取知识
try:
# 检查LPMM知识库是否启用
if qa_manager is None:
logger.debug("LPMM知识库已禁用跳过知识获取")
return ""
found_knowledge_from_lpmm = await qa_manager.get_knowledge(message)
end_time = time.time()
if found_knowledge_from_lpmm is not None:
logger.debug(
f"从LPMM知识库获取知识相关信息{found_knowledge_from_lpmm[:100]}...,信息长度: {len(found_knowledge_from_lpmm)}"
)
related_info += found_knowledge_from_lpmm
logger.debug(f"获取知识库内容耗时: {(end_time - start_time):.3f}")
logger.debug(f"获取知识库内容,相关信息:{related_info[:100]}...,信息长度: {len(related_info)}")
# 格式化知识信息
formatted_prompt_info = f"你有以下这些**知识**\n{related_info}\n请你**记住上面的知识**,之后可能会用到。\n"
return formatted_prompt_info
else:
logger.debug("从LPMM知识库获取知识失败可能是从未导入过知识返回空知识...")
return ""
except Exception as e:
logger.error(f"获取知识库内容时发生异常: {str(e)}")
return ""
init_prompt() init_prompt()

View File

@ -1,6 +1,7 @@
from typing import Dict, Any, Optional, List from typing import Dict, Optional, List, Tuple
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
@ -15,7 +16,7 @@ class ReplyerManager:
self, self,
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None, model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer]: ) -> Optional[DefaultReplyer]:
""" """
@ -49,7 +50,7 @@ class ReplyerManager:
# model_configs 只在此时(初始化时)生效 # model_configs 只在此时(初始化时)生效
replyer = DefaultReplyer( replyer = DefaultReplyer(
chat_stream=target_stream, chat_stream=target_stream,
model_configs=model_configs, # 可以是None此时使用默认模型 model_set_with_weight=model_set_with_weight, # 可以是None此时使用默认模型
request_type=request_type, request_type=request_type,
) )
self._repliers[stream_id] = replyer self._repliers[stream_id] = replyer

View File

@ -1,223 +0,0 @@
import ast
import json
import logging
from typing import Any, Dict, TypeVar, List, Union, Tuple, Optional
# 定义类型变量用于泛型类型提示
T = TypeVar("T")
# 获取logger
logger = logging.getLogger("json_utils")
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
"""
安全地解析JSON字符串出错时返回默认值
现在尝试处理单引号和标准JSON
参数:
json_str: 要解析的JSON字符串
default_value: 解析失败时返回的默认值
返回:
解析后的Python对象或在解析失败时返回default_value
"""
if not json_str or not isinstance(json_str, str):
logger.warning(f"safe_json_loads 接收到非字符串输入: {type(json_str)}, 值: {json_str}")
return default_value
try:
# 尝试标准的 JSON 解析
return json.loads(json_str)
except json.JSONDecodeError:
# 如果标准解析失败,尝试用 ast.literal_eval 解析
try:
# logger.debug(f"标准JSON解析失败尝试用 ast.literal_eval 解析: {json_str[:100]}...")
result = ast.literal_eval(json_str)
if isinstance(result, dict):
return result
logger.warning(f"ast.literal_eval 解析成功但结果不是字典: {type(result)}, 内容: {result}")
return default_value
except (ValueError, SyntaxError, MemoryError, RecursionError) as ast_e:
logger.error(f"使用 ast.literal_eval 解析失败: {ast_e}, 字符串: {json_str[:100]}...")
return default_value
except Exception as e:
logger.error(f"使用 ast.literal_eval 解析时发生意外错误: {e}, 字符串: {json_str[:100]}...")
return default_value
except Exception as e:
logger.error(f"JSON解析过程中发生意外错误: {e}, 字符串: {json_str[:100]}...")
return default_value
def extract_tool_call_arguments(
tool_call: Dict[str, Any], default_value: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
参数:
tool_call: 工具调用对象字典
default_value: 解析失败时返回的默认值
返回:
解析后的参数字典或在解析失败时返回default_value
"""
default_result = default_value or {}
if not tool_call or not isinstance(tool_call, dict):
logger.error(f"无效的工具调用对象: {tool_call}")
return default_result
try:
# 提取function参数
function_data = tool_call.get("function", {})
if not function_data or not isinstance(function_data, dict):
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
if arguments_str := function_data.get("arguments", "{}"):
# 解析JSON
return safe_json_loads(arguments_str, default_result)
else:
return default_result
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
"""
安全地将Python对象序列化为JSON字符串
参数:
obj: 要序列化的Python对象
default_value: 序列化失败时返回的默认值
ensure_ascii: 是否确保ASCII编码(默认False允许中文等非ASCII字符)
pretty: 是否美化输出JSON
返回:
序列化后的JSON字符串或在序列化失败时返回default_value
"""
try:
indent = 2 if pretty else None
return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent)
except TypeError as e:
logger.error(f"JSON序列化失败(类型错误): {e}")
return default_value
except Exception as e:
logger.error(f"JSON序列化过程中发生意外错误: {e}")
return default_value
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
"""
标准化LLM响应格式将各种格式如元组转换为统一的列表格式
参数:
response: 原始LLM响应
log_prefix: 日志前缀
返回:
元组 (成功标志, 标准化后的响应列表, 错误消息)
"""
logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
# 检查是否为None
if response is None:
return False, [], "LLM响应为None"
# 记录原始类型
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
# 将元组转换为列表
if isinstance(response, tuple):
logger.debug(f"{log_prefix}将元组响应转换为列表")
response = list(response)
# 确保是列表类型
if not isinstance(response, list):
return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
# 处理工具调用部分(如果存在)
if len(response) == 3:
content, reasoning, tool_calls = response
# 将工具调用部分转换为列表(如果是元组)
if isinstance(tool_calls, tuple):
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
tool_calls = list(tool_calls)
response[2] = tool_calls
return True, response, ""
def process_llm_tool_calls(
tool_calls: List[Dict[str, Any]], log_prefix: str = ""
) -> Tuple[bool, List[Dict[str, Any]], str]:
"""
处理并验证LLM响应中的工具调用列表
参数:
tool_calls: 从LLM响应中直接获取的工具调用列表
log_prefix: 日志前缀
返回:
元组 (成功标志, 验证后的工具调用列表, 错误消息)
"""
# 如果列表为空,表示没有工具调用,这不是错误
if not tool_calls:
return True, [], "工具调用列表为空"
# 验证每个工具调用的格式
valid_tool_calls = []
for i, tool_call in enumerate(tool_calls):
if not isinstance(tool_call, dict):
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
continue
# 检查基本结构
if tool_call.get("type") != "function":
logger.warning(
f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '未定义')}, 内容: {tool_call}"
)
continue
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
continue
func_details = tool_call["function"]
if "name" not in func_details or not isinstance(func_details.get("name"), str):
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
continue
# 验证参数 'arguments'
args_value = func_details.get("arguments")
# 1. 检查 arguments 是否存在且是字符串
if args_value is None or not isinstance(args_value, str):
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'字符串: {func_details}")
continue
# 2. 尝试安全地解析 arguments 字符串
parsed_args = safe_json_loads(args_value, None)
# 3. 检查解析结果是否为字典
if parsed_args is None or not isinstance(parsed_args, dict):
logger.warning(
f"{log_prefix}工具调用[{i}]的'arguments'无法解析为有效的JSON字典, "
f"原始字符串: {args_value[:100]}..., 解析结果类型: {type(parsed_args).__name__}"
)
continue
# 如果检查通过,将原始的 tool_call 加入有效列表
valid_tool_calls.append(tool_call)
if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
return False, [], "所有工具调用格式均无效"
return True, valid_tool_calls, ""

View File

@ -11,7 +11,7 @@ from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
@ -109,13 +109,11 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> tuple[bool, float]:
return is_mentioned, reply_probability return is_mentioned, reply_probability
async def get_embedding(text, request_type="embedding"): async def get_embedding(text, request_type="embedding") -> Optional[List[float]]:
"""获取文本的embedding向量""" """获取文本的embedding向量"""
# TODO: API-Adapter修改标记 llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type=request_type)
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
# return llm.get_embedding_sync(text)
try: try:
embedding = await llm.get_embedding(text) embedding, _ = await llm.get_embedding(text)
except Exception as e: except Exception as e:
logger.error(f"获取embedding失败: {str(e)}") logger.error(f"获取embedding失败: {str(e)}")
embedding = None embedding = None

View File

@ -14,7 +14,7 @@ from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions from src.common.database.database_model import Images, ImageDescriptions
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
install(extra_lines=3) install(extra_lines=3)
@ -37,7 +37,7 @@ class ImageManager:
self._ensure_image_dir() self._ensure_image_dir()
self._initialized = True self._initialized = True
self.vlm = LLMRequest(model=global_config.model.vlm, temperature=0.4, max_tokens=300, request_type="image") self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
@ -107,6 +107,7 @@ class ImageManager:
# 优先使用EmojiManager查询已注册表情包的描述 # 优先使用EmojiManager查询已注册表情包的描述
try: try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash) cached_emoji_description = await emoji_manager.get_emoji_description_by_hash(image_hash)
if cached_emoji_description: if cached_emoji_description:
@ -116,13 +117,12 @@ class ImageManager:
logger.debug(f"查询EmojiManager时出错: {e}") logger.debug(f"查询EmojiManager时出错: {e}")
# 查询ImageDescriptions表的缓存描述 # 查询ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "emoji") if cached_description := self._get_description_from_db(image_hash, "emoji"):
if cached_description:
logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") logger.info(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
# === 二步走识别流程 === # === 二步走识别流程 ===
# 第一步VLM视觉分析 - 生成详细描述 # 第一步VLM视觉分析 - 生成详细描述
if image_format in ["gif", "GIF"]: if image_format in ["gif", "GIF"]:
image_base64_processed = self.transform_gif(image_base64) image_base64_processed = self.transform_gif(image_base64)
@ -130,10 +130,16 @@ class ImageManager:
logger.warning("GIF转换失败无法获取描述") logger.warning("GIF转换失败无法获取描述")
return "[表情包(GIF处理失败)]" return "[表情包(GIF处理失败)]"
vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" vlm_prompt = "这是一个动态图表情包,每一张图代表了动态图的某一帧,黑色背景代表透明,描述一下表情包表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64_processed, "jpg") detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64_processed, "jpg", temperature=0.4, max_tokens=300
)
else: else:
vlm_prompt = "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析" vlm_prompt = (
detailed_description, _ = await self.vlm.generate_response_for_image(vlm_prompt, image_base64, image_format) "这是一个表情包,请详细描述一下表情包所表达的情感和内容,描述细节,从互联网梗,meme的角度去分析"
)
detailed_description, _ = await self.vlm.generate_response_for_image(
vlm_prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if detailed_description is None: if detailed_description is None:
logger.warning("VLM未能生成表情包详细描述") logger.warning("VLM未能生成表情包详细描述")
@ -150,31 +156,32 @@ class ImageManager:
3. 输出简短精准不要解释 3. 输出简短精准不要解释
4. 如果有多个词用逗号分隔 4. 如果有多个词用逗号分隔
""" """
# 使用较低温度确保输出稳定 # 使用较低温度确保输出稳定
emotion_llm = LLMRequest(model=global_config.model.utils, temperature=0.3, max_tokens=50, request_type="emoji") emotion_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="emoji")
emotion_result, _ = await emotion_llm.generate_response_async(emotion_prompt) emotion_result, _ = await emotion_llm.generate_response_async(
emotion_prompt, temperature=0.3, max_tokens=50
)
if emotion_result is None: if emotion_result is None:
logger.warning("LLM未能生成情感标签使用详细描述的前几个词") logger.warning("LLM未能生成情感标签使用详细描述的前几个词")
# 降级处理:从详细描述中提取关键词 # 降级处理:从详细描述中提取关键词
import jieba import jieba
words = list(jieba.cut(detailed_description)) words = list(jieba.cut(detailed_description))
emotion_result = "".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情") emotion_result = "".join(words[:2]) if len(words) >= 2 else (words[0] if words else "表情")
# 处理情感结果取前1-2个最重要的标签 # 处理情感结果取前1-2个最重要的标签
emotions = [e.strip() for e in emotion_result.replace("", ",").split(",") if e.strip()] emotions = [e.strip() for e in emotion_result.replace("", ",").split(",") if e.strip()]
final_emotion = emotions[0] if emotions else "表情" final_emotion = emotions[0] if emotions else "表情"
# 如果有第二个情感且不重复,也包含进来 # 如果有第二个情感且不重复,也包含进来
if len(emotions) > 1 and emotions[1] != emotions[0]: if len(emotions) > 1 and emotions[1] != emotions[0]:
final_emotion = f"{emotions[0]}{emotions[1]}" final_emotion = f"{emotions[0]}{emotions[1]}"
logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}") logger.info(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存,防止并发写入时重复生成 if cached_description := self._get_description_from_db(image_hash, "emoji"):
cached_description = self._get_description_from_db(image_hash, "emoji")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存表情包描述: {cached_description}")
return f"[表情包:{cached_description}]" return f"[表情包:{cached_description}]"
@ -242,9 +249,7 @@ class ImageManager:
logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...") logger.debug(f"[缓存命中] 使用Images表中的图片描述: {existing_image.description[:50]}...")
return f"[图片:{existing_image.description}]" return f"[图片:{existing_image.description}]"
# 查询ImageDescriptions表的缓存描述 if cached_description := self._get_description_from_db(image_hash, "image"):
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...") logger.debug(f"[缓存命中] 使用ImageDescriptions表中的描述: {cached_description[:50]}...")
return f"[图片:{cached_description}]" return f"[图片:{cached_description}]"
@ -252,7 +257,9 @@ class ImageManager:
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
prompt = global_config.custom_prompt.image_prompt prompt = global_config.custom_prompt.image_prompt
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None: if description is None:
logger.warning("AI未能生成图片描述") logger.warning("AI未能生成图片描述")
@ -445,10 +452,7 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64) image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
# 检查图片是否已存在 if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
if existing_image:
# 检查是否缺少必要字段,如果缺少则创建新记录 # 检查是否缺少必要字段,如果缺少则创建新记录
if ( if (
not hasattr(existing_image, "image_id") not hasattr(existing_image, "image_id")
@ -524,9 +528,7 @@ class ImageManager:
# 优先检查是否已有其他相同哈希的图片记录包含描述 # 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none( existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
(Images.description.is_null(False)) &
(Images.description != "")
) )
if existing_with_description and existing_with_description.id != image.id: if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
@ -538,8 +540,7 @@ class ImageManager:
return return
# 检查ImageDescriptions表的缓存描述 # 检查ImageDescriptions表的缓存描述
cached_description = self._get_description_from_db(image_hash, "image") if cached_description := self._get_description_from_db(image_hash, "image"):
if cached_description:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description image.description = cached_description
image.vlm_processed = True image.vlm_processed = True
@ -554,15 +555,15 @@ class ImageManager:
# 获取VLM描述 # 获取VLM描述
logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)") logger.info(f"[VLM异步调用] 为图片生成描述 (ID: {image_id}, Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(prompt, image_base64, image_format) description, _ = await self.vlm.generate_response_for_image(
prompt, image_base64, image_format, temperature=0.4, max_tokens=300
)
if description is None: if description is None:
logger.warning("VLM未能生成图片描述") logger.warning("VLM未能生成图片描述")
description = "无法生成描述" description = "无法生成描述"
# 再次检查缓存,防止并发写入时重复生成 if cached_description := self._get_description_from_db(image_hash, "image"):
cached_description = self._get_description_from_db(image_hash, "image")
if cached_description:
logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}") logger.warning(f"虽然生成了描述,但是找到缓存图片描述: {cached_description}")
description = cached_description description = cached_description
@ -606,7 +607,7 @@ def image_path_to_base64(image_path: str) -> str:
raise FileNotFoundError(f"图片文件不存在: {image_path}") raise FileNotFoundError(f"图片文件不存在: {image_path}")
with open(image_path, "rb") as f: with open(image_path, "rb") as f:
image_data = f.read() if image_data := f.read():
if not image_data: return base64.b64encode(image_data).decode("utf-8")
else:
raise IOError(f"读取图片文件失败: {image_path}") raise IOError(f"读取图片文件失败: {image_path}")
return base64.b64encode(image_data).decode("utf-8")

View File

@ -1,35 +1,29 @@
import base64 from src.config.config import global_config, model_config
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from rich.traceback import install from rich.traceback import install
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("chat_voice") logger = get_logger("chat_voice")
async def get_voice_text(voice_base64: str) -> str: async def get_voice_text(voice_base64: str) -> str:
"""获取音频文件描述""" """获取音频文件转录文本"""
if not global_config.voice.enable_asr: if not global_config.voice.enable_asr:
logger.warning("语音识别未启用,无法处理语音消息") logger.warning("语音识别未启用,无法处理语音消息")
return "[语音]" return "[语音]"
try: try:
# 解码base64音频数据 _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
# 确保base64字符串只包含ASCII字符 text = await _llm.generate_response_for_voice(voice_base64)
if isinstance(voice_base64, str):
voice_base64 = voice_base64.encode("ascii", errors="ignore").decode("ascii")
voice_bytes = base64.b64decode(voice_base64)
_llm = LLMRequest(model=global_config.model.voice, request_type="voice")
text = await _llm.generate_response_for_voice(voice_bytes)
if text is None: if text is None:
logger.warning("未能生成语音文本") logger.warning("未能生成语音文本")
return "[语音(文本生成失败)]" return "[语音(文本生成失败)]"
logger.debug(f"描述是{text}") logger.debug(f"描述是{text}")
return f"[语音:{text}]" return f"[语音:{text}]"
except Exception as e: except Exception as e:
logger.error(f"语音转文字失败: {str(e)}") logger.error(f"语音转文字失败: {str(e)}")
return "[语音]" return "[语音]"

View File

@ -19,13 +19,13 @@ Mxp 模式:梦溪畔独家赞助
下下策是询问一个菜鸟@梦溪畔 下下策是询问一个菜鸟@梦溪畔
""" """
from .willing_manager import BaseWillingManager
from typing import Dict from typing import Dict
import asyncio import asyncio
import time import time
import math import math
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from .willing_manager import BaseWillingManager
class MxpWillingManager(BaseWillingManager): class MxpWillingManager(BaseWillingManager):

View File

@ -281,20 +281,6 @@ class Memory(BaseModel):
table_name = "memory" table_name = "memory"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型
"""
content = TextField() # 知识内容的文本
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
# database = db # 继承自 BaseModel
table_name = "knowledges"
class Expression(BaseModel): class Expression(BaseModel):
""" """
用于存储表达风格的模型 用于存储表达风格的模型
@ -382,7 +368,6 @@ def create_tables():
ImageDescriptions, ImageDescriptions,
OnlineTime, OnlineTime,
PersonInfo, PersonInfo,
Knowledges,
Expression, Expression,
ThinkingLog, ThinkingLog,
GraphNodes, # 添加图节点表 GraphNodes, # 添加图节点表
@ -408,7 +393,6 @@ def initialize_database():
ImageDescriptions, ImageDescriptions,
OnlineTime, OnlineTime,
PersonInfo, PersonInfo,
Knowledges,
Expression, Expression,
Memory, Memory,
ThinkingLog, ThinkingLog,

View File

@ -334,7 +334,7 @@ MODULE_COLORS = {
"llm_models": "\033[36m", # 青色 "llm_models": "\033[36m", # 青色
"remote": "\033[38;5;242m", # 深灰色,更不显眼 "remote": "\033[38;5;242m", # 深灰色,更不显眼
"planner": "\033[36m", "planner": "\033[36m",
"memory": "\033[34m", "memory": "\033[38;5;117m", # 天蓝色
"hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读 "hfc": "\033[38;5;81m", # 稍微暗一些的青色,保持可读
"action_manager": "\033[38;5;208m", # 橙色不与replyer重复 "action_manager": "\033[38;5;208m", # 橙色不与replyer重复
# 关系系统 # 关系系统
@ -352,7 +352,7 @@ MODULE_COLORS = {
"expressor": "\033[38;5;166m", # 橙色 "expressor": "\033[38;5;166m", # 橙色
# 专注聊天模块 # 专注聊天模块
"replyer": "\033[38;5;166m", # 橙色 "replyer": "\033[38;5;166m", # 橙色
"memory_activator": "\033[34m", # 绿 "memory_activator": "\033[38;5;117m", # 天蓝
# 插件系统 # 插件系统
"plugins": "\033[31m", # 红色 "plugins": "\033[31m", # 红色
"plugin_api": "\033[33m", # 黄色 "plugin_api": "\033[33m", # 黄色
@ -451,7 +451,7 @@ class ModuleColoredConsoleRenderer:
# 日志级别颜色 # 日志级别颜色
self._level_colors = { self._level_colors = {
"debug": "\033[38;5;208m", # 橙色 "debug": "\033[38;5;208m", # 橙色
"info": "\033[34m", # 蓝色 "info": "\033[38;5;117m", # 天蓝色
"success": "\033[32m", # 绿色 "success": "\033[32m", # 绿色
"warning": "\033[33m", # 黄色 "warning": "\033[33m", # 黄色
"error": "\033[31m", # 红色 "error": "\033[31m", # 红色

View File

@ -0,0 +1,142 @@
from dataclasses import dataclass, field
from .config_base import ConfigBase
@dataclass
class APIProvider(ConfigBase):
"""API提供商配置类"""
name: str
"""API提供商名称"""
base_url: str
"""API基础URL"""
api_key: str = field(default_factory=str, repr=False)
"""API密钥列表"""
client_type: str = field(default="openai")
"""客户端类型如openai/google等默认为openai"""
max_retry: int = 2
"""最大重试次数单个模型API调用失败最多重试的次数"""
timeout: int = 10
"""API调用的超时时长超过这个时长本次请求将被视为“请求超时”单位"""
retry_interval: int = 10
"""重试间隔如果API调用失败重试的间隔时间单位"""
def get_api_key(self) -> str:
return self.api_key
def __post_init__(self):
"""确保api_key在repr中不被显示"""
if not self.api_key:
raise ValueError("API密钥不能为空请在配置中设置有效的API密钥。")
if not self.base_url:
raise ValueError("API基础URL不能为空请在配置中设置有效的基础URL。")
if not self.name:
raise ValueError("API提供商名称不能为空请在配置中设置有效的名称。")
@dataclass
class ModelInfo(ConfigBase):
"""单个模型信息配置类"""
model_identifier: str
"""模型标识符用于URL调用"""
name: str
"""模型名称(用于模块调用)"""
api_provider: str
"""API提供商如OpenAI、Azure等"""
price_in: float = field(default=0.0)
"""每M token输入价格"""
price_out: float = field(default=0.0)
"""每M token输出价格"""
force_stream_mode: bool = field(default=False)
"""是否强制使用流式输出模式"""
extra_params: dict = field(default_factory=dict)
"""额外参数用于API调用时的额外配置"""
def __post_init__(self):
if not self.model_identifier:
raise ValueError("模型标识符不能为空,请在配置中设置有效的模型标识符。")
if not self.name:
raise ValueError("模型名称不能为空,请在配置中设置有效的模型名称。")
if not self.api_provider:
raise ValueError("API提供商不能为空请在配置中设置有效的API提供商。")
@dataclass
class TaskConfig(ConfigBase):
"""任务配置类"""
model_list: list[str] = field(default_factory=list)
"""任务使用的模型列表"""
max_tokens: int = 1024
"""任务最大输出token数"""
temperature: float = 0.3
"""模型温度"""
@dataclass
class ModelTaskConfig(ConfigBase):
"""模型配置类"""
utils: TaskConfig
"""组件模型配置"""
utils_small: TaskConfig
"""组件小模型配置"""
replyer_1: TaskConfig
"""normal_chat首要回复模型模型配置"""
replyer_2: TaskConfig
"""normal_chat次要回复模型配置"""
memory: TaskConfig
"""记忆模型配置"""
emotion: TaskConfig
"""情绪模型配置"""
vlm: TaskConfig
"""视觉语言模型配置"""
voice: TaskConfig
"""语音识别模型配置"""
tool_use: TaskConfig
"""专注工具使用模型配置"""
planner: TaskConfig
"""规划模型配置"""
embedding: TaskConfig
"""嵌入模型配置"""
lpmm_entity_extract: TaskConfig
"""LPMM实体提取模型配置"""
lpmm_rdf_build: TaskConfig
"""LPMM RDF构建模型配置"""
lpmm_qa: TaskConfig
"""LPMM问答模型配置"""
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
if hasattr(self, task_name):
return getattr(self, task_name)
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")

View File

@ -1,162 +0,0 @@
import shutil
import tomlkit
from tomlkit.items import Table, KeyType
from pathlib import Path
from datetime import datetime
def get_key_comment(toml_table, key):
# 获取key的注释如果有
if hasattr(toml_table, "trivia") and hasattr(toml_table.trivia, "comment"):
return toml_table.trivia.comment
if hasattr(toml_table, "value") and isinstance(toml_table.value, dict):
item = toml_table.value.get(key)
if item is not None and hasattr(item, "trivia"):
return item.trivia.comment
if hasattr(toml_table, "keys"):
for k in toml_table.keys():
if isinstance(k, KeyType) and k.key == key:
return k.trivia.comment
return None
def compare_dicts(new, old, path=None, new_comments=None, old_comments=None, logs=None):
# 递归比较两个dict找出新增和删减项收集注释
if path is None:
path = []
if logs is None:
logs = []
if new_comments is None:
new_comments = {}
if old_comments is None:
old_comments = {}
# 新增项
for key in new:
if key == "version":
continue
if key not in old:
comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], new_comments, old_comments, logs)
# 删减项
for key in old:
if key == "version":
continue
if key not in new:
comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs
def update_config():
print("开始更新配置文件...")
# 获取根目录路径
root_dir = Path(__file__).parent.parent.parent.parent
template_dir = root_dir / "template"
config_dir = root_dir / "config"
old_config_dir = config_dir / "old"
# 创建old目录如果不存在
old_config_dir.mkdir(exist_ok=True)
# 定义文件路径
template_path = template_dir / "bot_config_template.toml"
old_config_path = config_dir / "bot_config.toml"
new_config_path = config_dir / "bot_config.toml"
# 读取旧配置文件
old_config = {}
if old_config_path.exists():
print(f"发现旧配置文件: {old_config_path}")
with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f)
# 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = old_config_dir / f"bot_config_{timestamp}.toml"
# 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path)
print(f"已备份旧配置文件到: {old_backup_path}")
# 复制模板文件到配置目录
print(f"从模板文件创建新配置: {template_path}")
shutil.copy2(template_path, new_config_path)
# 读取新配置文件
with open(new_config_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version:
print(f"检测到版本号相同 (v{old_version}),跳过更新")
# 如果version相同恢复旧配置文件并返回
shutil.move(old_backup_path, old_config_path) # type: ignore
return
else:
print(f"检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}")
# 输出新增和删减项及注释
if old_config:
print("配置项变动如下:")
logs = compare_dicts(new_config, old_config)
if logs:
for log in logs:
print(log)
else:
print("无新增或删减项")
# 递归更新配置
def update_dict(target, source):
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
if isinstance(value, dict) and isinstance(target[key], (dict, Table)):
update_dict(target[key], value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
if not value:
target[key] = tomlkit.array()
else:
# 特殊处理正则表达式数组和包含正则表达式的结构
if key == "ban_msgs_regex":
# 直接使用原始值,不进行额外处理
target[key] = value
elif key == "regex_rules":
# 对于regex_rules需要特殊处理其中的regex字段
target[key] = value
else:
# 检查是否包含正则表达式相关的字典项
contains_regex = False
if value and isinstance(value[0], dict) and "regex" in value[0]:
contains_regex = True
target[key] = value if contains_regex else tomlkit.array(str(value))
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中
print("开始合并新旧配置...")
update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config))
print("配置文件更新完成")
if __name__ == "__main__":
update_config()

View File

@ -1,12 +1,14 @@
import os import os
import tomlkit import tomlkit
import shutil import shutil
import sys
from datetime import datetime from datetime import datetime
from tomlkit import TOMLDocument from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType from tomlkit.items import Table, KeyType
from dataclasses import field, dataclass from dataclasses import field, dataclass
from rich.traceback import install from rich.traceback import install
from typing import List, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
@ -25,7 +27,6 @@ from src.config.official_configs import (
ResponseSplitterConfig, ResponseSplitterConfig,
TelemetryConfig, TelemetryConfig,
ExperimentalConfig, ExperimentalConfig,
ModelConfig,
MessageReceiveConfig, MessageReceiveConfig,
MaimMessageConfig, MaimMessageConfig,
LPMMKnowledgeConfig, LPMMKnowledgeConfig,
@ -36,6 +37,13 @@ from src.config.official_configs import (
CustomPromptConfig, CustomPromptConfig,
) )
from .api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
install(extra_lines=3) install(extra_lines=3)
@ -49,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/ # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.9.1" MMC_VERSION = "0.10.0-snapshot.4"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):
@ -79,7 +87,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue continue
if key not in old: if key not in old:
comment = get_key_comment(new, key) comment = get_key_comment(new, key)
logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}") logs.append(f"新增: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)): elif isinstance(new[key], (dict, Table)) and isinstance(old.get(key), (dict, Table)):
compare_dicts(new[key], old[key], path + [str(key)], logs) compare_dicts(new[key], old[key], path + [str(key)], logs)
# 删减项 # 删减项
@ -88,7 +96,7 @@ def compare_dicts(new, old, path=None, logs=None):
continue continue
if key not in new: if key not in new:
comment = get_key_comment(old, key) comment = get_key_comment(old, key)
logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment if comment else ''}") logs.append(f"删减: {'.'.join(path + [str(key)])} 注释: {comment or ''}")
return logs return logs
@ -123,67 +131,110 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
if key in old: if key in old:
if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)): if isinstance(new[key], (dict, Table)) and isinstance(old[key], (dict, Table)):
compare_default_values(new[key], old[key], path + [str(key)], logs, changes) compare_default_values(new[key], old[key], path + [str(key)], logs, changes)
else: elif new[key] != old[key]:
# 只要值发生变化就记录 logs.append(f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}")
if new[key] != old[key]: changes.append((path + [str(key)], old[key], new[key]))
logs.append(
f"默认值变化: {'.'.join(path + [str(key)])} 旧默认值: {old[key]} 新默认值: {new[key]}"
)
changes.append((path + [str(key)], old[key], new[key]))
return logs, changes return logs, changes
def update_config(): def _get_version_from_toml(toml_path) -> Optional[str]:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
def _version_tuple(v):
"""将版本字符串转换为元组以便比较"""
if v is None:
return (0,)
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
_update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
def _update_config_generic(config_name: str, template_name: str):
"""
通用的配置文件更新函数
Args:
config_name: 配置文件名不含扩展名 'bot_config' 'model_config'
template_name: 模板文件名不含扩展名 'bot_config_template' 'model_config_template'
"""
# 获取根目录路径 # 获取根目录路径
old_config_dir = os.path.join(CONFIG_DIR, "old") old_config_dir = os.path.join(CONFIG_DIR, "old")
compare_dir = os.path.join(TEMPLATE_DIR, "compare") compare_dir = os.path.join(TEMPLATE_DIR, "compare")
# 定义文件路径 # 定义文件路径
template_path = os.path.join(TEMPLATE_DIR, "bot_config_template.toml") template_path = os.path.join(TEMPLATE_DIR, f"{template_name}.toml")
old_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") old_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
new_config_path = os.path.join(CONFIG_DIR, "bot_config.toml") new_config_path = os.path.join(CONFIG_DIR, f"{config_name}.toml")
compare_path = os.path.join(compare_dir, "bot_config_template.toml") compare_path = os.path.join(compare_dir, f"{template_name}.toml")
# 创建compare目录如果不存在 # 创建compare目录如果不存在
os.makedirs(compare_dir, exist_ok=True) os.makedirs(compare_dir, exist_ok=True)
# 处理compare下的模板文件 template_version = _get_version_from_toml(template_path)
def get_version_from_toml(toml_path): compare_version = _get_version_from_toml(compare_path)
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
return None
template_version = get_version_from_toml(template_path) # 检查配置文件是否存在
compare_version = get_version_from_toml(compare_path) if not os.path.exists(old_config_path):
logger.info(f"{config_name}.toml配置文件不存在从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新{config_name}配置文件,请填写后重新运行: {old_config_path}")
# 新创建配置文件,退出
sys.exit(0)
def version_tuple(v): compare_config = None
if v is None: new_config = None
return (0,) old_config = None
return tuple(int(x) if x.isdigit() else 0 for x in str(v).replace("v", "").split("-")[0].split("."))
# 先读取 compare 下的模板(如果有),用于默认值变动检测 # 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path): if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f: with open(compare_path, "r", encoding="utf-8") as f:
compare_config = tomlkit.load(f) compare_config = tomlkit.load(f)
else:
compare_config = None
# 读取当前模板 # 读取当前模板
with open(template_path, "r", encoding="utf-8") as f: with open(template_path, "r", encoding="utf-8") as f:
new_config = tomlkit.load(f) new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做) # 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config is not None: if compare_config:
# 读取旧配置 # 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f: with open(old_config_path, "r", encoding="utf-8") as f:
old_config = tomlkit.load(f) old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config) logs, changes = compare_default_values(new_config, compare_config)
if logs: if logs:
logger.info("检测到模板默认值变动如下:") logger.info(f"检测到{config_name}模板默认值变动如下:")
for log in logs: for log in logs:
logger.info(log) logger.info(log)
# 检查旧配置是否等于旧默认值,如果是则更新为新默认值 # 检查旧配置是否等于旧默认值,如果是则更新为新默认值
@ -192,33 +243,20 @@ def update_config():
if old_value == old_default: if old_value == old_default:
set_value_by_path(old_config, path, new_default) set_value_by_path(old_config, path, new_default)
logger.info( logger.info(
f"已自动将配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}" f"已自动将{config_name}配置 {'.'.join(path)} 的值从旧默认值 {old_default} 更新为新默认值 {new_default}"
) )
else: else:
logger.info("未检测到模板默认值变动") logger.info(f"未检测到{config_name}模板默认值变动")
# 保存旧配置的变更(后续合并逻辑会用到 old_config
else:
old_config = None
# 检查 compare 下没有模板,或新模板版本更高,则复制 # 检查 compare 下没有模板,或新模板版本更高,则复制
if not os.path.exists(compare_path): if not os.path.exists(compare_path):
shutil.copy2(template_path, compare_path) shutil.copy2(template_path, compare_path)
logger.info(f"已将模板文件复制到: {compare_path}") logger.info(f"已将{config_name}模板文件复制到: {compare_path}")
elif _version_tuple(template_version) > _version_tuple(compare_version):
shutil.copy2(template_path, compare_path)
logger.info(f"{config_name}模板版本较新已替换compare下的模板: {compare_path}")
else: else:
if version_tuple(template_version) > version_tuple(compare_version): logger.debug(f"compare下的{config_name}模板版本不低于当前模板,无需替换: {compare_path}")
shutil.copy2(template_path, compare_path)
logger.info(f"模板版本较新已替换compare下的模板: {compare_path}")
else:
logger.debug(f"compare下的模板版本不低于当前模板无需替换: {compare_path}")
# 检查配置文件是否存在
if not os.path.exists(old_config_path):
logger.info("配置文件不存在,从模板创建新配置")
os.makedirs(CONFIG_DIR, exist_ok=True) # 创建文件夹
shutil.copy2(template_path, old_config_path) # 复制模板文件
logger.info(f"已创建新配置文件,请填写后重新运行: {old_config_path}")
# 如果是新创建的配置文件,直接返回
quit()
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次 # 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None: if old_config is None:
@ -226,79 +264,60 @@ def update_config():
old_config = tomlkit.load(f) old_config = tomlkit.load(f)
# new_config 已经读取 # new_config 已经读取
# 读取 compare_config 只用于默认值变动检测,后续合并逻辑不再用
# 检查version是否相同 # 检查version是否相同
if old_config and "inner" in old_config and "inner" in new_config: if old_config and "inner" in old_config and "inner" in new_config:
old_version = old_config["inner"].get("version") # type: ignore old_version = old_config["inner"].get("version") # type: ignore
new_version = new_config["inner"].get("version") # type: ignore new_version = new_config["inner"].get("version") # type: ignore
if old_version and new_version and old_version == new_version: if old_version and new_version and old_version == new_version:
logger.info(f"检测到配置文件版本号相同 (v{old_version}),跳过更新") logger.info(f"检测到{config_name}配置文件版本号相同 (v{old_version}),跳过更新")
return return
else: else:
logger.info( logger.info(
f"\n----------------------------------------\n检测到版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------" f"\n----------------------------------------\n检测到{config_name}版本号不同: 旧版本 v{old_version} -> 新版本 v{new_version}\n----------------------------------------"
) )
else: else:
logger.info("已有配置文件未检测到版本号,可能是旧版本。将进行更新") logger.info(f"已有{config_name}配置文件未检测到版本号,可能是旧版本。将进行更新")
# 创建old目录如果不存在 # 创建old目录如果不存在
os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名 os.makedirs(old_config_dir, exist_ok=True) # 生成带时间戳的新文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
old_backup_path = os.path.join(old_config_dir, f"bot_config_{timestamp}.toml") old_backup_path = os.path.join(old_config_dir, f"{config_name}_{timestamp}.toml")
# 移动旧配置文件到old目录 # 移动旧配置文件到old目录
shutil.move(old_config_path, old_backup_path) shutil.move(old_config_path, old_backup_path)
logger.info(f"已备份旧配置文件到: {old_backup_path}") logger.info(f"已备份旧{config_name}配置文件到: {old_backup_path}")
# 复制模板文件到配置目录 # 复制模板文件到配置目录
shutil.copy2(template_path, new_config_path) shutil.copy2(template_path, new_config_path)
logger.info(f"已创建新配置文件: {new_config_path}") logger.info(f"已创建新{config_name}配置文件: {new_config_path}")
# 输出新增和删减项及注释 # 输出新增和删减项及注释
if old_config: if old_config:
logger.info("配置项变动如下:\n----------------------------------------") logger.info(f"{config_name}配置项变动如下:\n----------------------------------------")
logs = compare_dicts(new_config, old_config) if logs := compare_dicts(new_config, old_config):
if logs:
for log in logs: for log in logs:
logger.info(log) logger.info(log)
else: else:
logger.info("无新增或删减项") logger.info("无新增或删减项")
def update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
"""
将source字典的值更新到target字典中如果target中存在相同的键
"""
for key, value in source.items():
# 跳过version字段的更新
if key == "version":
continue
if key in target:
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, (dict, Table)):
update_dict(target_value, value)
else:
try:
# 对数组类型进行特殊处理
if isinstance(value, list):
# 如果是空数组,确保它保持为空数组
target[key] = tomlkit.array(str(value)) if value else tomlkit.array()
else:
# 其他类型使用item方法创建新值
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
# 如果转换失败,直接赋值
target[key] = value
# 将旧配置的值更新到新配置中 # 将旧配置的值更新到新配置中
logger.info("开始合并新旧配置...") logger.info(f"开始合并{config_name}新旧配置...")
update_dict(new_config, old_config) _update_dict(new_config, old_config)
# 保存更新后的配置(保留注释和格式) # 保存更新后的配置(保留注释和格式)
with open(new_config_path, "w", encoding="utf-8") as f: with open(new_config_path, "w", encoding="utf-8") as f:
f.write(tomlkit.dumps(new_config)) f.write(tomlkit.dumps(new_config))
logger.info("配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息") logger.info(f"{config_name}配置文件更新完成,建议检查新配置文件中的内容,以免丢失重要信息")
quit()
def update_config():
"""更新bot_config.toml配置文件"""
_update_config_generic("bot_config", "bot_config_template")
def update_model_config():
"""更新model_config.toml配置文件"""
_update_config_generic("model_config", "model_config_template")
@dataclass @dataclass
@ -323,7 +342,6 @@ class Config(ConfigBase):
response_splitter: ResponseSplitterConfig response_splitter: ResponseSplitterConfig
telemetry: TelemetryConfig telemetry: TelemetryConfig
experimental: ExperimentalConfig experimental: ExperimentalConfig
model: ModelConfig
maim_message: MaimMessageConfig maim_message: MaimMessageConfig
lpmm_knowledge: LPMMKnowledgeConfig lpmm_knowledge: LPMMKnowledgeConfig
tool: ToolConfig tool: ToolConfig
@ -331,11 +349,69 @@ class Config(ConfigBase):
custom_prompt: CustomPromptConfig custom_prompt: CustomPromptConfig
voice: VoiceConfig voice: VoiceConfig
@dataclass
class APIAdapterConfig(ConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo]
"""模型列表"""
model_task_config: ModelTaskConfig
"""模型任务配置"""
api_providers: List[APIProvider] = field(default_factory=list)
"""API提供商列表"""
def __post_init__(self):
if not self.models:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
if not self.api_providers:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in self.api_providers]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
# 检查模型名称是否重复
model_names = [model.name for model in self.models]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
for model in self.models:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
if not model.api_provider or model.api_provider not in self.api_providers_dict:
raise ValueError(f"模型 '{model.name}' 的 api_provider '{model.api_provider}' 不存在")
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]
def load_config(config_path: str) -> Config: def load_config(config_path: str) -> Config:
""" """
加载配置文件 加载配置文件
:param config_path: 配置文件路径 Args:
:return: Config对象 config_path: 配置文件路径
Returns:
Config对象
""" """
# 读取配置文件 # 读取配置文件
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
@ -349,18 +425,32 @@ def load_config(config_path: str) -> Config:
raise e raise e
def get_config_dir() -> str: def api_ada_load_config(config_path: str) -> APIAdapterConfig:
""" """
获取配置目录 加载API适配器配置文件
:return: 配置目录路径 Args:
config_path: 配置文件路径
Returns:
APIAdapterConfig对象
""" """
return CONFIG_DIR # 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建APIAdapterConfig对象
try:
return APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.critical("API适配器配置文件解析失败")
raise e
# 获取配置文件路径 # 获取配置文件路径
logger.info(f"MaiCore当前版本: {MMC_VERSION}") logger.info(f"MaiCore当前版本: {MMC_VERSION}")
update_config() update_config()
update_model_config()
logger.info("正在品鉴配置文件...") logger.info("正在品鉴配置文件...")
global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml")) global_config = load_config(config_path=os.path.join(CONFIG_DIR, "bot_config.toml"))
model_config = api_ada_load_config(config_path=os.path.join(CONFIG_DIR, "model_config.toml"))
logger.info("非常的新鲜,非常的美味!") logger.info("非常的新鲜,非常的美味!")

View File

@ -1,7 +1,7 @@
import re import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Literal, Optional from typing import Literal, Optional
from src.config.config_base import ConfigBase from src.config.config_base import ConfigBase
@ -598,51 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024 embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致""" """嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class ModelConfig(ConfigBase):
"""模型配置类"""
model_max_output_length: int = 800 # 最大回复长度
utils: dict[str, Any] = field(default_factory=lambda: {})
"""组件模型配置"""
utils_small: dict[str, Any] = field(default_factory=lambda: {})
"""组件小模型配置"""
replyer_1: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat首要回复模型模型配置"""
replyer_2: dict[str, Any] = field(default_factory=lambda: {})
"""normal_chat次要回复模型配置"""
memory: dict[str, Any] = field(default_factory=lambda: {})
"""记忆模型配置"""
emotion: dict[str, Any] = field(default_factory=lambda: {})
"""情绪模型配置"""
vlm: dict[str, Any] = field(default_factory=lambda: {})
"""视觉语言模型配置"""
voice: dict[str, Any] = field(default_factory=lambda: {})
"""语音识别模型配置"""
tool_use: dict[str, Any] = field(default_factory=lambda: {})
"""专注工具使用模型配置"""
planner: dict[str, Any] = field(default_factory=lambda: {})
"""规划模型配置"""
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM实体提取模型配置"""
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM RDF构建模型配置"""
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM问答模型配置"""

View File

@ -4,7 +4,7 @@ import hashlib
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
from rich.traceback import install from rich.traceback import install
@ -23,10 +23,7 @@ class Individuality:
self.meta_info_file_path = "data/personality/meta.json" self.meta_info_file_path = "data/personality/meta.json"
self.personality_data_file_path = "data/personality/personality_data.json" self.personality_data_file_path = "data/personality/personality_data.json"
self.model = LLMRequest( self.model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="individuality.compress")
model=global_config.model.utils,
request_type="individuality.compress",
)
async def initialize(self) -> None: async def initialize(self) -> None:
"""初始化个体特征""" """初始化个体特征"""
@ -35,7 +32,6 @@ class Individuality:
personality_side = global_config.personality.personality_side personality_side = global_config.personality.personality_side
identity = global_config.personality.identity identity = global_config.personality.identity
person_info_manager = get_person_info_manager() person_info_manager = get_person_info_manager()
self.bot_person_id = person_info_manager.get_person_id("system", "bot_id") self.bot_person_id = person_info_manager.get_person_id("system", "bot_id")
self.name = bot_nickname self.name = bot_nickname
@ -85,16 +81,16 @@ class Individuality:
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}" bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
else: else:
bot_nickname = "" bot_nickname = ""
# 从文件获取 short_impression # 从文件获取 short_impression
personality, identity = self._get_personality_from_file() personality, identity = self._get_personality_from_file()
# 确保short_impression是列表格式且有足够的元素 # 确保short_impression是列表格式且有足够的元素
if not personality or not identity: if not personality or not identity:
logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值") logger.warning(f"personality或identity为空: {personality}, {identity}, 使用默认值")
personality = "友好活泼" personality = "友好活泼"
identity = "人类" identity = "人类"
prompt_personality = f"{personality}\n{identity}" prompt_personality = f"{personality}\n{identity}"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
@ -215,7 +211,7 @@ class Individuality:
def _get_personality_from_file(self) -> tuple[str, str]: def _get_personality_from_file(self) -> tuple[str, str]:
"""从文件获取personality数据 """从文件获取personality数据
Returns: Returns:
tuple: (personality, identity) tuple: (personality, identity)
""" """
@ -226,7 +222,7 @@ class Individuality:
def _save_personality_to_file(self, personality: str, identity: str): def _save_personality_to_file(self, personality: str, identity: str):
"""保存personality数据到文件 """保存personality数据到文件
Args: Args:
personality: 压缩后的人格描述 personality: 压缩后的人格描述
identity: 压缩后的身份描述 identity: 压缩后的身份描述
@ -235,7 +231,7 @@ class Individuality:
"personality": personality, "personality": personality,
"identity": identity, "identity": identity,
"bot_nickname": self.name, "bot_nickname": self.name,
"last_updated": int(time.time()) "last_updated": int(time.time()),
} }
self._save_personality_data(personality_data) self._save_personality_data(personality_data)
@ -269,7 +265,7 @@ class Individuality:
2. 尽量简洁不超过30字 2. 尽量简洁不超过30字
3. 直接输出压缩后的内容不要解释""" 3. 直接输出压缩后的内容不要解释"""
response, (_, _) = await self.model.generate_response_async( response, _ = await self.model.generate_response_async(
prompt=prompt, prompt=prompt,
) )
@ -281,7 +277,7 @@ class Individuality:
# 压缩失败时使用原始内容 # 压缩失败时使用原始内容
if personality_side: if personality_side:
personality_parts.append(personality_side) personality_parts.append(personality_side)
if personality_parts: if personality_parts:
personality_result = "".join(personality_parts) personality_result = "".join(personality_parts)
else: else:
@ -308,7 +304,7 @@ class Individuality:
2. 尽量简洁不超过30字 2. 尽量简洁不超过30字
3. 直接输出压缩后的内容不要解释""" 3. 直接输出压缩后的内容不要解释"""
response, (_, _) = await self.model.generate_response_async( response, _ = await self.model.generate_response_async(
prompt=prompt, prompt=prompt,
) )

View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2025 Mai.To.The.Gate
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

View File

@ -0,0 +1,98 @@
from typing import Any
# 常见Error Code Mapping (以OpenAI API为例)
error_code_mapping = {
400: "参数不正确",
401: "API-Key错误认证失败请检查/config/model_list.toml中的配置是否正确",
402: "账号余额不足",
403: "模型拒绝访问,可能需要实名或余额不足",
404: "Not Found",
413: "请求体过大,请尝试压缩图片或减少输入内容",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高",
}
class NetworkConnectionError(Exception):
"""连接异常,常见于网络问题或服务器不可用"""
def __init__(self):
super().__init__()
def __str__(self):
return "连接异常请检查网络连接状态或URL是否正确"
class ReqAbortException(Exception):
"""请求异常退出,常见于请求被中断或取消"""
def __init__(self, message: str | None = None):
super().__init__(message)
self.message = message
def __str__(self):
return self.message or "请求因未知原因异常终止"
class RespNotOkException(Exception):
"""请求响应异常,见于请求未能成功响应(非 '200 OK'"""
def __init__(self, status_code: int, message: str | None = None):
super().__init__(message)
self.status_code = status_code
self.message = message
def __str__(self):
if self.status_code in error_code_mapping:
return error_code_mapping[self.status_code]
elif self.message:
return self.message
else:
return f"未知的异常响应代码:{self.status_code}"
class RespParseException(Exception):
"""响应解析错误,常见于响应格式不正确或解析方法不匹配"""
def __init__(self, ext_info: Any, message: str | None = None):
super().__init__(message)
self.ext_info = ext_info
self.message = message
def __str__(self):
return self.message or "解析响应内容时发生未知错误,请检查是否配置了正确的解析方法"
class PayLoadTooLargeError(Exception):
"""自定义异常类,用于处理请求体过大错误"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return "请求体过大,请尝试压缩图片或减少输入内容。"
class RequestAbortException(Exception):
"""自定义异常类,用于处理请求中断异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message
class PermissionDeniedException(Exception):
"""自定义异常类,用于处理访问拒绝的异常"""
def __init__(self, message: str):
super().__init__(message)
self.message = message
def __str__(self):
return self.message

View File

@ -0,0 +1,8 @@
from src.config.config import model_config
used_client_types = {provider.client_type for provider in model_config.api_providers}
if "openai" in used_client_types:
from . import openai_client # noqa: F401
if "gemini" in used_client_types:
from . import gemini_client # noqa: F401

View File

@ -0,0 +1,172 @@
import asyncio
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import Callable, Any, Optional
from src.config.api_ada_configs import ModelInfo, APIProvider
from ..payload_content.message import Message
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolCall
@dataclass
class UsageRecord:
"""
使用记录类
"""
model_name: str
"""模型名称"""
provider_name: str
"""提供商名称"""
prompt_tokens: int
"""提示token数"""
completion_tokens: int
"""完成token数"""
total_tokens: int
"""总token数"""
@dataclass
class APIResponse:
"""
API响应类
"""
content: str | None = None
"""响应内容"""
reasoning_content: str | None = None
"""推理内容"""
tool_calls: list[ToolCall] | None = None
"""工具调用 [(工具名称, 工具参数), ...]"""
embedding: list[float] | None = None
"""嵌入向量"""
usage: UsageRecord | None = None
"""使用情况 (prompt_tokens, completion_tokens, total_tokens)"""
raw_data: Any = None
"""响应原始数据"""
class BaseClient(ABC):
"""
基础客户端
"""
api_provider: APIProvider
def __init__(self, api_provider: APIProvider):
self.api_provider = api_provider
@abstractmethod
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
] = None,
async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
:param model_info: 模型信息
:param message_list: 对话体
:param tool_options: 工具选项可选默认为None
:param max_tokens: 最大token数可选默认为1024
:param temperature: 温度可选默认为0.7
:param response_format: 响应格式可选默认为 NotGiven
:param stream_response_handler: 流式响应处理函数可选
:param async_response_parser: 响应解析函数可选
:param interrupt_flag: 中断信号量可选默认为None
:return: (响应文本, 推理文本, 工具调用, 其他数据)
"""
raise NotImplementedError("'get_response' method should be overridden in subclasses")
@abstractmethod
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
raise NotImplementedError("'get_embedding' method should be overridden in subclasses")
@abstractmethod
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: base64编码的音频数据
:extra_params: 附加的请求参数
:return: 音频转录响应
"""
raise NotImplementedError("'get_audio_transcriptions' method should be overridden in subclasses")
@abstractmethod
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
raise NotImplementedError("'get_support_image_formats' method should be overridden in subclasses")
class ClientRegistry:
def __init__(self) -> None:
self.client_registry: dict[str, type[BaseClient]] = {}
def register_client_class(self, client_type: str):
"""
注册API客户端类
Args:
client_class: API客户端类
"""
def decorator(cls: type[BaseClient]) -> type[BaseClient]:
if not issubclass(cls, BaseClient):
raise TypeError(f"{cls.__name__} is not a subclass of BaseClient")
self.client_registry[client_type] = cls
return cls
return decorator
def get_client_class(self, client_type: str) -> type[BaseClient]:
"""
获取注册的API客户端类
Args:
client_type: 客户端类型
Returns:
type[BaseClient]: 注册的API客户端类
"""
if client_type not in self.client_registry:
raise KeyError(f"'{client_type}' 类型的 Client 未注册")
return self.client_registry[client_type]
client_registry = ClientRegistry()

View File

@ -0,0 +1,496 @@
import asyncio
import io
import base64
from typing import Callable, AsyncIterator, Optional, Coroutine, Any, List
from google import genai
from google.genai.types import (
Content,
Part,
FunctionDeclaration,
GenerateContentResponse,
ContentListUnion,
ContentUnion,
ThinkingConfig,
Tool,
GenerateContentConfig,
EmbedContentResponse,
EmbedContentConfig,
)
from google.genai.errors import (
ClientError,
ServerError,
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
)
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat, RespFormatType
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("Gemini客户端")
def _convert_messages(
messages: list[Message],
) -> tuple[ContentListUnion, list[str] | None]:
"""
转换消息格式 - 将消息转换为Gemini API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表(和可能存在的system消息)
"""
def _convert_message_item(message: Message) -> Content:
"""
转换单个消息格式除了system和tool类型的消息
:param message: 消息对象
:return: 转换后的消息字典
"""
# 将openai格式的角色重命名为gemini格式的角色
if message.role == RoleType.Assistant:
role = "model"
elif message.role == RoleType.User:
role = "user"
# 添加Content
if isinstance(message.content, str):
content = [Part.from_text(text=message.content)]
elif isinstance(message.content, list):
content: List[Part] = []
for item in message.content:
if isinstance(item, tuple):
content.append(
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{item[0].lower()}")
)
elif isinstance(item, str):
content.append(Part.from_text(text=item))
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
return Content(role=role, parts=content)
temp_list: list[ContentUnion] = []
system_instructions: list[str] = []
for message in messages:
if message.role == RoleType.System:
if isinstance(message.content, str):
system_instructions.append(message.content)
else:
raise ValueError("你tm怎么往system里面塞图片base64")
elif message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
else:
temp_list.append(_convert_message_item(message))
if system_instructions:
# 如果有system消息就把它加上去
ret: tuple = (temp_list, system_instructions)
else:
# 如果没有system消息就直接返回
ret: tuple = (temp_list, None)
return ret
def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclaration]:
"""
转换工具选项格式 - 将工具选项转换为Gemini API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具对象列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return_dict: dict[str, Any] = {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
if tool_option_param.enum_values:
return_dict["enum"] = tool_option_param.enum_values
return return_dict
def _convert_tool_option_item(tool_option: ToolOption) -> FunctionDeclaration:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的Gemini工具选项对象
"""
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
ret1 = FunctionDeclaration(**ret)
return ret1
return [_convert_tool_option_item(tool_option) for tool_option in tool_options]
def _process_delta(
delta: GenerateContentResponse,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, dict[str, Any]]],
):
if not hasattr(delta, "candidates") or not delta.candidates:
raise RespParseException(delta, "响应解析失败缺失candidates字段")
if delta.text:
fc_delta_buffer.write(delta.text)
if delta.function_calls: # 为什么不用hasattr呢是因为这个属性一定有即使是个空的
for call in delta.function_calls:
try:
if not isinstance(call.args, dict): # gemini返回的function call参数就是dict格式的了
raise RespParseException(delta, "响应解析失败,工具调用参数无法解析为字典类型")
if not call.id or not call.name:
raise RespParseException(delta, "响应解析失败工具调用缺失id或name字段")
tool_calls_buffer.append(
(
call.id,
call.name,
call.args or {}, # 如果args是None则转换为一个空字典
)
)
except Exception as e:
raise RespParseException(delta, "响应解析失败,无法解析工具调用参数") from e
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, dict]],
) -> APIResponse:
# sourcery skip: simplify-len-comparison, use-assigned-variable
resp = APIResponse()
if _fc_delta_buffer.tell() > 0:
# 如果正式内容缓冲区不为空则将其写入APIResponse对象
resp.content = _fc_delta_buffer.getvalue()
_fc_delta_buffer.close()
if len(_tool_calls_buffer) > 0:
# 如果工具调用缓冲区不为空则将其解析为ToolCall对象列表
resp.tool_calls = []
for call_id, function_name, arguments_buffer in _tool_calls_buffer:
if arguments_buffer is not None:
arguments = arguments_buffer
if not isinstance(arguments, dict):
raise RespParseException(
None,
f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{arguments_buffer}",
)
else:
arguments = None
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
return resp
async def _default_stream_response_handler(
resp_stream: AsyncIterator[GenerateContentResponse],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
流式响应处理函数 - 处理Gemini API的流式响应
:param resp_stream: 流式响应对象,是一个神秘的iterator我完全不知道这个玩意能不能跑不过遍历一遍之后它就空了如果跑不了一点的话可以考虑改成别的东西
:return: APIResponse对象
"""
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[tuple[str, str, dict]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
async for chunk in resp_stream:
# 检查是否有中断量
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
raise ReqAbortException("请求被外部信号中断")
_process_delta(
chunk,
_fc_delta_buffer,
_tool_calls_buffer,
)
if chunk.usage_metadata:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
chunk.usage_metadata.prompt_token_count or 0,
(chunk.usage_metadata.candidates_token_count or 0) + (chunk.usage_metadata.thoughts_token_count or 0),
chunk.usage_metadata.total_token_count or 0,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
def _default_normal_response_parser(
resp: GenerateContentResponse,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
解析对话补全响应 - 将Gemini API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "candidates") or not resp.candidates:
raise RespParseException(resp, "响应解析失败缺失candidates字段")
try:
if resp.candidates[0].content and resp.candidates[0].content.parts:
for part in resp.candidates[0].content.parts:
if not part.text:
continue
if part.thought:
api_response.reasoning_content = (
api_response.reasoning_content + part.text if api_response.reasoning_content else part.text
)
except Exception as e:
logger.warning(f"解析思考内容时发生错误: {e},跳过解析")
if resp.text:
api_response.content = resp.text
if resp.function_calls:
api_response.tool_calls = []
for call in resp.function_calls:
try:
if not isinstance(call.args, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
if not call.name:
raise RespParseException(resp, "响应解析失败工具调用缺失name字段")
api_response.tool_calls.append(ToolCall(call.id or "gemini-tool_call", call.name, call.args or {}))
except Exception as e:
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
if resp.usage_metadata:
_usage_record = (
resp.usage_metadata.prompt_token_count or 0,
(resp.usage_metadata.candidates_token_count or 0) + (resp.usage_metadata.thoughts_token_count or 0),
resp.usage_metadata.total_token_count or 0,
)
else:
_usage_record = None
api_response.raw_data = resp
return api_response, _usage_record
@client_registry.register_client_class("gemini")
class GeminiClient(BaseClient):
client: genai.Client
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self.client = genai.Client(
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[AsyncIterator[GenerateContentResponse], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[GenerateContentResponse], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
Args:
model_info: 模型信息
message_list: 对话体
tool_options: 工具选项可选默认为None
max_tokens: 最大token数可选默认为1024
temperature: 温度可选默认为0.7
response_format: 响应格式默认为text/plain,如果是输入的JSON Schema则必须遵守OpenAPI3.0格式,理论上和openai是一样的暂不支持其它相应格式输入
stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
async_response_parser: 响应解析函数可选默认为default_response_parser
interrupt_flag: 中断信号量可选默认为None
Returns:
APIResponse对象包含响应内容推理内容工具调用等信息
"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
if async_response_parser is None:
async_response_parser = _default_normal_response_parser
# 将messages构造为Gemini API所需的格式
messages = _convert_messages(message_list)
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
# 将response_format转换为Gemini API所需的格式
generation_config_dict = {
"max_output_tokens": max_tokens,
"temperature": temperature,
"response_modalities": ["TEXT"],
"thinking_config": ThinkingConfig(
include_thoughts=True,
thinking_budget=(
extra_params["thinking_budget"] if extra_params and "thinking_budget" in extra_params else None
),
),
}
if tools:
generation_config_dict["tools"] = Tool(function_declarations=tools)
if messages[1]:
# 如果有system消息则将其添加到配置中
generation_config_dict["system_instructions"] = messages[1]
if response_format and response_format.format_type == RespFormatType.TEXT:
generation_config_dict["response_mime_type"] = "text/plain"
elif response_format and response_format.format_type in (RespFormatType.JSON_OBJ, RespFormatType.JSON_SCHEMA):
generation_config_dict["response_mime_type"] = "application/json"
generation_config_dict["response_schema"] = response_format.to_dict()
generation_config = GenerateContentConfig(**generation_config_dict)
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
self.client.aio.models.generate_content_stream(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
req_task = asyncio.create_task(
self.client.aio.models.generate_content(
model=model_info.model_identifier,
contents=messages[0],
config=generation_config,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result())
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.code, e.message) from None
except (
UnknownFunctionCallArgumentError,
UnsupportedFunctionError,
FunctionInvocationError,
) as e:
raise ValueError(f"工具类型错误:请检查工具选项和参数:{str(e)}") from None
except Exception as e:
raise NetworkConnectionError() from e
if usage_record:
resp.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
try:
raw_response: EmbedContentResponse = await self.client.aio.models.embed_content(
model=model_info.model_identifier,
contents=embedding_input,
config=EmbedContentConfig(task_type="SEMANTIC_SIMILARITY"),
)
except (ClientError, ServerError) as e:
# 重封装ClientError和ServerError为RespNotOkException
raise RespNotOkException(e.code) from None
except Exception as e:
raise NetworkConnectionError() from e
response = APIResponse()
# 解析嵌入响应和使用情况
if hasattr(raw_response, "embeddings") and raw_response.embeddings:
response.embedding = raw_response.embeddings[0].values
else:
raise RespParseException(raw_response, "响应解析失败缺失embeddings字段")
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=len(embedding_input),
completion_tokens=0,
total_tokens=len(embedding_input),
)
return response
def get_audio_transcriptions(
self, model_info: ModelInfo, audio_base64: str, extra_params: dict[str, Any] | None = None
) -> APIResponse:
raise NotImplementedError("尚未实现音频转录功能")
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
return ["png", "jpg", "jpeg", "webp", "heic", "heif"]

View File

@ -0,0 +1,580 @@
import asyncio
import io
import json
import re
import base64
from collections.abc import Iterable
from typing import Callable, Any, Coroutine, Optional
from json_repair import repair_json
from openai import (
AsyncOpenAI,
APIConnectionError,
APIStatusError,
NOT_GIVEN,
AsyncStream,
)
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from src.config.api_ada_configs import ModelInfo, APIProvider
from src.common.logger import get_logger
from .base_client import APIResponse, UsageRecord, BaseClient, client_registry
from ..exceptions import (
RespParseException,
NetworkConnectionError,
RespNotOkException,
ReqAbortException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat
from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall
logger = get_logger("OpenAI客户端")
def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessageParam]:
"""
转换消息格式 - 将消息转换为OpenAI API所需的格式
:param messages: 消息列表
:return: 转换后的消息列表
"""
def _convert_message_item(message: Message) -> ChatCompletionMessageParam:
"""
转换单个消息格式
:param message: 消息对象
:return: 转换后的消息字典
"""
# 添加Content
content: str | list[dict[str, Any]]
if isinstance(message.content, str):
content = message.content
elif isinstance(message.content, list):
content = []
for item in message.content:
if isinstance(item, tuple):
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
}
)
elif isinstance(item, str):
content.append({"type": "text", "text": item})
else:
raise RuntimeError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret = {
"role": message.role.value,
"content": content,
}
# 添加工具调用ID
if message.role == RoleType.Tool:
if not message.tool_call_id:
raise ValueError("无法触及的代码请使用MessageBuilder类构建消息对象")
ret["tool_call_id"] = message.tool_call_id
return ret # type: ignore
return [_convert_message_item(message) for message in messages]
def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
"""
转换工具选项格式 - 将工具选项转换为OpenAI API所需的格式
:param tool_options: 工具选项列表
:return: 转换后的工具选项列表
"""
def _convert_tool_param(tool_option_param: ToolParam) -> dict[str, Any]:
"""
转换单个工具参数格式
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
return_dict: dict[str, Any] = {
"type": tool_option_param.param_type.value,
"description": tool_option_param.description,
}
if tool_option_param.enum_values:
return_dict["enum"] = tool_option_param.enum_values
return return_dict
def _convert_tool_option_item(tool_option: ToolOption) -> dict[str, Any]:
"""
转换单个工具项格式
:param tool_option: 工具选项对象
:return: 转换后的工具选项字典
"""
ret: dict[str, Any] = {
"name": tool_option.name,
"description": tool_option.description,
}
if tool_option.params:
ret["parameters"] = {
"type": "object",
"properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
"required": [param.name for param in tool_option.params if param.required],
}
return ret
return [
{
"type": "function",
"function": _convert_tool_option_item(tool_option),
}
for tool_option in tool_options
]
def _process_delta(
delta: ChoiceDelta,
has_rc_attr_flag: bool,
in_rc_flag: bool,
rc_delta_buffer: io.StringIO,
fc_delta_buffer: io.StringIO,
tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> bool:
# 接收content
if has_rc_attr_flag:
# 有独立的推理内容块则无需考虑content内容的判读
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# 如果有推理内容,则将其写入推理内容缓冲区
assert isinstance(delta.reasoning_content, str) # type: ignore
rc_delta_buffer.write(delta.reasoning_content) # type: ignore
elif delta.content:
# 如果有正式内容,则将其写入正式内容缓冲区
fc_delta_buffer.write(delta.content)
elif hasattr(delta, "content") and delta.content is not None:
# 没有独立的推理内容块,但有正式内容
if in_rc_flag:
# 当前在推理内容块中
if delta.content == "</think>":
# 如果当前内容是</think>,则将其视为推理内容的结束标记,退出推理内容块
in_rc_flag = False
else:
# 其他情况视为推理内容,加入推理内容缓冲区
rc_delta_buffer.write(delta.content)
elif delta.content == "<think>" and not fc_delta_buffer.getvalue():
# 如果当前内容是<think>,且正式内容缓冲区为空,说明<think>为输出的首个token
# 则将其视为推理内容的开始标记,进入推理内容块
in_rc_flag = True
else:
# 其他情况视为正式内容,加入正式内容缓冲区
fc_delta_buffer.write(delta.content)
# 接收tool_calls
if hasattr(delta, "tool_calls") and delta.tool_calls:
tool_call_delta = delta.tool_calls[0]
if tool_call_delta.index >= len(tool_calls_buffer):
# 调用索引号大于等于缓冲区长度,说明是新的工具调用
if tool_call_delta.id and tool_call_delta.function and tool_call_delta.function.name:
tool_calls_buffer.append(
(
tool_call_delta.id,
tool_call_delta.function.name,
io.StringIO(),
)
)
else:
logger.warning("工具调用索引号大于等于缓冲区长度但缺少ID或函数信息。")
if tool_call_delta.function and tool_call_delta.function.arguments:
# 如果有工具调用参数,则添加到对应的工具调用的参数串缓冲区中
tool_calls_buffer[tool_call_delta.index][2].write(tool_call_delta.function.arguments)
return in_rc_flag
def _build_stream_api_resp(
_fc_delta_buffer: io.StringIO,
_rc_delta_buffer: io.StringIO,
_tool_calls_buffer: list[tuple[str, str, io.StringIO]],
) -> APIResponse:
resp = APIResponse()
if _rc_delta_buffer.tell() > 0:
# 如果推理内容缓冲区不为空则将其写入APIResponse对象
resp.reasoning_content = _rc_delta_buffer.getvalue()
_rc_delta_buffer.close()
if _fc_delta_buffer.tell() > 0:
# 如果正式内容缓冲区不为空则将其写入APIResponse对象
resp.content = _fc_delta_buffer.getvalue()
_fc_delta_buffer.close()
if _tool_calls_buffer:
# 如果工具调用缓冲区不为空则将其解析为ToolCall对象列表
resp.tool_calls = []
for call_id, function_name, arguments_buffer in _tool_calls_buffer:
if arguments_buffer.tell() > 0:
# 如果参数串缓冲区不为空则解析为JSON对象
raw_arg_data = arguments_buffer.getvalue()
arguments_buffer.close()
try:
arguments = json.loads(repair_json(raw_arg_data))
if not isinstance(arguments, dict):
raise RespParseException(
None,
f"响应解析失败,工具调用参数无法解析为字典类型。工具调用参数原始响应:\n{raw_arg_data}",
)
except json.JSONDecodeError as e:
raise RespParseException(
None,
f"响应解析失败,无法解析工具调用参数。工具调用参数原始响应:{raw_arg_data}",
) from e
else:
arguments_buffer.close()
arguments = None
resp.tool_calls.append(ToolCall(call_id, function_name, arguments))
return resp
async def _default_stream_response_handler(
resp_stream: AsyncStream[ChatCompletionChunk],
interrupt_flag: asyncio.Event | None,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
流式响应处理函数 - 处理OpenAI API的流式响应
:param resp_stream: 流式响应对象
:return: APIResponse对象
"""
_has_rc_attr_flag = False # 标记是否有独立的推理内容块
_in_rc_flag = False # 标记是否在推理内容块中
_rc_delta_buffer = io.StringIO() # 推理内容缓冲区,用于存储接收到的推理内容
_fc_delta_buffer = io.StringIO() # 正式内容缓冲区,用于存储接收到的正式内容
_tool_calls_buffer: list[tuple[str, str, io.StringIO]] = [] # 工具调用缓冲区,用于存储接收到的工具调用
_usage_record = None # 使用情况记录
def _insure_buffer_closed():
# 确保缓冲区被关闭
if _rc_delta_buffer and not _rc_delta_buffer.closed:
_rc_delta_buffer.close()
if _fc_delta_buffer and not _fc_delta_buffer.closed:
_fc_delta_buffer.close()
for _, _, buffer in _tool_calls_buffer:
if buffer and not buffer.closed:
buffer.close()
async for event in resp_stream:
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量被设置则抛出ReqAbortException
_insure_buffer_closed()
raise ReqAbortException("请求被外部信号中断")
delta = event.choices[0].delta # 获取当前块的delta内容
if hasattr(delta, "reasoning_content") and delta.reasoning_content: # type: ignore
# 标记:有独立的推理内容块
_has_rc_attr_flag = True
_in_rc_flag = _process_delta(
delta,
_has_rc_attr_flag,
_in_rc_flag,
_rc_delta_buffer,
_fc_delta_buffer,
_tool_calls_buffer,
)
if event.usage:
# 如果有使用情况则将其存储在APIResponse对象中
_usage_record = (
event.usage.prompt_tokens or 0,
event.usage.completion_tokens or 0,
event.usage.total_tokens or 0,
)
try:
return _build_stream_api_resp(
_fc_delta_buffer,
_rc_delta_buffer,
_tool_calls_buffer,
), _usage_record
except Exception:
# 确保缓冲区被关闭
_insure_buffer_closed()
raise
pattern = re.compile(
r"<think>(?P<think>.*?)</think>(?P<content>.*)|<think>(?P<think_unclosed>.*)|(?P<content_only>.+)",
re.DOTALL,
)
"""用于解析推理内容的正则表达式"""
def _default_normal_response_parser(
resp: ChatCompletion,
) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
"""
解析对话补全响应 - 将OpenAI API响应解析为APIResponse对象
:param resp: 响应对象
:return: APIResponse对象
"""
api_response = APIResponse()
if not hasattr(resp, "choices") or len(resp.choices) == 0:
raise RespParseException(resp, "响应解析失败缺失choices字段")
message_part = resp.choices[0].message
if hasattr(message_part, "reasoning_content") and message_part.reasoning_content: # type: ignore
# 有有效的推理字段
api_response.content = message_part.content
api_response.reasoning_content = message_part.reasoning_content # type: ignore
elif message_part.content:
# 提取推理和内容
match = pattern.match(message_part.content)
if not match:
raise RespParseException(resp, "响应解析失败,无法捕获推理内容和输出内容")
if match.group("think") is not None:
result = match.group("think").strip(), match.group("content").strip()
elif match.group("think_unclosed") is not None:
result = match.group("think_unclosed").strip(), None
else:
result = None, match.group("content_only").strip()
api_response.reasoning_content, api_response.content = result
# 提取工具调用
if message_part.tool_calls:
api_response.tool_calls = []
for call in message_part.tool_calls:
try:
arguments = json.loads(repair_json(call.function.arguments))
if not isinstance(arguments, dict):
raise RespParseException(resp, "响应解析失败,工具调用参数无法解析为字典类型")
api_response.tool_calls.append(ToolCall(call.id, call.function.name, arguments))
except json.JSONDecodeError as e:
raise RespParseException(resp, "响应解析失败,无法解析工具调用参数") from e
# 提取Usage信息
if resp.usage:
_usage_record = (
resp.usage.prompt_tokens or 0,
resp.usage.completion_tokens or 0,
resp.usage.total_tokens or 0,
)
else:
_usage_record = None
# 将原始响应存储在原始数据中
api_response.raw_data = resp
return api_response, _usage_record
@client_registry.register_client_class("openai")
class OpenaiClient(BaseClient):
def __init__(self, api_provider: APIProvider):
super().__init__(api_provider)
self.client: AsyncOpenAI = AsyncOpenAI(
base_url=api_provider.base_url,
api_key=api_provider.api_key,
max_retries=0,
)
async def get_response(
self,
model_info: ModelInfo,
message_list: list[Message],
tool_options: list[ToolOption] | None = None,
max_tokens: int = 1024,
temperature: float = 0.7,
response_format: RespFormat | None = None,
stream_response_handler: Optional[
Callable[
[AsyncStream[ChatCompletionChunk], asyncio.Event | None],
Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
]
] = None,
async_response_parser: Optional[
Callable[[ChatCompletion], tuple[APIResponse, Optional[tuple[int, int, int]]]]
] = None,
interrupt_flag: asyncio.Event | None = None,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取对话响应
Args:
model_info: 模型信息
message_list: 对话体
tool_options: 工具选项可选默认为None
max_tokens: 最大token数可选默认为1024
temperature: 温度可选默认为0.7
response_format: 响应格式可选默认为 NotGiven
stream_response_handler: 流式响应处理函数可选默认为default_stream_response_handler
async_response_parser: 响应解析函数可选默认为default_response_parser
interrupt_flag: 中断信号量可选默认为None
Returns:
(响应文本, 推理文本, 工具调用, 其他数据)
"""
if stream_response_handler is None:
stream_response_handler = _default_stream_response_handler
if async_response_parser is None:
async_response_parser = _default_normal_response_parser
# 将messages构造为OpenAI API所需的格式
messages: Iterable[ChatCompletionMessageParam] = _convert_messages(message_list)
# 将tool_options转换为OpenAI API所需的格式
tools: Iterable[ChatCompletionToolParam] = _convert_tool_options(tool_options) if tool_options else NOT_GIVEN # type: ignore
try:
if model_info.force_stream_mode:
req_task = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=True,
response_format=NOT_GIVEN,
extra_body=extra_params,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.1) # 等待0.1秒后再次检查任务&中断信号量状态
resp, usage_record = await stream_response_handler(req_task.result(), interrupt_flag)
else:
# 发送请求并获取响应
req_task = asyncio.create_task(
self.client.chat.completions.create(
model=model_info.model_identifier,
messages=messages,
tools=tools,
temperature=temperature,
max_tokens=max_tokens,
stream=False,
response_format=NOT_GIVEN,
extra_body=extra_params,
)
)
while not req_task.done():
if interrupt_flag and interrupt_flag.is_set():
# 如果中断量存在且被设置,则取消任务并抛出异常
req_task.cancel()
raise ReqAbortException("请求被外部信号中断")
await asyncio.sleep(0.5) # 等待0.5秒后再次检查任务&中断信号量状态
resp, usage_record = async_response_parser(req_task.result())
except APIConnectionError as e:
# 重封装APIConnectionError为NetworkConnectionError
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code, e.message) from e
if usage_record:
resp.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=usage_record[0],
completion_tokens=usage_record[1],
total_tokens=usage_record[2],
)
return resp
async def get_embedding(
self,
model_info: ModelInfo,
embedding_input: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取文本嵌入
:param model_info: 模型信息
:param embedding_input: 嵌入输入文本
:return: 嵌入响应
"""
try:
raw_response = await self.client.embeddings.create(
model=model_info.model_identifier,
input=embedding_input,
extra_body=extra_params,
)
except APIConnectionError as e:
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
response = APIResponse()
# 解析嵌入响应
if len(raw_response.data) > 0:
response.embedding = raw_response.data[0].embedding
else:
raise RespParseException(
raw_response,
"响应解析失败,缺失嵌入数据。",
)
# 解析使用情况
if hasattr(raw_response, "usage"):
response.usage = UsageRecord(
model_name=model_info.name,
provider_name=model_info.api_provider,
prompt_tokens=raw_response.usage.prompt_tokens or 0,
completion_tokens=raw_response.usage.completion_tokens or 0, # type: ignore
total_tokens=raw_response.usage.total_tokens or 0,
)
return response
async def get_audio_transcriptions(
self,
model_info: ModelInfo,
audio_base64: str,
extra_params: dict[str, Any] | None = None,
) -> APIResponse:
"""
获取音频转录
:param model_info: 模型信息
:param audio_base64: base64编码的音频数据
:extra_params: 附加的请求参数
:return: 音频转录响应
"""
try:
raw_response = await self.client.audio.transcriptions.create(
model=model_info.model_identifier,
file=("audio.wav", io.BytesIO(base64.b64decode(audio_base64))),
extra_body=extra_params,
)
except APIConnectionError as e:
raise NetworkConnectionError() from e
except APIStatusError as e:
# 重封装APIError为RespNotOkException
raise RespNotOkException(e.status_code) from e
response = APIResponse()
# 解析转录响应
if hasattr(raw_response, "text"):
response.content = raw_response.text
else:
raise RespParseException(
raw_response,
"响应解析失败,缺失转录文本。",
)
return response
def get_support_image_formats(self) -> list[str]:
"""
获取支持的图片格式
:return: 支持的图片格式列表
"""
return ["jpg", "jpeg", "png", "webp", "gif"]

View File

@ -0,0 +1,3 @@
from .tool_option import ToolCall
__all__ = ["ToolCall"]

View File

@ -0,0 +1,107 @@
from enum import Enum
# 设计这系列类的目的是为未来可能的扩展做准备
class RoleType(Enum):
System = "system"
User = "user"
Assistant = "assistant"
Tool = "tool"
SUPPORTED_IMAGE_FORMATS = ["jpg", "jpeg", "png", "webp", "gif"] # openai支持的图片格式
class Message:
def __init__(
self,
role: RoleType,
content: str | list[tuple[str, str] | str],
tool_call_id: str | None = None,
):
"""
初始化消息对象
不应直接修改Message类而应使用MessageBuilder类来构建对象
"""
self.role: RoleType = role
self.content: str | list[tuple[str, str] | str] = content
self.tool_call_id: str | None = tool_call_id
class MessageBuilder:
def __init__(self):
self.__role: RoleType = RoleType.User
self.__content: list[tuple[str, str] | str] = []
self.__tool_call_id: str | None = None
def set_role(self, role: RoleType = RoleType.User) -> "MessageBuilder":
"""
设置角色默认为User
:param role: 角色
:return: MessageBuilder对象
"""
self.__role = role
return self
def add_text_content(self, text: str) -> "MessageBuilder":
"""
添加文本内容
:param text: 文本内容
:return: MessageBuilder对象
"""
self.__content.append(text)
return self
def add_image_content(
self,
image_format: str,
image_base64: str,
support_formats: list[str] = SUPPORTED_IMAGE_FORMATS, # 默认支持格式
) -> "MessageBuilder":
"""
添加图片内容
:param image_format: 图片格式
:param image_base64: 图片的base64编码
:return: MessageBuilder对象
"""
if image_format.lower() not in support_formats:
raise ValueError("不受支持的图片格式")
if not image_base64:
raise ValueError("图片的base64编码不能为空")
self.__content.append((image_format, image_base64))
return self
def add_tool_call(self, tool_call_id: str) -> "MessageBuilder":
"""
添加工具调用指令调用时请确保已设置为Tool角色
:param tool_call_id: 工具调用指令的id
:return: MessageBuilder对象
"""
if self.__role != RoleType.Tool:
raise ValueError("仅当角色为Tool时才能添加工具调用ID")
if not tool_call_id:
raise ValueError("工具调用ID不能为空")
self.__tool_call_id = tool_call_id
return self
def build(self) -> Message:
"""
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0:
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")
return Message(
role=self.__role,
content=(
self.__content[0]
if (len(self.__content) == 1 and isinstance(self.__content[0], str))
else self.__content
),
tool_call_id=self.__tool_call_id,
)

View File

@ -0,0 +1,223 @@
from enum import Enum
from typing import Optional, Any
from pydantic import BaseModel
from typing_extensions import TypedDict, Required
class RespFormatType(Enum):
TEXT = "text" # 文本
JSON_OBJ = "json_object" # JSON
JSON_SCHEMA = "json_schema" # JSON Schema
class JsonSchema(TypedDict, total=False):
name: Required[str]
"""
The name of the response format.
Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length
of 64.
"""
description: Optional[str]
"""
A description of what the response format is for, used by the model to determine
how to respond in the format.
"""
schema: dict[str, object]
"""
The schema for the response format, described as a JSON Schema object. Learn how
to build JSON schemas [here](https://json-schema.org/).
"""
strict: Optional[bool]
"""
Whether to enable strict schema adherence when generating the output. If set to
true, the model will always follow the exact schema defined in the `schema`
field. Only a subset of JSON Schema is supported when `strict` is `true`. To
learn more, read the
[Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
"""
def _json_schema_type_check(instance) -> str | None:
if "name" not in instance:
return "schema必须包含'name'字段"
elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
return "schema的'name'字段必须是非空字符串"
if "description" in instance and (
not isinstance(instance["description"], str)
or instance["description"].strip() == ""
):
return "schema的'description'字段只能填入非空字符串"
if "schema" not in instance:
return "schema必须包含'schema'字段"
elif not isinstance(instance["schema"], dict):
return "schema的'schema'字段必须是字典详见https://json-schema.org/"
if "strict" in instance and not isinstance(instance["strict"], bool):
return "schema的'strict'字段只能填入布尔值"
return None
def _remove_title(schema: dict[str, Any] | list[Any]) -> dict[str, Any] | list[Any]:
"""
递归移除JSON Schema中的title字段
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "title" in schema:
del schema["title"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
"""
链接JSON Schema中的definitions字段
"""
def link_definitions_recursive(
path: str, sub_schema: list[Any] | dict[str, Any], defs: dict[str, Any]
) -> dict[str, Any]:
"""
递归链接JSON Schema中的definitions字段
:param path: 当前路径
:param sub_schema: 子Schema
:param defs: Schema定义集
:return:
"""
if isinstance(sub_schema, list):
# 如果当前Schema是列表则遍历每个元素
for i in range(len(sub_schema)):
if isinstance(sub_schema[i], dict):
sub_schema[i] = link_definitions_recursive(
f"{path}/{str(i)}", sub_schema[i], defs
)
else:
# 否则为字典
if "$defs" in sub_schema:
# 如果当前Schema有$def字段则将其添加到defs中
key_prefix = f"{path}/$defs/"
for key, value in sub_schema["$defs"].items():
def_key = key_prefix + key
if def_key not in defs:
defs[def_key] = value
del sub_schema["$defs"]
if "$ref" in sub_schema:
# 如果当前Schema有$ref字段则将其替换为defs中的定义
def_key = sub_schema["$ref"]
if def_key in defs:
sub_schema = defs[def_key]
else:
raise ValueError(f"Schema中引用的定义'{def_key}'不存在")
# 遍历键值对
for key, value in sub_schema.items():
if isinstance(value, (dict, list)):
# 如果当前值是字典或列表,则递归调用
sub_schema[key] = link_definitions_recursive(
f"{path}/{key}", value, defs
)
return sub_schema
return link_definitions_recursive("#", schema, {})
def _remove_defs(schema: dict[str, Any]) -> dict[str, Any]:
"""
递归移除JSON Schema中的$defs字段
"""
if isinstance(schema, list):
# 如果当前Schema是列表则对所有dict/list子元素递归调用
for idx, item in enumerate(schema):
if isinstance(item, (dict, list)):
schema[idx] = _remove_title(item)
elif isinstance(schema, dict):
# 是字典移除title字段并对所有dict/list子元素递归调用
if "$defs" in schema:
del schema["$defs"]
for key, value in schema.items():
if isinstance(value, (dict, list)):
schema[key] = _remove_title(value)
return schema
class RespFormat:
"""
响应格式
"""
@staticmethod
def _generate_schema_from_model(schema):
json_schema = {
"name": schema.__name__,
"schema": _remove_defs(
_link_definitions(_remove_title(schema.model_json_schema()))
),
"strict": False,
}
if schema.__doc__:
json_schema["description"] = schema.__doc__
return json_schema
def __init__(
self,
format_type: RespFormatType = RespFormatType.TEXT,
schema: type | JsonSchema | None = None,
):
"""
响应格式
:param format_type: 响应格式类型默认为文本
:param schema: 模板类或JsonSchema仅当format_type为JSON Schema时有效
"""
self.format_type: RespFormatType = format_type
if format_type == RespFormatType.JSON_SCHEMA:
if schema is None:
raise ValueError("当format_type为'JSON_SCHEMA'schema不能为空")
if isinstance(schema, dict):
if check_msg := _json_schema_type_check(schema):
raise ValueError(f"schema格式不正确{check_msg}")
self.schema = schema
elif issubclass(schema, BaseModel):
try:
json_schema = self._generate_schema_from_model(schema)
self.schema = json_schema
except Exception as e:
raise ValueError(
f"自动生成JSON Schema时发生异常请检查模型类{schema.__name__}的定义,详细信息:\n"
f"{schema.__name__}:\n"
) from e
else:
raise ValueError("schema必须是BaseModel的子类或JsonSchema")
else:
self.schema = None
def to_dict(self):
"""
将响应格式转换为字典
:return: 字典
"""
if self.schema:
return {
"format_type": self.format_type.value,
"schema": self.schema,
}
else:
return {
"format_type": self.format_type.value,
}

View File

@ -0,0 +1,163 @@
from enum import Enum
class ToolParamType(Enum):
"""
工具调用参数类型
"""
STRING = "string" # 字符串
INTEGER = "integer" # 整型
FLOAT = "float" # 浮点型
BOOLEAN = "bool" # 布尔型
class ToolParam:
"""
工具调用参数
"""
def __init__(
self,
name: str,
param_type: ToolParamType,
description: str,
required: bool,
enum_values: list[str] | None = None,
):
"""
初始化工具调用参数
不应直接修改ToolParam类而应使用ToolOptionBuilder类来构建对象
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填
"""
self.name: str = name
self.param_type: ToolParamType = param_type
self.description: str = description
self.required: bool = required
self.enum_values: list[str] | None = enum_values
class ToolOption:
"""
工具调用项
"""
def __init__(
self,
name: str,
description: str,
params: list[ToolParam] | None = None,
):
"""
初始化工具调用项
不应直接修改ToolOption类而应使用ToolOptionBuilder类来构建对象
:param name: 工具名称
:param description: 工具描述
:param params: 工具参数列表
"""
self.name: str = name
self.description: str = description
self.params: list[ToolParam] | None = params
class ToolOptionBuilder:
"""
工具调用项构建器
"""
def __init__(self):
self.__name: str = ""
self.__description: str = ""
self.__params: list[ToolParam] = []
def set_name(self, name: str) -> "ToolOptionBuilder":
"""
设置工具名称
:param name: 工具名称
:return: ToolBuilder实例
"""
if not name:
raise ValueError("工具名称不能为空")
self.__name = name
return self
def set_description(self, description: str) -> "ToolOptionBuilder":
"""
设置工具描述
:param description: 工具描述
:return: ToolBuilder实例
"""
if not description:
raise ValueError("工具描述不能为空")
self.__description = description
return self
def add_param(
self,
name: str,
param_type: ToolParamType,
description: str,
required: bool = False,
enum_values: list[str] | None = None,
) -> "ToolOptionBuilder":
"""
添加工具参数
:param name: 参数名称
:param param_type: 参数类型
:param description: 参数描述
:param required: 是否必填默认为False
:return: ToolBuilder实例
"""
if not name or not description:
raise ValueError("参数名称/描述不能为空")
self.__params.append(
ToolParam(
name=name,
param_type=param_type,
description=description,
required=required,
enum_values=enum_values,
)
)
return self
def build(self):
"""
构建工具调用项
:return: 工具调用项
"""
if self.__name == "" or self.__description == "":
raise ValueError("工具名称/描述不能为空")
return ToolOption(
name=self.__name,
description=self.__description,
params=None if len(self.__params) == 0 else self.__params,
)
class ToolCall:
"""
来自模型反馈的工具调用
"""
def __init__(
self,
call_id: str,
func_name: str,
args: dict | None = None,
):
"""
初始化工具调用
:param call_id: 工具调用ID
:param func_name: 要调用的函数名称
:param args: 工具调用参数
"""
self.call_id: str = call_id
self.func_name: str = func_name
self.args: dict | None = args

View File

@ -0,0 +1,186 @@
import base64
import io
from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.config.api_ada_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
logger = get_logger("消息压缩工具")
def compress_messages(messages: list[Message], img_target_size: int = 1 * 1024 * 1024) -> list[Message]:
"""
压缩消息列表中的图片
:param messages: 消息列表
:param img_target_size: 图片目标大小默认1MB
:return: 压缩后的消息列表
"""
def reformat_static_image(image_data: bytes) -> bytes:
"""
将静态图片转换为JPEG格式
:param image_data: 图片数据
:return: 转换后的图片数据
"""
try:
image = Image.open(image_data)
if image.format and (image.format.upper() in ["JPEG", "JPG", "PNG", "WEBP"]):
# 静态图像转换为JPEG格式
reformated_image_data = io.BytesIO()
image.save(reformated_image_data, format="JPEG", quality=95, optimize=True)
image_data = reformated_image_data.getvalue()
return image_data
except Exception as e:
logger.error(f"图片转换格式失败: {str(e)}")
return image_data
def rescale_image(image_data: bytes, scale: float) -> tuple[bytes, tuple[int, int] | None, tuple[int, int] | None]:
"""
缩放图片
:param image_data: 图片数据
:param scale: 缩放比例
:return: 缩放后的图片数据
"""
try:
image = Image.open(image_data)
# 原始尺寸
original_size = (image.width, image.height)
# 计算新的尺寸
new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
output_buffer = io.BytesIO()
if getattr(image, "is_animated", False):
# 动态图片,处理所有帧
frames = []
new_size = (new_size[0] // 2, new_size[1] // 2) # 动图,缩放尺寸再打折
for frame_idx in range(getattr(image, "n_frames", 1)):
image.seek(frame_idx)
new_frame = image.copy()
new_frame = new_frame.resize(new_size, Image.Resampling.LANCZOS)
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format="GIF",
save_all=True,
append_images=frames[1:],
optimize=True,
duration=image.info.get("duration", 100),
loop=image.info.get("loop", 0),
)
else:
# 静态图片,直接缩放保存
resized_image = image.resize(new_size, Image.Resampling.LANCZOS)
resized_image.save(output_buffer, format="JPEG", quality=95, optimize=True)
return output_buffer.getvalue(), original_size, new_size
except Exception as e:
logger.error(f"图片缩放失败: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return image_data, None, None
def compress_base64_image(base64_data: str, target_size: int = 1 * 1024 * 1024) -> str:
original_b64_data_size = len(base64_data) # 计算原始数据大小
image_data = base64.b64decode(base64_data)
# 先尝试转换格式为JPEG
image_data = reformat_static_image(image_data)
base64_data = base64.b64encode(image_data).decode("utf-8")
if len(base64_data) <= target_size:
# 如果转换后小于目标大小,直接返回
logger.info(f"成功将图片转为JPEG格式编码后大小: {len(base64_data) / 1024:.1f}KB")
return base64_data
# 如果转换后仍然大于目标大小,进行尺寸压缩
scale = min(1.0, target_size / len(base64_data))
image_data, original_size, new_size = rescale_image(image_data, scale)
base64_data = base64.b64encode(image_data).decode("utf-8")
if original_size and new_size:
logger.info(
f"压缩图片: {original_size[0]}x{original_size[1]} -> {new_size[0]}x{new_size[1]}\n"
f"压缩前大小: {original_b64_data_size / 1024:.1f}KB, 压缩后大小: {len(base64_data) / 1024:.1f}KB"
)
return base64_data
compressed_messages = []
for message in messages:
if isinstance(message.content, list):
# 检查content如有图片则压缩
message_builder = MessageBuilder()
for content_item in message.content:
if isinstance(content_item, tuple):
# 图片,进行压缩
message_builder.add_image_content(
content_item[0],
compress_base64_image(content_item[1], target_size=img_target_size),
)
else:
message_builder.add_text_content(content_item)
compressed_messages.append(message_builder.build())
else:
compressed_messages.append(message)
return compressed_messages
class LLMUsageRecorder:
"""
LLM使用情况记录器
"""
def __init__(self):
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
def record_usage_to_database(
self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str
):
input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=model_info.model_identifier,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "
f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
f"总计: {model_usage.total_tokens}"
)
except Exception as e:
logger.error(f"记录token使用情况失败: {str(e)}")
llm_usage_recorder = LLMUsageRecorder()

File diff suppressed because it is too large Load Diff

View File

@ -2,13 +2,15 @@ from src.chat.message_receive.chat_stream import get_chat_manager
import time import time
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def init_prompt(): def init_prompt():
Prompt( Prompt(
""" """
@ -32,10 +34,8 @@ def init_prompt():
) )
class MaiThinking: class MaiThinking:
def __init__(self,chat_id): def __init__(self, chat_id):
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_stream = get_chat_manager().get_stream(chat_id)
self.platform = self.chat_stream.platform self.platform = self.chat_stream.platform
@ -44,11 +44,11 @@ class MaiThinking:
self.is_group = True self.is_group = True
else: else:
self.is_group = False self.is_group = False
self.s4u_message_processor = S4UMessageProcessor() self.s4u_message_processor = S4UMessageProcessor()
self.mind = "" self.mind = ""
self.memory_block = "" self.memory_block = ""
self.relation_info_block = "" self.relation_info_block = ""
self.time_block = "" self.time_block = ""
@ -59,17 +59,13 @@ class MaiThinking:
self.identity = "" self.identity = ""
self.sender = "" self.sender = ""
self.target = "" self.target = ""
self.thinking_model = LLMRequest( self.thinking_model = LLMRequest(model_set=model_config.model_task_config.replyer_1, request_type="thinking")
model=global_config.model.replyer_1,
request_type="thinking",
)
async def do_think_before_response(self): async def do_think_before_response(self):
pass pass
async def do_think_after_response(self,reponse:str): async def do_think_after_response(self, reponse: str):
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"after_response_think_prompt", "after_response_think_prompt",
mind=self.mind, mind=self.mind,
@ -85,47 +81,44 @@ class MaiThinking:
sender=self.sender, sender=self.sender,
target=self.target, target=self.target,
) )
result, _ = await self.thinking_model.generate_response_async(prompt) result, _ = await self.thinking_model.generate_response_async(prompt)
self.mind = result self.mind = result
logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}") logger.info(f"[{self.chat_id}] 思考前想法:{self.mind}")
# logger.info(f"[{self.chat_id}] 思考前prompt{prompt}") # logger.info(f"[{self.chat_id}] 思考前prompt{prompt}")
logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}") logger.info(f"[{self.chat_id}] 思考后想法:{self.mind}")
msg_recv = await self.build_internal_message_recv(self.mind) msg_recv = await self.build_internal_message_recv(self.mind)
await self.s4u_message_processor.process_message(msg_recv) await self.s4u_message_processor.process_message(msg_recv)
internal_manager.set_internal_state(self.mind) internal_manager.set_internal_state(self.mind)
async def do_think_when_receive_message(self): async def do_think_when_receive_message(self):
pass pass
async def build_internal_message_recv(self,message_text:str): async def build_internal_message_recv(self, message_text: str):
msg_id = f"internal_{time.time()}" msg_id = f"internal_{time.time()}"
message_dict = { message_dict = {
"message_info": { "message_info": {
"message_id": msg_id, "message_id": msg_id,
"time": time.time(), "time": time.time(),
"user_info": { "user_info": {
"user_id": "internal", # 内部用户ID "user_id": "internal", # 内部用户ID
"user_nickname": "内心", # 内部昵称 "user_nickname": "内心", # 内部昵称
"platform": self.platform, # 平台标记为 internal "platform": self.platform, # 平台标记为 internal
# 其他 user_info 字段按需补充 # 其他 user_info 字段按需补充
}, },
"platform": self.platform, # 平台 "platform": self.platform, # 平台
# 其他 message_info 字段按需补充 # 其他 message_info 字段按需补充
}, },
"message_segment": { "message_segment": {
"type": "text", # 消息类型 "type": "text", # 消息类型
"data": message_text, # 消息内容 "data": message_text, # 消息内容
# 其他 segment 字段按需补充 # 其他 segment 字段按需补充
}, },
"raw_message": message_text, # 原始消息内容 "raw_message": message_text, # 原始消息内容
"processed_plain_text": message_text, # 处理后的纯文本 "processed_plain_text": message_text, # 处理后的纯文本
# 下面这些字段可选,根据 MessageRecv 需要 # 下面这些字段可选,根据 MessageRecv 需要
"is_emoji": False, "is_emoji": False,
"has_emoji": False, "has_emoji": False,
@ -139,45 +132,36 @@ class MaiThinking:
"priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级 "priority_info": {"message_priority": 10.0}, # 内部消息可设高优先级
"interest_value": 1.0, "interest_value": 1.0,
} }
if self.is_group: if self.is_group:
message_dict["message_info"]["group_info"] = { message_dict["message_info"]["group_info"] = {
"platform": self.platform, "platform": self.platform,
"group_id": self.chat_stream.group_info.group_id, "group_id": self.chat_stream.group_info.group_id,
"group_name": self.chat_stream.group_info.group_name, "group_name": self.chat_stream.group_info.group_name,
} }
msg_recv = MessageRecvS4U(message_dict) msg_recv = MessageRecvS4U(message_dict)
msg_recv.chat_info = self.chat_info msg_recv.chat_info = self.chat_info
msg_recv.chat_stream = self.chat_stream msg_recv.chat_stream = self.chat_stream
msg_recv.is_internal = True msg_recv.is_internal = True
return msg_recv return msg_recv
class MaiThinkingManager: class MaiThinkingManager:
def __init__(self): def __init__(self):
self.mai_think_list = [] self.mai_think_list = []
def get_mai_think(self,chat_id): def get_mai_think(self, chat_id):
for mai_think in self.mai_think_list: for mai_think in self.mai_think_list:
if mai_think.chat_id == chat_id: if mai_think.chat_id == chat_id:
return mai_think return mai_think
mai_think = MaiThinking(chat_id) mai_think = MaiThinking(chat_id)
self.mai_think_list.append(mai_think) self.mai_think_list.append(mai_think)
return mai_think return mai_think
mai_thinking_manager = MaiThinkingManager() mai_thinking_manager = MaiThinkingManager()
init_prompt() init_prompt()

View File

@ -1,14 +1,16 @@
import json import json
import time import time
from json_repair import repair_json
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
from json_repair import repair_json
from src.mais4u.s4u_config import s4u_config from src.mais4u.s4u_config import s4u_config
logger = get_logger("action") logger = get_logger("action")
@ -32,7 +34,7 @@ BODY_CODE = {
"帅气的姿势": "010_0190", "帅气的姿势": "010_0190",
"另一个帅气的姿势": "010_0191", "另一个帅气的姿势": "010_0191",
"手掌朝前可爱": "010_0210", "手掌朝前可爱": "010_0210",
"平静,双手后放":"平静,双手后放", "平静,双手后放": "平静,双手后放",
"思考": "思考", "思考": "思考",
"优雅,左手放在腰上": "优雅,左手放在腰上", "优雅,左手放在腰上": "优雅,左手放在腰上",
"一般": "一般", "一般": "一般",
@ -94,19 +96,15 @@ class ChatAction:
self.body_action_cooldown: dict[str, int] = {} self.body_action_cooldown: dict[str, int] = {}
print(s4u_config.models.motion) print(s4u_config.models.motion)
print(global_config.model.emotion) print(model_config.model_task_config.emotion)
self.action_model = LLMRequest(
model=global_config.model.emotion,
temperature=0.7,
request_type="motion",
)
self.last_change_time = 0 self.action_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
self.last_change_time: float = 0
async def send_action_update(self): async def send_action_update(self):
"""发送动作更新到前端""" """发送动作更新到前端"""
body_code = BODY_CODE.get(self.body_action, "") body_code = BODY_CODE.get(self.body_action, "")
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="body_action", message_type="body_action",
@ -115,13 +113,11 @@ class ChatAction:
storage_message=False, storage_message=False,
show_log=True, show_log=True,
) )
async def update_action_by_message(self, message: MessageRecv): async def update_action_by_message(self, message: MessageRecv):
self.regression_count = 0 self.regression_count = 0
message_time = message.message_info.time message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_change_time, timestamp_start=self.last_change_time,
@ -147,13 +143,13 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try: try:
# 冷却池处理:过滤掉冷却中的动作 # 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown() self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
all_actions = "\n".join(available_actions) all_actions = "\n".join(available_actions)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"change_action_prompt", "change_action_prompt",
chat_talking_prompt=chat_talking_prompt, chat_talking_prompt=chat_talking_prompt,
@ -163,19 +159,18 @@ class ChatAction:
) )
logger.info(f"prompt: {prompt}") logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}") logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}") logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response)) if action_data := json.loads(repair_json(response)):
if action_data:
# 记录原动作,切换后进入冷却 # 记录原动作,切换后进入冷却
prev_body_action = self.body_action prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action) new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action: if new_body_action != prev_body_action and prev_body_action:
if prev_body_action: self.body_action_cooldown[prev_body_action] = 3
self.body_action_cooldown[prev_body_action] = 3
self.body_action = new_body_action self.body_action = new_body_action
self.head_action = action_data.get("head_action", self.head_action) self.head_action = action_data.get("head_action", self.head_action)
# 发送动作更新 # 发送动作更新
@ -213,7 +208,6 @@ class ChatAction:
prompt_personality = global_config.personality.personality_core prompt_personality = global_config.personality.personality_core
indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}" indentify_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
try: try:
# 冷却池处理:过滤掉冷却中的动作 # 冷却池处理:过滤掉冷却中的动作
self._update_body_action_cooldown() self._update_body_action_cooldown()
available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown] available_actions = [k for k in BODY_CODE.keys() if k not in self.body_action_cooldown]
@ -228,17 +222,17 @@ class ChatAction:
) )
logger.info(f"prompt: {prompt}") logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await self.action_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.action_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"response: {response}") logger.info(f"response: {response}")
logger.info(f"reasoning_content: {reasoning_content}") logger.info(f"reasoning_content: {reasoning_content}")
action_data = json.loads(repair_json(response)) if action_data := json.loads(repair_json(response)):
if action_data:
prev_body_action = self.body_action prev_body_action = self.body_action
new_body_action = action_data.get("body_action", self.body_action) new_body_action = action_data.get("body_action", self.body_action)
if new_body_action != prev_body_action: if new_body_action != prev_body_action and prev_body_action:
if prev_body_action: self.body_action_cooldown[prev_body_action] = 6
self.body_action_cooldown[prev_body_action] = 6
self.body_action = new_body_action self.body_action = new_body_action
# 发送动作更新 # 发送动作更新
await self.send_action_update() await self.send_action_update()
@ -306,9 +300,6 @@ class ActionManager:
return new_action_state return new_action_state
init_prompt() init_prompt()
action_manager = ActionManager() action_manager = ActionManager()

View File

@ -137,7 +137,7 @@ class MessageSenderContainer:
await self.storage.store_message(bot_message, self.chat_stream) await self.storage.store_message(bot_message, self.chat_stream)
except Exception as e: except Exception as e:
logger.error(f"[{self.chat_stream.get_stream_name()}] 消息发送或存储时出现错误: {e}", exc_info=True) logger.error(f"[消息流: {self.chat_stream.stream_id}] 消息发送或存储时出现错误: {e}", exc_info=True)
finally: finally:
# CRUCIAL: Always call task_done() for any item that was successfully retrieved. # CRUCIAL: Always call task_done() for any item that was successfully retrieved.

View File

@ -6,7 +6,7 @@ from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
@ -114,18 +114,12 @@ class ChatMood:
self.regression_count: int = 0 self.regression_count: int = 0
self.mood_model = LLMRequest( self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood_text")
model=global_config.model.emotion,
temperature=0.7,
request_type="mood_text",
)
self.mood_model_numerical = LLMRequest( self.mood_model_numerical = LLMRequest(
model=global_config.model.emotion, model_set=model_config.model_task_config.emotion, request_type="mood_numerical"
temperature=0.4,
request_type="mood_numerical",
) )
self.last_change_time = 0 self.last_change_time: float = 0
# 发送初始情绪状态到ws端 # 发送初始情绪状态到ws端
asyncio.create_task(self.send_emotion_update(self.mood_values)) asyncio.create_task(self.send_emotion_update(self.mood_values))
@ -164,7 +158,7 @@ class ChatMood:
async def update_mood_by_message(self, message: MessageRecv): async def update_mood_by_message(self, message: MessageRecv):
self.regression_count = 0 self.regression_count = 0
message_time = message.message_info.time message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id, chat_id=self.chat_id,
timestamp_start=self.last_change_time, timestamp_start=self.last_change_time,
@ -199,7 +193,9 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
logger.debug(f"text mood prompt: {prompt}") logger.debug(f"text mood prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text mood response: {response}") logger.info(f"text mood response: {response}")
logger.debug(f"text mood reasoning_content: {reasoning_content}") logger.debug(f"text mood reasoning_content: {reasoning_content}")
return response return response
@ -216,8 +212,8 @@ class ChatMood:
fear=self.mood_values["fear"], fear=self.mood_values["fear"],
) )
logger.debug(f"numerical mood prompt: {prompt}") logger.debug(f"numerical mood prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt prompt=prompt, temperature=0.4
) )
logger.info(f"numerical mood response: {response}") logger.info(f"numerical mood response: {response}")
logger.debug(f"numerical mood reasoning_content: {reasoning_content}") logger.debug(f"numerical mood reasoning_content: {reasoning_content}")
@ -276,7 +272,9 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
logger.debug(f"text regress prompt: {prompt}") logger.debug(f"text regress prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
logger.info(f"text regress response: {response}") logger.info(f"text regress response: {response}")
logger.debug(f"text regress reasoning_content: {reasoning_content}") logger.debug(f"text regress reasoning_content: {reasoning_content}")
return response return response
@ -293,8 +291,9 @@ class ChatMood:
fear=self.mood_values["fear"], fear=self.mood_values["fear"],
) )
logger.debug(f"numerical regress prompt: {prompt}") logger.debug(f"numerical regress prompt: {prompt}")
response, (reasoning_content, model_name) = await self.mood_model_numerical.generate_response_async( response, (reasoning_content, _, _) = await self.mood_model_numerical.generate_response_async(
prompt=prompt prompt=prompt,
temperature=0.4,
) )
logger.info(f"numerical regress response: {response}") logger.info(f"numerical regress response: {response}")
logger.debug(f"numerical regress reasoning_content: {reasoning_content}") logger.debug(f"numerical regress reasoning_content: {reasoning_content}")
@ -447,6 +446,7 @@ class MoodManager:
# 发送初始情绪状态到ws端 # 发送初始情绪状态到ws端
asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values)) asyncio.create_task(new_mood.send_emotion_update(new_mood.mood_values))
if ENABLE_S4U: if ENABLE_S4U:
init_prompt() init_prompt()
mood_manager = MoodManager() mood_manager = MoodManager()

View File

@ -150,19 +150,18 @@ class PromptBuilder:
relation_prompt = "" relation_prompt = ""
if global_config.relationship.enable_relationship and who_chat_in_group: if global_config.relationship.enable_relationship and who_chat_in_group:
relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id) relationship_fetcher = relationship_fetcher_manager.get_fetcher(chat_stream.stream_id)
# 将 (platform, user_id, nickname) 转换为 person_id # 将 (platform, user_id, nickname) 转换为 person_id
person_ids = [] person_ids = []
for person in who_chat_in_group: for person in who_chat_in_group:
person_id = PersonInfoManager.get_person_id(person[0], person[1]) person_id = PersonInfoManager.get_person_id(person[0], person[1])
person_ids.append(person_id) person_ids.append(person_id)
# 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为 # 使用 RelationshipFetcher 的 build_relation_info 方法,设置 points_num=3 保持与原来相同的行为
relation_info_list = await asyncio.gather( relation_info_list = await asyncio.gather(
*[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids] *[relationship_fetcher.build_relation_info(person_id, points_num=3) for person_id in person_ids]
) )
relation_info = "".join(relation_info_list) if relation_info := "".join(relation_info_list):
if relation_info:
relation_prompt = await global_prompt_manager.format_prompt( relation_prompt = await global_prompt_manager.format_prompt(
"relation_prompt", relation_info=relation_info "relation_prompt", relation_info=relation_info
) )
@ -186,9 +185,9 @@ class PromptBuilder:
timestamp=time.time(), timestamp=time.time(),
limit=300, limit=300,
) )
talk_type = message.message_info.platform + ":" + str(message.chat_stream.user_info.user_id)
talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
core_dialogue_list = [] core_dialogue_list = []
background_dialogue_list = [] background_dialogue_list = []
@ -258,19 +257,19 @@ class PromptBuilder:
all_msg_seg_list.append(msg_seg_str) all_msg_seg_list.append(msg_seg_str)
for msg in all_msg_seg_list: for msg in all_msg_seg_list:
core_msg_str += msg core_msg_str += msg
all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat( all_dialogue_prompt = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_stream.stream_id, chat_id=chat_stream.stream_id,
timestamp=time.time(), timestamp=time.time(),
limit=20, limit=20,
) )
all_dialogue_prompt_str = build_readable_messages( all_dialogue_prompt_str = build_readable_messages(
all_dialogue_prompt, all_dialogue_prompt,
timestamp_mode="normal_no_YMD", timestamp_mode="normal_no_YMD",
show_pic=False, show_pic=False,
) )
return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str return core_msg_str, background_dialogue_prompt,all_dialogue_prompt_str

View File

@ -1,7 +1,7 @@
import os import os
from typing import AsyncGenerator from typing import AsyncGenerator
from src.mais4u.openai_client import AsyncOpenAIClient from src.mais4u.openai_client import AsyncOpenAIClient
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecvS4U from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger from src.common.logger import get_logger
@ -14,24 +14,27 @@ logger = get_logger("s4u_stream_generator")
class S4UStreamGenerator: class S4UStreamGenerator:
def __init__(self): def __init__(self):
replyer_1_config = global_config.model.replyer_1 replyer_1_config = model_config.model_task_config.replyer_1
provider = replyer_1_config.get("provider") model_to_use = replyer_1_config.model_list[0]
if not provider: model_info = model_config.get_model_info(model_to_use)
logger.error("`replyer_1` 在配置文件中缺少 `provider` 字段") if not model_info:
raise ValueError("`replyer_1` 在配置文件中缺少 `provider` 字段") logger.error(f"模型 {model_to_use} 在配置中未找到")
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
provider_name = model_info.api_provider
provider_info = model_config.get_provider(provider_name)
if not provider_info:
logger.error("`replyer_1` 找不到对应的Provider")
raise ValueError("`replyer_1` 找不到对应的Provider")
api_key = os.environ.get(f"{provider.upper()}_KEY") api_key = provider_info.api_key
base_url = os.environ.get(f"{provider.upper()}_BASE_URL") base_url = provider_info.base_url
if not api_key: if not api_key:
logger.error(f"环境变量 {provider.upper()}_KEY 未设置") logger.error(f"{provider_name}没有配置API KEY")
raise ValueError(f"环境变量 {provider.upper()}_KEY 未设置") raise ValueError(f"{provider_name}没有配置API KEY")
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url) self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
self.model_1_name = replyer_1_config.get("name") self.model_1_name = model_to_use
if not self.model_1_name:
logger.error("`replyer_1` 在配置文件中缺少 `model_name` 字段")
raise ValueError("`replyer_1` 在配置文件中缺少 `model_name` 字段")
self.replyer_1_config = replyer_1_config self.replyer_1_config = replyer_1_config
self.current_model_name = "unknown model" self.current_model_name = "unknown model"
@ -44,10 +47,10 @@ class S4UStreamGenerator:
r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符 r'[^.。!?\n\r]+(?:[.。!?\n\r](?![\'"])|$))', # 匹配直到句子结束符
re.UNICODE | re.DOTALL, re.UNICODE | re.DOTALL,
) )
self.chat_stream =None self.chat_stream = None
async def build_last_internal_message(self,message:MessageRecvS4U,previous_reply_context:str = ""): async def build_last_internal_message(self, message: MessageRecvS4U, previous_reply_context: str = ""):
# person_id = PersonInfoManager.get_person_id( # person_id = PersonInfoManager.get_person_id(
# message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id # message.chat_stream.user_info.platform, message.chat_stream.user_info.user_id
# ) # )
@ -71,14 +74,10 @@ class S4UStreamGenerator:
[这是用户发来的新消息, 你需要结合上下文对此进行回复]: [这是用户发来的新消息, 你需要结合上下文对此进行回复]:
{message.processed_plain_text} {message.processed_plain_text}
""" """
return True,message_txt return True, message_txt
else: else:
message_txt = message.processed_plain_text message_txt = message.processed_plain_text
return False,message_txt return False, message_txt
async def generate_response( async def generate_response(
self, message: MessageRecvS4U, previous_reply_context: str = "" self, message: MessageRecvS4U, previous_reply_context: str = ""
@ -88,7 +87,7 @@ class S4UStreamGenerator:
self.partial_response = "" self.partial_response = ""
message_txt = message.processed_plain_text message_txt = message.processed_plain_text
if not message.is_internal: if not message.is_internal:
interupted,message_txt_added = await self.build_last_internal_message(message,previous_reply_context) interupted, message_txt_added = await self.build_last_internal_message(message, previous_reply_context)
if interupted: if interupted:
message_txt = message_txt_added message_txt = message_txt_added
@ -105,7 +104,6 @@ class S4UStreamGenerator:
current_client = self.client_1 current_client = self.client_1
self.current_model_name = self.model_1_name self.current_model_name = self.model_1_name
extra_kwargs = {} extra_kwargs = {}
if self.replyer_1_config.get("enable_thinking") is not None: if self.replyer_1_config.get("enable_thinking") is not None:
extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking") extra_kwargs["enable_thinking"] = self.replyer_1_config.get("enable_thinking")

View File

@ -214,51 +214,49 @@ class SuperChatManager:
def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str: def build_superchat_display_string(self, chat_id: str, max_count: int = 10) -> str:
"""构建SuperChat显示字符串""" """构建SuperChat显示字符串"""
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
if not superchats: if not superchats:
return "" return ""
# 限制显示数量 # 限制显示数量
display_superchats = superchats[:max_count] display_superchats = superchats[:max_count]
lines = [] lines = ["📢 当前有效超级弹幕:"]
lines.append("📢 当前有效超级弹幕:")
for i, sc in enumerate(display_superchats, 1): for i, sc in enumerate(display_superchats, 1):
remaining_minutes = int(sc.remaining_time() / 60) remaining_minutes = int(sc.remaining_time() / 60)
remaining_seconds = int(sc.remaining_time() % 60) remaining_seconds = int(sc.remaining_time() % 60)
time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}" time_display = f"{remaining_minutes}{remaining_seconds}" if remaining_minutes > 0 else f"{remaining_seconds}"
line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}" line = f"{i}. 【{sc.price}元】{sc.user_nickname}: {sc.message_text}"
if len(line) > 100: # 限制单行长度 if len(line) > 100: # 限制单行长度
line = line[:97] + "..." line = f"{line[:97]}..."
line += f" (剩余{time_display})" line += f" (剩余{time_display})"
lines.append(line) lines.append(line)
if len(superchats) > max_count: if len(superchats) > max_count:
lines.append(f"... 还有{len(superchats) - max_count}条SuperChat") lines.append(f"... 还有{len(superchats) - max_count}条SuperChat")
return "\n".join(lines) return "\n".join(lines)
def build_superchat_summary_string(self, chat_id: str) -> str: def build_superchat_summary_string(self, chat_id: str) -> str:
"""构建SuperChat摘要字符串""" """构建SuperChat摘要字符串"""
superchats = self.get_superchats_by_chat(chat_id) superchats = self.get_superchats_by_chat(chat_id)
if not superchats: if not superchats:
return "当前没有有效的超级弹幕" return "当前没有有效的超级弹幕"
lines = [] lines = []
for sc in superchats: for sc in superchats:
single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}" single_sc_str = f"{sc.user_nickname} - {sc.price}元 - {sc.message_text}"
if len(single_sc_str) > 100: if len(single_sc_str) > 100:
single_sc_str = single_sc_str[:97] + "..." single_sc_str = f"{single_sc_str[:97]}..."
single_sc_str += f" (剩余{int(sc.remaining_time())}秒)" single_sc_str += f" (剩余{int(sc.remaining_time())}秒)"
lines.append(single_sc_str) lines.append(single_sc_str)
total_amount = sum(sc.price for sc in superchats) total_amount = sum(sc.price for sc in superchats)
count = len(superchats) count = len(superchats)
highest_amount = max(sc.price for sc in superchats) highest_amount = max(sc.price for sc in superchats)
final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}" final_str = f"当前有{count}条超级弹幕,总金额{total_amount}元,最高单笔{highest_amount}"
if lines: if lines:
final_str += "\n" + "\n".join(lines) final_str += "\n" + "\n".join(lines)
@ -287,7 +285,7 @@ class SuperChatManager:
"lowest_amount": min(amounts) "lowest_amount": min(amounts)
} }
async def shutdown(self): async def shutdown(self): # sourcery skip: use-contextlib-suppress
"""关闭管理器,清理资源""" """关闭管理器,清理资源"""
if self._cleanup_task and not self._cleanup_task.done(): if self._cleanup_task and not self._cleanup_task.done():
self._cleanup_task.cancel() self._cleanup_task.cancel()
@ -300,6 +298,7 @@ class SuperChatManager:
# sourcery skip: assign-if-exp
if ENABLE_S4U: if ENABLE_S4U:
super_chat_manager = SuperChatManager() super_chat_manager = SuperChatManager()
else: else:

View File

@ -1,19 +1,14 @@
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import model_config
from src.plugin_system.apis import send_api from src.plugin_system.apis import send_api
logger = get_logger(__name__) logger = get_logger(__name__)
head_actions_list = [ head_actions_list = ["不做额外动作", "点头一次", "点头两次", "摇头", "歪脑袋", "低头望向一边"]
"不做额外动作",
"点头一次",
"点头两次",
"摇头",
"歪脑袋",
"低头望向一边"
]
async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat_id: str = ""):
async def yes_or_no_head(text: str, emotion: str = "", chat_history: str = "", chat_id: str = ""):
prompt = f""" prompt = f"""
{chat_history} {chat_history}
以上是对方的发言 以上是对方的发言
@ -30,22 +25,14 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
低头望向一边 低头望向一边
请从上面的动作中选择一个并输出请只输出你选择的动作就好不要输出其他内容""" 请从上面的动作中选择一个并输出请只输出你选择的动作就好不要输出其他内容"""
model = LLMRequest( model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="motion")
model=global_config.model.emotion,
temperature=0.7,
request_type="motion",
)
try: try:
# logger.info(f"prompt: {prompt}") # logger.info(f"prompt: {prompt}")
response, (reasoning_content, model_name) = await model.generate_response_async(prompt=prompt) response, _ = await model.generate_response_async(prompt=prompt, temperature=0.7)
logger.info(f"response: {response}") logger.info(f"response: {response}")
if response in head_actions_list: head_action = response if response in head_actions_list else "不做额外动作"
head_action = response
else:
head_action = "不做额外动作"
await send_api.custom_to_stream( await send_api.custom_to_stream(
message_type="head_action", message_type="head_action",
content=head_action, content=head_action,
@ -53,11 +40,7 @@ async def yes_or_no_head(text: str,emotion: str = "",chat_history: str = "",chat
storage_message=False, storage_message=False,
show_log=True, show_log=True,
) )
except Exception as e: except Exception as e:
logger.error(f"yes_or_no_head error: {e}") logger.error(f"yes_or_no_head error: {e}")
return "不做额外动作" return "不做额外动作"

View File

@ -3,13 +3,14 @@ import random
import time import time
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.message_receive.message import MessageRecv from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.manager.async_task_manager import AsyncTask, async_task_manager from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.chat.message_receive.chat_stream import get_chat_manager
logger = get_logger("mood") logger = get_logger("mood")
@ -49,7 +50,7 @@ class ChatMood:
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
self.chat_stream = chat_manager.get_stream(self.chat_id) self.chat_stream = chat_manager.get_stream(self.chat_id)
if not self.chat_stream: if not self.chat_stream:
raise ValueError(f"Chat stream for chat_id {chat_id} not found") raise ValueError(f"Chat stream for chat_id {chat_id} not found")
@ -59,11 +60,7 @@ class ChatMood:
self.regression_count: int = 0 self.regression_count: int = 0
self.mood_model = LLMRequest( self.mood_model = LLMRequest(model_set=model_config.model_task_config.emotion, request_type="mood")
model=global_config.model.emotion,
temperature=0.7,
request_type="mood",
)
self.last_change_time: float = 0 self.last_change_time: float = 0
@ -83,12 +80,16 @@ class ChatMood:
logger.debug( logger.debug(
f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}" f"base_probability: {base_probability}, time_multiplier: {time_multiplier}, interest_multiplier: {interest_multiplier}"
) )
update_probability = global_config.mood.mood_update_threshold * min(1.0, base_probability * time_multiplier * interest_multiplier) update_probability = global_config.mood.mood_update_threshold * min(
1.0, base_probability * time_multiplier * interest_multiplier
)
if random.random() > update_probability: if random.random() > update_probability:
return return
logger.debug(f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}") logger.debug(
f"{self.log_prefix} 更新情绪状态,感兴趣度: {interested_rate:.2f}, 更新概率: {update_probability:.2f}"
)
message_time: float = message.message_info.time # type: ignore message_time: float = message.message_info.time # type: ignore
message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive( message_list_before_now = get_raw_msg_by_timestamp_with_chat_inclusive(
@ -124,7 +125,9 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} prompt: {prompt}")
logger.info(f"{self.log_prefix} response: {response}") logger.info(f"{self.log_prefix} response: {response}")
@ -171,7 +174,9 @@ class ChatMood:
mood_state=self.mood_state, mood_state=self.mood_state,
) )
response, (reasoning_content, model_name) = await self.mood_model.generate_response_async(prompt=prompt) response, (reasoning_content, _, _) = await self.mood_model.generate_response_async(
prompt=prompt, temperature=0.7
)
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} prompt: {prompt}") logger.info(f"{self.log_prefix} prompt: {prompt}")

View File

@ -11,7 +11,7 @@ from src.common.logger import get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import PersonInfo from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
""" """
@ -54,11 +54,7 @@ person_info_default = {
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
self.person_name_list = {} self.person_name_list = {}
# TODO: API-Adapter修改标记 self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
self.qv_name_llm = LLMRequest(
model=global_config.model.utils,
request_type="relation.qv_name",
)
try: try:
db.connect(reuse_if_open=True) db.connect(reuse_if_open=True)
# 设置连接池参数 # 设置连接池参数
@ -199,7 +195,7 @@ class PersonInfoManager:
if existing: if existing:
logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建") logger.debug(f"用户 {p_data['person_id']} 已存在,跳过创建")
return True return True
# 尝试创建 # 尝试创建
PersonInfo.create(**p_data) PersonInfo.create(**p_data)
return True return True
@ -376,7 +372,7 @@ class PersonInfoManager:
"nickname": "昵称", "nickname": "昵称",
"reason": "理由" "reason": "理由"
}""" }"""
response, (reasoning_content, model_name) = await self.qv_name_llm.generate_response_async(qv_name_prompt) response, _ = await self.qv_name_llm.generate_response_async(qv_name_prompt)
# logger.info(f"取名提示词:{qv_name_prompt}\n取名回复{response}") # logger.info(f"取名提示词:{qv_name_prompt}\n取名回复{response}")
result = self._extract_json_from_text(response) result = self._extract_json_from_text(response)
@ -592,7 +588,7 @@ class PersonInfoManager:
record = PersonInfo.get_or_none(PersonInfo.person_id == p_id) record = PersonInfo.get_or_none(PersonInfo.person_id == p_id)
if record: if record:
return record, False # 记录存在,未创建 return record, False # 记录存在,未创建
# 记录不存在,尝试创建 # 记录不存在,尝试创建
try: try:
PersonInfo.create(**init_data) PersonInfo.create(**init_data)
@ -622,7 +618,7 @@ class PersonInfoManager:
"points": [], "points": [],
"forgotten_points": [], "forgotten_points": [],
} }
# 序列化JSON字段 # 序列化JSON字段
for key in JSON_SERIALIZED_FIELDS: for key in JSON_SERIALIZED_FIELDS:
if key in initial_data: if key in initial_data:
@ -630,12 +626,12 @@ class PersonInfoManager:
initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False) initial_data[key] = json.dumps(initial_data[key], ensure_ascii=False)
elif initial_data[key] is None: elif initial_data[key] is None:
initial_data[key] = json.dumps([], ensure_ascii=False) initial_data[key] = json.dumps([], ensure_ascii=False)
model_fields = PersonInfo._meta.fields.keys() # type: ignore model_fields = PersonInfo._meta.fields.keys() # type: ignore
filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields} filtered_initial_data = {k: v for k, v in initial_data.items() if v is not None and k in model_fields}
record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data) record, was_created = await asyncio.to_thread(_db_get_or_create_sync, person_id, filtered_initial_data)
if was_created: if was_created:
logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。") logger.info(f"用户 {platform}:{user_id} (person_id: {person_id}) 不存在,将创建新记录 (Peewee)。")
logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}") logger.info(f"已为 {person_id} 创建新记录,初始数据 (filtered for model): {filtered_initial_data}")

View File

@ -7,7 +7,7 @@ from typing import List, Dict, Any
from json_repair import repair_json from json_repair import repair_json
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
@ -73,14 +73,12 @@ class RelationshipFetcher:
# LLM模型配置 # LLM模型配置
self.llm_model = LLMRequest( self.llm_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small, request_type="relation.fetcher"
request_type="relation.fetcher",
) )
# 小模型用于即时信息提取 # 小模型用于即时信息提取
self.instant_llm_model = LLMRequest( self.instant_llm_model = LLMRequest(
model=global_config.model.utils_small, model_set=model_config.model_task_config.utils_small, request_type="relation.fetch"
request_type="relation.fetch",
) )
name = get_chat_manager().get_stream_name(self.chat_id) name = get_chat_manager().get_stream_name(self.chat_id)
@ -96,7 +94,7 @@ class RelationshipFetcher:
if not self.info_fetched_cache[person_id]: if not self.info_fetched_cache[person_id]:
del self.info_fetched_cache[person_id] del self.info_fetched_cache[person_id]
async def build_relation_info(self, person_id, points_num = 3): async def build_relation_info(self, person_id, points_num=3):
# 清理过期的信息缓存 # 清理过期的信息缓存
self._cleanup_expired_cache() self._cleanup_expired_cache()
@ -361,7 +359,6 @@ class RelationshipFetcher:
logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}") logger.error(f"{self.log_prefix} 执行信息提取时出错: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str): async def _save_info_to_cache(self, person_id: str, info_type: str, info_content: str):
# sourcery skip: use-next # sourcery skip: use-next
"""将提取到的信息保存到 person_info 的 info_list 字段中 """将提取到的信息保存到 person_info 的 info_list 字段中

View File

@ -3,7 +3,7 @@ from .person_info import PersonInfoManager, get_person_info_manager
import time import time
import random import random
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages from src.chat.utils.chat_message_builder import build_readable_messages
import json import json
from json_repair import repair_json from json_repair import repair_json
@ -20,9 +20,8 @@ logger = get_logger("relation")
class RelationshipManager: class RelationshipManager:
def __init__(self): def __init__(self):
self.relationship_llm = LLMRequest( self.relationship_llm = LLMRequest(
model=global_config.model.utils, model_set=model_config.model_task_config.utils, request_type="relationship"
request_type="relationship", # 用于动作规划 ) # 用于动作规划
)
@staticmethod @staticmethod
async def is_known_some_one(platform, user_id): async def is_known_some_one(platform, user_id):
@ -181,18 +180,14 @@ class RelationshipManager:
try: try:
points = repair_json(points) points = repair_json(points)
points_data = json.loads(points) points_data = json.loads(points)
# 只处理正确的格式,错误格式直接跳过 # 只处理正确的格式,错误格式直接跳过
if points_data == "none" or not points_data: if points_data == "none" or not points_data:
points_list = [] points_list = []
elif isinstance(points_data, str) and points_data.lower() == "none": elif isinstance(points_data, str) and points_data.lower() == "none":
points_list = [] points_list = []
elif isinstance(points_data, list): elif isinstance(points_data, list):
# 正确格式:数组格式 [{"point": "...", "weight": 10}, ...] points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
if not points_data: # 空数组
points_list = []
else:
points_list = [(item["point"], float(item["weight"]), current_time) for item in points_data]
else: else:
# 错误格式,直接跳过不解析 # 错误格式,直接跳过不解析
logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(points_data)}, 内容: {points_data}") logger.warning(f"LLM返回了错误的JSON格式跳过解析: {type(points_data)}, 内容: {points_data}")

View File

@ -9,6 +9,7 @@ from .base import (
BasePlugin, BasePlugin,
BaseAction, BaseAction,
BaseCommand, BaseCommand,
BaseTool,
ConfigField, ConfigField,
ComponentType, ComponentType,
ActionActivationType, ActionActivationType,
@ -17,11 +18,13 @@ from .base import (
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
PluginInfo, PluginInfo,
ToolInfo,
PythonDependency, PythonDependency,
BaseEventHandler, BaseEventHandler,
EventHandlerInfo, EventHandlerInfo,
EventType, EventType,
MaiMessages, MaiMessages,
ToolParamType,
) )
# 导入工具模块 # 导入工具模块
@ -34,6 +37,7 @@ from .utils import (
from .apis import ( from .apis import (
chat_api, chat_api,
tool_api,
component_manage_api, component_manage_api,
config_api, config_api,
database_api, database_api,
@ -44,17 +48,17 @@ from .apis import (
person_api, person_api,
plugin_manage_api, plugin_manage_api,
send_api, send_api,
utils_api,
register_plugin, register_plugin,
get_logger, get_logger,
) )
__version__ = "1.0.0" __version__ = "2.0.0"
__all__ = [ __all__ = [
# API 模块 # API 模块
"chat_api", "chat_api",
"tool_api",
"component_manage_api", "component_manage_api",
"config_api", "config_api",
"database_api", "database_api",
@ -65,13 +69,13 @@ __all__ = [
"person_api", "person_api",
"plugin_manage_api", "plugin_manage_api",
"send_api", "send_api",
"utils_api",
"register_plugin", "register_plugin",
"get_logger", "get_logger",
# 基础类 # 基础类
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"BaseEventHandler", "BaseEventHandler",
# 类型定义 # 类型定义
"ComponentType", "ComponentType",
@ -81,9 +85,11 @@ __all__ = [
"ActionInfo", "ActionInfo",
"CommandInfo", "CommandInfo",
"PluginInfo", "PluginInfo",
"ToolInfo",
"PythonDependency", "PythonDependency",
"EventHandlerInfo", "EventHandlerInfo",
"EventType", "EventType",
"ToolParamType",
# 消息 # 消息
"MaiMessages", "MaiMessages",
# 装饰器 # 装饰器

View File

@ -17,7 +17,7 @@ from src.plugin_system.apis import (
person_api, person_api,
plugin_manage_api, plugin_manage_api,
send_api, send_api,
utils_api, tool_api,
) )
from .logging_api import get_logger from .logging_api import get_logger
from .plugin_register_api import register_plugin from .plugin_register_api import register_plugin
@ -35,7 +35,7 @@ __all__ = [
"person_api", "person_api",
"plugin_manage_api", "plugin_manage_api",
"send_api", "send_api",
"utils_api",
"get_logger", "get_logger",
"register_plugin", "register_plugin",
"tool_api",
] ]

View File

@ -5,6 +5,7 @@ from src.plugin_system.base.component_types import (
EventHandlerInfo, EventHandlerInfo,
PluginInfo, PluginInfo,
ComponentType, ComponentType,
ToolInfo,
) )
@ -119,6 +120,21 @@ def get_registered_command_info(command_name: str) -> Optional[CommandInfo]:
return component_registry.get_registered_command_info(command_name) return component_registry.get_registered_command_info(command_name)
def get_registered_tool_info(tool_name: str) -> Optional[ToolInfo]:
"""
获取指定 Tool 的注册信息
Args:
tool_name (str): Tool 名称
Returns:
ToolInfo: Tool 信息对象如果 Tool 不存在则返回 None
"""
from src.plugin_system.core.component_registry import component_registry
return component_registry.get_registered_tool_info(tool_name)
# === EventHandler 特定查询方法 === # === EventHandler 特定查询方法 ===
def get_registered_event_handler_info( def get_registered_event_handler_info(
event_handler_name: str, event_handler_name: str,
@ -191,6 +207,8 @@ def locally_enable_component(component_name: str, component_type: ComponentType,
return global_announcement_manager.enable_specific_chat_action(stream_id, component_name) return global_announcement_manager.enable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND: case ComponentType.COMMAND:
return global_announcement_manager.enable_specific_chat_command(stream_id, component_name) return global_announcement_manager.enable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.enable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name) return global_announcement_manager.enable_specific_chat_event_handler(stream_id, component_name)
case _: case _:
@ -216,11 +234,14 @@ def locally_disable_component(component_name: str, component_type: ComponentType
return global_announcement_manager.disable_specific_chat_action(stream_id, component_name) return global_announcement_manager.disable_specific_chat_action(stream_id, component_name)
case ComponentType.COMMAND: case ComponentType.COMMAND:
return global_announcement_manager.disable_specific_chat_command(stream_id, component_name) return global_announcement_manager.disable_specific_chat_command(stream_id, component_name)
case ComponentType.TOOL:
return global_announcement_manager.disable_specific_chat_tool(stream_id, component_name)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name) return global_announcement_manager.disable_specific_chat_event_handler(stream_id, component_name)
case _: case _:
raise ValueError(f"未知 component type: {component_type}") raise ValueError(f"未知 component type: {component_type}")
def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]: def get_locally_disabled_components(stream_id: str, component_type: ComponentType) -> list[str]:
""" """
获取指定消息流中禁用的组件列表 获取指定消息流中禁用的组件列表
@ -239,7 +260,9 @@ def get_locally_disabled_components(stream_id: str, component_type: ComponentTyp
return global_announcement_manager.get_disabled_chat_actions(stream_id) return global_announcement_manager.get_disabled_chat_actions(stream_id)
case ComponentType.COMMAND: case ComponentType.COMMAND:
return global_announcement_manager.get_disabled_chat_commands(stream_id) return global_announcement_manager.get_disabled_chat_commands(stream_id)
case ComponentType.TOOL:
return global_announcement_manager.get_disabled_chat_tools(stream_id)
case ComponentType.EVENT_HANDLER: case ComponentType.EVENT_HANDLER:
return global_announcement_manager.get_disabled_chat_event_handlers(stream_id) return global_announcement_manager.get_disabled_chat_event_handlers(stream_id)
case _: case _:
raise ValueError(f"未知 component type: {component_type}") raise ValueError(f"未知 component type: {component_type}")

View File

@ -152,10 +152,7 @@ async def db_query(
except DoesNotExist: except DoesNotExist:
# 记录不存在 # 记录不存在
if query_type == "get" and single_result: return None if query_type == "get" and single_result else []
return None
return []
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc() traceback.print_exc()
@ -170,7 +167,8 @@ async def db_query(
async def db_save( async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Union[Dict[str, Any], None]: ) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新) """保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新 如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
@ -203,10 +201,9 @@ async def db_save(
try: try:
# 如果提供了key_field和key_value尝试更新现有记录 # 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None: if key_field and key_value is not None:
# 查找现有记录 if existing_records := list(
existing_records = list(model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)) model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
):
if existing_records:
# 更新现有记录 # 更新现有记录
existing_record = existing_records[0] existing_record = existing_records[0]
for field, value in data.items(): for field, value in data.items():
@ -244,8 +241,8 @@ async def db_get(
Args: Args:
model_class: Peewee模型类 model_class: Peewee模型类
filters: 过滤条件字段名和值的字典 filters: 过滤条件字段名和值的字典
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
limit: 结果数量限制 limit: 结果数量限制
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表 single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns: Returns:
@ -310,7 +307,7 @@ async def store_action_info(
thinking_id: str = "", thinking_id: str = "",
action_data: Optional[dict] = None, action_data: Optional[dict] = None,
action_name: str = "", action_name: str = "",
) -> Union[Dict[str, Any], None]: ) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库 """存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建 将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建

View File

@ -65,14 +65,14 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
return None return None
async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str, str]]]: async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
"""随机获取指定数量的表情包 """随机获取指定数量的表情包
Args: Args:
count: 要获取的表情包数量默认为1 count: 要获取的表情包数量默认为1
Returns: Returns:
Optional[List[Tuple[str, str, str]]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表如果失败则为None List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表失败则返回空列表
Raises: Raises:
TypeError: 如果count不是整数类型 TypeError: 如果count不是整数类型
@ -94,13 +94,13 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not all_emojis: if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包") logger.warning("[EmojiAPI] 没有可用的表情包")
return None return []
# 过滤有效表情包 # 过滤有效表情包
valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted] valid_emojis = [emoji for emoji in all_emojis if not emoji.is_deleted]
if not valid_emojis: if not valid_emojis:
logger.warning("[EmojiAPI] 没有有效的表情包") logger.warning("[EmojiAPI] 没有有效的表情包")
return None return []
if len(valid_emojis) < count: if len(valid_emojis) < count:
logger.warning( logger.warning(
@ -127,14 +127,14 @@ async def get_random(count: Optional[int] = 1) -> Optional[List[Tuple[str, str,
if not results and count > 0: if not results and count > 0:
logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理") logger.warning("[EmojiAPI] 随机获取表情包失败,没有一个可以成功处理")
return None return []
logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包") logger.info(f"[EmojiAPI] 成功获取 {len(results)} 个随机表情包")
return results return results
except Exception as e: except Exception as e:
logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}") logger.error(f"[EmojiAPI] 获取随机表情包失败: {e}")
return None return []
async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
@ -162,10 +162,11 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
# 筛选匹配情感的表情包 # 筛选匹配情感的表情包
matching_emojis = [] matching_emojis = []
for emoji_obj in all_emojis: matching_emojis.extend(
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]: emoji_obj
matching_emojis.append(emoji_obj) for emoji_obj in all_emojis
if not emoji_obj.is_deleted and emotion.lower() in [e.lower() for e in emoji_obj.emotion]
)
if not matching_emojis: if not matching_emojis:
logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包") logger.warning(f"[EmojiAPI] 未找到匹配情感 '{emotion}' 的表情包")
return None return None
@ -256,10 +257,11 @@ def get_descriptions() -> List[str]:
emoji_manager = get_emoji_manager() emoji_manager = get_emoji_manager()
descriptions = [] descriptions = []
for emoji_obj in emoji_manager.emoji_objects: descriptions.extend(
if not emoji_obj.is_deleted and emoji_obj.description: emoji_obj.description
descriptions.append(emoji_obj.description) for emoji_obj in emoji_manager.emoji_objects
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions return descriptions
except Exception as e: except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}") logger.error(f"[EmojiAPI] 获取表情包描述失败: {e}")

View File

@ -12,6 +12,7 @@ import traceback
from typing import Tuple, Any, Dict, List, Optional from typing import Tuple, Any, Dict, List, Optional
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
@ -31,7 +32,7 @@ logger = get_logger("generator_api")
def get_replyer( def get_replyer(
chat_stream: Optional[ChatStream] = None, chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
model_configs: Optional[List[Dict[str, Any]]] = None, model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> Optional[DefaultReplyer]: ) -> Optional[DefaultReplyer]:
"""获取回复器对象 """获取回复器对象
@ -42,7 +43,7 @@ def get_replyer(
Args: Args:
chat_stream: 聊天流对象优先 chat_stream: 聊天流对象优先
chat_id: 聊天ID实际上就是stream_id chat_id: 聊天ID实际上就是stream_id
model_configs: 模型配置列表 model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型 request_type: 请求类型
Returns: Returns:
@ -58,7 +59,7 @@ def get_replyer(
return replyer_manager.get_replyer( return replyer_manager.get_replyer(
chat_stream=chat_stream, chat_stream=chat_stream,
chat_id=chat_id, chat_id=chat_id,
model_configs=model_configs, model_set_with_weight=model_set_with_weight,
request_type=request_type, request_type=request_type,
) )
except Exception as e: except Exception as e:
@ -83,31 +84,36 @@ async def generate_reply(
enable_splitter: bool = True, enable_splitter: bool = True,
enable_chinese_typo: bool = True, enable_chinese_typo: bool = True,
return_prompt: bool = False, return_prompt: bool = False,
model_configs: Optional[List[Dict[str, Any]]] = None, model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
request_type: str = "", request_type: str = "generator_api",
enable_timeout: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]: ) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""生成回复 """生成回复
Args: Args:
chat_stream: 聊天流对象优先 chat_stream: 聊天流对象优先
chat_id: 聊天ID备用 chat_id: 聊天ID备用
action_data: 动作数据 action_data: 动作数据向下兼容包含reply_to和extra_info
reply_to: 回复对象格式为 "发送者:消息内容"
extra_info: 额外信息用于补充上下文
available_actions: 可用动作
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
return_prompt: 是否返回提示词 return_prompt: 是否返回提示词
model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
request_type: 请求类型可选记录LLM使用
Returns: Returns:
Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词) Tuple[bool, List[Tuple[str, Any]], Optional[str]]: (是否成功, 回复集合, 提示词)
""" """
try: try:
# 获取回复器 # 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs, request_type=request_type) replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight, request_type=request_type)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None return False, [], None
logger.debug("[GeneratorAPI] 开始生成回复") logger.debug("[GeneratorAPI] 开始生成回复")
if not reply_to and action_data: if not reply_to and action_data:
reply_to = action_data.get("reply_to", "") reply_to = action_data.get("reply_to", "")
if not extra_info and action_data: if not extra_info and action_data:
@ -118,7 +124,6 @@ async def generate_reply(
reply_to=reply_to, reply_to=reply_to,
extra_info=extra_info, extra_info=extra_info,
available_actions=available_actions, available_actions=available_actions,
enable_timeout=enable_timeout,
enable_tool=enable_tool, enable_tool=enable_tool,
) )
reply_set = [] reply_set = []
@ -150,33 +155,35 @@ async def rewrite_reply(
chat_id: Optional[str] = None, chat_id: Optional[str] = None,
enable_splitter: bool = True, enable_splitter: bool = True,
enable_chinese_typo: bool = True, enable_chinese_typo: bool = True,
model_configs: Optional[List[Dict[str, Any]]] = None, model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
raw_reply: str = "", raw_reply: str = "",
reason: str = "", reason: str = "",
reply_to: str = "", reply_to: str = "",
) -> Tuple[bool, List[Tuple[str, Any]]]: return_prompt: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
"""重写回复 """重写回复
Args: Args:
chat_stream: 聊天流对象优先 chat_stream: 聊天流对象优先
reply_data: 回复数据字典备用当其他参数缺失时从此获取 reply_data: 回复数据字典向下兼容备用当其他参数缺失时从此获取
chat_id: 聊天ID备用 chat_id: 聊天ID备用
enable_splitter: 是否启用消息分割器 enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器 enable_chinese_typo: 是否启用错字生成器
model_configs: 模型配置列表 model_set_with_weight: 模型配置列表每个元素为 (TaskConfig, weight) 元组
raw_reply: 原始回复内容 raw_reply: 原始回复内容
reason: 回复原因 reason: 回复原因
reply_to: 回复对象 reply_to: 回复对象
return_prompt: 是否返回提示词
Returns: Returns:
Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合) Tuple[bool, List[Tuple[str, Any]]]: (是否成功, 回复集合)
""" """
try: try:
# 获取回复器 # 获取回复器
replyer = get_replyer(chat_stream, chat_id, model_configs=model_configs) replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer: if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器") logger.error("[GeneratorAPI] 无法获取回复器")
return False, [] return False, [], None
logger.info("[GeneratorAPI] 开始重写回复") logger.info("[GeneratorAPI] 开始重写回复")
@ -187,10 +194,11 @@ async def rewrite_reply(
reply_to = reply_to or reply_data.get("reply_to", "") reply_to = reply_to or reply_data.get("reply_to", "")
# 调用回复器重写回复 # 调用回复器重写回复
success, content = await replyer.rewrite_reply_with_context( success, content, prompt = await replyer.rewrite_reply_with_context(
raw_reply=raw_reply, raw_reply=raw_reply,
reason=reason, reason=reason,
reply_to=reply_to, reply_to=reply_to,
return_prompt=return_prompt,
) )
reply_set = [] reply_set = []
if content: if content:
@ -201,14 +209,14 @@ async def rewrite_reply(
else: else:
logger.warning("[GeneratorAPI] 重写回复失败") logger.warning("[GeneratorAPI] 重写回复失败")
return success, reply_set return success, reply_set, prompt if return_prompt else None
except ValueError as ve: except ValueError as ve:
raise ve raise ve
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 重写回复时出错: {e}") logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
return False, [] return False, [], None
async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]: async def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:
@ -234,3 +242,27 @@ async def process_human_text(content: str, enable_splitter: bool, enable_chinese
except Exception as e: except Exception as e:
logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}") logger.error(f"[GeneratorAPI] 处理人形文本时出错: {e}")
return [] return []
async def generate_response_custom(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
model_set_with_weight: Optional[List[Tuple[TaskConfig, float]]] = None,
prompt: str = "",
) -> Optional[str]:
replyer = get_replyer(chat_stream, chat_id, model_set_with_weight=model_set_with_weight)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return None
try:
logger.debug("[GeneratorAPI] 开始生成自定义回复")
response, _, _, _ = await replyer.llm_generate_content(prompt)
if response:
logger.debug("[GeneratorAPI] 自定义回复生成成功")
return response
else:
logger.warning("[GeneratorAPI] 自定义回复生成失败")
return None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成自定义回复时出错: {e}")
return None

View File

@ -7,10 +7,12 @@
success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config) success, response, reasoning, model_name = await llm_api.generate_with_model(prompt, model_config)
""" """
from typing import Tuple, Dict, Any from typing import Tuple, Dict, List, Any, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from src.llm_models.payload_content.tool_option import ToolCall
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config from src.config.config import global_config, model_config
from src.config.api_ada_configs import TaskConfig
logger = get_logger("llm_api") logger = get_logger("llm_api")
@ -19,9 +21,7 @@ logger = get_logger("llm_api")
# ============================================================================= # =============================================================================
def get_available_models() -> Dict[str, TaskConfig]:
def get_available_models() -> Dict[str, Any]:
"""获取所有可用的模型配置 """获取所有可用的模型配置
Returns: Returns:
@ -33,14 +33,14 @@ def get_available_models() -> Dict[str, Any]:
return {} return {}
# 自动获取所有属性并转换为字典形式 # 自动获取所有属性并转换为字典形式
rets = {} models = model_config.model_task_config
models = global_config.model
attrs = dir(models) attrs = dir(models)
rets: Dict[str, TaskConfig] = {}
for attr in attrs: for attr in attrs:
if not attr.startswith("__"): if not attr.startswith("__"):
try: try:
value = getattr(models, attr) value = getattr(models, attr)
if not callable(value): # 排除方法 if not callable(value) and isinstance(value, TaskConfig):
rets[attr] = value rets[attr] = value
except Exception as e: except Exception as e:
logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}") logger.debug(f"[LLMAPI] 获取属性 {attr} 失败: {e}")
@ -53,7 +53,11 @@ def get_available_models() -> Dict[str, Any]:
async def generate_with_model( async def generate_with_model(
prompt: str, model_config: Dict[str, Any], request_type: str = "plugin.generate", **kwargs prompt: str,
model_config: TaskConfig,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str]: ) -> Tuple[bool, str, str, str]:
"""使用指定模型生成内容 """使用指定模型生成内容
@ -61,22 +65,62 @@ async def generate_with_model(
prompt: 提示词 prompt: 提示词
model_config: 模型配置 get_available_models 获取的模型配置 model_config: 模型配置 get_available_models 获取的模型配置
request_type: 请求类型标识 request_type: 请求类型标识
**kwargs: 其他模型特定参数如temperaturemax_tokens等
Returns: Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称) Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
""" """
try: try:
model_name = model_config.get("name") model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型 {model_name} 生成内容") logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}") logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model=model_config, request_type=request_type, **kwargs) llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning, model_name) = await llm_request.generate_response_async(prompt) response, (reasoning_content, model_name, _) = await llm_request.generate_response_async(prompt, temperature=temperature, max_tokens=max_tokens)
return True, response, reasoning, model_name return True, response, reasoning_content, model_name
except Exception as e: except Exception as e:
error_msg = f"生成内容时出错: {str(e)}" error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}") logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "" return False, error_msg, "", ""
async def generate_with_model_with_tools(
prompt: str,
model_config: TaskConfig,
tool_options: List[Dict[str, Any]] | None = None,
request_type: str = "plugin.generate",
temperature: Optional[float] = None,
max_tokens: Optional[int] = None,
) -> Tuple[bool, str, str, str, List[ToolCall] | None]:
"""使用指定模型和工具生成内容
Args:
prompt: 提示词
model_config: 模型配置 get_available_models 获取的模型配置
tool_options: 工具选项列表
request_type: 请求类型标识
temperature: 温度参数
max_tokens: 最大token数
Returns:
Tuple[bool, str, str, str]: (是否成功, 生成的内容, 推理过程, 模型名称)
"""
try:
model_name_list = model_config.model_list
logger.info(f"[LLMAPI] 使用模型集合 {model_name_list} 生成内容")
logger.debug(f"[LLMAPI] 完整提示词: {prompt}")
llm_request = LLMRequest(model_set=model_config, request_type=request_type)
response, (reasoning_content, model_name, tool_call) = await llm_request.generate_response_async(
prompt,
tools=tool_options,
temperature=temperature,
max_tokens=max_tokens
)
return True, response, reasoning_content, model_name, tool_call
except Exception as e:
error_msg = f"生成内容时出错: {str(e)}"
logger.error(f"[LLMAPI] {error_msg}")
return False, error_msg, "", "", None

View File

@ -207,7 +207,7 @@ def get_random_chat_messages(
def get_messages_by_time_for_users( def get_messages_by_time_for_users(
start_time: float, end_time: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" start_time: float, end_time: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取指定用户在所有聊天中指定时间范围内的消息 获取指定用户在所有聊天中指定时间范围内的消息
@ -287,7 +287,7 @@ def get_messages_before_time_in_chat(
return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit) return get_raw_msg_before_timestamp_with_chat(chat_id, timestamp, limit)
def get_messages_before_time_for_users(timestamp: float, person_ids: list, limit: int = 0) -> List[Dict[str, Any]]: def get_messages_before_time_for_users(timestamp: float, person_ids: List[str], limit: int = 0) -> List[Dict[str, Any]]:
""" """
获取指定用户在指定时间戳之前的消息 获取指定用户在指定时间戳之前的消息
@ -372,7 +372,7 @@ def count_new_messages(chat_id: str, start_time: float = 0.0, end_time: Optional
return num_new_messages_since(chat_id, start_time, end_time) return num_new_messages_since(chat_id, start_time, end_time)
def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: list) -> int: def count_new_messages_for_users(chat_id: str, start_time: float, end_time: float, person_ids: List[str]) -> int:
""" """
计算指定聊天中指定用户从开始时间到结束时间的新消息数量 计算指定聊天中指定用户从开始时间到结束时间的新消息数量

View File

@ -1,10 +1,12 @@
from typing import Tuple, List from typing import Tuple, List
def list_loaded_plugins() -> List[str]: def list_loaded_plugins() -> List[str]:
""" """
列出所有当前加载的插件 列出所有当前加载的插件
Returns: Returns:
list: 当前加载的插件名称列表 List[str]: 当前加载的插件名称列表
""" """
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.plugin_manager import plugin_manager
@ -16,17 +18,38 @@ def list_registered_plugins() -> List[str]:
列出所有已注册的插件 列出所有已注册的插件
Returns: Returns:
list: 已注册的插件名称列表 List[str]: 已注册的插件名称列表
""" """
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.list_registered_plugins() return plugin_manager.list_registered_plugins()
def get_plugin_path(plugin_name: str) -> str:
"""
获取指定插件的路径
Args:
plugin_name (str): 插件名称
Returns:
str: 插件目录的绝对路径
Raises:
ValueError: 如果插件不存在
"""
from src.plugin_system.core.plugin_manager import plugin_manager
if plugin_path := plugin_manager.get_plugin_path(plugin_name):
return plugin_path
else:
raise ValueError(f"插件 '{plugin_name}' 不存在。")
async def remove_plugin(plugin_name: str) -> bool: async def remove_plugin(plugin_name: str) -> bool:
""" """
卸载指定的插件 卸载指定的插件
**此函数是异步的确保在异步环境中调用** **此函数是异步的确保在异步环境中调用**
Args: Args:
@ -43,7 +66,7 @@ async def remove_plugin(plugin_name: str) -> bool:
async def reload_plugin(plugin_name: str) -> bool: async def reload_plugin(plugin_name: str) -> bool:
""" """
重新加载指定的插件 重新加载指定的插件
**此函数是异步的确保在异步环境中调用** **此函数是异步的确保在异步环境中调用**
Args: Args:
@ -71,6 +94,7 @@ def load_plugin(plugin_name: str) -> Tuple[bool, int]:
return plugin_manager.load_registered_plugin_classes(plugin_name) return plugin_manager.load_registered_plugin_classes(plugin_name)
def add_plugin_directory(plugin_directory: str) -> bool: def add_plugin_directory(plugin_directory: str) -> bool:
""" """
添加插件目录 添加插件目录
@ -84,6 +108,7 @@ def add_plugin_directory(plugin_directory: str) -> bool:
return plugin_manager.add_plugin_directory(plugin_directory) return plugin_manager.add_plugin_directory(plugin_directory)
def rescan_plugin_directory() -> Tuple[int, int]: def rescan_plugin_directory() -> Tuple[int, int]:
""" """
重新扫描插件目录加载新插件 重新扫描插件目录加载新插件
@ -92,4 +117,4 @@ def rescan_plugin_directory() -> Tuple[int, int]:
""" """
from src.plugin_system.core.plugin_manager import plugin_manager from src.plugin_system.core.plugin_manager import plugin_manager
return plugin_manager.rescan_plugin_directory() return plugin_manager.rescan_plugin_directory()

View File

@ -49,7 +49,7 @@ async def _send_to_target(
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
reply_to: str = "", reply_to: str = "",
reply_to_platform_id: str = "", reply_to_platform_id: Optional[str] = None,
storage_message: bool = True, storage_message: bool = True,
show_log: bool = True, show_log: bool = True,
) -> bool: ) -> bool:
@ -60,8 +60,11 @@ async def _send_to_target(
content: 消息内容 content: 消息内容
stream_id: 目标流ID stream_id: 目标流ID
display_message: 显示消息 display_message: 显示消息
typing: 是否显示正在输入 typing: 是否模拟打字等待
reply_to: 回复消息的格式"发送者:消息内容" reply_to: 回复消息格式为"发送者:消息内容"
reply_to_platform_id: 回复消息格式为"平台:用户ID"如果不提供则自动查找插件开发者禁用
storage_message: 是否存储消息到数据库
show_log: 发送是否显示日志
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
@ -97,6 +100,10 @@ async def _send_to_target(
anchor_message = None anchor_message = None
if reply_to: if reply_to:
anchor_message = await _find_reply_message(target_stream, reply_to) anchor_message = await _find_reply_message(target_stream, reply_to)
if anchor_message and anchor_message.message_info.user_info and not reply_to_platform_id:
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
)
# 构建发送消息对象 # 构建发送消息对象
bot_message = MessageSending( bot_message = MessageSending(
@ -262,12 +269,22 @@ async def text_to_stream(
stream_id: 聊天流ID stream_id: 聊天流ID
typing: 是否显示正在输入 typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容" reply_to: 回复消息格式为"发送者:消息内容"
reply_to_platform_id: 回复消息格式为"平台:用户ID"如果不提供则自动查找插件开发者禁用
storage_message: 是否存储消息到数据库 storage_message: 是否存储消息到数据库
Returns: Returns:
bool: 是否发送成功 bool: 是否发送成功
""" """
return await _send_to_target("text", text, stream_id, "", typing, reply_to, reply_to_platform_id, storage_message) return await _send_to_target(
"text",
text,
stream_id,
"",
typing,
reply_to,
reply_to_platform_id=reply_to_platform_id,
storage_message=storage_message,
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool: async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True) -> bool:
@ -318,7 +335,7 @@ async def command_to_stream(
async def custom_to_stream( async def custom_to_stream(
message_type: str, message_type: str,
content: str, content: str | dict,
stream_id: str, stream_id: str,
display_message: str = "", display_message: str = "",
typing: bool = False, typing: bool = False,
@ -350,249 +367,3 @@ async def custom_to_stream(
storage_message=storage_message, storage_message=storage_message,
show_log=show_log, show_log=show_log,
) )
async def text_to_group(
text: str,
group_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向群聊发送文本消息
Args:
text: 要发送的文本内容
group_id: 群聊ID
platform: 平台默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def text_to_user(
text: str,
user_id: str,
platform: str = "qq",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向用户发送私聊文本消息
Args:
text: 要发送的文本内容
user_id: 用户ID
platform: 平台默认为"qq"
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("text", text, stream_id, "", typing, reply_to, storage_message=storage_message)
async def emoji_to_group(emoji_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送表情包
Args:
emoji_base64: 表情包的base64编码
group_id: 群聊ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def emoji_to_user(emoji_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送表情包
Args:
emoji_base64: 表情包的base64编码
user_id: 用户ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_group(image_base64: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送图片
Args:
image_base64: 图片的base64编码
group_id: 群聊ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message)
async def image_to_user(image_base64: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送图片
Args:
image_base64: 图片的base64编码
user_id: 用户ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("image", image_base64, stream_id, "", typing=False)
async def command_to_group(command: str, group_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向群聊发送命令
Args:
command: 命令
group_id: 群聊ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
async def command_to_user(command: str, user_id: str, platform: str = "qq", storage_message: bool = True) -> bool:
"""向用户发送命令
Args:
command: 命令
user_id: 用户ID
platform: 平台默认为"qq"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target("command", command, stream_id, "", typing=False, storage_message=storage_message)
# =============================================================================
# 通用发送函数 - 支持任意消息类型
# =============================================================================
async def custom_to_group(
message_type: str,
content: str,
group_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向群聊发送自定义类型消息
Args:
message_type: 消息类型"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
group_id: 群聊ID
platform: 平台默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, group_id, True)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_to_user(
message_type: str,
content: str,
user_id: str,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""向用户发送自定义类型消息
Args:
message_type: 消息类型"text""image""emoji""video""file"
content: 消息内容通常是base64编码或文本
user_id: 用户ID
platform: 平台默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
"""
stream_id = get_chat_manager().get_stream_id(platform, user_id, False)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)
async def custom_message(
message_type: str,
content: str,
target_id: str,
is_group: bool = True,
platform: str = "qq",
display_message: str = "",
typing: bool = False,
reply_to: str = "",
storage_message: bool = True,
) -> bool:
"""发送自定义消息的通用接口
Args:
message_type: 消息类型"text""image""emoji""video""file""audio"
content: 消息内容
target_id: 目标ID群ID或用户ID
is_group: 是否为群聊True为群聊False为私聊
platform: 平台默认为"qq"
display_message: 显示消息
typing: 是否显示正在输入
reply_to: 回复消息格式为"发送者:消息内容"
Returns:
bool: 是否发送成功
示例:
# 发送视频到群聊
await send_api.custom_message("video", video_base64, "123456", True)
# 发送文件到用户
await send_api.custom_message("file", file_base64, "987654", False)
# 发送音频到群聊并回复特定消息
await send_api.custom_message("audio", audio_base64, "123456", True, reply_to="张三:你好")
"""
stream_id = get_chat_manager().get_stream_id(platform, target_id, is_group)
return await _send_to_target(
message_type, content, stream_id, display_message, typing, reply_to, storage_message=storage_message
)

View File

@ -0,0 +1,27 @@
from typing import Optional, Type
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
from src.common.logger import get_logger
logger = get_logger("tool_api")
def get_tool_instance(tool_name: str) -> Optional[BaseTool]:
"""获取公开工具实例"""
from src.plugin_system.core import component_registry
tool_class: Type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
return tool_class() if tool_class else None
def get_llm_available_tool_definitions():
"""获取LLM可用的工具定义列表
Returns:
List[Tuple[str, Dict[str, Any]]]: 工具定义列表[("tool_name", 定义)]
"""
from src.plugin_system.core import component_registry
llm_available_tools = component_registry.get_llm_available_tools()
return [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]

View File

@ -1,168 +0,0 @@
"""工具类API模块
提供了各种辅助功能
使用方式
from src.plugin_system.apis import utils_api
plugin_path = utils_api.get_plugin_path()
data = utils_api.read_json_file("data.json")
timestamp = utils_api.get_timestamp()
"""
import os
import json
import time
import inspect
import datetime
import uuid
from typing import Any, Optional
from src.common.logger import get_logger
logger = get_logger("utils_api")
# =============================================================================
# 文件操作API函数
# =============================================================================
def get_plugin_path(caller_frame=None) -> str:
"""获取调用者插件的路径
Args:
caller_frame: 调用者的栈帧默认为None自动获取
Returns:
str: 插件目录的绝对路径
"""
try:
if caller_frame is None:
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_module_path = inspect.getfile(caller_frame) # type: ignore
plugin_dir = os.path.dirname(plugin_module_path)
return plugin_dir
except Exception as e:
logger.error(f"[UtilsAPI] 获取插件路径失败: {e}")
return ""
def read_json_file(file_path: str, default: Any = None) -> Any:
"""读取JSON文件
Args:
file_path: 文件路径可以是相对于插件目录的路径
default: 如果文件不存在或读取失败时返回的默认值
Returns:
Any: JSON数据或默认值
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
if not os.path.exists(file_path):
logger.warning(f"[UtilsAPI] 文件不存在: {file_path}")
return default
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"[UtilsAPI] 读取JSON文件出错: {e}")
return default
def write_json_file(file_path: str, data: Any, indent: int = 2) -> bool:
"""写入JSON文件
Args:
file_path: 文件路径可以是相对于插件目录的路径
data: 要写入的数据
indent: JSON缩进
Returns:
bool: 是否写入成功
"""
try:
# 如果是相对路径,则相对于调用者的插件目录
if not os.path.isabs(file_path):
caller_frame = inspect.currentframe().f_back # type: ignore
plugin_dir = get_plugin_path(caller_frame)
file_path = os.path.join(plugin_dir, file_path)
# 确保目录存在
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=indent)
return True
except Exception as e:
logger.error(f"[UtilsAPI] 写入JSON文件出错: {e}")
return False
# =============================================================================
# 时间相关API函数
# =============================================================================
def get_timestamp() -> int:
"""获取当前时间戳
Returns:
int: 当前时间戳
"""
return int(time.time())
def format_time(timestamp: Optional[int | float] = None, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""格式化时间
Args:
timestamp: 时间戳如果为None则使用当前时间
format_str: 时间格式字符串
Returns:
str: 格式化后的时间字符串
"""
try:
if timestamp is None:
timestamp = time.time()
return datetime.datetime.fromtimestamp(timestamp).strftime(format_str)
except Exception as e:
logger.error(f"[UtilsAPI] 格式化时间失败: {e}")
return ""
def parse_time(time_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> int:
"""解析时间字符串为时间戳
Args:
time_str: 时间字符串
format_str: 时间格式字符串
Returns:
int: 时间戳
"""
try:
dt = datetime.datetime.strptime(time_str, format_str)
return int(dt.timestamp())
except Exception as e:
logger.error(f"[UtilsAPI] 解析时间失败: {e}")
return 0
# =============================================================================
# 其他工具函数
# =============================================================================
def generate_unique_id() -> str:
"""生成唯一ID
Returns:
str: 唯一ID
"""
return str(uuid.uuid4())

View File

@ -6,6 +6,7 @@
from .base_plugin import BasePlugin from .base_plugin import BasePlugin
from .base_action import BaseAction from .base_action import BaseAction
from .base_tool import BaseTool
from .base_command import BaseCommand from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler from .base_events_handler import BaseEventHandler
from .component_types import ( from .component_types import (
@ -15,11 +16,13 @@ from .component_types import (
ComponentInfo, ComponentInfo,
ActionInfo, ActionInfo,
CommandInfo, CommandInfo,
ToolInfo,
PluginInfo, PluginInfo,
PythonDependency, PythonDependency,
EventHandlerInfo, EventHandlerInfo,
EventType, EventType,
MaiMessages, MaiMessages,
ToolParamType,
) )
from .config_types import ConfigField from .config_types import ConfigField
@ -27,12 +30,14 @@ __all__ = [
"BasePlugin", "BasePlugin",
"BaseAction", "BaseAction",
"BaseCommand", "BaseCommand",
"BaseTool",
"ComponentType", "ComponentType",
"ActionActivationType", "ActionActivationType",
"ChatMode", "ChatMode",
"ComponentInfo", "ComponentInfo",
"ActionInfo", "ActionInfo",
"CommandInfo", "CommandInfo",
"ToolInfo",
"PluginInfo", "PluginInfo",
"PythonDependency", "PythonDependency",
"ConfigField", "ConfigField",
@ -40,4 +45,5 @@ __all__ = [
"EventType", "EventType",
"BaseEventHandler", "BaseEventHandler",
"MaiMessages", "MaiMessages",
"ToolParamType",
] ]

Some files were not shown because too many files have changed in this diff Show More