Merge branch 'dev' into feat-lpmm知识库加强

pull/1386/head
Dawn ARC 2025-12-18 18:58:10 +08:00 committed by GitHub
commit 20c9cbad3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
166 changed files with 22730 additions and 3447 deletions

View File

@@ -2,8 +2,7 @@ name: Docker Build and Push (Dev)
on:
schedule:
- cron: '0 0 * * *'
# push:
- cron: '0 0 * * *' # every day at midnight UTC
# branches:
# - dev
workflow_dispatch: # 允许手动触发工作流
@@ -28,11 +27,11 @@ jobs:
fetch-depth: 0
# Clone required dependencies
- name: Clone maim_message
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
# - name: Clone maim_message
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
- name: Clone lpmm
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -82,11 +81,11 @@ jobs:
fetch-depth: 0
# Clone required dependencies
- name: Clone maim_message
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
# - name: Clone maim_message
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
- name: Clone lpmm
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -147,8 +146,8 @@ jobs:
with:
images: ${{ secrets.DOCKERHUB_USERNAME }}/maibot
tags: |
type=ref,event=branch
type=sha,prefix=${{ github.ref_name }}-,enable=${{ github.ref_type == 'branch' }}
type=raw,value=dev
type=schedule,pattern=dev-{{date 'YYMMDD'}}
- name: Create and Push Manifest
run: |

View File

@@ -31,11 +31,11 @@ jobs:
fetch-depth: 0
# Clone required dependencies
- name: Clone maim_message
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
# - name: Clone maim_message
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
- name: Clone lpmm
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
@@ -84,11 +84,11 @@ jobs:
fetch-depth: 0
# Clone required dependencies
- name: Clone maim_message
run: git clone https://github.com/MaiM-with-u/maim_message maim_message
# - name: Clone maim_message
# run: git clone https://github.com/MaiM-with-u/maim_message maim_message
- name: Clone lpmm
run: git clone https://github.com/MaiM-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
run: git clone https://github.com/Mai-with-u/MaiMBot-LPMM.git MaiMBot-LPMM
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

7
.gitignore vendored
View File

@@ -35,9 +35,6 @@ message_queue_content.bat
message_queue_window.bat
message_queue_window.txt
queue_update.txt
memory_graph.gml
/src/tools/tool_can_use/auto_create_tool.py
/src/tools/tool_can_use/execute_python_code_tool.py
.env
.env.*
.cursor
@@ -48,9 +45,6 @@ config/lpmm_config.toml
config/lpmm_config.toml.bak
template/compare/bot_config_template.toml
template/compare/model_config_template.toml
(测试版)麦麦生成人格.bat
(临时版)麦麦开始学习.bat
src/plugins/utils/statistic.py
CLAUDE.md
MaiBot-Dashboard/
cloudflare-workers/
@@ -327,6 +321,7 @@ run_pet.bat
!/plugins/emoji_manage_plugin
!/plugins/take_picture_plugin
!/plugins/deep_think
!/plugins/MaiBot_MCPBridgePlugin
!/plugins/ChatFrequency/
!/plugins/__init__.py

View File

@@ -71,7 +71,6 @@
1. **GitHub Issues**: 对于公开的违规行为可以在相关issue中直接指出
2. **私下联系**: 可以通过GitHub私信联系项目维护者
3. **邮件联系**: [如果有项目邮箱地址,请在此提供]
所有报告都将得到及时和公正的处理。我们承诺保护报告者的隐私和安全。

27
bot.py
View File

@@ -34,13 +34,17 @@ else:
print(f"自动创建 .env 失败: {e}")
raise
initialize_logging()
# 检查是否是 Worker 进程,只在 Worker 进程中输出详细的初始化信息
# Runner 进程只需要基本的日志功能,不需要详细的初始化日志
is_worker = os.environ.get("MAIBOT_WORKER_PROCESS") == "1"
initialize_logging(verbose=is_worker)
install(extra_lines=3)
logger = get_logger("main")
# 定义重启退出码
RESTART_EXIT_CODE = 42
def run_runner_process():
"""
Runner 进程逻辑,作为守护进程运行,负责启动和监控 Worker 进程
@@ -55,25 +59,25 @@ def run_runner_process():
while True:
logger.info(f"正在启动 {script_file}...")
# 启动子进程 (Worker)
# 使用 sys.executable 确保使用相同的 Python 解释器
cmd = [python_executable, script_file] + sys.argv[1:]
process = subprocess.Popen(cmd, env=env)
try:
# 等待子进程结束
return_code = process.wait()
if return_code == RESTART_EXIT_CODE:
logger.info("检测到重启请求 (退出码 42),正在重启...")
time.sleep(1) # 稍作等待
time.sleep(1) # 稍作等待
continue
else:
logger.info(f"程序已退出 (退出码 {return_code})")
sys.exit(return_code)
except KeyboardInterrupt:
# 向子进程发送终止信号
if process.poll() is None:
@@ -87,6 +91,7 @@ def run_runner_process():
process.kill()
sys.exit(0)
# 检查是否是 Worker 进程
# 如果没有设置 MAIBOT_WORKER_PROCESS 环境变量,说明是直接运行的脚本,
# 此时应该作为 Runner 运行。
@@ -99,8 +104,10 @@ if os.environ.get("MAIBOT_WORKER_PROCESS") != "1":
# 以下是 Worker 进程的逻辑
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
# from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
# initialize_logging()
# 注意Runner 进程已经在第 37 行初始化了日志系统,但 Worker 进程是独立进程,需要重新初始化
# 由于 Runner 和 Worker 是不同进程,它们有独立的内存空间,所以都会初始化一次
# 这是正常的,但为了避免重复的初始化日志,我们在 initialize_logging() 中添加了防重复机制
# 不过由于是不同进程,每个进程仍会初始化一次,这是预期的行为
from src.main import MainSystem # noqa
from src.manager.async_task_manager import async_task_manager # noqa
@@ -143,7 +150,7 @@ def print_opensource_notice():
"",
f"{Fore.WHITE} 官方仓库: {Fore.BLUE}https://github.com/MaiM-with-u/MaiBot {Style.RESET_ALL}",
f"{Fore.WHITE} 官方文档: {Fore.BLUE}https://docs.mai-mai.org {Style.RESET_ALL}",
f"{Fore.WHITE} 官方群聊: {Fore.BLUE}766798517{Style.RESET_ALL}",
f"{Fore.WHITE} 官方群聊: {Fore.BLUE}1006149251{Style.RESET_ALL}",
f"{Fore.CYAN}{'' * 70}{Style.RESET_ALL}",
f"{Fore.RED} ⚠ 将本软件作为「商品」倒卖、隐瞒开源性质均违反协议!{Style.RESET_ALL}",
f"{Fore.CYAN}{'' * 70}{Style.RESET_ALL}",

View File

@@ -1,5 +1,21 @@
# Changelog
## [0.12.0] - 2025-12-16
### 🌟 重大更新
- 添加思考力度机制,动态控制回复时间和长度
- planner和replyer现在开启联动,带来更好的回复逻辑
- 新的私聊系统吸收了pfc的优秀机制
- 增加麦麦做梦功能
- mcp插件作为内置插件加入,默认不启用
- 添加全局记忆配置项,现在可以选择让记忆为全局的
### 细节功能更改
- 移除频率自动调整
- 移除情绪功能
- 优化记忆相关的超时设置
- 修复部分配置为0的bug
- 黑话和表达不再提取包含名称的内容
## [0.11.6] - 2025-12-2
### 🌟 重大更新
- 大幅提高记忆检索能力,略微提高token消耗

10
dummy 100644
View File

@@ -0,0 +1,10 @@
{
"cells": [],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -0,0 +1,569 @@
# MCP 桥接插件 - 开发文档
本文档面向 AI 助手或开发者进行插件开发/维护。
## 前置知识
本插件基于 MaiBot 插件系统开发,需要了解:
- MaiBot 插件框架:`BasePlugin`, `BaseTool`, `BaseCommand`, `BaseEventHandler`
- 配置系统:`ConfigField`, `config_schema`
- 组件注册:`component_registry.register_component()`
详见项目根目录 `.kiro/steering/plugin-dev.md`
---
## 版本历史
| 版本 | 主要功能 |
|------|----------|
| v1.5.4 | 易用性优化:新增 MCP 服务器获取快捷入口 |
| v1.5.3 | 配置优化:新增智能心跳 WebUI 配置项 |
| v1.5.2 | 性能优化:智能心跳间隔,根据服务器稳定性动态调整 |
| v1.5.1 | 易用性优化:新增「快速添加服务器」表单式配置 |
| v1.5.0 | 性能优化:服务器并行连接,大幅减少启动时间 |
| v1.4.4 | 修复首次生成默认配置文件时多行字符串导致 TOML 解析失败 |
| v1.4.3 | 修复 WebUI 保存配置后多行字符串格式错误导致配置文件无法读取 |
| v1.4.2 | HTTP 鉴权头支持(headers 字段) |
| v1.4.0 | 工具禁用、调用追踪、缓存、权限控制、WebUI 易用性改进 |
| v1.3.0 | 结果后处理(LLM 摘要提炼) |
| v1.2.0 | Resources/Prompts 支持(实验性) |
| v1.1.x | 心跳检测、自动重连、调用统计、`/mcp` 命令 |
| v1.0.0 | 基础 MCP 桥接 |
---
## 项目结构
```
MCPBridgePlugin/
├── plugin.py # 主插件逻辑(1800+ 行)
├── mcp_client.py # MCP 客户端封装(800+ 行)
├── _manifest.json # 插件清单
├── config.example.toml # 配置示例
├── requirements.txt # 依赖(mcp>=1.0.0)
├── README.md # 用户文档
└── DEVELOPMENT.md # 开发文档(本文件)
```
---
## 核心模块详解
### 1. mcp_client.py - MCP 客户端
负责与 MCP 服务器通信,可独立于 MaiBot 运行测试。
#### 数据类
```python
class TransportType(Enum):
STDIO = "stdio" # 本地进程
SSE = "sse" # Server-Sent Events
HTTP = "http" # HTTP
STREAMABLE_HTTP = "streamable_http" # HTTP Streamable(推荐)
@dataclass
class MCPServerConfig:
name: str # 服务器唯一标识
enabled: bool = True
transport: TransportType = TransportType.STDIO
command: str = "" # stdio: 启动命令
args: List[str] = field(default_factory=list) # stdio: 参数
env: Dict[str, str] = field(default_factory=dict) # stdio: 环境变量
url: str = "" # http/sse: 服务器 URL
@dataclass
class MCPToolInfo:
name: str # 工具原始名称
description: str
input_schema: Dict[str, Any] # JSON Schema
server_name: str
@dataclass
class MCPCallResult:
success: bool
content: str = ""
error: Optional[str] = None
duration_ms: float = 0.0
@dataclass
class MCPResourceInfo:
uri: str
name: str
description: str = ""
mime_type: Optional[str] = None
server_name: str = ""
@dataclass
class MCPPromptInfo:
name: str
description: str = ""
arguments: List[Dict[str, Any]] = field(default_factory=list)
server_name: str = ""
```
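作为参考,下面给出基于上述数据类构造服务器配置的示意(假设性示例,仅使用上文列出的字段,URL 取自 README 中的免费时间服务器):
```python
# 示意:构造两种典型的 MCPServerConfig(远程 streamable_http 与本地 stdio)
from mcp_client import MCPServerConfig, TransportType

# 远程 Streamable HTTP 服务器(URL 来自 README 示例)
remote_cfg = MCPServerConfig(
    name="time",
    enabled=True,
    transport=TransportType.STREAMABLE_HTTP,
    url="https://mcp.api-inference.modelscope.cn/server/mcp-server-time",
)

# 本地 stdio 服务器,通过 uvx 启动 mcp-server-fetch
local_cfg = MCPServerConfig(
    name="fetch",
    transport=TransportType.STDIO,
    command="uvx",
    args=["mcp-server-fetch"],
)
```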
#### MCPClientSession
管理单个 MCP 服务器连接。
```python
class MCPClientSession:
def __init__(self, config: MCPServerConfig): ...
async def connect(self) -> bool:
"""连接服务器,返回是否成功"""
async def disconnect(self) -> None:
"""断开连接"""
async def call_tool(self, tool_name: str, arguments: Dict) -> MCPCallResult:
"""调用工具"""
async def check_health(self) -> bool:
"""健康检查(用于心跳)"""
async def fetch_resources(self) -> bool:
"""获取资源列表"""
async def read_resource(self, uri: str) -> MCPCallResult:
"""读取资源"""
async def fetch_prompts(self) -> bool:
"""获取提示模板列表"""
async def get_prompt(self, name: str, arguments: Optional[Dict]) -> MCPCallResult:
"""获取提示模板"""
@property
def tools(self) -> List[MCPToolInfo]: ...
@property
def resources(self) -> List[MCPResourceInfo]: ...
@property
def prompts(self) -> List[MCPPromptInfo]: ...
@property
def is_connected(self) -> bool: ...
```
#### MCPClientManager
全局单例,管理多服务器。
```python
class MCPClientManager:
def configure(self, settings: Dict) -> None:
"""配置超时、重试等参数"""
async def add_server(self, config: MCPServerConfig) -> bool:
"""添加并连接服务器"""
async def remove_server(self, server_name: str) -> bool:
"""移除服务器"""
async def reconnect_server(self, server_name: str) -> bool:
"""重连服务器"""
async def call_tool(self, tool_key: str, arguments: Dict) -> MCPCallResult:
"""调用工具tool_key 格式: mcp_{server}_{tool}"""
async def start_heartbeat(self) -> None:
"""启动心跳检测"""
async def shutdown(self) -> None:
"""关闭所有连接"""
def get_status(self) -> Dict[str, Any]:
"""获取状态"""
def get_all_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
def set_status_change_callback(self, callback: Callable) -> None:
"""设置状态变化回调"""
@property
def all_tools(self) -> Dict[str, Tuple[MCPToolInfo, MCPClientSession]]: ...
@property
def all_resources(self) -> Dict[str, Tuple[MCPResourceInfo, MCPClientSession]]: ...
@property
def all_prompts(self) -> Dict[str, Tuple[MCPPromptInfo, MCPClientSession]]: ...
@property
def disconnected_servers(self) -> List[str]: ...
# 全局单例
mcp_manager = MCPClientManager()
```
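结合上述接口,一个最小的独立使用流程示意如下(假设性示例:工具名 mcp_time_get_current_time 及其参数仅为演示,实际以服务器返回的工具列表为准):
```python
# 示意:配置管理器、连接服务器并调用一个工具
import asyncio
from mcp_client import mcp_manager, MCPServerConfig, TransportType

async def demo():
    # 配置超时、重试等参数
    mcp_manager.configure({"tool_prefix": "mcp", "call_timeout": 60.0, "retry_attempts": 3})
    # 添加并连接服务器
    ok = await mcp_manager.add_server(
        MCPServerConfig(
            name="time",
            transport=TransportType.STREAMABLE_HTTP,
            url="https://mcp.api-inference.modelscope.cn/server/mcp-server-time",
        )
    )
    if ok:
        # tool_key 格式为 mcp_{server}_{tool},此处工具名仅为假设
        result = await mcp_manager.call_tool("mcp_time_get_current_time", {"timezone": "Asia/Shanghai"})
        print(result.success, result.content or result.error)
    await mcp_manager.shutdown()

asyncio.run(demo())
```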
---
### 2. plugin.py - MaiBot 插件
#### v1.4.0 新增模块
```python
# ============ 调用追踪 ============
@dataclass
class ToolCallRecord:
call_id: str # UUID
timestamp: float
tool_name: str
server_name: str
chat_id: str = ""
user_id: str = ""
user_query: str = ""
arguments: Dict = field(default_factory=dict)
raw_result: str = ""
processed_result: str = ""
duration_ms: float = 0.0
success: bool = True
error: str = ""
post_processed: bool = False
cache_hit: bool = False
class ToolCallTracer:
def configure(self, enabled: bool, max_records: int, log_enabled: bool, log_path: Path): ...
def record(self, record: ToolCallRecord) -> None: ...
def get_recent(self, n: int = 10) -> List[ToolCallRecord]: ...
def get_by_tool(self, tool_name: str) -> List[ToolCallRecord]: ...
def clear(self) -> None: ...
tool_call_tracer = ToolCallTracer()
# ============ 调用缓存 ============
@dataclass
class CacheEntry:
tool_name: str
args_hash: str # MD5(tool_name + sorted_json_args)
result: str
created_at: float
expires_at: float
hit_count: int = 0
class ToolCallCache:
def configure(self, enabled: bool, ttl: int, max_entries: int, exclude_tools: str): ...
def get(self, tool_name: str, args: Dict) -> Optional[str]: ...
def set(self, tool_name: str, args: Dict, result: str) -> None: ...
def clear(self) -> None: ...
def get_stats(self) -> Dict[str, Any]: ...
tool_call_cache = ToolCallCache()
# ============ 权限控制 ============
class PermissionChecker:
def configure(self, enabled: bool, default_mode: str, rules_json: str,
quick_deny_groups: str = "", quick_allow_users: str = ""): ...
def check(self, tool_name: str, chat_id: str, user_id: str, is_group: bool) -> bool: ...
def get_rules_for_tool(self, tool_name: str) -> List[Dict]: ...
permission_checker = PermissionChecker()
```
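CacheEntry 中 args_hash 的注释表明缓存键为 MD5(tool_name + 排序后的参数 JSON),其计算思路可以用下面的示意说明(函数名 make_args_hash 为假设,实际实现以 plugin.py 为准):
```python
# 示意:根据工具名和参数生成与参数顺序无关的缓存键
import hashlib
import json

def make_args_hash(tool_name: str, args: dict) -> str:
    # 参数按 key 排序后序列化,保证相同的参数组合得到同一个键
    payload = tool_name + json.dumps(args, sort_keys=True, ensure_ascii=False)
    return hashlib.md5(payload.encode("utf-8")).hexdigest()

# 参数顺序不同但内容相同时,缓存键一致
assert make_args_hash("mcp_demo_tool", {"a": 1, "b": 2}) == make_args_hash("mcp_demo_tool", {"b": 2, "a": 1})
```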
#### 工具代理
```python
class MCPToolProxy(BaseTool):
"""所有 MCP 工具的基类"""
# 类属性(动态子类覆盖)
name: str = ""
description: str = ""
parameters: List[Tuple] = []
available_for_llm: bool = True
# MCP 属性
_mcp_tool_key: str = ""
_mcp_original_name: str = ""
_mcp_server_name: str = ""
async def execute(self, function_args: Dict) -> Dict[str, Any]:
"""执行流程:
1. 权限检查 → 拒绝则返回错误
2. 缓存检查 → 命中则返回缓存
3. 调用 MCP 服务器
4. 存入缓存
5. 后处理(可选)
6. 记录追踪
7. 返回结果
"""
def create_mcp_tool_class(tool_key: str, tool_info: MCPToolInfo,
tool_prefix: str, disabled: bool = False) -> Type[MCPToolProxy]:
"""动态创建工具类"""
```
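「动态创建工具类」通常可以用 type() 以 MCPToolProxy 为基类生成子类并填充类属性,下面是一个简化示意(省略了 parameters 的 JSON Schema 转换,函数名为假设,实际逻辑以 create_mcp_tool_class 为准):
```python
# 示意:以 MCPToolProxy 为基类动态生成工具类
# 假设 MCPToolProxy 已在当前模块(plugin.py)中定义
from typing import Type

def build_tool_class_sketch(tool_key: str, tool_info, tool_prefix: str) -> Type["MCPToolProxy"]:
    attrs = {
        "name": f"{tool_prefix}_{tool_info.server_name}_{tool_info.name}",
        "description": tool_info.description,
        "_mcp_tool_key": tool_key,
        "_mcp_original_name": tool_info.name,
        "_mcp_server_name": tool_info.server_name,
        # parameters 需由 tool_info.input_schema 转换得到,此处省略
    }
    return type(f"MCPTool_{tool_info.server_name}_{tool_info.name}", (MCPToolProxy,), attrs)
```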
#### 内置工具
```python
class MCPStatusTool(BaseTool):
"""mcp_status - 查询状态/工具/资源/模板/统计/追踪/缓存"""
name = "mcp_status"
parameters = [
("query_type", STRING, "查询类型", False,
["status", "tools", "resources", "prompts", "stats", "trace", "cache", "all"]),
("server_name", STRING, "服务器名称", False, None),
]
class MCPReadResourceTool(BaseTool):
"""mcp_read_resource - 读取资源"""
name = "mcp_read_resource"
class MCPGetPromptTool(BaseTool):
"""mcp_get_prompt - 获取提示模板"""
name = "mcp_get_prompt"
```
#### 命令
```python
class MCPStatusCommand(BaseCommand):
"""处理 /mcp 命令"""
command_pattern = r"^[/]mcp(?:\s+(?P<subcommand>status|tools|stats|reconnect|trace|cache|perm))?(?:\s+(?P<arg>\S+))?$"
# 子命令处理
async def _handle_reconnect(self, server_name): ...
async def _handle_trace(self, arg): ...
async def _handle_cache(self, arg): ...
async def _handle_perm(self, arg): ...
```
#### 事件处理器
```python
class MCPStartupHandler(BaseEventHandler):
"""ON_START - 连接服务器、注册工具"""
event_type = EventType.ON_START
class MCPStopHandler(BaseEventHandler):
"""ON_STOP - 关闭连接"""
event_type = EventType.ON_STOP
```
#### 主插件类
```python
@register_plugin
class MCPBridgePlugin(BasePlugin):
plugin_name = "mcp_bridge_plugin"
python_dependencies = ["mcp"]
config_section_descriptions = {
"guide": "📖 快速入门",
"servers": "🔌 服务器配置",
"status": "📊 运行状态",
"plugin": "插件开关",
"settings": "⚙️ 高级设置",
"tools": "🔧 工具管理",
"permissions": "🔐 权限控制",
}
config_schema = {
"guide": { "quick_start": ConfigField(...) },
"plugin": { "enabled": ConfigField(...) },
"settings": {
# 基础tool_prefix, connect_timeout, call_timeout, auto_connect, retry_*
# 心跳heartbeat_enabled, heartbeat_interval, auto_reconnect, max_reconnect_attempts
# 高级enable_resources, enable_prompts
# 后处理post_process_*
# 追踪trace_*
# 缓存cache_*
},
"tools": { "tool_list", "disabled_tools" },
"permissions": { "perm_enabled", "perm_default_mode", "quick_deny_groups", "quick_allow_users", "perm_rules" },
"servers": { "list" },
"status": { "connection_status" },
}
def __init__(self):
# 配置 mcp_manager, tool_call_tracer, tool_call_cache, permission_checker
async def _async_connect_servers(self):
# 解析配置 → 连接服务器 → 注册工具(检查禁用列表)
def _update_status_display(self):
# 更新 WebUI 状态显示
def _update_tool_list_display(self):
# 更新工具清单显示
```
---
## 数据流
```
MaiBot 启动
MCPBridgePlugin.__init__()
├─ mcp_manager.configure(settings)
├─ tool_call_tracer.configure(...)
├─ tool_call_cache.configure(...)
└─ permission_checker.configure(...)
ON_START 事件 → MCPStartupHandler.execute()
_async_connect_servers()
├─ 解析 servers.list JSON
├─ 遍历服务器配置
│ ├─ mcp_manager.add_server(config)
│ ├─ 获取工具列表
│ ├─ 检查 disabled_tools
│ └─ component_registry.register_component(tool_info, tool_class)
├─ _update_status_display()
└─ _update_tool_list_display()
mcp_manager.start_heartbeat()
LLM 调用工具 → MCPToolProxy.execute(function_args)
├─ 1. permission_checker.check() → 拒绝则返回错误
├─ 2. tool_call_cache.get() → 命中则跳到步骤 5
├─ 3. mcp_manager.call_tool()
├─ 4. tool_call_cache.set()
├─ 5. _post_process_result() (如果启用且超过阈值)
├─ 6. tool_call_tracer.record()
└─ 7. 返回 {"name": ..., "content": ...}
ON_STOP 事件 → MCPStopHandler.execute()
mcp_manager.shutdown()
mcp_tool_registry.clear()
```
---
## 配置项速查
### settings(高级设置)
| 配置项 | 类型 | 默认值 | 说明 |
|--------|------|--------|------|
| tool_prefix | str | "mcp" | 工具名前缀 |
| connect_timeout | float | 30.0 | 连接超时(秒) |
| call_timeout | float | 60.0 | 调用超时(秒) |
| auto_connect | bool | true | 自动连接 |
| retry_attempts | int | 3 | 重试次数 |
| retry_interval | float | 5.0 | 重试间隔 |
| heartbeat_enabled | bool | true | 心跳检测 |
| heartbeat_interval | float | 60.0 | 心跳间隔 |
| auto_reconnect | bool | true | 自动重连 |
| max_reconnect_attempts | int | 3 | 最大重连次数 |
| enable_resources | bool | false | Resources 支持 |
| enable_prompts | bool | false | Prompts 支持 |
| post_process_enabled | bool | false | 结果后处理 |
| post_process_threshold | int | 500 | 后处理阈值 |
| trace_enabled | bool | true | 调用追踪 |
| trace_max_records | int | 100 | 追踪记录上限 |
| cache_enabled | bool | false | 调用缓存 |
| cache_ttl | int | 300 | 缓存 TTL |
| cache_max_entries | int | 200 | 最大缓存条目 |
### permissions(权限控制)
| 配置项 | 说明 |
|--------|------|
| perm_enabled | 启用权限控制 |
| perm_default_mode | allow_all / deny_all |
| quick_deny_groups | 禁用群列表(每行一个群号) |
| quick_allow_users | 管理员白名单(每行一个 QQ 号) |
| perm_rules | 高级规则 JSON |
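perm_rules 中的工具名支持通配符 *,其判定方式可以理解为类似 fnmatch 的模式匹配,示意如下(假设性示例,chat_key 的拼接格式取自配置示例中的 qq:ID:group 约定,实际判定以 PermissionChecker.check 为准):
```python
# 示意:按通配符规则判断某个会话是否被拒绝使用某工具
import fnmatch
import json

rules = json.loads('[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}]')

def is_denied_sketch(tool_name: str, chat_key: str) -> bool:
    # 工具名命中规则,且 chat_key 出现在 denied 列表中,则拒绝
    for rule in rules:
        if fnmatch.fnmatch(tool_name, rule["tool"]) and chat_key in rule.get("denied", []):
            return True
    return False

print(is_denied_sketch("mcp_filesystem_delete_file", "qq:123456:group"))  # True
print(is_denied_sketch("mcp_time_now", "qq:123456:group"))                # False
```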
---
## 扩展开发示例
### 添加新命令子命令
```python
# 1. 修改 command_pattern
command_pattern = r"^[/]mcp(?:\s+(?P<subcommand>status|...|newcmd))?..."
# 2. 在 execute() 添加分支
if subcommand == "newcmd":
return await self._handle_newcmd(arg)
# 3. 实现处理方法
async def _handle_newcmd(self, arg: str = None):
# 处理逻辑
await self.send_text("结果")
return (True, None, True)
```
### 添加新配置项
```python
# 1. config_schema 添加
"settings": {
"new_option": ConfigField(
type=bool,
default=False,
description="新选项说明",
label="🆕 新选项",
order=50,
),
}
# 2. 在 __init__ 或相应方法中读取
new_option = settings.get("new_option", False)
```
### 添加新的全局模块
```python
# 1. 定义数据类和管理类
@dataclass
class NewRecord:
...
class NewManager:
def configure(self, ...): ...
def do_something(self, ...): ...
new_manager = NewManager()
# 2. 在 MCPBridgePlugin.__init__ 中配置
new_manager.configure(...)
# 3. 在 MCPToolProxy.execute() 中使用
result = new_manager.do_something(...)
```
---
## 调试
```python
# 导入
from plugins.MCPBridgePlugin.mcp_client import mcp_manager
from plugins.MCPBridgePlugin.plugin import tool_call_tracer, tool_call_cache, permission_checker
# 检查状态
mcp_manager.get_status()
mcp_manager.get_all_stats()
# 追踪记录
tool_call_tracer.get_recent(10)
# 缓存状态
tool_call_cache.get_stats()
# 手动调用
result = await mcp_manager.call_tool("mcp_server_tool", {"arg": "value"})
```
---
## 依赖
- MaiBot >= 0.11.6
- Python >= 3.10
- mcp >= 1.0.0
## 许可证
AGPL-3.0

View File

@@ -0,0 +1,220 @@
# MCP 桥接插件
将 [MCP (Model Context Protocol)](https://modelcontextprotocol.io/) 服务器的工具桥接到 MaiBot,使麦麦能够调用外部 MCP 工具。
<img width="3012" height="1794" alt="image" src="https://github.com/user-attachments/assets/ece56404-301a-4abf-b16d-87bd430fc977" />
## 🚀 快速开始
### 1. 安装
```bash
# 克隆到 MaiBot 插件目录
cd /path/to/MaiBot/plugins
git clone https://github.com/CharTyr/MaiBot_MCPBridgePlugin.git MCPBridgePlugin
# 安装依赖
pip install mcp
# 复制配置文件
cd MCPBridgePlugin
cp config.example.toml config.toml
```
### 2. 添加服务器
编辑 `config.toml`,在 `[servers]` 的 `list` 中添加服务器:
**免费服务器:**
```json
{"name": "time", "enabled": true, "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time"}
```
**带鉴权的服务器(v1.4.2)**
```json
{"name": "my-server", "enabled": true, "transport": "streamable_http", "url": "https://mcp.xxx.com/mcp", "headers": {"Authorization": "Bearer 你的密钥"}}
```
**本地服务器(需要 uvx)**
```json
{"name": "fetch", "enabled": true, "transport": "stdio", "command": "uvx", "args": ["mcp-server-fetch"]}
```
### 3. 启动
重启 MaiBot,或发送 `/mcp reconnect`
---
## 📚 去哪找 MCP 服务器?
| 平台 | 说明 |
|------|------|
| [mcp.modelscope.cn](https://mcp.modelscope.cn/) | 魔搭 ModelScope(免费,推荐) |
| [smithery.ai](https://smithery.ai/) | MCP 服务器注册中心 |
| [github.com/modelcontextprotocol/servers](https://github.com/modelcontextprotocol/servers) | 官方服务器列表 |
---
## 💡 常用命令
| 命令 | 说明 |
|------|------|
| `/mcp` | 查看连接状态 |
| `/mcp tools` | 查看可用工具 |
| `/mcp reconnect` | 重连服务器 |
| `/mcp trace` | 查看调用记录 |
| `/mcp cache` | 查看缓存状态 |
| `/mcp perm` | 查看权限配置 |
| `/mcp import <json>` | 🆕 导入 Claude Desktop 配置 |
| `/mcp export [claude]` | 🆕 导出配置 |
| `/mcp search <关键词>` | 🆕 搜索工具 |
---
## ✨ 功能特性
### 核心功能
- 🔌 多服务器同时连接
- 📡 支持 stdio / SSE / HTTP / Streamable HTTP
- 🔄 自动重试、心跳检测、断线重连
- 🖥️ WebUI 完整配置支持
### v1.7.0 新增
- ⚡ **断路器模式** - 故障服务器快速失败,避免拖慢整体响应
- 🔄 **状态实时刷新** - WebUI 自动更新连接状态(可配置间隔)
- 🔍 **工具搜索** - `/mcp search <关键词>` 快速查找工具
### v1.6.0 新增
- 📥 **配置导入** - 从 Claude Desktop 格式一键导入
- 📤 **配置导出** - 导出为 Claude Desktop / Kiro / MaiBot 格式
### v1.4.0 新增
- 🚫 **工具禁用** - WebUI 直接禁用不想用的工具
- 🔍 **调用追踪** - 记录每次调用详情,便于调试
- 🗄️ **调用缓存** - 相同请求自动缓存
- 🔐 **权限控制** - 按群/用户限制工具使用
### 高级功能
- 📦 Resources 支持(实验性)
- 📝 Prompts 支持(实验性)
- 🔄 结果后处理(LLM 摘要提炼)
---
## ⚙️ 配置说明
### 服务器配置
```json
[
{
"name": "服务器名",
"enabled": true,
"transport": "streamable_http",
"url": "https://..."
}
]
```
| 字段 | 说明 |
|------|------|
| `name` | 服务器名称(唯一) |
| `enabled` | 是否启用 |
| `transport` | `stdio` / `sse` / `http` / `streamable_http` |
| `url` | 远程服务器地址 |
| `headers` | 🆕 鉴权头(如 `{"Authorization": "Bearer xxx"}`) |
| `command` / `args` | 本地服务器启动命令 |
### 权限控制(v1.4.0)
**快捷配置(推荐):**
```toml
[permissions]
perm_enabled = true
quick_deny_groups = "123456789" # 禁用的群号
quick_allow_users = "111111111" # 管理员白名单
```
**高级规则:**
```json
[{"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}]
```
### 工具禁用
```toml
[tools]
disabled_tools = '''
mcp_filesystem_delete_file
mcp_filesystem_write_file
'''
```
### 调用缓存
```toml
[settings]
cache_enabled = true
cache_ttl = 300
cache_exclude_tools = "mcp_*_time_*"
```
---
## ❓ 常见问题
**Q: 工具没有注册?**
- 检查 `enabled = true`
- 检查 MaiBot 日志错误信息
- 确认 `pip install mcp`
**Q: JSON 格式报错?**
- 多行 JSON 用 `'''` 三引号包裹
- 使用英文双引号 `"`
**Q: 如何手动重连?**
- `/mcp reconnect` 或 `/mcp reconnect 服务器名`
---
## 📥 配置导入导出(v1.6.0)
### 从 Claude Desktop 导入
如果你已有 Claude Desktop 的 MCP 配置,可以直接导入:
```
/mcp import {"mcpServers":{"time":{"command":"uvx","args":["mcp-server-time"]},"fetch":{"command":"uvx","args":["mcp-server-fetch"]}}}
```
支持的格式:
- Claude Desktop 格式(`mcpServers` 对象)
- Kiro MCP 格式
- MaiBot 格式(数组)
### 导出配置
```
/mcp export # 导出为 Claude Desktop 格式(默认)
/mcp export claude # 导出为 Claude Desktop 格式
/mcp export kiro # 导出为 Kiro MCP 格式
/mcp export maibot # 导出为 MaiBot 格式
```
### 注意事项
- 导入时会自动跳过同名服务器
- 导入后需要发送 `/mcp reconnect` 使配置生效
- 支持 stdio、sse、http、streamable_http 全部传输类型
---
## 📋 依赖
- MaiBot >= 0.11.6
- Python >= 3.10
- mcp >= 1.0.0
## 📄 许可证
AGPL-3.0

View File

@@ -0,0 +1,44 @@
"""
MCP 桥接插件
将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
v1.1.0 新增功能:
- 心跳检测和自动重连
- 调用统计(次数、成功率、耗时)
- 更好的错误处理
v1.2.0 新增功能:
- Resources 支持(资源读取)
- Prompts 支持(提示模板)
"""
from .plugin import MCPBridgePlugin, mcp_tool_registry, MCPStartupHandler, MCPStopHandler
from .mcp_client import (
mcp_manager,
MCPClientManager,
MCPServerConfig,
TransportType,
MCPCallResult,
MCPToolInfo,
MCPResourceInfo,
MCPPromptInfo,
ToolCallStats,
ServerStats,
)
__all__ = [
"MCPBridgePlugin",
"mcp_tool_registry",
"mcp_manager",
"MCPClientManager",
"MCPServerConfig",
"TransportType",
"MCPCallResult",
"MCPToolInfo",
"MCPResourceInfo",
"MCPPromptInfo",
"ToolCallStats",
"ServerStats",
"MCPStartupHandler",
"MCPStopHandler",
]

View File

@@ -0,0 +1,60 @@
{
"manifest_version": 1,
"name": "MCP桥接插件",
"version": "1.7.0",
"description": "将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot使麦麦能够调用外部 MCP 工具",
"author": {
"name": "CharTyr",
"url": "https://github.com/CharTyr"
},
"license": "AGPL-3.0",
"host_application": {
"min_version": "0.11.6"
},
"homepage_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"repository_url": "https://github.com/CharTyr/MaiBot_MCPBridgePlugin",
"keywords": [
"mcp",
"bridge",
"tool",
"integration",
"resources",
"prompts",
"post-process",
"cache",
"trace",
"permissions",
"import",
"export",
"claude-desktop"
],
"categories": [
"工具扩展",
"外部集成"
],
"default_locale": "zh-CN",
"plugin_info": {
"is_built_in": false,
"components": [],
"features": [
"支持多个 MCP 服务器",
"自动发现并注册 MCP 工具",
"支持 stdio、SSE、HTTP、Streamable HTTP 四种传输方式",
"工具参数自动转换",
"心跳检测与自动重连",
"调用统计(次数、成功率、耗时)",
"WebUI 配置支持",
"Resources 支持(实验性)",
"Prompts 支持(实验性)",
"结果后处理LLM 摘要提炼)",
"工具禁用管理",
"调用链路追踪",
"工具调用缓存LRU",
"工具权限控制(群/用户级别)",
"配置导入导出Claude Desktop / Kiro 格式)",
"断路器模式(故障快速失败)",
"状态实时刷新"
]
},
"id": "MaiBot Community.MCPBridgePlugin"
}

View File

@@ -0,0 +1,334 @@
# MCP桥接插件 v1.7.0 - 配置文件示例
# 将 MCP (Model Context Protocol) 服务器的工具桥接到 MaiBot
#
# 使用方法:复制此文件为 config.toml,然后根据需要修改配置
#
# ============================================================
# 🎯 快速开始(三步)
# ============================================================
# 1. 在下方 [servers] 添加 MCP 服务器配置
# 2. 将 enabled 改为 true 启用服务器
# 3. 重启 MaiBot 或发送 /mcp reconnect
#
# ============================================================
# 📚 去哪找 MCP 服务器?
# ============================================================
#
# 【远程服务(推荐新手)】
# - ModelScope: https://mcp.modelscope.cn/ (免费,推荐)
# - Smithery: https://smithery.ai/
# - Glama: https://glama.ai/mcp/servers
#
# 【本地服务(需要 npx 或 uvx
# - 官方列表: https://github.com/modelcontextprotocol/servers
#
# ============================================================
# ============================================================
# 插件基本信息
# ============================================================
[plugin]
name = "mcp_bridge_plugin"
version = "1.7.0"
config_version = "1.7.0"
enabled = false # 默认禁用,在 WebUI 中启用
# ============================================================
# 🆕 v1.5.4 快速入门(只读,WebUI 显示)
# ============================================================
[guide]
# 🚀 快速入门 - 三步开始使用
quick_start = "1. 从下方链接获取 MCP 服务器 2. 在「快速添加」填写信息 3. 保存后发送 /mcp reconnect"
# 🌐 获取 MCP 服务器 - 复制链接到浏览器打开,获取免费 MCP 服务器
# 魔搭 ModelScope 国内免费推荐,复制服务器 URL 到「快速添加」即可
mcp_sources = "https://modelscope.cn/mcp (魔搭·推荐) | https://smithery.ai | https://glama.ai | https://mcp.so"
# 📝 配置示例 - 复制到服务器列表可直接使用(免费时间服务器)
example_config = '{"name": "time", "enabled": true, "transport": "streamable_http", "url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time"}'
# ============================================================
# 全局设置
# ============================================================
[settings]
# 🏷️ 工具前缀 - 用于区分 MCP 工具和原生工具
tool_prefix = "mcp"
# ⏱️ 连接超时(秒)
connect_timeout = 30.0
# ⏱️ 调用超时(秒)
call_timeout = 60.0
# 🔄 自动连接 - 启动时自动连接所有已启用的服务器
auto_connect = true
# 🔁 重试次数 - 连接失败时的重试次数
retry_attempts = 3
# ⏳ 重试间隔(秒)
retry_interval = 5.0
# 💓 心跳检测 - 定期检测服务器连接状态
heartbeat_enabled = true
# 💓 心跳间隔(秒)- 建议 30-120 秒
heartbeat_interval = 60.0
# 🧠 智能心跳 - 根据服务器稳定性自动调整心跳间隔(v1.5.2)
# 稳定服务器逐渐增加间隔,断开的服务器缩短间隔
heartbeat_adaptive = true
# 📈 最大间隔倍数 - 稳定服务器心跳间隔最高可达 基准间隔 × 此值(v1.5.3)
heartbeat_max_multiplier = 3.0
# 🔄 自动重连 - 检测到断开时自动尝试重连
auto_reconnect = true
# 🔄 最大重连次数 - 连续重连失败后暂停重连
max_reconnect_attempts = 3
# ============================================================
# v1.7.0 状态实时刷新
# ============================================================
# 📊 启用状态实时刷新 - 定期更新 WebUI 状态显示
status_refresh_enabled = true
# 📊 状态刷新间隔(秒)- 值越小刷新越频繁,但会增加少量磁盘写入
status_refresh_interval = 10.0
# ============================================================
# v1.2.0 高级功能(实验性)
# ============================================================
# 📦 启用 Resources - 允许读取 MCP 服务器提供的资源
enable_resources = false
# 📝 启用 Prompts - 允许使用 MCP 服务器提供的提示模板
enable_prompts = false
# ============================================================
# v1.3.0 结果后处理功能
# ============================================================
# 当 MCP 工具返回的内容过长时,使用 LLM 对结果进行摘要提炼
# 🔄 启用结果后处理
post_process_enabled = false
# 📏 后处理阈值(字符数)- 结果长度超过此值才触发后处理
post_process_threshold = 500
# 后处理输出限制 - LLM 摘要输出的最大 token 数
post_process_max_tokens = 500
# 🤖 后处理模型(可选)- 留空则使用 utils 模型组
post_process_model = ""
# 后处理提示词模板
post_process_prompt = '''{query}
{result}
'''
# ============================================================
# 🆕 v1.4.0 调用链路追踪
# ============================================================
# 记录工具调用详情,便于调试和分析
# 🔍 启用调用追踪
trace_enabled = true
# 📊 追踪记录上限 - 内存中保留的最大记录数
trace_max_records = 100
# 📝 追踪日志文件 - 是否将追踪记录写入日志文件
# 启用后记录写入 plugins/MaiBot_MCPBridgePlugin/logs/trace.jsonl
trace_log_enabled = false
# ============================================================
# 🆕 v1.4.0 工具调用缓存
# ============================================================
# 缓存相同参数的调用结果,减少重复请求
# 🗄️ 启用调用缓存
cache_enabled = false
# ⏱️ 缓存有效期(秒)
cache_ttl = 300
# 📦 最大缓存条目 - 超出后 LRU 淘汰
cache_max_entries = 200
# 缓存排除列表 - 即不缓存的工具(每行一个,支持通配符 *)
# 时间类、随机类工具建议排除
cache_exclude_tools = '''
mcp_*_time_*
mcp_*_random_*
'''
# ============================================================
# 🆕 v1.4.0 工具管理
# ============================================================
[tools]
# 📋 工具清单(只读)- 启动后自动生成
tool_list = "(启动后自动生成)"
# 🚫 禁用工具列表 - 要禁用的工具名(每行一个)
# 从上方工具清单复制工具名,禁用后该工具不会被 LLM 调用
# 示例:
# disabled_tools = '''
# mcp_filesystem_delete_file
# mcp_filesystem_write_file
# '''
disabled_tools = ""
# ============================================================
# 🆕 v1.5.1 快速添加服务器
# ============================================================
# 表单式配置,无需手写 JSON
[quick_add]
# 📛 服务器名称 - 服务器唯一名称(英文,如 time-server)
server_name = ""
# 📡 传输类型 - 远程服务器选 streamable_http/http/sse,本地选 stdio
server_type = "streamable_http"
# 🌐 服务器 URL - 远程服务器必填(streamable_http/http/sse 类型)
server_url = ""
# ⌨️ 启动命令 - stdio 类型必填(如 uvx、npx、python)
server_command = ""
# 📝 命令参数 - stdio 类型使用,每行一个参数
server_args = ""
# 🔑 鉴权头(可选)- JSON 格式,如 {"Authorization": "Bearer xxx"}
server_headers = ""
# ============================================================
# 🆕 v1.6.0 配置导入导出
# ============================================================
# 支持从 Claude Desktop / Kiro / MaiBot 格式导入导出
[import_export]
# 📥 导入配置 - 粘贴 Claude Desktop 或其他格式的 MCP 配置 JSON
# 粘贴配置后点击保存,2秒内自动导入。查看下方「导入结果」确认状态
import_config = ""
# 📋 导入结果(只读)- 显示导入操作的结果
import_result = ""
# 📤 导出格式 - claude: Claude Desktop 格式 | kiro: Kiro MCP 格式 | maibot: 本插件格式
export_format = "claude"
# 📤 导出结果(只读,可复制)- 点击保存后生成,可复制到 Claude Desktop 或其他支持 MCP 的应用
export_result = "(点击保存后生成)"
# ============================================================
# 🆕 v1.4.0 权限控制
# ============================================================
[permissions]
# 🔐 启用权限控制 - 按群/用户限制工具使用
perm_enabled = false
# 📋 默认模式
# allow_all: 未配置规则的工具默认允许
# deny_all: 未配置规则的工具默认禁止
perm_default_mode = "allow_all"
# ────────────────────────────────────────────────────────────
# 🚀 快捷配置(推荐新手使用)
# ────────────────────────────────────────────────────────────
# 🚫 禁用群列表 - 这些群无法使用任何 MCP 工具(每行一个群号)
# 示例:
# quick_deny_groups = '''
# 123456789
# 987654321
# '''
quick_deny_groups = ""
# ✅ 管理员白名单 - 这些用户始终可以使用所有工具每行一个QQ号
# 示例:
# quick_allow_users = '''
# 111111111
# '''
quick_allow_users = ""
# ────────────────────────────────────────────────────────────
# 📜 高级权限规则(可选,针对特定工具配置)
# ────────────────────────────────────────────────────────────
# 格式: qq:ID:group/private/user,工具名支持通配符 *
# 示例:
# perm_rules = '''
# [
# {"tool": "mcp_*_delete_*", "denied": ["qq:123456:group"]}
# ]
# '''
perm_rules = "[]"
# ============================================================
# 🔌 MCP 服务器配置
# ============================================================
#
# ⚠️ 重要:JSON 格式说明
# ────────────────────────────────────────────────────────────
# 服务器列表必须是 JSON 数组格式!
#
# ❌ 错误写法:
# { "name": "server1", ... },
# { "name": "server2", ... }
#
# ✅ 正确写法:
# [
# { "name": "server1", ... },
# { "name": "server2", ... }
# ]
#
# ────────────────────────────────────────────────────────────
# 每个服务器的配置字段:
# name - 服务器名称(唯一标识)
# enabled - 是否启用 (true/false)
# transport - 传输方式: "stdio" / "sse" / "http" / "streamable_http"
# url - 服务器地址(sse/http/streamable_http 模式必填)
# headers - 🆕 鉴权头(可选,如 {"Authorization": "Bearer xxx"})
# command - 启动命令(stdio 模式,如 "npx" 或 "uvx")
# args - 命令参数数组(stdio 模式)
# env - 环境变量对象(stdio 模式,可选)
# post_process - 服务器级别后处理配置(可选)
#
# ============================================================
[servers]
list = '''
[
{
"name": "time-mcp-server",
"enabled": false,
"transport": "streamable_http",
"url": "https://mcp.api-inference.modelscope.cn/server/mcp-server-time"
},
{
"name": "my-auth-server",
"enabled": false,
"transport": "streamable_http",
"url": "https://mcp.api-inference.modelscope.net/xxxxxx/mcp",
"headers": {
"Authorization": "Bearer ms-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
}
},
{
"name": "fetch-local",
"enabled": false,
"transport": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"]
}
]
'''
# ============================================================
# 状态显示(只读)
# ============================================================
[status]
connection_status = "未初始化"

View File

@@ -0,0 +1,436 @@
"""
MCP 配置格式转换模块 v1.0.0
支持的格式:
- Claude Desktop (claude_desktop_config.json)
- Kiro MCP (mcp.json)
- MaiBot MCP Bridge Plugin (本插件格式)
转换规则:
- stdio: command + args + env
- sse/http/streamable_http: url + headers
"""
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class ConversionResult:
"""转换结果"""
success: bool
servers: List[Dict[str, Any]] = field(default_factory=list)
errors: List[str] = field(default_factory=list)
warnings: List[str] = field(default_factory=list)
skipped: List[str] = field(default_factory=list)
class ConfigConverter:
"""MCP 配置格式转换器"""
# transport 类型映射 (外部格式 -> 内部格式)
TRANSPORT_MAP_IN = {
"sse": "sse",
"http": "http",
"streamable-http": "streamable_http",
"streamable_http": "streamable_http",
"streamable-http": "streamable_http",
"stdio": "stdio",
}
# 支持的 transport 字段名(有些格式用 type 而不是 transport
TRANSPORT_FIELD_NAMES = ["transport", "type"]
# transport 类型映射 (内部格式 -> Claude 格式)
TRANSPORT_MAP_OUT = {
"sse": "sse",
"http": "http",
"streamable_http": "streamable-http",
"stdio": "stdio",
}
@classmethod
def detect_format(cls, config: Dict[str, Any]) -> Optional[str]:
"""检测配置格式类型
Returns:
"claude": Claude Desktop 格式 (mcpServers 对象)
"kiro": Kiro MCP 格式 (mcpServers 对象 Claude 相同)
"maibot": MaiBot 插件格式 (数组)
None: 无法识别
"""
if isinstance(config, list):
# 数组格式,检查是否是 MaiBot 格式
if len(config) == 0:
return "maibot"
if isinstance(config[0], dict) and "name" in config[0]:
return "maibot"
return None
if isinstance(config, dict):
# 对象格式
if "mcpServers" in config:
return "claude" # Claude 和 Kiro 格式相同
# 可能是单个服务器配置
if "name" in config:
return "maibot_single"
return None
return None
@classmethod
def parse_json_safe(cls, json_str: str) -> Tuple[Optional[Any], Optional[str]]:
"""安全解析 JSON 字符串
Returns:
(解析结果, 错误信息)
"""
if not json_str or not json_str.strip():
return None, "输入为空"
json_str = json_str.strip()
try:
return json.loads(json_str), None
except json.JSONDecodeError as e:
# 尝试提供更友好的错误信息
line = e.lineno
col = e.colno
return None, f"JSON 解析失败 (行 {line}, 列 {col}): {e.msg}"
@classmethod
def validate_server_config(cls, name: str, config: Dict[str, Any]) -> Tuple[bool, Optional[str], List[str]]:
"""验证单个服务器配置
Args:
name: 服务器名称
config: 服务器配置字典
Returns:
(是否有效, 错误信息, 警告列表)
"""
warnings = []
if not isinstance(config, dict):
return False, f"服务器 '{name}' 配置必须是对象", []
has_command = "command" in config
has_url = "url" in config
# 必须有 command 或 url 之一
if not has_command and not has_url:
return False, f"服务器 '{name}' 缺少 'command''url' 字段", []
# 同时有 command 和 url 时给出警告
if has_command and has_url:
warnings.append(f"'{name}': 同时存在 command 和 url将优先使用 stdio 模式")
# 验证 url 格式
if has_url and not has_command:
url = config.get("url", "")
if not isinstance(url, str):
return False, f"服务器 '{name}' 的 url 必须是字符串", []
if not url.startswith(("http://", "https://")):
warnings.append(f"'{name}': url 不是标准 HTTP(S) 地址")
# 验证 command 格式
if has_command:
command = config.get("command", "")
if not isinstance(command, str):
return False, f"服务器 '{name}' 的 command 必须是字符串", []
if not command.strip():
return False, f"服务器 '{name}' 的 command 不能为空", []
# 验证 args 格式
if "args" in config:
args = config.get("args")
if not isinstance(args, list):
return False, f"服务器 '{name}' 的 args 必须是数组", []
for i, arg in enumerate(args):
if not isinstance(arg, str):
warnings.append(f"'{name}': args[{i}] 不是字符串,将自动转换")
# 验证 env 格式
if "env" in config:
env = config.get("env")
if not isinstance(env, dict):
return False, f"服务器 '{name}' 的 env 必须是对象", []
# 验证 headers 格式
if "headers" in config:
headers = config.get("headers")
if not isinstance(headers, dict):
return False, f"服务器 '{name}' 的 headers 必须是对象", []
# 验证 transport/type 格式
transport_value = None
for field_name in cls.TRANSPORT_FIELD_NAMES:
if field_name in config:
transport_value = config.get(field_name, "").lower()
break
if transport_value and transport_value not in cls.TRANSPORT_MAP_IN:
warnings.append(f"'{name}': 未知的 transport 类型 '{transport_value}',将自动推断")
return True, None, warnings
@classmethod
def convert_claude_server(cls, name: str, config: Dict[str, Any]) -> Dict[str, Any]:
"""将单个 Claude 格式服务器配置转换为 MaiBot 格式
Args:
name: 服务器名称
config: Claude 格式的服务器配置
Returns:
MaiBot 格式的服务器配置
"""
result = {
"name": name,
"enabled": True,
}
has_command = "command" in config
if has_command:
# stdio 模式
result["transport"] = "stdio"
result["command"] = config.get("command", "")
# 处理 args
args = config.get("args", [])
if args:
# 确保所有 args 都是字符串
result["args"] = [str(arg) for arg in args]
# 处理 env
env = config.get("env", {})
if env and isinstance(env, dict):
result["env"] = env
else:
# 远程模式 (sse/http/streamable_http)
# 支持 transport 或 type 字段
transport_raw = None
for field_name in cls.TRANSPORT_FIELD_NAMES:
if field_name in config:
transport_raw = config.get(field_name, "").lower()
break
if not transport_raw:
transport_raw = "sse"
result["transport"] = cls.TRANSPORT_MAP_IN.get(transport_raw, "sse")
result["url"] = config.get("url", "")
# 处理 headers
headers = config.get("headers", {})
if headers and isinstance(headers, dict):
result["headers"] = headers
return result
@classmethod
def convert_maibot_server(cls, config: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""将单个 MaiBot 格式服务器配置转换为 Claude 格式
Args:
config: MaiBot 格式的服务器配置
Returns:
(服务器名称, Claude 格式的服务器配置)
"""
name = config.get("name", "unnamed")
result = {}
transport = config.get("transport", "stdio").lower()
if transport == "stdio":
# stdio 模式
result["command"] = config.get("command", "")
args = config.get("args", [])
if args:
result["args"] = args
env = config.get("env", {})
if env:
result["env"] = env
else:
# 远程模式
result["url"] = config.get("url", "")
# 转换 transport 名称
claude_transport = cls.TRANSPORT_MAP_OUT.get(transport, "sse")
if claude_transport != "sse": # sse 是默认值,可以省略
result["transport"] = claude_transport
headers = config.get("headers", {})
if headers:
result["headers"] = headers
return name, result
@classmethod
def from_claude_format(cls, config: Dict[str, Any], existing_names: Optional[set] = None) -> ConversionResult:
"""从 Claude Desktop 格式转换为 MaiBot 格式
Args:
config: Claude Desktop 配置 (包含 mcpServers 字段)
existing_names: 已存在的服务器名称集合,用于跳过重复
Returns:
ConversionResult
"""
result = ConversionResult(success=True)
existing_names = existing_names or set()
# 检查格式
if not isinstance(config, dict):
result.success = False
result.errors.append("配置必须是 JSON 对象")
return result
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
result.success = False
result.errors.append("mcpServers 必须是对象")
return result
if not mcp_servers:
result.warnings.append("mcpServers 为空,没有服务器可导入")
return result
# 转换每个服务器
for name, srv_config in mcp_servers.items():
# 检查名称是否已存在
if name in existing_names:
result.skipped.append(f"'{name}' (已存在)")
continue
# 验证配置
valid, error, warnings = cls.validate_server_config(name, srv_config)
result.warnings.extend(warnings)
if not valid:
result.errors.append(error)
continue
# 转换配置
try:
converted = cls.convert_claude_server(name, srv_config)
result.servers.append(converted)
except Exception as e:
result.errors.append(f"转换服务器 '{name}' 失败: {str(e)}")
# 如果有错误但也有成功的,仍然标记为成功(部分成功)
if result.errors and not result.servers:
result.success = False
return result
@classmethod
def to_claude_format(cls, servers: List[Dict[str, Any]]) -> Dict[str, Any]:
"""将 MaiBot 格式转换为 Claude Desktop 格式
Args:
servers: MaiBot 格式的服务器列表
Returns:
Claude Desktop 格式的配置
"""
mcp_servers = {}
for srv in servers:
if not isinstance(srv, dict):
continue
name, config = cls.convert_maibot_server(srv)
mcp_servers[name] = config
return {"mcpServers": mcp_servers}
@classmethod
def import_from_string(cls, json_str: str, existing_names: Optional[set] = None) -> ConversionResult:
"""从 JSON 字符串导入配置
自动检测格式并转换为 MaiBot 格式
Args:
json_str: JSON 字符串
existing_names: 已存在的服务器名称集合
Returns:
ConversionResult
"""
result = ConversionResult(success=True)
existing_names = existing_names or set()
# 解析 JSON
parsed, error = cls.parse_json_safe(json_str)
if error:
result.success = False
result.errors.append(error)
return result
# 检测格式
fmt = cls.detect_format(parsed)
if fmt is None:
result.success = False
result.errors.append("无法识别的配置格式")
return result
if fmt == "maibot":
# 已经是 MaiBot 格式,直接验证并返回
for srv in parsed:
if not isinstance(srv, dict):
result.warnings.append("跳过非对象元素")
continue
name = srv.get("name", "")
if not name:
result.warnings.append("跳过缺少 name 的服务器")
continue
if name in existing_names:
result.skipped.append(f"'{name}' (已存在)")
continue
result.servers.append(srv)
elif fmt == "maibot_single":
# 单个 MaiBot 格式服务器
name = parsed.get("name", "")
if name in existing_names:
result.skipped.append(f"'{name}' (已存在)")
else:
result.servers.append(parsed)
elif fmt in ("claude", "kiro"):
# Claude/Kiro 格式
return cls.from_claude_format(parsed, existing_names)
return result
@classmethod
def export_to_string(cls, servers: List[Dict[str, Any]], format_type: str = "claude", pretty: bool = True) -> str:
"""导出配置为 JSON 字符串
Args:
servers: MaiBot 格式的服务器列表
format_type: 导出格式 ("claude", "kiro", "maibot")
pretty: 是否格式化输出
Returns:
JSON 字符串
"""
indent = 2 if pretty else None
if format_type in ("claude", "kiro"):
config = cls.to_claude_format(servers)
else:
config = servers
return json.dumps(config, ensure_ascii=False, indent=indent)

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,2 @@
# MCP 桥接插件依赖
mcp>=1.0.0

View File

@@ -0,0 +1,278 @@
#!/usr/bin/env python3
"""
MCP 客户端测试脚本
测试 mcp_client.py 的基本功能
"""
import asyncio
import sys
import os
# 确保当前目录在 path 中
sys.path.insert(0, os.path.dirname(__file__))
from mcp_client import (
MCPClientManager,
MCPServerConfig,
TransportType,
ToolCallStats,
ServerStats,
)
async def test_stats():
"""测试统计类"""
print("\n=== 测试统计类 ===")
# 测试 ToolCallStats
stats = ToolCallStats(tool_key="test_tool")
stats.record_call(True, 100.0)
stats.record_call(True, 200.0)
stats.record_call(False, 50.0, "timeout")
assert stats.total_calls == 3
assert stats.success_calls == 2
assert stats.failed_calls == 1
assert stats.success_rate == (2 / 3) * 100
assert stats.avg_duration_ms == 150.0
assert stats.last_error == "timeout"
print(f"✅ ToolCallStats: {stats.to_dict()}")
# 测试 ServerStats
server_stats = ServerStats(server_name="test_server")
server_stats.record_connect()
server_stats.record_heartbeat()
server_stats.record_disconnect()
server_stats.record_failure()
server_stats.record_failure()
assert server_stats.connect_count == 1
assert server_stats.disconnect_count == 1
assert server_stats.consecutive_failures == 2
print(f"✅ ServerStats: {server_stats.to_dict()}")
return True
async def test_manager_basic():
"""测试管理器基本功能"""
print("\n=== 测试管理器基本功能 ===")
# 创建新的管理器实例(绕过单例)
manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False
manager.__init__()
# 配置
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": False,
}
)
# 测试状态
status = manager.get_status()
assert status["total_servers"] == 0
assert status["connected_servers"] == 0
print(f"✅ 初始状态: {status}")
# 测试添加禁用的服务器
config = MCPServerConfig(
name="disabled_server", enabled=False, transport=TransportType.HTTP, url="https://example.com/mcp"
)
result = await manager.add_server(config)
assert result == True
assert "disabled_server" in manager._clients
assert manager._clients["disabled_server"].is_connected == False
print("✅ 添加禁用服务器成功")
# 测试重复添加
result = await manager.add_server(config)
assert result == False
print("✅ 重复添加被拒绝")
# 测试移除
result = await manager.remove_server("disabled_server")
assert result == True
assert "disabled_server" not in manager._clients
print("✅ 移除服务器成功")
# 清理
await manager.shutdown()
print("✅ 管理器关闭成功")
return True
async def test_http_connection():
"""测试 HTTP 连接(使用真实的 MCP 服务器)"""
print("\n=== 测试 HTTP 连接 ===")
# 创建新的管理器实例
manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False
manager.__init__()
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 2,
"retry_interval": 2.0,
"heartbeat_enabled": False,
}
)
# 使用 HowToCook MCP 服务器测试
config = MCPServerConfig(
name="howtocook",
enabled=True,
transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
)
print(f"正在连接 {config.url} ...")
result = await manager.add_server(config)
if result:
print("✅ 连接成功!")
# 检查工具
tools = manager.all_tools
print(f"✅ 发现 {len(tools)} 个工具:")
for tool_key in tools:
print(f" - {tool_key}")
# 测试心跳
client = manager._clients["howtocook"]
healthy = await client.check_health()
print(f"✅ 心跳检测: {'健康' if healthy else '异常'}")
# 测试工具调用
if "mcp_howtocook_whatToEat" in tools:
print("\n正在调用 whatToEat 工具...")
call_result = await manager.call_tool("mcp_howtocook_whatToEat", {})
if call_result.success:
print(f"✅ 工具调用成功 (耗时: {call_result.duration_ms:.0f}ms)")
print(
f" 结果: {call_result.content[:200]}..."
if len(str(call_result.content)) > 200
else f" 结果: {call_result.content}"
)
else:
print(f"❌ 工具调用失败: {call_result.error}")
# 查看统计
stats = manager.get_all_stats()
print("\n📊 统计信息:")
print(f" 全局调用: {stats['global']['total_tool_calls']}")
print(f" 成功: {stats['global']['successful_calls']}")
print(f" 失败: {stats['global']['failed_calls']}")
else:
print("❌ 连接失败")
# 清理
await manager.shutdown()
return result
async def test_heartbeat():
"""测试心跳检测功能"""
print("\n=== 测试心跳检测 ===")
# 创建新的管理器实例
manager = MCPClientManager.__new__(MCPClientManager)
manager._initialized = False
manager.__init__()
manager.configure(
{
"tool_prefix": "mcp",
"call_timeout": 30.0,
"retry_attempts": 1,
"retry_interval": 1.0,
"heartbeat_enabled": True,
"heartbeat_interval": 5.0, # 5秒间隔用于测试
"auto_reconnect": True,
"max_reconnect_attempts": 2,
}
)
# 添加一个测试服务器
config = MCPServerConfig(
name="heartbeat_test",
enabled=True,
transport=TransportType.HTTP,
url="https://mcp.api-inference.modelscope.net/c9b55951d4ed47/mcp",
)
print("正在连接服务器...")
result = await manager.add_server(config)
if result:
print("✅ 服务器连接成功")
# 启动心跳检测
await manager.start_heartbeat()
print("✅ 心跳检测已启动")
# 等待一个心跳周期
print("等待心跳检测...")
await asyncio.sleep(2)
# 检查状态
status = manager.get_status()
print(f"✅ 心跳运行状态: {status['heartbeat_running']}")
# 停止心跳
await manager.stop_heartbeat()
print("✅ 心跳检测已停止")
else:
print("❌ 服务器连接失败,跳过心跳测试")
await manager.shutdown()
return True
async def main():
"""运行所有测试"""
print("=" * 50)
print("MCP 客户端测试")
print("=" * 50)
try:
# 基础测试
await test_stats()
await test_manager_basic()
# 网络测试
print("\n是否进行网络连接测试? (需要网络) [y/N]: ", end="")
# 自动进行网络测试
await test_http_connection()
# 心跳测试
await test_heartbeat()
print("\n" + "=" * 50)
print("✅ 所有测试通过!")
print("=" * 50)
except Exception as e:
print(f"\n❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
return True
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -38,6 +38,10 @@ dependencies = [
]
[tool.uv]
index-url = "https://pypi.tuna.tsinghua.edu.cn/simple"
[tool.ruff]
include = ["*.py"]

View File

@@ -0,0 +1,567 @@
"""
模拟 Expression 合并过程
用法:
python scripts/expression_merge_simulation.py
或指定 chat_id:
python scripts/expression_merge_simulation.py --chat-id <chat_id>
或指定相似度阈值:
python scripts/expression_merge_simulation.py --similarity-threshold 0.8
"""
import sys
import os
import json
import argparse
import asyncio
import random
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.bw_learner.learner_utils import calculate_style_similarity # noqa: E402
from src.llm_models.utils_model import LLMRequest # noqa: E402
from src.config.config import model_config # noqa: E402
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def parse_content_list(stored_list: Optional[str]) -> List[str]:
"""解析 content_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def parse_style_list(stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def find_exact_style_match(
expressions: List[Expression],
target_style: str,
chat_id: str,
exclude_ids: set
) -> Optional[Expression]:
"""
查找具有完全匹配 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
"""
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
if expr.style == target_style:
return expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
if target_style in style_list:
return expr
return None
def find_similar_style_expression(
expressions: List[Expression],
target_style: str,
chat_id: str,
similarity_threshold: float,
exclude_ids: set
) -> Optional[Tuple[Expression, float]]:
"""
查找具有相似 style 的 Expression 记录
检查 style 字段和 style_list 中的每一项
Returns:
(Expression, similarity) 或 None
"""
best_match = None
best_similarity = 0.0
for expr in expressions:
if expr.chat_id != chat_id or expr.id in exclude_ids:
continue
# 检查 style 字段
similarity = calculate_style_similarity(target_style, expr.style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
# 检查 style_list 中的每一项
style_list = parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(target_style, existing_style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
return (best_match, best_similarity)
return None
async def compose_situation_text(content_list: List[str], summary_model: LLMRequest) -> str:
"""组合 situation 文本,尝试使用 LLM 总结"""
sanitized = [c.strip() for c in content_list if c.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 situation 失败: {e}")
# 如果总结失败,返回用 "/" 连接的字符串
return "/".join(sanitized)
async def compose_style_text(style_list: List[str], summary_model: LLMRequest) -> str:
"""组合 style 文本,尝试使用 LLM 总结"""
sanitized = [s.strip() for s in style_list if s.strip()]
if not sanitized:
return ""
if len(sanitized) == 1:
return sanitized[0]
# 尝试使用 LLM 总结
prompt = (
"请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in sanitized[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await summary_model.generate_response_async(prompt, temperature=0.2)
print(f"Prompt:{prompt} Summary:{summary}")
summary = summary.strip()
if summary:
return summary
except Exception as e:
print(f" ⚠️ LLM 总结 style 失败: {e}")
# 如果总结失败,返回第一个
return sanitized[0]
async def simulate_merge(
expressions: List[Expression],
similarity_threshold: float = 0.75,
use_llm: bool = False,
max_samples: int = 10,
) -> Dict:
"""
模拟合并过程
Args:
expressions: Expression 列表(从数据库读出的原始记录)
similarity_threshold: style 相似度阈值
use_llm: 是否使用 LLM 进行实际总结
max_samples: 最多随机抽取的 Expression 数量,为 0 或 None 时表示不限制
Returns:
包含合并统计信息的字典
"""
# 如果样本太多,随机抽取一部分进行模拟,避免运行时间过长
if max_samples and len(expressions) > max_samples:
expressions = random.sample(expressions, max_samples)
# 按 chat_id 分组
expressions_by_chat = defaultdict(list)
for expr in expressions:
expressions_by_chat[expr.chat_id].append(expr)
# 初始化 LLM 模型(如果需要)
summary_model = None
if use_llm:
try:
summary_model = LLMRequest(
model_set=model_config.model_task_config.utils_small,
request_type="expression.summary"
)
print("✅ LLM 模型已初始化,将进行实际总结")
except Exception as e:
print(f"⚠️ LLM 模型初始化失败: {e},将跳过 LLM 总结")
use_llm = False
merge_stats = {
"total_expressions": len(expressions),
"total_chats": len(expressions_by_chat),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merge_details": [],
"chat_stats": {},
"use_llm": use_llm
}
# 为每个 chat_id 模拟合并
for chat_id, chat_expressions in expressions_by_chat.items():
chat_name = get_chat_name(chat_id)
chat_stat = {
"chat_id": chat_id,
"chat_name": chat_name,
"total": len(chat_expressions),
"exact_matches": 0,
"similar_matches": 0,
"new_records": 0,
"merges": []
}
processed_ids = set()
for expr in chat_expressions:
if expr.id in processed_ids:
continue
target_style = expr.style
target_situation = expr.situation
# 第一层:检查完全匹配
exact_match = find_exact_style_match(
chat_expressions,
target_style,
chat_id,
{expr.id}
)
if exact_match:
# 完全匹配(不使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(exact_match.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(exact_match.style_list)
if exact_match.style and exact_match.style not in target_style_list:
target_style_list.append(exact_match.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
merge_info = {
"type": "exact",
"source_id": expr.id,
"target_id": exact_match.id,
"source_style": target_style,
"target_style": exact_match.style,
"source_situation": target_situation,
"target_situation": exact_match.situation,
"similarity": 1.0,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": exact_match.situation, # 完全匹配时保持原 situation
"merged_style": exact_match.style # 完全匹配时保持原 style
}
chat_stat["exact_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["exact_matches"] += 1
processed_ids.add(expr.id)
continue
# 第二层:检查相似匹配
similar_match = find_similar_style_expression(
chat_expressions,
target_style,
chat_id,
similarity_threshold,
{expr.id}
)
if similar_match:
match_expr, similarity = similar_match
# 相似匹配(使用 LLM 总结)
# 模拟合并后的 content_list 和 style_list
target_content_list = parse_content_list(match_expr.content_list)
target_content_list.append(target_situation)
target_style_list = parse_style_list(match_expr.style_list)
if match_expr.style and match_expr.style not in target_style_list:
target_style_list.append(match_expr.style)
if target_style not in target_style_list:
target_style_list.append(target_style)
# 使用 LLM 总结(如果启用)
merged_situation = match_expr.situation
merged_style = match_expr.style or target_style
if use_llm and summary_model:
try:
merged_situation = await compose_situation_text(target_content_list, summary_model)
merged_style = await compose_style_text(target_style_list, summary_model)
except Exception as e:
print(f" ⚠️ 处理记录 {expr.id} 时 LLM 总结失败: {e}")
# 如果总结失败,使用 fallback
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
else:
# 不使用 LLM 时,使用简单拼接
merged_situation = "/".join([c.strip() for c in target_content_list if c.strip()]) or match_expr.situation
merged_style = target_style_list[0] if target_style_list else (match_expr.style or target_style)
merge_info = {
"type": "similar",
"source_id": expr.id,
"target_id": match_expr.id,
"source_style": target_style,
"target_style": match_expr.style,
"source_situation": target_situation,
"target_situation": match_expr.situation,
"similarity": similarity,
"merged_content_list": target_content_list,
"merged_style_list": target_style_list,
"merged_situation": merged_situation,
"merged_style": merged_style,
"llm_used": use_llm and summary_model is not None
}
chat_stat["similar_matches"] += 1
chat_stat["merges"].append(merge_info)
merge_stats["similar_matches"] += 1
processed_ids.add(expr.id)
continue
# 没有匹配,作为新记录
chat_stat["new_records"] += 1
merge_stats["new_records"] += 1
processed_ids.add(expr.id)
merge_stats["chat_stats"][chat_id] = chat_stat
merge_stats["merge_details"].extend(chat_stat["merges"])
return merge_stats
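# Illustrative walkthrough of the flow above (assuming the default threshold of 0.75): a style
# that already exists verbatim on another record of the same chat counts as an exact match and
# is merged without any LLM call; a style that only scores at or above the threshold takes the
# similar-match branch and, when --use-llm is passed, has its situation/style re-summarized;
# anything below the threshold is counted as a new record.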
def print_merge_results(stats: Dict, show_details: bool = True, max_details: int = 50):
"""打印合并结果"""
print("\n" + "=" * 80)
print("Expression 合并模拟结果")
print("=" * 80)
print("\n📊 总体统计:")
print(f" 总 Expression 数: {stats['total_expressions']}")
print(f" 总聊天数: {stats['total_chats']}")
print(f" 完全匹配合并: {stats['exact_matches']}")
print(f" 相似匹配合并: {stats['similar_matches']}")
print(f" 新记录(无匹配): {stats['new_records']}")
if stats.get('use_llm'):
print(" LLM 总结: 已启用")
else:
print(" LLM 总结: 未启用(仅模拟)")
total_merges = stats['exact_matches'] + stats['similar_matches']
if stats['total_expressions'] > 0:
merge_ratio = (total_merges / stats['total_expressions']) * 100
print(f" 合并比例: {merge_ratio:.1f}%")
# 按聊天分组显示
print("\n📋 按聊天分组统计:")
for chat_id, chat_stat in stats['chat_stats'].items():
print(f"\n {chat_stat['chat_name']} ({chat_id[:8]}...):")
print(f" 总数: {chat_stat['total']}")
print(f" 完全匹配: {chat_stat['exact_matches']}")
print(f" 相似匹配: {chat_stat['similar_matches']}")
print(f" 新记录: {chat_stat['new_records']}")
# 显示合并详情
if show_details and stats['merge_details']:
print(f"\n📝 合并详情 (显示前 {min(max_details, len(stats['merge_details']))} 条):")
print()
for idx, merge in enumerate(stats['merge_details'][:max_details], 1):
merge_type = "完全匹配" if merge['type'] == 'exact' else f"相似匹配 (相似度: {merge['similarity']:.3f})"
print(f" {idx}. {merge_type}")
print(f" 源记录 ID: {merge['source_id']}")
print(f" 目标记录 ID: {merge['target_id']}")
print(f" 源 Style: {merge['source_style'][:50]}")
print(f" 目标 Style: {merge['target_style'][:50]}")
print(f" 源 Situation: {merge['source_situation'][:50]}")
print(f" 目标 Situation: {merge['target_situation'][:50]}")
# 显示合并后的结果
if 'merged_situation' in merge:
print(f" → 合并后 Situation: {merge['merged_situation'][:50]}")
if 'merged_style' in merge:
print(f" → 合并后 Style: {merge['merged_style'][:50]}")
if merge.get('llm_used'):
print(" → LLM 总结: 已使用")
elif merge['type'] == 'similar':
print(" → LLM 总结: 未使用(模拟模式)")
# 显示合并后的列表
if 'merged_content_list' in merge and len(merge['merged_content_list']) > 1:
print(f" → Content List ({len(merge['merged_content_list'])} 项): {', '.join(merge['merged_content_list'][:3])}")
if len(merge['merged_content_list']) > 3:
print(f" ... 还有 {len(merge['merged_content_list']) - 3}")
if 'merged_style_list' in merge and len(merge['merged_style_list']) > 1:
print(f" → Style List ({len(merge['merged_style_list'])} 项): {', '.join(merge['merged_style_list'][:3])}")
if len(merge['merged_style_list']) > 3:
print(f" ... 还有 {len(merge['merged_style_list']) - 3}")
print()
if len(stats['merge_details']) > max_details:
print(f" ... 还有 {len(stats['merge_details']) - max_details} 条合并记录未显示")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="模拟 Expression 合并过程")
parser.add_argument(
"--chat-id",
type=str,
default=None,
help="指定要分析的 chat_id不指定则分析所有"
)
parser.add_argument(
"--similarity-threshold",
type=float,
default=0.75,
help="相似度阈值 (0-1, 默认: 0.75)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-details",
type=int,
default=50,
help="最多显示的合并详情数 (默认: 50)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
parser.add_argument(
"--use-llm",
action="store_true",
help="启用 LLM 进行实际总结(默认: 仅模拟,不调用 LLM"
)
parser.add_argument(
"--max-samples",
type=int,
default=10,
help="最多随机抽取的 Expression 数量 (默认: 10设置为 0 表示不限制)"
)
args = parser.parse_args()
# 验证阈值
if not 0 <= args.similarity_threshold <= 1:
print("错误: similarity-threshold 必须在 0-1 之间")
return
# 确定输出文件路径
if args.output:
output_file = args.output
else:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_merge_simulation_{timestamp}.txt")
# 查询 Expression 记录
print("正在从数据库加载Expression数据...")
try:
if args.chat_id:
expressions = list(Expression.select().where(Expression.chat_id == args.chat_id))
print(f"✅ 成功加载 {len(expressions)} 条Expression记录 (chat_id: {args.chat_id})")
else:
expressions = list(Expression.select())
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
# 执行合并模拟
print(f"\n正在模拟合并过程(相似度阈值: {args.similarity_threshold},最大样本数: {args.max_samples}...")
if args.use_llm:
print("⚠️ 已启用 LLM 总结,将进行实际的 API 调用")
else:
print(" 未启用 LLM 总结,仅进行模拟(使用 --use-llm 启用实际 LLM 调用)")
stats = asyncio.run(
simulate_merge(
expressions,
similarity_threshold=args.similarity_threshold,
use_llm=args.use_llm,
max_samples=args.max_samples,
)
)
# 输出结果
original_stdout = sys.stdout
try:
with open(output_file, "w", encoding="utf-8") as f:
sys.stdout = f
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
sys.stdout = original_stdout
# 同时在控制台输出
print_merge_results(stats, show_details=not args.no_details, max_details=args.max_details)
except Exception as e:
sys.stdout = original_stdout
print(f"❌ 写入文件失败: {e}")
return
print(f"\n✅ 模拟结果已保存到: {output_file}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,564 @@
"""
分析expression库中situation和style的相似度
用法:
python scripts/expression_similarity_analysis.py
或指定阈值:
python scripts/expression_similarity_analysis.py --situation-threshold 0.8 --style-threshold 0.7
"""
import sys
import os
import argparse
from typing import List, Optional, Tuple
from collections import defaultdict
from difflib import SequenceMatcher
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# Import after setting up path (required for project imports)
from src.common.database.database_model import Expression, ChatStreams # noqa: E402
from src.config.config import global_config # noqa: E402
import hashlib # noqa: E402
class TeeOutput:
"""同时输出到控制台和文件的类"""
def __init__(self, file_path: str):
self.file = open(file_path, "w", encoding="utf-8")
self.console = sys.stdout
def write(self, text: str):
"""写入文本到控制台和文件"""
self.console.write(text)
self.file.write(text)
self.file.flush() # 立即刷新到文件
def flush(self):
"""刷新输出"""
self.console.flush()
self.file.flush()
def close(self):
"""关闭文件"""
if self.file:
self.file.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
return False
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""
解析'platform:id:type'为chat_id与ExpressionSelector中的逻辑一致
"""
try:
parts = stream_config_str.split(":")
if len(parts) != 3:
return None
platform = parts[0]
id_str = parts[1]
stream_type = parts[2]
is_group = stream_type == "group"
if is_group:
components = [platform, str(id_str)]
else:
components = [platform, str(id_str), "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
except Exception:
return None
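# Minimal sketch of the hashing above (hypothetical stream configs):
#   _parse_stream_config_to_chat_id("qq:123456:group")   == hashlib.md5("qq_123456".encode()).hexdigest()
#   _parse_stream_config_to_chat_id("qq:123456:private") == hashlib.md5("qq_123456_private".encode()).hexdigest()
# i.e. group streams hash "platform_id" while private streams append a "private" component.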
def build_chat_id_groups() -> dict[str, set[str]]:
"""
根据expression_groups配置构建chat_id到相关chat_id集合的映射
Returns:
dict: {chat_id: set of related chat_ids (including itself)}
"""
groups = global_config.expression.expression_groups
chat_id_groups: dict[str, set[str]] = {}
# 检查是否存在全局共享组(包含"*"的组)
global_group_exists = any("*" in group for group in groups)
if global_group_exists:
# 如果存在全局共享组收集所有配置中的chat_id
all_chat_ids = set()
for group in groups:
for stream_config_str in group:
if stream_config_str == "*":
continue
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
all_chat_ids.add(chat_id_candidate)
# 所有chat_id都互相相关
for chat_id in all_chat_ids:
chat_id_groups[chat_id] = all_chat_ids.copy()
else:
# 处理普通组
for group in groups:
group_chat_ids = set()
for stream_config_str in group:
if chat_id_candidate := _parse_stream_config_to_chat_id(stream_config_str):
group_chat_ids.add(chat_id_candidate)
# 组内的所有chat_id都互相相关
for chat_id in group_chat_ids:
if chat_id not in chat_id_groups:
chat_id_groups[chat_id] = set()
chat_id_groups[chat_id].update(group_chat_ids)
# 确保每个chat_id至少包含自身
for chat_id in chat_id_groups:
chat_id_groups[chat_id].add(chat_id)
return chat_id_groups
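# Minimal sketch (hypothetical expression_groups entry): a group such as
#   [["qq:111:group", "qq:222:group"]]
# maps both derived chat_ids to the same related set, so their expressions are compared as one
# pool; a chat_id that appears in no group is only related to itself.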
def are_chat_ids_related(chat_id1: str, chat_id2: str, chat_id_groups: dict[str, set[str]]) -> bool:
"""
判断两个chat_id是否相关相同或同组
Args:
chat_id1: 第一个chat_id
chat_id2: 第二个chat_id
chat_id_groups: chat_id到相关chat_id集合的映射
Returns:
bool: 如果两个chat_id相同或同组返回True
"""
if chat_id1 == chat_id2:
return True
# 如果chat_id1在映射中检查chat_id2是否在其相关集合中
if chat_id1 in chat_id_groups:
return chat_id2 in chat_id_groups[chat_id1]
# 如果chat_id1不在映射中说明它不在任何组中只与自己相关
return False
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id[:8]}...)"
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
else:
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def text_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
使用SequenceMatcher计算相似度返回0-1之间的值
在计算前会移除"使用""句式"这两个词
"""
if not text1 or not text2:
return 0.0
# 移除"使用"和"句式"这两个词
def remove_ignored_words(text: str) -> str:
"""移除需要忽略的词"""
text = text.replace("使用", "")
text = text.replace("句式", "")
return text.strip()
cleaned_text1 = remove_ignored_words(text1)
cleaned_text2 = remove_ignored_words(text2)
# 如果清理后文本为空返回0
if not cleaned_text1 or not cleaned_text2:
return 0.0
return SequenceMatcher(None, cleaned_text1, cleaned_text2).ratio()
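# Rough examples of the scoring above (exact values come from SequenceMatcher, shown as a guide):
#   text_similarity("使用 对对对 句式", "对对对")  -> 1.0 (identical once the ignored words are removed)
#   text_similarity("这么强!", "难绷")            -> 0.0 (no overlapping characters)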
def find_similar_pairs(
expressions: List[Expression],
field_name: str,
threshold: float,
    max_pairs: Optional[int] = None
) -> List[Tuple[int, int, float, str, str]]:
"""
找出相似的expression对
Args:
expressions: Expression对象列表
field_name: 要比较的字段名 ('situation' 'style')
threshold: 相似度阈值 (0-1)
max_pairs: 最多返回的对数None表示返回所有
Returns:
List of (index1, index2, similarity, text1, text2) tuples
"""
similar_pairs = []
n = len(expressions)
print(f"正在分析 {field_name} 字段的相似度...")
print(f"总共需要比较 {n * (n - 1) // 2} 对...")
for i in range(n):
if (i + 1) % 100 == 0:
print(f" 已处理 {i + 1}/{n} 个项目...")
expr1 = expressions[i]
text1 = getattr(expr1, field_name, "")
for j in range(i + 1, n):
expr2 = expressions[j]
text2 = getattr(expr2, field_name, "")
similarity = text_similarity(text1, text2)
if similarity >= threshold:
similar_pairs.append((i, j, similarity, text1, text2))
# 按相似度降序排序
similar_pairs.sort(key=lambda x: x[2], reverse=True)
if max_pairs:
similar_pairs = similar_pairs[:max_pairs]
return similar_pairs
def group_similar_items(
expressions: List[Expression],
field_name: str,
threshold: float,
chat_id_groups: dict[str, set[str]]
) -> List[List[int]]:
"""
将相似的expression分组仅比较相同chat_id或同组的项目
Args:
expressions: Expression对象列表
field_name: 要比较的字段名 ('situation' 'style')
threshold: 相似度阈值 (0-1)
chat_id_groups: chat_id到相关chat_id集合的映射
Returns:
List of groups, each group is a list of indices
"""
n = len(expressions)
# 使用并查集的思想来分组
parent = list(range(n))
def find(x):
if parent[x] != x:
parent[x] = find(parent[x])
return parent[x]
def union(x, y):
px, py = find(x), find(y)
if px != py:
parent[px] = py
print(f"正在对 {field_name} 字段进行分组仅比较相同chat_id或同组的项目...")
# 统计需要比较的对数
total_pairs = 0
for i in range(n):
for j in range(i + 1, n):
if are_chat_ids_related(expressions[i].chat_id, expressions[j].chat_id, chat_id_groups):
total_pairs += 1
print(f"总共需要比较 {total_pairs}已过滤不同chat_id且不同组的项目...")
compared_pairs = 0
for i in range(n):
if (i + 1) % 100 == 0:
print(f" 已处理 {i + 1}/{n} 个项目...")
expr1 = expressions[i]
text1 = getattr(expr1, field_name, "")
for j in range(i + 1, n):
expr2 = expressions[j]
# 只比较相同chat_id或同组的项目
if not are_chat_ids_related(expr1.chat_id, expr2.chat_id, chat_id_groups):
continue
compared_pairs += 1
text2 = getattr(expr2, field_name, "")
similarity = text_similarity(text1, text2)
if similarity >= threshold:
union(i, j)
# 收集分组
groups = defaultdict(list)
for i in range(n):
root = find(i)
groups[root].append(i)
# 只返回包含多个项目的组
result = [group for group in groups.values() if len(group) > 1]
result.sort(key=len, reverse=True)
return result
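# Note on the union-find grouping above: membership follows direct links and is therefore
# transitive. For example (illustrative numbers with threshold 0.7), sim(A, B) = 0.75 and
# sim(B, C) = 0.72 place A, B and C in one group even if sim(A, C) is far below the threshold;
# the per-group report below lists the direct pairs that caused each grouping.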
def print_similarity_analysis(
expressions: List[Expression],
field_name: str,
threshold: float,
chat_id_groups: dict[str, set[str]],
show_details: bool = True,
max_groups: int = 20
):
"""打印相似度分析结果"""
print("\n" + "=" * 80)
print(f"{field_name.upper()} 相似度分析 (阈值: {threshold})")
print("=" * 80)
# 分组分析
groups = group_similar_items(expressions, field_name, threshold, chat_id_groups)
total_items = len(expressions)
similar_items_count = sum(len(group) for group in groups)
unique_groups = len(groups)
print("\n📊 统计信息:")
print(f" 总项目数: {total_items}")
print(f" 相似项目数: {similar_items_count} ({similar_items_count / total_items * 100:.1f}%)")
print(f" 相似组数: {unique_groups}")
print(f" 平均每组项目数: {similar_items_count / unique_groups:.1f}" if unique_groups > 0 else " 平均每组项目数: 0")
if not groups:
print(f"\n未找到相似度 >= {threshold} 的项目组")
return
print(f"\n📋 相似组详情 (显示前 {min(max_groups, len(groups))} 组):")
print()
for group_idx, group in enumerate(groups[:max_groups], 1):
print(f"{group_idx} (共 {len(group)} 个项目):")
if show_details:
# 显示组内所有项目的详细信息
for idx in group:
expr = expressions[idx]
text = getattr(expr, field_name, "")
chat_name = get_chat_name(expr.chat_id)
# 截断过长的文本
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" [{expr.id}] {display_text}")
print(f" 聊天: {chat_name}, Count: {expr.count}")
# 计算组内平均相似度
if len(group) > 1:
similarities = []
above_threshold_pairs = [] # 存储满足阈值的相似对
above_threshold_count = 0
for i in range(len(group)):
for j in range(i + 1, len(group)):
text1 = getattr(expressions[group[i]], field_name, "")
text2 = getattr(expressions[group[j]], field_name, "")
sim = text_similarity(text1, text2)
similarities.append(sim)
if sim >= threshold:
above_threshold_count += 1
# 存储满足阈值的对的信息
expr1 = expressions[group[i]]
expr2 = expressions[group[j]]
display_text1 = text1[:40] + "..." if len(text1) > 40 else text1
display_text2 = text2[:40] + "..." if len(text2) > 40 else text2
above_threshold_pairs.append((
expr1.id, display_text1,
expr2.id, display_text2,
sim
))
if similarities:
avg_sim = sum(similarities) / len(similarities)
min_sim = min(similarities)
max_sim = max(similarities)
above_threshold_ratio = above_threshold_count / len(similarities) * 100
print(f" 平均相似度: {avg_sim:.3f} (范围: {min_sim:.3f} - {max_sim:.3f})")
print(f" 满足阈值({threshold})的比例: {above_threshold_ratio:.1f}% ({above_threshold_count}/{len(similarities)})")
# 显示满足阈值的相似对(这些是直接连接,导致它们被分到一组)
if above_threshold_pairs:
print(" ⚠️ 直接相似的对 (这些对导致它们被分到一组):")
# 按相似度降序排序
above_threshold_pairs.sort(key=lambda x: x[4], reverse=True)
for idx1, text1, idx2, text2, sim in above_threshold_pairs[:10]: # 最多显示10对
print(f" [{idx1}] ↔ [{idx2}]: {sim:.3f}")
print(f" \"{text1}\"\"{text2}\"")
if len(above_threshold_pairs) > 10:
print(f" ... 还有 {len(above_threshold_pairs) - 10} 对满足阈值")
else:
print(f" ⚠️ 警告: 组内没有任何对满足阈值({threshold:.2f}),可能是通过传递性连接")
else:
# 只显示组内第一个项目作为示例
expr = expressions[group[0]]
text = getattr(expr, field_name, "")
display_text = text[:60] + "..." if len(text) > 60 else text
print(f" 示例: {display_text}")
print(f" ... 还有 {len(group) - 1} 个相似项目")
print()
if len(groups) > max_groups:
print(f"... 还有 {len(groups) - max_groups} 组未显示")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="分析expression库中situation和style的相似度")
parser.add_argument(
"--situation-threshold",
type=float,
default=0.7,
help="situation相似度阈值 (0-1, 默认: 0.7)"
)
parser.add_argument(
"--style-threshold",
type=float,
default=0.7,
help="style相似度阈值 (0-1, 默认: 0.7)"
)
parser.add_argument(
"--no-details",
action="store_true",
help="不显示详细信息,只显示统计"
)
parser.add_argument(
"--max-groups",
type=int,
default=20,
help="最多显示的组数 (默认: 20)"
)
parser.add_argument(
"--output",
type=str,
default=None,
help="输出文件路径 (默认: 自动生成带时间戳的文件)"
)
args = parser.parse_args()
# 验证阈值
if not 0 <= args.situation_threshold <= 1:
print("错误: situation-threshold 必须在 0-1 之间")
return
if not 0 <= args.style_threshold <= 1:
print("错误: style-threshold 必须在 0-1 之间")
return
# 确定输出文件路径
if args.output:
output_file = args.output
else:
# 自动生成带时间戳的输出文件
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = os.path.join(project_root, "data", "temp")
os.makedirs(output_dir, exist_ok=True)
output_file = os.path.join(output_dir, f"expression_similarity_analysis_{timestamp}.txt")
# 使用TeeOutput同时输出到控制台和文件
with TeeOutput(output_file) as tee:
# 临时替换sys.stdout
original_stdout = sys.stdout
sys.stdout = tee
try:
print("=" * 80)
print("Expression 相似度分析工具")
print("=" * 80)
print(f"输出文件: {output_file}")
print()
_run_analysis(args)
finally:
# 恢复原始stdout
sys.stdout = original_stdout
print(f"\n✅ 分析结果已保存到: {output_file}")
def _run_analysis(args):
"""执行分析的主逻辑"""
# 查询所有Expression记录
print("正在从数据库加载Expression数据...")
try:
expressions = list(Expression.select())
except Exception as e:
print(f"❌ 加载数据失败: {e}")
return
if not expressions:
print("❌ 数据库中没有找到Expression记录")
return
print(f"✅ 成功加载 {len(expressions)} 条Expression记录")
print()
# 构建chat_id分组映射
print("正在构建chat_id分组映射根据expression_groups配置...")
try:
chat_id_groups = build_chat_id_groups()
print(f"✅ 成功构建 {len(chat_id_groups)} 个chat_id的分组映射")
if chat_id_groups:
# 统计分组信息
total_related = sum(len(related) for related in chat_id_groups.values())
avg_related = total_related / len(chat_id_groups)
print(f" 平均每个chat_id与 {avg_related:.1f} 个chat_id相关包括自身")
print()
except Exception as e:
print(f"⚠️ 构建chat_id分组映射失败: {e}")
print(" 将使用默认行为只比较相同chat_id的项目")
chat_id_groups = {}
# 分析situation相似度
print_similarity_analysis(
expressions,
"situation",
args.situation_threshold,
chat_id_groups,
show_details=not args.no_details,
max_groups=args.max_groups
)
# 分析style相似度
print_similarity_analysis(
expressions,
"style",
args.style_threshold,
chat_id_groups,
show_details=not args.no_details,
max_groups=args.max_groups
)
print("\n" + "=" * 80)
print("分析完成!")
print("=" * 80)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,303 @@
"""
统计和展示 replyer 动作选择记录
用法:
python scripts/replyer_action_stats.py
"""
import json
import os
import sys
from collections import Counter, defaultdict
from datetime import datetime
from typing import Dict, List, Any
from pathlib import Path
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
try:
from src.common.database.database_model import ChatStreams
from src.chat.message_receive.chat_stream import get_chat_manager
except ImportError:
ChatStreams = None
get_chat_manager = None
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
if ChatStreams:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
if chat_stream.group_name:
return f"{chat_stream.group_name}"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊"
if get_chat_manager:
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(chat_id)
if stream_name:
return stream_name
return f"未知聊天 ({chat_id[:8]}...)"
except Exception:
return f"查询失败 ({chat_id[:8]}...)"
def load_records(temp_dir: str = "data/temp") -> List[Dict[str, Any]]:
"""加载所有 replyer 动作记录"""
records = []
temp_path = Path(temp_dir)
if not temp_path.exists():
print(f"目录不存在: {temp_dir}")
return records
# 查找所有 replyer_action_*.json 文件
pattern = "replyer_action_*.json"
for file_path in temp_path.glob(pattern):
try:
with open(file_path, "r", encoding="utf-8") as f:
data = json.load(f)
records.append(data)
except Exception as e:
print(f"读取文件失败 {file_path}: {e}")
# 按时间戳排序
records.sort(key=lambda x: x.get("timestamp", ""))
return records
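# Expected shape of each replyer_action_*.json file, inferred from the fields read below
# (illustrative values; real files may carry additional keys):
# {
#     "timestamp": "2025-01-01T12:00:00",
#     "chat_id": "abcdef1234567890",
#     "think_level": 1,
#     "reason": "..."
# }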
def format_timestamp(ts: str) -> str:
"""格式化时间戳"""
try:
dt = datetime.fromisoformat(ts)
return dt.strftime("%Y-%m-%d %H:%M:%S")
except Exception:
return ts
def calculate_time_distribution(records: List[Dict[str, Any]]) -> Dict[str, int]:
"""计算时间分布"""
now = datetime.now()
distribution = {
"今天": 0,
"昨天": 0,
"3天内": 0,
"7天内": 0,
"30天内": 0,
"更早": 0,
}
for record in records:
try:
ts = record.get("timestamp", "")
if not ts:
continue
dt = datetime.fromisoformat(ts)
diff = (now - dt).days
if diff == 0:
distribution["今天"] += 1
elif diff == 1:
distribution["昨天"] += 1
elif diff < 3:
distribution["3天内"] += 1
elif diff < 7:
distribution["7天内"] += 1
elif diff < 30:
distribution["30天内"] += 1
else:
distribution["更早"] += 1
except Exception:
pass
return distribution
def print_statistics(records: List[Dict[str, Any]]):
"""打印统计信息"""
if not records:
print("没有找到任何记录")
return
print("=" * 80)
print("Replyer 动作选择记录统计")
print("=" * 80)
print()
# 总记录数
total_count = len(records)
print(f"📊 总记录数: {total_count}")
print()
# 时间范围
timestamps = [r.get("timestamp", "") for r in records if r.get("timestamp")]
if timestamps:
first_time = format_timestamp(min(timestamps))
last_time = format_timestamp(max(timestamps))
print(f"📅 时间范围: {first_time} ~ {last_time}")
print()
# 按 think_level 统计
think_levels = [r.get("think_level", 0) for r in records]
think_level_counter = Counter(think_levels)
print("🧠 思考深度分布:")
for level in sorted(think_level_counter.keys()):
count = think_level_counter[level]
percentage = (count / total_count) * 100
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {count} 次 ({percentage:.1f}%)")
print()
# 按 chat_id 统计(总体)
chat_counter = Counter([r.get("chat_id", "未知") for r in records])
print(f"💬 聊天分布 (共 {len(chat_counter)} 个聊天):")
# 只显示前10个
for chat_id, count in chat_counter.most_common(10):
chat_name = get_chat_name(chat_id)
percentage = (count / total_count) * 100
print(f" {chat_name}: {count} 次 ({percentage:.1f}%)")
if len(chat_counter) > 10:
print(f" ... 还有 {len(chat_counter) - 10} 个聊天")
print()
# 每个 chat_id 的详细统计
print("=" * 80)
print("每个聊天的详细统计")
print("=" * 80)
print()
# 按 chat_id 分组记录
records_by_chat = defaultdict(list)
for record in records:
chat_id = record.get("chat_id", "未知")
records_by_chat[chat_id].append(record)
# 按记录数排序
sorted_chats = sorted(records_by_chat.items(), key=lambda x: len(x[1]), reverse=True)
for chat_id, chat_records in sorted_chats:
chat_name = get_chat_name(chat_id)
chat_count = len(chat_records)
chat_percentage = (chat_count / total_count) * 100
print(f"📱 {chat_name} ({chat_id[:8]}...)")
print(f" 总记录数: {chat_count} ({chat_percentage:.1f}%)")
# 该聊天的 think_level 分布
chat_think_levels = [r.get("think_level", 0) for r in chat_records]
chat_think_counter = Counter(chat_think_levels)
print(" 思考深度分布:")
for level in sorted(chat_think_counter.keys()):
level_count = chat_think_counter[level]
level_percentage = (level_count / chat_count) * 100
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f" Level {level} ({level_name}): {level_count} 次 ({level_percentage:.1f}%)")
# 该聊天的时间范围
chat_timestamps = [r.get("timestamp", "") for r in chat_records if r.get("timestamp")]
if chat_timestamps:
first_time = format_timestamp(min(chat_timestamps))
last_time = format_timestamp(max(chat_timestamps))
print(f" 时间范围: {first_time} ~ {last_time}")
# 该聊天的时间分布
chat_time_dist = calculate_time_distribution(chat_records)
print(" 时间分布:")
for period, count in chat_time_dist.items():
if count > 0:
period_percentage = (count / chat_count) * 100
print(f" {period}: {count} 次 ({period_percentage:.1f}%)")
# 显示该聊天最近的一条理由示例
if chat_records:
latest_record = chat_records[-1]
reason = latest_record.get("reason", "无理由")
if len(reason) > 120:
reason = reason[:120] + "..."
timestamp = format_timestamp(latest_record.get("timestamp", ""))
think_level = latest_record.get("think_level", 0)
print(f" 最新记录 [{timestamp}] (Level {think_level}): {reason}")
print()
# 时间分布
time_dist = calculate_time_distribution(records)
print("⏰ 时间分布:")
for period, count in time_dist.items():
if count > 0:
percentage = (count / total_count) * 100
print(f" {period}: {count} 次 ({percentage:.1f}%)")
print()
# 显示一些示例理由
print("📝 示例理由 (最近5条):")
recent_records = records[-5:]
for i, record in enumerate(recent_records, 1):
reason = record.get("reason", "无理由")
think_level = record.get("think_level", 0)
timestamp = format_timestamp(record.get("timestamp", ""))
chat_id = record.get("chat_id", "未知")
chat_name = get_chat_name(chat_id)
# 截断过长的理由
if len(reason) > 100:
reason = reason[:100] + "..."
print(f" {i}. [{timestamp}] {chat_name} (Level {think_level})")
print(f" {reason}")
print()
# 按 think_level 分组显示理由示例
print("=" * 80)
print("按思考深度分类的示例理由")
print("=" * 80)
print()
for level in [0, 1, 2]:
level_records = [r for r in records if r.get("think_level") == level]
if not level_records:
continue
level_name = {0: "不需要思考", 1: "简单思考", 2: "深度思考"}.get(level, f"未知({level})")
print(f"Level {level} ({level_name}) - 共 {len(level_records)} 条:")
# 显示3个示例选择最近的
examples = level_records[-3:] if len(level_records) >= 3 else level_records
for i, record in enumerate(examples, 1):
reason = record.get("reason", "无理由")
if len(reason) > 150:
reason = reason[:150] + "..."
timestamp = format_timestamp(record.get("timestamp", ""))
chat_id = record.get("chat_id", "未知")
chat_name = get_chat_name(chat_id)
print(f" {i}. [{timestamp}] {chat_name}")
print(f" {reason}")
print()
# 统计信息汇总
print("=" * 80)
print("统计汇总")
print("=" * 80)
print(f"总记录数: {total_count}")
print(f"涉及聊天数: {len(chat_counter)}")
if chat_counter:
avg_count = total_count / len(chat_counter)
print(f"平均每个聊天记录数: {avg_count:.1f}")
else:
print("平均每个聊天记录数: N/A")
print()
def main():
"""主函数"""
records = load_records()
print_statistics(records)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,775 @@
import time
import json
import os
import re
import asyncio
from typing import List, Optional, Tuple, Any, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import (
build_anonymous_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.bw_learner.learner_utils import (
filter_message_content,
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
calculate_style_similarity,
)
from src.bw_learner.jargon_miner import miner_manager
from json_repair import repair_json
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def init_prompt() -> None:
learn_style_prompt = """{chat_str}
你的名字是{bot_name},现在请你完成两个提取任务
任务1请从上面这段群聊中总结用户的语言风格和说话方式
1. 只考虑文字不要考虑表情包和图片
2. 不要总结SELF的发言
3. 不要涉及具体的人名也不要涉及具体名词
4. 思考有没有特殊的梗一并总结成语言风格
5. 例子仅供参考请严格根据群聊内容总结!!!
注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
表达方式在3-5个左右不要超过10个
任务2请从上面这段聊天内容中提取"可能是黑话"的候选项黑话/俚语/网络缩写/口头禅
- 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语没有明确含义的词语请不要选择有明确含义或者含义清晰的词语
- 排除人名@表情包/图片中的内容纯标点常规功能词如的啊等
- 每个词条长度建议 2-8 个字符不强制尽量短小
- 请你提取出可能的黑话最多30个黑话请尽量提取所有
黑话必须为以下几种类型
- 由字母构成的汉语拼音首字母的简写词例如nbyydsxswl
- 英文词语的缩写用英文字母概括一个词汇或含义例如CPUGPUAPI
- 中文词语的缩写用几个汉字概括一个词汇或含义例如社死内卷
输出要求
将表达方式语言风格和黑话以 JSON 数组输出每个元素为一个对象结构如下注意字段名
注意请不要输出重复内容请对表达方式和黑话进行去重
[
{{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
{{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
{{"situation": "对某件事表示十分惊叹", "style": "使用 我嘞个xxxx", "source_id": "[消息编号]"}},
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
{{"content": "词条", "source_id": "12"}},
{{"content": "词条2", "source_id": "5"}}
]
其中
表达方式条目
- situation表示在什么情境下的简短概括不超过20个字
- style表示对应的语言风格或常用表达不超过20个字
- source_id该表达方式对应的来源行编号即上方聊天记录中方括号里的数字例如 [3]请只输出数字本身不要包含方括号
黑话jargon条目
- content:表示黑话的内容
- source_id该黑话对应的来源行编号即上方聊天记录中方括号里的数字例如 [3]请只输出数字本身不要包含方括号
现在请你输出 JSON
"""
Prompt(learn_style_prompt, "learn_style_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
)
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock()
async def learn_and_store(
self,
messages: List[Any],
    ) -> Optional[List[Tuple[str, str, str]]]:
"""
学习并存储表达方式
Args:
messages: 外部传入的消息列表必需
"""
if not messages:
return None
random_msg = messages
# 学习用(开启行编号,便于溯源)
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
prompt: str = await global_prompt_manager.format_prompt(
"learn_style_prompt",
bot_name=global_config.bot.nickname,
chat_str=random_msg_str,
)
# print(f"random_msg_str:{random_msg_str}")
# logger.info(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
# 解析 LLM 返回的表达方式列表和黑话列表(包含来源行编号)
expressions: List[Tuple[str, str, str]]
jargon_entries: List[Tuple[str, str]] # (content, source_id)
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
expressions = []
# 检查黑话数量如果超过30个则放弃本次黑话学习
if len(jargon_entries) > 30:
logger.info(f"黑话提取数量超过30个实际{len(jargon_entries)}个),放弃本次黑话学习")
jargon_entries = []
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
if jargon_entries:
await self._process_jargon_entries(jargon_entries, random_msg)
# 如果没有表达方式,直接返回
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return []
logger.info(f"学习的prompt: {prompt}")
logger.info(f"学习的expressions: {expressions}")
logger.info(f"学习的jargon_entries: {jargon_entries}")
logger.info(f"学习的response: {response}")
# 直接根据 source_id 在 random_msg 中溯源,获取 context
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
for situation, style, source_id in expressions:
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
# 无效的来源行编号,跳过
continue
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
if line_index < 0 or line_index >= len(random_msg):
# 超出范围,跳过
continue
# 当前行的原始内容
current_msg = random_msg[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
continue
context = filter_message_content(current_msg.processed_plain_text or "")
if not context:
continue
# 过滤掉包含 SELF 的内容(不学习)
if "SELF" in (situation or "") or "SELF" in (style or "") or "SELF" in context:
logger.info(
f"跳过包含 SELF 的表达方式: situation={situation}, style={style}, source_id={source_id}"
)
continue
filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
        if not learnt_expressions:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
return learnt_expressions
def parse_expression_response(self, response: str) -> Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
"""
解析 LLM 返回的表达风格总结和黑话 JSON提取两个列表
期望的 JSON 结构
[
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}, // 表达方式
{"content": "词条", "source_id": "12"}, // 黑话
...
]
Returns:
Tuple[List[Tuple[str, str, str]], List[Tuple[str, str]]]:
第一个列表是表达方式 (situation, style, source_id)
第二个列表是黑话 (content, source_id)
"""
if not response:
return [], []
raw = response.strip()
# 尝试提取 ```json 代码块
json_block_pattern = r"```json\s*(.*?)\s*```"
match = re.search(json_block_pattern, raw, re.DOTALL)
if match:
raw = match.group(1).strip()
else:
# 去掉可能存在的通用 ``` 包裹
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
raw = raw.strip()
parsed = None
expressions: List[Tuple[str, str, str]] = [] # (situation, style, source_id)
jargon_entries: List[Tuple[str, str]] = [] # (content, source_id)
try:
# 优先尝试直接解析
if raw.startswith("[") and raw.endswith("]"):
parsed = json.loads(raw)
else:
repaired = repair_json(raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as parse_error:
# 如果解析失败,尝试修复中文引号问题
# 使用状态机方法,在 JSON 字符串值内部将中文引号替换为转义的英文引号
try:
def fix_chinese_quotes_in_json(text):
"""使用状态机修复 JSON 字符串值中的中文引号"""
result = []
i = 0
in_string = False
escape_next = False
while i < len(text):
char = text[i]
if escape_next:
# 当前字符是转义字符后的字符,直接添加
result.append(char)
escape_next = False
i += 1
continue
if char == "\\":
# 转义字符
result.append(char)
escape_next = True
i += 1
continue
if char == '"' and not escape_next:
# 遇到英文引号,切换字符串状态
in_string = not in_string
result.append(char)
i += 1
continue
if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号
                            if char == "\u201c":  # 中文左引号 U+201C
                                result.append('\\"')
                            elif char == "\u201d":  # 中文右引号 U+201D
                                result.append('\\"')
else:
result.append(char)
else:
# 不在字符串内,直接添加
result.append(char)
i += 1
return "".join(result)
fixed_raw = fix_chinese_quotes_in_json(raw)
# 再次尝试解析
if fixed_raw.startswith("[") and fixed_raw.endswith("]"):
parsed = json.loads(fixed_raw)
else:
repaired = repair_json(fixed_raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception as fix_error:
logger.error(f"解析表达风格 JSON 失败,初始错误: {type(parse_error).__name__}: {str(parse_error)}")
logger.error(f"修复中文引号后仍失败,错误: {type(fix_error).__name__}: {str(fix_error)}")
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
logger.error(f"处理后的 JSON 字符串前500字符{raw[:500]}")
                return [], []
if isinstance(parsed, dict):
parsed_list = [parsed]
elif isinstance(parsed, list):
parsed_list = parsed
else:
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
            return [], []
for item in parsed_list:
if not isinstance(item, dict):
continue
# 检查是否是表达方式条目(有 situation 和 style
situation = str(item.get("situation", "")).strip()
style = str(item.get("style", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if situation and style and source_id:
# 表达方式条目
expressions.append((situation, style, source_id))
elif item.get("content"):
# 黑话条目(有 content 字段)
content = str(item.get("content", "")).strip()
source_id = str(item.get("source_id", "")).strip()
if content and source_id:
jargon_entries.append((content, source_id))
return expressions, jargon_entries
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str, str]] = []
removed_count = 0
for situation, style, source_id in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style, source_id))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
# 第一层:检查是否有完全一致的 style检查 style 字段和 style_list
expr_obj = await self._find_exact_style_match(style)
if expr_obj:
# 找到完全匹配的 style合并到现有记录不使用 LLM 总结)
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=False,
)
return
# 第二层:检查是否有相似的 style相似度 >= 0.75,检查 style 字段和 style_list
similar_expr_obj = await self._find_similar_style_expression(style, similarity_threshold=0.75)
if similar_expr_obj:
# 找到相似的 style合并到现有记录使用 LLM 总结)
await self._update_existing_expression(
expr_obj=similar_expr_obj,
situation=situation,
style=style,
context=context,
current_time=current_time,
use_llm_summary=True,
)
return
# 没有找到匹配的记录,创建新记录
await self._create_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
async def _create_expression_record(
self,
situation: str,
style: str,
context: str,
current_time: float,
) -> None:
content_list = [situation]
# 创建新记录时,直接使用原始的 situation不进行总结
formatted_situation = situation
Expression.create(
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
style_list=None, # 新记录初始时 style_list 为空
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
style: str,
context: str,
current_time: float,
use_llm_summary: bool = True,
) -> None:
"""
更新现有 Expression 记录style 完全匹配或相似的情况
将新的 situation 添加到 content_list将新的 style 添加到 style_list如果不同
Args:
use_llm_summary: 是否使用 LLM 进行总结完全匹配时为 False相似匹配时为 True
"""
# 更新 content_list添加新的 situation
content_list = self._parse_content_list(expr_obj.content_list)
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
# 更新 style_list如果 style 不同,添加到 style_list
style_list = self._parse_style_list(expr_obj.style_list)
# 将原有的 style 也加入 style_list如果还没有的话
if expr_obj.style and expr_obj.style not in style_list:
style_list.append(expr_obj.style)
# 如果新的 style 不在 style_list 中,添加它
if style not in style_list:
style_list.append(style)
expr_obj.style_list = json.dumps(style_list, ensure_ascii=False)
# 更新其他字段
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time
expr_obj.context = context
if use_llm_summary:
# 相似匹配时,使用 LLM 重新组合 situation 和 style
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
new_style = await self._compose_style_text(
style_list=style_list,
count=expr_obj.count,
fallback=expr_obj.style or style,
)
expr_obj.style = new_style
else:
# 完全匹配时,不进行 LLM 总结,保持原有的 situation 和 style 不变
# 只更新 content_list 和 style_list
pass
expr_obj.save()
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
def _parse_style_list(self, stored_list: Optional[str]) -> List[str]:
"""解析 style_list JSON 字符串为列表,逻辑与 _parse_content_list 相同"""
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _find_exact_style_match(self, style: str) -> Optional[Expression]:
"""
查找具有完全匹配 style Expression 记录
检查 style 字段和 style_list 中的每一项
Args:
style: 要查找的 style
Returns:
找到的 Expression 对象如果没有找到则返回 None
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
for expr in all_expressions:
# 检查 style 字段
if expr.style == style:
return expr
# 检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
if style in style_list:
return expr
return None
async def _find_similar_style_expression(self, style: str, similarity_threshold: float = 0.75) -> Optional[Expression]:
"""
查找具有相似 style Expression 记录
检查 style 字段和 style_list 中的每一项
Args:
style: 要查找的 style
similarity_threshold: 相似度阈值默认 0.75
Returns:
找到的最相似的 Expression 对象如果没有找到则返回 None
"""
# 查询同一 chat_id 的所有记录
all_expressions = Expression.select().where(Expression.chat_id == self.chat_id)
best_match = None
best_similarity = 0.0
for expr in all_expressions:
# 检查 style 字段
similarity = calculate_style_similarity(style, expr.style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
# 检查 style_list 中的每一项
style_list = self._parse_style_list(expr.style_list)
for existing_style in style_list:
similarity = calculate_style_similarity(style, existing_style)
if similarity >= similarity_threshold and similarity > best_similarity:
best_similarity = similarity
best_match = expr
if best_match:
logger.debug(f"找到相似的 style: 相似度={best_similarity:.3f}, 现有='{best_match.style}', 新='{style}'")
return best_match
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _compose_style_text(self, style_list: List[str], count: int, fallback: str = "") -> str:
"""
组合 style 文本如果 style_list 有多个元素则尝试总结
"""
sanitized = [s.strip() for s in style_list if s.strip()]
if len(sanitized) > 1:
# 只有当有多个 style 时才尝试总结
summary = await self._summarize_styles(sanitized)
if summary:
return summary
# 如果只有一个或总结失败,返回第一个或 fallback
return sanitized[0] if sanitized else fallback
async def _summarize_styles(self, styles: List[str]) -> Optional[str]:
"""总结多个 style生成一个概括性的 style 描述"""
if not styles or len(styles) <= 1:
return None
prompt = (
"请阅读以下多个语言风格/表达方式,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in styles[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达风格失败: {e}")
return None
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
"""
处理从 expression learner 提取的黑话条目路由到 jargon_miner
Args:
jargon_entries: 黑话条目列表每个元素是 (content, source_id)
messages: 消息列表用于构建上下文
"""
if not jargon_entries or not messages:
return
# 获取 jargon_miner 实例
jargon_miner = miner_manager.get_miner(self.chat_id)
# 构建黑话条目格式,与 jargon_miner.run_once 中的格式一致
entries: List[Dict[str, List[str]]] = []
for content, source_id in jargon_entries:
content = content.strip()
if not content:
continue
# 过滤掉包含 SELF 的黑话,不学习
if "SELF" in content:
logger.info(f"跳过包含 SELF 的黑话: {content}")
continue
# 检查是否包含机器人名称
if contains_bot_self_name(content):
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
continue
# 解析 source_id
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
logger.warning(f"黑话条目 source_id 无效: content={content}, source_id={source_id_str}")
continue
# build_anonymous_messages 的编号从 1 开始
line_index = int(source_id_str) - 1
if line_index < 0 or line_index >= len(messages):
logger.warning(f"黑话条目 source_id 超出范围: content={content}, source_id={source_id_str}")
continue
# 检查是否是机器人自己的消息
target_msg = messages[line_index]
if is_bot_message(target_msg):
logger.info(f"跳过引用机器人自身消息的黑话: content={content}, source_id={source_id_str}")
continue
# 构建上下文段落
context_paragraph = build_context_paragraph(messages, line_index)
if not context_paragraph:
logger.warning(f"黑话条目上下文为空: content={content}, source_id={source_id_str}")
continue
entries.append({"content": content, "raw_content": [context_paragraph]})
if not entries:
return
# 调用 jargon_miner 处理这些条目
await jargon_miner.process_extracted_entries(entries)
init_prompt()
class ExpressionLearnerManager:
def __init__(self):
self.expression_learners = {}
self._ensure_expression_directories()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
expression_learner_manager = ExpressionLearnerManager()
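# Usage sketch (illustrative only; assumes an async caller that already has a chat_id and a list
# of recent message objects):
#   learner = expression_learner_manager.get_expression_learner(chat_id)
#   learnt = await learner.learn_and_store(messages)
# learn_and_store returns the stored (situation, style, context) tuples, or an empty list / None
# when nothing usable was extracted.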

View File

@ -82,9 +82,7 @@ class ExpressionReflector:
# 获取未检查的表达
try:
logger.info("[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (
Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
)
expressions = Expression.select().where((~Expression.checked) & (~Expression.rejected)).limit(50)
expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
@ -147,7 +145,7 @@ expression_reflector_manager = ExpressionReflectorManager()
async def _check_tracker_exists(operator_config: str) -> bool:
"""检查指定 Operator 是否已有活跃的 Tracker"""
from src.express.reflect_tracker import reflect_tracker_manager
from src.bw_learner.reflect_tracker import reflect_tracker_manager
chat_manager = get_chat_manager()
chat_stream = None
@ -242,7 +240,7 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
stream_id = chat_stream.stream_id
# 注册 Tracker
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
from src.bw_learner.reflect_tracker import ReflectTracker, reflect_tracker_manager
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
reflect_tracker_manager.add_tracker(stream_id, tracker)

View File

@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.express_utils import weighted_sample
from src.bw_learner.learner_utils import weighted_sample
logger = get_logger("expression_selector")
@ -111,6 +111,85 @@ class ExpressionSelector:
return group_chat_ids
return [chat_id]
def _select_expressions_simple(self, chat_id: str, max_num: int) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
        简单模式:只选择 count > 1 的项目,至少有 8 个时随机选 5 个;不足时降级为随机选 3 个,完全没有候选时退化为全量随机抽样,全程不调用 LLM
Args:
chat_id: 聊天流ID
            max_num: 最大选择数量(仅在降级为全量随机抽样时用于限制抽取数量)
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 支持多chat_id合并抽选
related_chat_ids = self.get_related_chat_ids(chat_id)
# 查询所有相关chat_id的表达方式排除 rejected=1 的,且只选择 count > 1 的
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected) & (Expression.count > 1)
)
style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
}
for expr in style_query
]
# 要求至少有一定数量的 count > 1 的表达方式才进行“完整简单模式”选择
min_required = 8
if len(style_exprs) < min_required:
                # 高 count 样本不足:若还有候选,降级为随机选 3 个;若一个都没有,则退化为全量随机抽样,仍无结果才返回空
if not style_exprs:
logger.info(
f"聊天流 {chat_id} 没有满足 count > 1 且未被拒绝的表达方式,简单模式不进行选择"
)
# 完全没有高 count 样本时退化为全量随机抽样不进入LLM流程
fallback_num = min(3, max_num) if max_num > 0 else 3
fallback_selected = self._random_expressions(chat_id, fallback_num)
if fallback_selected:
self.update_expressions_last_active_time(fallback_selected)
selected_ids = [expr["id"] for expr in fallback_selected]
logger.info(
f"聊天流 {chat_id} 使用简单模式降级随机抽选 {len(fallback_selected)} 个表达(无 count>1 样本)"
)
return fallback_selected, selected_ids
return [], []
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_required} 个(实际 {len(style_exprs)} 个),"
f"简单模式降级为随机选择 3 个"
)
select_count = min(3, len(style_exprs))
else:
# 高 count 数量达标时,固定选择 5 个
select_count = 5
import random
selected_style = random.sample(style_exprs, select_count)
# 更新last_active_time
if selected_style:
self.update_expressions_last_active_time(selected_style)
selected_ids = [expr["id"] for expr in selected_style]
logger.debug(
f"think_level=0: 从 {len(style_exprs)} 个 count>1 的表达方式中随机选择了 {len(selected_style)}"
)
return selected_style, selected_ids
except Exception as e:
logger.error(f"简单模式选择表达方式失败: {e}")
return [], []
def _random_expressions(self, chat_id: str, total_num: int) -> List[Dict[str, Any]]:
"""
随机选择表达方式
@ -127,9 +206,7 @@ class ExpressionSelector:
related_chat_ids = self.get_related_chat_ids(chat_id)
# 优化一次性查询所有相关chat_id的表达方式排除 rejected=1 的表达
style_query = Expression.select().where(
(Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected)
)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
style_exprs = [
{
@ -164,6 +241,7 @@ class ExpressionSelector:
max_num: int = 10,
target_message: Optional[str] = None,
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
选择适合的表达方式使用classic模式随机选择+LLM选择
@ -174,6 +252,7 @@ class ExpressionSelector:
max_num: 最大选择数量
target_message: 目标消息内容
reply_reason: planner给出的回复理由
think_level: 思考级别0/1
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
@ -184,8 +263,10 @@ class ExpressionSelector:
return [], []
# 使用classic模式随机选择+LLM选择
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式")
return await self._select_expressions_classic(chat_id, chat_info, max_num, target_message, reply_reason)
logger.debug(f"使用classic模式为聊天流 {chat_id} 选择表达方式think_level={think_level}")
return await self._select_expressions_classic(
chat_id, chat_info, max_num, target_message, reply_reason, think_level
)
async def _select_expressions_classic(
self,
@ -194,6 +275,7 @@ class ExpressionSelector:
max_num: int = 10,
target_message: Optional[str] = None,
reply_reason: Optional[str] = None,
think_level: int = 1,
) -> Tuple[List[Dict[str, Any]], List[int]]:
"""
classic模式随机选择+LLM选择
@ -204,24 +286,91 @@ class ExpressionSelector:
max_num: 最大选择数量
target_message: 目标消息内容
reply_reason: planner给出的回复理由
think_level: 思考级别0/1
Returns:
Tuple[List[Dict[str, Any]], List[int]]: 选中的表达方式列表和ID列表
"""
try:
# 1. 使用随机抽样选择表达方式
style_exprs = self._random_expressions(chat_id, 20)
# think_level == 0: 只选择 count > 1 的项目随机选10个不进行LLM选择
if think_level == 0:
return self._select_expressions_simple(chat_id, max_num)
if len(style_exprs) < 10:
logger.info(f"聊天流 {chat_id} 表达方式正在积累中")
# think_level == 1: 先选高count再从所有表达方式中随机抽样
# 1. 获取所有表达方式并分离 count > 1 和 count <= 1 的
related_chat_ids = self.get_related_chat_ids(chat_id)
style_query = Expression.select().where((Expression.chat_id.in_(related_chat_ids)) & (~Expression.rejected))
all_style_exprs = [
{
"id": expr.id,
"situation": expr.situation,
"style": expr.style,
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
"checked": expr.checked if getattr(expr, "checked", None) is not None else False,
}
for expr in style_query
]
# 分离 count > 1 和 count <= 1 的表达方式
high_count_exprs = [expr for expr in all_style_exprs if (expr.get("count", 1) or 1) > 1]
            # 根据 think_level 设置要求(仅支持 0/1,think_level=0 已在上方返回)
min_high_count = 10
min_total_count = 10
select_high_count = 5
select_random_count = 5
# 检查数量要求
# 对于高 count 表达:如果数量不足,不再直接停止,而是仅跳过“高 count 优先选择”
if len(high_count_exprs) < min_high_count:
logger.info(
f"聊天流 {chat_id} count > 1 的表达方式不足 {min_high_count} 个(实际 {len(high_count_exprs)} 个),"
f"将跳过高 count 优先选择,仅从全部表达中随机抽样"
)
high_count_valid = False
else:
high_count_valid = True
# 总量不足仍然直接返回,避免样本过少导致选择质量过低
if len(all_style_exprs) < min_total_count:
logger.info(
f"聊天流 {chat_id} 总表达方式不足 {min_total_count} 个(实际 {len(all_style_exprs)} 个),不进行选择"
)
return [], []
# 先选取高count的表达方式如果数量达标
if high_count_valid:
selected_high = weighted_sample(high_count_exprs, min(len(high_count_exprs), select_high_count))
else:
selected_high = []
# 然后从所有表达方式中随机抽样(使用加权抽样)
remaining_num = select_random_count
selected_random = weighted_sample(all_style_exprs, min(len(all_style_exprs), remaining_num))
# 合并候选池(去重,避免重复)
candidate_exprs = selected_high.copy()
candidate_ids = {expr["id"] for expr in candidate_exprs}
for expr in selected_random:
if expr["id"] not in candidate_ids:
candidate_exprs.append(expr)
candidate_ids.add(expr["id"])
# 打乱顺序避免高count的都在前面
import random
random.shuffle(candidate_exprs)
# 2. 构建所有表达方式的索引和情境列表
all_expressions: List[Dict[str, Any]] = []
all_situations: List[str] = []
# 添加style表达方式
for expr in style_exprs:
for expr in candidate_exprs:
expr = expr.copy()
all_expressions.append(expr)
all_situations.append(f"{len(all_expressions)}.当 {expr['situation']} 时,使用 {expr['style']}")
@ -233,7 +382,7 @@ class ExpressionSelector:
all_situations_str = "\n".join(all_situations)
if target_message:
target_message_str = f",现在你想要对这条消息进行回复:“{target_message}"
target_message_str = f',现在你想要对这条消息进行回复:"{target_message}"'
target_message_extra_block = "4.考虑你要回复的目标消息"
else:
target_message_str = ""
@ -262,7 +411,8 @@ class ExpressionSelector:
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt)
print(prompt)
print(content)
if not content:
logger.warning("LLM返回空结果")

View File

@ -7,8 +7,13 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.jargon.jargon_miner import search_jargon
from src.jargon.jargon_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
from src.bw_learner.jargon_miner import search_jargon
from src.bw_learner.learner_utils import (
is_bot_message,
contains_bot_self_name,
parse_chat_id_list,
chat_id_list_contains,
)
logger = get_logger("jargon")
@ -82,7 +87,7 @@ class JargonExplainer:
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
# 开启all_global只查询is_global=True的记录
query = query.where(Jargon.is_global)
else:
@ -107,7 +112,7 @@ class JargonExplainer:
continue
# 检查chat_id如果all_global=False
if not global_config.jargon.all_global:
if not global_config.expression.all_global_jargon:
if jargon.is_global:
# 全局黑话,包含
pass
@ -181,7 +186,7 @@ class JargonExplainer:
content = entry["content"]
# 根据是否开启全局黑话,决定查询方式
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
# 开启全局黑话查询所有is_global=True的记录
results = search_jargon(
keyword=content,
@ -265,7 +270,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
return []
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
query = query.where(Jargon.is_global)
query = query.order_by(Jargon.count.desc())
@ -277,7 +282,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
if not content:
continue
if not global_config.jargon.all_global and not jargon.is_global:
if not global_config.expression.all_global_jargon and not jargon.is_global:
chat_id_list = parse_chat_id_list(jargon.chat_id)
if not chat_id_list_contains(chat_id_list, chat_id):
continue
@ -357,4 +362,4 @@ async def retrieve_concepts_with_jargon(concepts: List[str], chat_id: str) -> st
if results:
return "【概念检索结果】\n" + "\n".join(results) + "\n"
return ""
return ""

View File

@ -1,8 +1,8 @@
import time
import json
import asyncio
import random
from collections import OrderedDict
from typing import List, Dict, Optional, Any
from typing import List, Dict, Optional, Any, Callable
from json_repair import repair_json
from peewee import fn
@ -13,10 +13,9 @@ from src.config.config import model_config, global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_by_timestamp_with_chat_inclusive,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.jargon.jargon_utils import (
from src.bw_learner.learner_utils import (
is_bot_message,
build_context_paragraph,
contains_bot_self_name,
@ -29,6 +28,29 @@ from src.jargon.jargon_utils import (
logger = get_logger("jargon")
def _is_single_char_jargon(content: str) -> bool:
"""
判断是否是单字黑话(单个汉字、英文或数字)
Args:
content: 词条内容
Returns:
bool: 如果是单字黑话返回True,否则返回False
"""
if not content or len(content) != 1:
return False
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
"\u4e00" <= char <= "\u9fff" # 汉字
or "a" <= char <= "z" # 小写字母
or "A" <= char <= "Z" # 大写字母
or "0" <= char <= "9" # 数字
)
def _init_prompt() -> None:
prompt_str = """
**聊天内容(其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID)**
@ -36,11 +58,9 @@ def _init_prompt() -> None:
请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅):
- 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语,没有明确含义的词语
- 请不要选择有明确含义或者含义清晰的词语
- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义或者含义清晰的词语
- 排除人名、@、表情包/图片中的内容、纯标点、常规功能词(如"的""啊"等)
- 每个词条长度建议 2-8 个字符(不强制),尽量短小
- 合并重复项,去重
黑话必须为以下几种类型:
- 由字母构成的汉语拼音首字母的简写词,例如nb、yyds、xswl
@ -48,7 +68,7 @@ def _init_prompt() -> None:
- 中文词语的缩写,用几个汉字概括一个词汇或含义,例如社死、内卷
JSON 数组输出,元素为对象,严格按以下结构:
请你提取出可能的黑话,最多10个
请你提取出可能的黑话,最多30个黑话,请尽量提取所有
[
{{"content": "词条", "msg_id": "m12"}}, // msg_id 必须与上方聊天中展示的ID完全一致
{{"content": "词条2", "msg_id": "m15"}}
@ -67,12 +87,14 @@ def _init_inference_prompts() -> None:
{content}
**词条出现的上下文(其中的{bot_name}的发言内容是你自己的发言)**
{raw_content_list}
{previous_meaning_section}
请根据上下文推断"{content}"这个词条的含义
- 如果这是一个黑话俚语或网络用语请推断其含义
- 如果含义明确常规词汇也请说明
- {bot_name} 的发言内容可能包含错误请不要参考其发言内容
- 如果上下文信息不足无法推断含义请设置 no_info true
{previous_meaning_instruction}
JSON 格式输出
{{
@ -166,23 +188,24 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
class JargonMiner:
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.last_learning_time: float = time.time()
# 频率控制,可按需调整
self.min_messages_for_learning: int = 10
self.min_learning_interval: float = 20
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
self.llm_inference = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.inference",
)
# 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
self.stream_name = stream_name if stream_name else self.chat_id
self.cache_limit = 100
self.cache_limit = 50
self.cache: OrderedDict[str, None] = OrderedDict()
# 黑话提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
@ -195,6 +218,10 @@ class JargonMiner:
if not key:
return
# 单字黑话(单个汉字、英文或数字)不记录到缓存
if _is_single_char_jargon(key):
return
if key in self.cache:
self.cache.move_to_end(key)
else:
@ -267,16 +294,44 @@ class JargonMiner:
logger.warning(f"jargon {content} 没有raw_content跳过推断")
return
# 获取当前count和上一次的meaning
current_count = jargon_obj.count or 0
previous_meaning = jargon_obj.meaning or ""
# 当count为24, 60时随机移除一半的raw_content项目
if current_count in [24, 60] and len(raw_content_list) > 1:
# 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(
f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目"
)
# 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list)
# 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考
previous_meaning_section = ""
previous_meaning_instruction = ""
if current_count in [24, 60, 100] and previous_meaning:
previous_meaning_section = f"""
**上一次推断的含义(仅供参考)**
{previous_meaning}
"""
previous_meaning_instruction = (
"- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
)
prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt",
content=content,
bot_name=global_config.bot.nickname,
raw_content_list=raw_content_text,
previous_meaning_section=previous_meaning_section,
previous_meaning_instruction=previous_meaning_instruction,
)
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3)
if not response1:
logger.warning(f"jargon {content} 推断1失败无响应")
return
@ -313,7 +368,7 @@ class JargonMiner:
content=content,
)
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
response2, _ = await self.llm_inference.generate_response_async(prompt2, temperature=0.3)
if not response2:
logger.warning(f"jargon {content} 推断2失败无响应")
return
@ -360,7 +415,7 @@ class JargonMiner:
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 比较提示词: {prompt3}")
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
response3, _ = await self.llm_inference.generate_response_async(prompt3, temperature=0.3)
if not response3:
logger.warning(f"jargon {content} 比较失败:无响应")
return
@ -425,45 +480,21 @@ class JargonMiner:
traceback.print_exc()
def should_trigger(self) -> bool:
# 冷却时间检查
if time.time() - self.last_learning_time < self.min_learning_interval:
return False
async def run_once(
self,
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
"""
运行一次黑话提取
# 拉取最近消息数量是否足够
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
async def run_once(self) -> None:
Args:
messages: 外部传入的消息列表(必需)
person_name_filter: 可选的过滤函数,用于检查内容是否包含人物名称
"""
# 使用异步锁防止并发执行
async with self._extraction_lock:
try:
# 在锁内检查,避免并发触发
if not self.should_trigger():
return
chat_stream = get_chat_manager().get_stream(self.chat_id)
if not chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_learning_time
extraction_end_time = time.time()
# 立即更新学习时间,防止并发触发
self.last_learning_time = extraction_end_time
# 拉取学习窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
limit=20,
)
if not messages:
return
@ -538,6 +569,11 @@ class JargonMiner:
if contains_bot_self_name(content):
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
# 检查是否包含人物名称
if person_name_filter and person_name_filter(content):
logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
continue
msg_id_str = str(msg_id_value or "").strip()
if not msg_id_str:
@ -603,7 +639,7 @@ class JargonMiner:
# 查找匹配的记录
matched_obj = None
for obj in query:
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
# 开启all_global所有content匹配的记录都可以
matched_obj = obj
break
@ -626,7 +662,9 @@ class JargonMiner:
if obj.raw_content:
try:
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
json.loads(obj.raw_content)
if isinstance(obj.raw_content, str)
else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
@ -643,7 +681,7 @@ class JargonMiner:
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
# 开启all_global时确保记录标记为is_global=True
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
obj.is_global = True
# 关闭all_global时保持原有is_global不变不修改
@ -659,7 +697,7 @@ class JargonMiner:
updated += 1
else:
# 没找到匹配记录,创建新记录
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
# 开启all_global新记录默认为is_global=True
is_global_new = True
else:
@ -699,6 +737,158 @@ class JargonMiner:
logger.error(f"JargonMiner 运行失败: {e}")
# 即使失败也保持时间戳更新,避免频繁重试
async def process_extracted_entries(
self,
entries: List[Dict[str, List[str]]],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
"""
处理已提取的黑话条目 expression_learner 路由过来的
Args:
entries: 黑话条目列表每个元素格式为 {"content": "...", "raw_content": [...]}
person_name_filter: 可选的过滤函数用于检查内容是否包含人物名称
"""
if not entries:
return
try:
# 去重并合并raw_content按 content 聚合)
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
for entry in entries:
content_key = entry["content"]
# 检查是否包含人物名称
# logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}")
# logger.info(f"person_name_filter: {person_name_filter}")
if person_name_filter and person_name_filter(content_key):
logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}")
continue
raw_list = entry.get("raw_content", []) or []
if content_key in merged_entries:
merged_entries[content_key]["raw_content"].extend(raw_list)
else:
merged_entries[content_key] = {
"content": content_key,
"raw_content": list(raw_list),
}
uniq_entries = []
for merged_entry in merged_entries.values():
raw_content_list = merged_entry["raw_content"]
if raw_content_list:
merged_entry["raw_content"] = list(dict.fromkeys(raw_content_list))
uniq_entries.append(merged_entry)
saved = 0
updated = 0
for entry in uniq_entries:
content = entry["content"]
raw_content_list = entry["raw_content"] # 已经是列表
try:
# 查询所有content匹配的记录
query = Jargon.select().where(Jargon.content == content)
# 查找匹配的记录
matched_obj = None
for obj in query:
if global_config.expression.all_global_jargon:
# 开启all_global所有content匹配的记录都可以
matched_obj = obj
break
else:
# 关闭all_global需要检查chat_id列表是否包含目标chat_id
chat_id_list = parse_chat_id_list(obj.chat_id)
if chat_id_list_contains(chat_id_list, self.chat_id):
matched_obj = obj
break
if matched_obj:
obj = matched_obj
try:
obj.count = (obj.count or 0) + 1
except Exception:
obj.count = 1
# 合并raw_content列表读取现有列表追加新值去重
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
existing_raw_content = [obj.raw_content] if obj.raw_content else []
# 合并并去重
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
# 更新chat_id列表增加当前chat_id的计数
chat_id_list = parse_chat_id_list(obj.chat_id)
updated_chat_id_list = update_chat_id_list(chat_id_list, self.chat_id, increment=1)
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
# 开启all_global时确保记录标记为is_global=True
if global_config.expression.all_global_jargon:
obj.is_global = True
# 关闭all_global时保持原有is_global不变不修改
obj.save()
# 检查是否需要推断(达到阈值且超过上次判定值)
if _should_infer_meaning(obj):
# 异步触发推断,不阻塞主流程
# 重新加载对象以确保数据最新
jargon_id = obj.id
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
updated += 1
else:
# 没找到匹配记录,创建新记录
if global_config.expression.all_global_jargon:
# 开启all_global新记录默认为is_global=True
is_global_new = True
else:
# 关闭all_global新记录is_global=False
is_global_new = False
# 使用新格式创建chat_id列表[[chat_id, count]]
chat_id_list = [[self.chat_id, 1]]
chat_id_json = json.dumps(chat_id_list, ensure_ascii=False)
Jargon.create(
content=content,
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
chat_id=chat_id_json,
is_global=is_global_new,
count=1,
)
saved += 1
except Exception as e:
logger.error(f"保存jargon失败: chat_id={self.chat_id}, content={content}, err={e}")
continue
finally:
self._add_to_cache(content)
# 固定输出提取的jargon结果,格式化为可读形式,只要有提取结果就输出
if uniq_entries:
# 收集所有提取的jargon内容
jargon_list = [entry["content"] for entry in uniq_entries]
jargon_str = ",".join(jargon_list)
# 输出格式化的结果,使用logger.info会自动应用jargon模块的颜色
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
if saved or updated:
logger.debug(f"jargon写入: 新增 {saved} 条,更新 {updated}chat_id={self.chat_id}")
except Exception as e:
logger.error(f"处理已提取的黑话条目失败: {e}")
class JargonMinerManager:
def __init__(self) -> None:
@ -713,11 +903,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager()
async def extract_and_store_jargon(chat_id: str) -> None:
miner = miner_manager.get_miner(chat_id)
await miner.run_once()
def search_jargon(
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:
@ -765,7 +950,7 @@ def search_jargon(
query = query.where(search_condition)
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
if global_config.expression.all_global_jargon:
# 开启all_global所有记录都是全局的查询所有is_global=True的记录无视chat_id
query = query.where(Jargon.is_global)
# 注意对于all_global=False的情况chat_id过滤在Python层面进行以便兼容新旧格式
@ -782,7 +967,7 @@ def search_jargon(
results = []
for jargon in query:
# 如果提供了chat_id且all_global=False需要检查chat_id列表是否包含目标chat_id
if chat_id and not global_config.jargon.all_global:
if chat_id and not global_config.expression.all_global_jargon:
chat_id_list = parse_chat_id_list(jargon.chat_id)
# 如果记录是is_global=True或者chat_id列表包含目标chat_id则包含
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):

View File

@ -0,0 +1,380 @@
import re
import difflib
import random
import json
from datetime import datetime
from typing import Optional, List, Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.chat.utils.utils import parse_platform_accounts
logger = get_logger("learner_utils")
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容,移除回复、@、图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID
content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
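A self-contained sketch of the same filtering rules, handy for eyeballing the regexes (local copy, not part of the diff):

import re

def filter_message_content_demo(content):
    if not content:
        return ""
    content = re.sub(r"\[回复.*?\],说:\s*", "", content)   # reply prefix
    content = re.sub(r"@<[^>]*>", "", content)              # @-mentions
    content = re.sub(r"\[picid:[^\]]*\]", "", content)      # picture ids
    content = re.sub(r"\[表情包:[^\]]*\]", "", content)     # sticker placeholders
    return content.strip()

print(filter_message_content_demo("[回复某人],说: @<小明> 看这个 [picid:abc123] [表情包:笑]"))
# -> "看这个"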
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度,返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值,范围0-1
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def calculate_style_similarity(style1: str, style2: str) -> float:
"""
计算两个 style 的相似度,返回0-1之间的值
在计算前会移除"使用"、"句式"这两个词(参考 expression_similarity_analysis.py)
Args:
style1: 第一个 style
style2: 第二个 style
Returns:
float: 相似度值,范围0-1
"""
if not style1 or not style2:
return 0.0
# 移除"使用"和"句式"这两个词
def remove_ignored_words(text: str) -> str:
"""移除需要忽略的词"""
text = text.replace("使用", "")
text = text.replace("句式", "")
return text.strip()
cleaned_style1 = remove_ignored_words(style1)
cleaned_style2 = remove_ignored_words(style2)
# 如果清理后文本为空,返回0
if not cleaned_style1 or not cleaned_style2:
return 0.0
return difflib.SequenceMatcher(None, cleaned_style1, cleaned_style2).ratio()
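A rough standalone illustration of why the style variant strips "使用"/"句式" first: the shared boilerplate inflates the plain ratio (not part of the diff):

import difflib

a, b = "使用反问句式", "使用感叹句式"
plain = difflib.SequenceMatcher(None, a, b).ratio()              # ~0.67, mostly from 使用/句式
cleaned = difflib.SequenceMatcher(None, "反问", "感叹").ratio()  # 0.0 once the filler is removed
print(round(plain, 2), round(cleaned, 2))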
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重,范围限定在1~5之间
count越高权重越高,但最多为基础权重的5倍
如果表达已checked,权重会再乘以3倍
"""
if not population:
return []
counts = []
checked_flags = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
# 获取checked状态
checked = item.get("checked", False)
checked_flags.append(bool(checked))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
base_weights = [1.0 for _ in counts]
else:
base_weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
base_weights.append(1.0 + normalized * 4.0) # 1~5
# 如果checked权重乘以3
weights = []
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
if checked:
weights.append(base_weight * 3.0)
else:
weights.append(base_weight)
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected
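To make the weighting above concrete, a small data-only walkthrough (counts map linearly into [1, 5], then x3 when checked); the commented call shows the intended use and assumes this module is importable:

population = [
    {"style": "A", "count": 1, "checked": False},   # weight 1.0 (minimum of the range)
    {"style": "B", "count": 10, "checked": False},  # weight 5.0 (maximum of the range)
    {"style": "C", "count": 10, "checked": True},   # weight 15.0 (5.0 * 3 for checked)
]
# picked = weighted_sample(population, k=2)  # draws 2 distinct items, favouring B and C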
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
"""
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
Args:
chat_id_value: 可能是字符串(旧格式)或JSON字符串(新格式)
Returns:
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
"""
if not chat_id_value:
return []
# 如果是字符串,尝试解析为JSON
if isinstance(chat_id_value, str):
# 尝试解析JSON
try:
parsed = json.loads(chat_id_value)
if isinstance(parsed, list):
# 新格式:已经是列表
return parsed
elif isinstance(parsed, str):
# 解析后还是字符串,说明是旧格式
return [[parsed, 1]]
else:
# 其他类型,当作旧格式处理
return [[str(chat_id_value), 1]]
except (json.JSONDecodeError, TypeError):
# 解析失败,当作旧格式(纯字符串)
return [[str(chat_id_value), 1]]
elif isinstance(chat_id_value, list):
# 已经是列表格式
return chat_id_value
else:
# 其他类型,转换为旧格式
return [[str(chat_id_value), 1]]
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
"""
更新chat_id列表,如果target_chat_id已存在则增加计数,否则添加新条目
Args:
chat_id_list: 当前的chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要更新或添加的chat_id
increment: 增加的计数,默认为1
Returns:
List[List[Any]]: 更新后的chat_id列表
"""
item = _find_chat_id_item(chat_id_list, target_chat_id)
if item is not None:
# 找到匹配的chat_id,增加计数
if len(item) >= 2:
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
else:
item.append(increment)
else:
# 未找到,添加新条目
chat_id_list.append([target_chat_id, increment])
return chat_id_list
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
"""
在chat_id列表中查找匹配的项(辅助函数)
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
如果找到则返回匹配的项,否则返回None
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return item
return None
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
"""
检查chat_id列表中是否包含指定的chat_id
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
bool: 如果包含则返回True
"""
return _find_chat_id_item(chat_id_list, target_chat_id) is not None
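A short walkthrough of the chat_id formats these helpers bridge (legacy bare string vs. the new [[chat_id, count], ...] JSON), using only the standard library; standalone sketch, not part of the diff:

import json

legacy = "qq:12345"                    # old format: a bare chat_id string in the column
parsed = [[legacy, 1]]                 # parse_chat_id_list() wraps it as [[chat_id, 1]]

# update_chat_id_list(parsed, "qq:12345") bumps the matching count,
# and would append ["qq:67890", 1] for an id that is not present yet.
parsed[0][1] += 1

stored = json.dumps(parsed, ensure_ascii=False)  # new format written back to Jargon.chat_id
print(stored)                                    # -> [["qq:12345", 2]]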
def contains_bot_self_name(content: str) -> bool:
"""
判断词条是否包含机器人的昵称或别名
"""
if not content:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates)
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
"""
构建包含中心消息上下文的段落(前3条+后3条),使用标准的 readable builder 输出
"""
if not messages or center_index < 0 or center_index >= len(messages):
return None
context_start = max(0, center_index - 3)
context_end = min(len(messages), center_index + 1 + 3)
context_messages = messages[context_start:context_end]
if not context_messages:
return None
try:
paragraph = build_readable_messages(
messages=context_messages,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=False,
show_actions=False,
show_pic=True,
message_id_list=None,
remove_emoji_stickers=False,
pic_single=True,
)
except Exception as e:
logger.warning(f"构建上下文段落失败: {e}")
return None
paragraph = paragraph.strip()
return paragraph or None
def is_bot_message(msg: Any) -> bool:
"""判断消息是否来自机器人自身"""
if msg is None:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
platform = (
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
.strip()
.lower()
)
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
if not platform or not user_id:
return False
platform_accounts = {}
try:
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
except Exception:
platform_accounts = {}
bot_accounts: Dict[str, str] = {}
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
if qq_account:
bot_accounts["qq"] = qq_account
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
if telegram_account:
bot_accounts["telegram"] = telegram_account
for plat, account in platform_accounts.items():
if account and plat not in bot_accounts:
bot_accounts[plat] = account
bot_account = bot_accounts.get(platform)
return bool(bot_account and user_id == bot_account)

View File

@ -0,0 +1,212 @@
import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.jargon_miner import miner_manager
logger = get_logger("bw_learner")
class MessageRecorder:
"""
统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner
"""
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()
# 提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
# 获取 expression 和 jargon 的配置参数
self._init_parameters()
# 获取 expression_learner 和 jargon_miner 实例
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
self.jargon_miner = miner_manager.get_miner(chat_id)
def _init_parameters(self) -> None:
"""初始化提取参数"""
# 获取 expression 配置
_, self.enable_expression_learning, self.enable_jargon_learning = (
global_config.expression.get_expression_config_for_chat(self.chat_id)
)
self.min_messages_for_extraction = 30
self.min_extraction_interval = 60
logger.debug(
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
f"min_messages={self.min_messages_for_extraction}, "
f"min_interval={self.min_extraction_interval}"
)
def should_trigger_extraction(self) -> bool:
"""
检查是否应该触发消息提取
Returns:
bool: 是否应该触发提取
"""
# 检查时间间隔
time_diff = time.time() - self.last_extraction_time
if time_diff < self.min_extraction_interval:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_extraction_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
return False
return True
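The trigger above reduces to two gates: at least 60 s since the last extraction and at least 30 messages in the window. A minimal restatement with the database lookup stubbed out (the fetch_count callable is hypothetical, not the real query):

import time

MIN_INTERVAL = 60   # mirrors min_extraction_interval above
MIN_MESSAGES = 30   # mirrors min_messages_for_extraction above

def should_trigger(last_extraction_time, fetch_count):
    if time.time() - last_extraction_time < MIN_INTERVAL:
        return False                       # still inside the cooldown window
    return fetch_count() >= MIN_MESSAGES   # enough fresh messages since the last run

print(should_trigger(time.time() - 120, lambda: 45))  # True: old enough and busy enough
print(should_trigger(time.time() - 10, lambda: 45))   # False: cooldown not over yet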
async def extract_and_distribute(self) -> None:
"""
提取消息并分发给 expression_learner jargon_miner
"""
# 使用异步锁防止并发执行
async with self._extraction_lock:
# 在锁内检查,避免并发触发
if not self.should_trigger_extraction():
return
# 检查 chat_stream 是否存在
if not self.chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_extraction_time
extraction_end_time = time.time()
# 立即更新提取时间,防止并发触发
self.last_extraction_time = extraction_end_time
try:
# logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
# 拉取提取窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
)
if not messages:
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
return
# 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0)
logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
if self.enable_expression_learning:
asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
)
# 触发 jargon 提取(如果启用),传递消息
# if self.enable_jargon_learning:
# asyncio.create_task(
# self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
# )
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 expression 学习,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息给 ExpressionLearner必需参数
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback
traceback.print_exc()
async def _trigger_jargon_extraction(
self, timestamp_start: float, timestamp_end: float, messages: List[Any]
) -> None:
"""
触发 jargon 提取,使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息给 JargonMiner,避免它重复获取
await self.jargon_miner.run_once(messages=messages)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
class MessageRecorderManager:
"""MessageRecorder 管理器"""
def __init__(self) -> None:
self._recorders: dict[str, MessageRecorder] = {}
def get_recorder(self, chat_id: str) -> MessageRecorder:
"""获取或创建指定 chat_id 的 MessageRecorder"""
if chat_id not in self._recorders:
self._recorders[chat_id] = MessageRecorder(chat_id)
return self._recorders[chat_id]
# 全局管理器实例
recorder_manager = MessageRecorderManager()
async def extract_and_distribute_messages(chat_id: str) -> None:
"""
统一的消息提取和分发入口函数
Args:
chat_id: 聊天流ID
"""
recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute()
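A hedged usage sketch of the entry point above; the module path is assumed (the diff view does not show the file path) and an event loop is expected to be running:

import asyncio
from src.bw_learner.message_recorder import extract_and_distribute_messages  # assumed path

async def on_message_stored(chat_id: str) -> None:
    # Fire-and-forget: the recorder itself checks the time/volume gates before doing any work.
    asyncio.create_task(extract_and_distribute_messages(chat_id))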

View File

@ -0,0 +1,491 @@
import time
from typing import Tuple, Optional # 增加了 Optional
from src.common.logger_manager import get_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
logger = get_logger("pfc_action_planner")
# --- 定义 Prompt 模板 ---
# Prompt(1): 首次回复或非连续回复时的决策 Prompt
PROMPT_INITIAL_REPLY = """{persona_text}。现在你在参与一场QQ私聊请根据以下【所有信息】审慎且灵活的决策下一步行动可以回复可以倾听可以调取知识甚至可以屏蔽对方
当前对话目标
{goals_str}
{knowledge_info_str}
最近行动历史概要
{action_history_summary}
上一次行动的详细情况和结果
{last_action_context}
时间和超时提示
{time_since_last_bot_message_info}{timeout_context}
最近的对话记录(包括你已成功发送的消息 新收到的消息)
{chat_history_text}
------
可选行动类型以及解释
fetch_knowledge: 需要调取知识或记忆,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
listening: 倾听对方发言,当你认为对方话才说到一半、发言明显未结束时选择
direct_reply: 直接回复对方
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标、或当前目标不再适用、或话题卡住时选择(注意私聊的环境是灵活的,有可能需要经常选择)
end_conversation: 结束对话,对方长时间没回复,或者当你觉得对话告一段落时可以选择
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适或你遭到各类骚扰时选择
请以JSON格式输出你的决策:
{{
"action": "选择的行动类型 (必须是上面列表中的一个)",
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的)"
}}
注意请严格按照JSON格式输出不要包含任何其他内容"""
# Prompt(2): 上一次成功回复后,决定继续发言时的决策 Prompt
PROMPT_FOLLOW_UP = """{persona_text}。现在你在参与一场QQ私聊刚刚你已经回复了对方请根据以下【所有信息】审慎且灵活的决策下一步行动可以继续发送新消息可以等待可以倾听可以调取知识甚至可以屏蔽对方
当前对话目标
{goals_str}
{knowledge_info_str}
最近行动历史概要
{action_history_summary}
上一次行动的详细情况和结果
{last_action_context}
时间和超时提示
{time_since_last_bot_message_info}{timeout_context}
最近的对话记录(包括你已成功发送的消息 新收到的消息)
{chat_history_text}
------
可选行动类型以及解释
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择,对方若提到你不太认识的人名或实体也可以尝试选择
wait: 暂时不说话,留给对方交互空间,等待对方回复,尤其是在你刚发言后、或上次发言因重复发言过多被拒时、或不确定做什么时,这是不错的选择
listening: 倾听对方发言,虽然你刚发过言,但如果对方立刻回复且明显话没说完,可以选择这个
send_new_message: 发送一条新消息继续对话,允许适当的追问、补充、深入话题或开启相关新话题,**但是避免在因重复被拒后立即使用,也不要在对方没有回复的情况下过多的消息轰炸或重复发言**
rethink_goal: 思考一个对话目标,当你觉得目前对话需要目标、或当前目标不再适用、或话题卡住时选择(注意私聊的环境是灵活的,有可能需要经常选择)
end_conversation: 结束对话,对方长时间没回复,或者当你觉得对话告一段落时可以选择
block_and_ignore: 更加极端的结束对话方式,直接结束对话并在一段时间内无视对方所有发言(屏蔽),当对话让你感到十分不适或你遭到各类骚扰时选择
请以JSON格式输出你的决策:
{{
"action": "选择的行动类型 (必须是上面列表中的一个)",
"reason": "选择该行动的详细原因 (必须有解释你是如何根据“上一次行动结果”、“对话记录”和自身设定人设做出合理判断的。请说明你为什么选择继续发言而不是等待,以及打算发送什么类型的新消息连续发言,必须记录已经发言了几次)"
}}
注意请严格按照JSON格式输出不要包含任何其他内容"""
# 新增Prompt(3): 决定是否在结束对话前发送告别语
PROMPT_END_DECISION = """{persona_text}。刚刚你决定结束一场 QQ 私聊。
你们之前的聊天记录
{chat_history_text}
你觉得你们的对话已经完整结束了吗?有时候在对话自然结束后再说点什么可能会有点奇怪,但有时也可能需要一条简短的消息来圆满结束。
如果觉得确实有必要再发一条简短、自然、符合你人设的告别消息(比如 "好,下次再聊~" 或 "嗯,先这样吧"),就输出 "yes"。
如果觉得当前状态下直接结束对话更好,没有必要再发消息,就输出 "no"。
请以 JSON 格式输出你的选择:
{{
"say_bye": "yes/no",
"reason": "选择 yes 或 no 的原因和内心想法 (简要说明)"
}}
注意请严格按照 JSON 格式输出不要包含任何其他内容"""
# ActionPlanner 类定义,顶格
class ActionPlanner:
"""行动规划器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model=global_config.llm_PFC_action_planner,
temperature=global_config.llm_PFC_action_planner["temp"],
max_tokens=1500,
request_type="action_planning",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量
# 修改 plan 方法签名,增加 last_successful_reply_action 参数
async def plan(
self,
observation_info: ObservationInfo,
conversation_info: ConversationInfo,
last_successful_reply_action: Optional[str],
) -> Tuple[str, str]:
"""规划下一步行动
Args:
observation_info: 决策信息
conversation_info: 对话信息
last_successful_reply_action: 上一次成功的回复动作类型 ('direct_reply' 'send_new_message' None)
Returns:
Tuple[str, str]: (行动类型, 行动原因)
"""
# --- 获取 Bot 上次发言时间信息 ---
# (这部分逻辑不变)
time_since_last_bot_message_info = ""
try:
bot_id = str(global_config.BOT_QQ)
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
for i in range(len(observation_info.chat_history) - 1, -1, -1):
msg = observation_info.chat_history[i]
if not isinstance(msg, dict):
continue
sender_info = msg.get("user_info", {})
sender_id = str(sender_info.get("user_id")) if isinstance(sender_info, dict) else None
msg_time = msg.get("time")
if sender_id == bot_id and msg_time:
time_diff = time.time() - msg_time
if time_diff < 60.0:
time_since_last_bot_message_info = (
f"提示:你上一条成功发送的消息是在 {time_diff:.1f} 秒前。\n"
)
break
else:
logger.debug(
f"[私聊][{self.private_name}]Observation info chat history is empty or not available for bot time check."
)
except AttributeError:
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo object might not have chat_history attribute yet for bot time check."
)
except Exception as e:
logger.warning(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
# --- 获取超时提示信息 ---
# (这部分逻辑不变)
timeout_context = ""
try:
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
last_goal_dict = conversation_info.goal_list[-1]
if isinstance(last_goal_dict, dict) and "goal" in last_goal_dict:
last_goal_text = last_goal_dict["goal"]
if isinstance(last_goal_text, str) and "分钟,思考接下来要做什么" in last_goal_text:
try:
timeout_minutes_text = last_goal_text.split("")[0].replace("你等待了", "")
timeout_context = f"重要提示:对方已经长时间({timeout_minutes_text})没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
except Exception:
timeout_context = "重要提示:对方已经长时间没有回复你的消息了(这可能代表对方繁忙/不想回复/没注意到你的消息等情况,或在对方看来本次聊天已告一段落),请基于此情况规划下一步。\n"
else:
logger.debug(
f"[私聊][{self.private_name}]Conversation info goal_list is empty or not available for timeout check."
)
except AttributeError:
logger.warning(
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet for timeout check."
)
except Exception as e:
logger.warning(f"[私聊][{self.private_name}]检查超时目标时出错: {e}")
# --- 构建通用 Prompt 参数 ---
logger.debug(
f"[私聊][{self.private_name}]开始规划行动:当前目标: {getattr(conversation_info, 'goal_list', '不可用')}"
)
# 构建对话目标 (goals_str)
goals_str = ""
try:
if hasattr(conversation_info, "goal_list") and conversation_info.goal_list:
for goal_reason in conversation_info.goal_list:
if isinstance(goal_reason, dict):
goal = goal_reason.get("goal", "目标内容缺失")
reasoning = goal_reason.get("reasoning", "没有明确原因")
else:
goal = str(goal_reason)
reasoning = "没有明确原因"
goal = str(goal) if goal is not None else "目标内容缺失"
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
if not goals_str:
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
else:
goals_str = "- 目前没有明确对话目标,请考虑设定一个。\n"
except AttributeError:
logger.warning(
f"[私聊][{self.private_name}]ConversationInfo object might not have goal_list attribute yet."
)
goals_str = "- 获取对话目标时出错。\n"
except Exception as e:
logger.error(f"[私聊][{self.private_name}]构建对话目标字符串时出错: {e}")
goals_str = "- 构建对话目标时出错。\n"
# --- 知识信息字符串构建开始 ---
knowledge_info_str = "【已获取的相关知识和记忆】\n"
try:
# 检查 conversation_info 是否有 knowledge_list 并且不为空
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
# 最多只显示最近的 5 条知识,防止 Prompt 过长
recent_knowledge = conversation_info.knowledge_list[-5:]
for i, knowledge_item in enumerate(recent_knowledge):
if isinstance(knowledge_item, dict):
query = knowledge_item.get("query", "未知查询")
knowledge = knowledge_item.get("knowledge", "无知识内容")
source = knowledge_item.get("source", "未知来源")
# 只取知识内容的前 2000 个字,避免太长
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
knowledge_info_str += (
f"{i + 1}. 关于 '{query}' 的知识 (来源: {source}):\n {knowledge_snippet}\n"
)
else:
# 处理列表里不是字典的异常情况
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
if not recent_knowledge: # 如果 knowledge_list 存在但为空
knowledge_info_str += "- 暂无相关知识和记忆。\n"
else:
# 如果 conversation_info 没有 knowledge_list 属性,或者列表为空
knowledge_info_str += "- 暂无相关知识记忆。\n"
except AttributeError:
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
knowledge_info_str += "- 获取知识列表时出错。\n"
except Exception as e:
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
knowledge_info_str += "- 处理知识列表时出错。\n"
# --- 知识信息字符串构建结束 ---
# 获取聊天历史记录 (chat_history_text)
try:
if hasattr(observation_info, "chat_history") and observation_info.chat_history:
chat_history_text = observation_info.chat_history_str
if not chat_history_text:
chat_history_text = "还没有聊天记录。\n"
else:
chat_history_text = "还没有聊天记录。\n"
if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
new_messages_list = observation_info.unprocessed_messages
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += (
f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
)
else:
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo has new_messages_count > 0 but unprocessed_messages is empty or missing."
)
except AttributeError:
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo object might be missing expected attributes for chat history."
)
chat_history_text = "获取聊天记录时出错。\n"
except Exception as e:
logger.error(f"[私聊][{self.private_name}]处理聊天记录时发生未知错误: {e}")
chat_history_text = "处理聊天记录时出错。\n"
# 构建 Persona 文本 (persona_text)
persona_text = f"你的名字是{self.name}{self.personality_info}"
# 构建行动历史和上一次行动结果 (action_history_summary, last_action_context)
# (这部分逻辑不变)
action_history_summary = "你最近执行的行动历史:\n"
last_action_context = "关于你【上一次尝试】的行动:\n"
action_history_list = []
try:
if hasattr(conversation_info, "done_action") and conversation_info.done_action:
action_history_list = conversation_info.done_action[-5:]
else:
logger.debug(f"[私聊][{self.private_name}]Conversation info done_action is empty or not available.")
except AttributeError:
logger.warning(
f"[私聊][{self.private_name}]ConversationInfo object might not have done_action attribute yet."
)
except Exception as e:
logger.error(f"[私聊][{self.private_name}]访问行动历史时出错: {e}")
if not action_history_list:
action_history_summary += "- 还没有执行过行动。\n"
last_action_context += "- 这是你规划的第一个行动。\n"
else:
for i, action_data in enumerate(action_history_list):
action_type = "未知"
plan_reason = "未知"
status = "未知"
final_reason = ""
action_time = ""
if isinstance(action_data, dict):
action_type = action_data.get("action", "未知")
plan_reason = action_data.get("plan_reason", "未知规划原因")
status = action_data.get("status", "未知")
final_reason = action_data.get("final_reason", "")
action_time = action_data.get("time", "")
elif isinstance(action_data, tuple):
# 假设旧格式兼容
if len(action_data) > 0:
action_type = action_data[0]
if len(action_data) > 1:
plan_reason = action_data[1] # 可能是规划原因或最终原因
if len(action_data) > 2:
status = action_data[2]
if status == "recall" and len(action_data) > 3:
final_reason = action_data[3]
elif status == "done" and action_type in ["direct_reply", "send_new_message"]:
plan_reason = "成功发送" # 简化显示
reason_text = f", 失败/取消原因: {final_reason}" if final_reason else ""
summary_line = f"- 时间:{action_time}, 尝试行动:'{action_type}', 状态:{status}{reason_text}"
action_history_summary += summary_line + "\n"
if i == len(action_history_list) - 1:
last_action_context += f"- 上次【规划】的行动是: '{action_type}'\n"
last_action_context += f"- 当时规划的【原因】是: {plan_reason}\n"
if status == "done":
last_action_context += "- 该行动已【成功执行】。\n"
# 记录这次成功的行动类型,供下次决策
# self.last_successful_action_type = action_type # 不在这里记录,由 conversation 控制
elif status == "recall":
last_action_context += "- 但该行动最终【未能执行/被取消】。\n"
if final_reason:
last_action_context += f"- 【重要】失败/取消的具体原因是: “{final_reason}\n"
else:
last_action_context += "- 【重要】失败/取消原因未明确记录。\n"
# self.last_successful_action_type = None # 行动失败,清除记录
else:
last_action_context += f"- 该行动当前状态: {status}\n"
# self.last_successful_action_type = None # 非完成状态,清除记录
# --- 选择 Prompt ---
if last_successful_reply_action in ["direct_reply", "send_new_message"]:
prompt_template = PROMPT_FOLLOW_UP
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_FOLLOW_UP (追问决策)")
else:
prompt_template = PROMPT_INITIAL_REPLY
logger.debug(f"[私聊][{self.private_name}]使用 PROMPT_INITIAL_REPLY (首次/非连续回复决策)")
# --- 格式化最终的 Prompt ---
prompt = prompt_template.format(
persona_text=persona_text,
goals_str=goals_str if goals_str.strip() else "- 目前没有明确对话目标,请考虑设定一个。",
action_history_summary=action_history_summary,
last_action_context=last_action_context,
time_since_last_bot_message_info=time_since_last_bot_message_info,
timeout_context=timeout_context,
chat_history_text=chat_history_text if chat_history_text.strip() else "还没有聊天记录。",
knowledge_info_str=knowledge_info_str,
)
logger.debug(f"[私聊][{self.private_name}]发送到LLM的最终提示词:\n------\n{prompt}\n------")
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"[私聊][{self.private_name}]LLM (行动规划) 原始返回内容: {content}")
# --- 初始行动规划解析 ---
success, initial_result = get_items_from_json(
content,
self.private_name,
"action",
"reason",
default_values={"action": "wait", "reason": "LLM返回格式错误或未提供原因默认等待"},
)
initial_action = initial_result.get("action", "wait")
initial_reason = initial_result.get("reason", "LLM未提供原因默认等待")
# 检查是否需要进行结束对话决策 ---
if initial_action == "end_conversation":
logger.info(f"[私聊][{self.private_name}]初步规划结束对话,进入告别决策...")
# 使用新的 PROMPT_END_DECISION
end_decision_prompt = PROMPT_END_DECISION.format(
persona_text=persona_text, # 复用之前的 persona_text
chat_history_text=chat_history_text, # 复用之前的 chat_history_text
)
logger.debug(
f"[私聊][{self.private_name}]发送到LLM的结束决策提示词:\n------\n{end_decision_prompt}\n------"
)
try:
end_content, _ = await self.llm.generate_response_async(end_decision_prompt) # 再次调用LLM
logger.debug(f"[私聊][{self.private_name}]LLM (结束决策) 原始返回内容: {end_content}")
# 解析结束决策的JSON
end_success, end_result = get_items_from_json(
end_content,
self.private_name,
"say_bye",
"reason",
default_values={"say_bye": "no", "reason": "结束决策LLM返回格式错误默认不告别"},
required_types={"say_bye": str, "reason": str}, # 明确类型
)
say_bye_decision = end_result.get("say_bye", "no").lower() # 转小写方便比较
end_decision_reason = end_result.get("reason", "未提供原因")
if end_success and say_bye_decision == "yes":
# 决定要告别,返回新的 'say_goodbye' 动作
logger.info(
f"[私聊][{self.private_name}]结束决策: yes, 准备生成告别语. 原因: {end_decision_reason}"
)
# 注意:这里的 reason 可以考虑拼接初始原因和结束决策原因,或者只用结束决策原因
final_action = "say_goodbye"
final_reason = f"决定发送告别语。决策原因: {end_decision_reason} (原结束理由: {initial_reason})"
return final_action, final_reason
else:
# 决定不告别 (包括解析失败或明确说no)
logger.info(
f"[私聊][{self.private_name}]结束决策: no, 直接结束对话. 原因: {end_decision_reason}"
)
# 返回原始的 'end_conversation' 动作
final_action = "end_conversation"
final_reason = initial_reason # 保持原始的结束理由
return final_action, final_reason
except Exception as end_e:
logger.error(f"[私聊][{self.private_name}]调用结束决策LLM或处理结果时出错: {str(end_e)}")
# 出错时,默认执行原始的结束对话
logger.warning(f"[私聊][{self.private_name}]结束决策出错,将按原计划执行 end_conversation")
return "end_conversation", initial_reason # 返回原始动作和原因
else:
action = initial_action
reason = initial_reason
# 验证action类型 (保持不变)
valid_actions = [
"direct_reply",
"send_new_message",
"fetch_knowledge",
"wait",
"listening",
"rethink_goal",
"end_conversation", # 仍然需要验证,因为可能从上面决策后返回
"block_and_ignore",
"say_goodbye", # 也要验证这个新动作
]
if action not in valid_actions:
logger.warning(f"[私聊][{self.private_name}]LLM返回了未知的行动类型: '{action}',强制改为 wait")
reason = f"(原始行动'{action}'无效已强制改为wait) {reason}"
action = "wait"
logger.info(f"[私聊][{self.private_name}]规划的行动: {action}")
logger.info(f"[私聊][{self.private_name}]行动原因: {reason}")
return action, reason
except Exception as e:
# 外层异常处理保持不变
logger.error(f"[私聊][{self.private_name}]规划行动时调用 LLM 或处理结果出错: {str(e)}")
return "wait", f"行动规划处理中发生错误,暂时等待: {str(e)}"

View File

@ -0,0 +1,379 @@
import time
import asyncio
import traceback
from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger
from maim_message import UserInfo
from ...config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("chat_observer")
class ChatObserver:
"""聊天状态观察器"""
# 类级别的实例管理
_instances: Dict[str, "ChatObserver"] = {}
@classmethod
def get_instance(cls, stream_id: str, private_name: str) -> "ChatObserver":
"""获取或创建观察器实例
Args:
stream_id: 聊天流ID
private_name: 私聊名称
Returns:
ChatObserver: 观察器实例
"""
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id, private_name)
return cls._instances[stream_id]
def __init__(self, stream_id: str, private_name: str):
"""初始化观察器
Args:
stream_id: 聊天流ID
"""
self.last_check_time = None
self.last_bot_speak_time = None
self.last_user_speak_time = None
if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id
self.private_name = private_name
self.message_storage = MongoDBMessageStorage()
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
self.last_message_time: float = time.time()
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 运行状态
self._running: bool = False
self._task: Optional[asyncio.Task] = None
self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件
# 通知管理器
self.notification_manager = NotificationManager()
# 冷场检查配置
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
self.last_cold_chat_check: float = time.time()
self.is_cold_chat_state: bool = False
self.update_event = asyncio.Event()
self.update_interval = 2 # 更新间隔(秒)
self.message_cache = []
self.update_running = False
async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息
Returns:
bool: 是否有新消息
"""
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
if new_message_exists:
logger.debug(f"[私聊][{self.private_name}]发现新消息")
self.last_check_time = time.time()
return new_message_exists
async def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录并发送通知
Args:
message: 消息数据
"""
try:
# 发送新消息通知
notification = create_new_message_notification(
sender="chat_observer", target="observation_info", message=message
)
# print(self.notification_manager)
await self.notification_manager.send_notification(notification)
except Exception as e:
logger.error(f"[私聊][{self.private_name}]添加消息到历史记录时出错: {e}")
print(traceback.format_exc())
# 检查并更新冷场状态
await self._check_cold_chat()
async def _check_cold_chat(self):
"""检查是否处于冷场状态并发送通知"""
current_time = time.time()
# 每10秒检查一次冷场状态
if current_time - self.last_cold_chat_check < 10:
return
self.last_cold_chat_check = current_time
# 判断是否冷场
is_cold = (
True
if self.last_message_time is None
else (current_time - self.last_message_time) > self.cold_chat_threshold
)
# 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification)
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
if self.last_message_time is None:
logger.debug(f"[私聊][{self.private_name}]没有最后消息时间,返回 False")
return False
has_new = self.last_message_time > time_point
logger.debug(
f"[私聊][{self.private_name}]判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}"
)
return has_new
def get_message_history(
self,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
limit: Optional[int] = None,
user_id: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""获取消息历史
Args:
start_time: 开始时间戳
end_time: 结束时间戳
limit: 限制返回消息数量
user_id: 指定用户ID
Returns:
List[Dict[str, Any]]: 消息列表
"""
filtered_messages = self.message_history
if start_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]
if end_time is not None:
filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]
if user_id is not None:
filtered_messages = [
m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
]
if limit is not None:
filtered_messages = filtered_messages[-limit:]
return filtered_messages
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
Returns:
List[Dict[str, Any]]: 新消息列表
"""
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
if new_messages:
self.last_message_read = new_messages[-1]
self.last_message_time = new_messages[-1]["time"]
# print(f"获取数据库中找到的新消息: {new_messages}")
return new_messages
async def _fetch_new_messages_before(self, time_point: float) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
time_point: 时间戳
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
logger.debug(f"[私聊][{self.private_name}]获取指定时间点111之前的消息: {new_messages}")
return new_messages
"""主要观察循环"""
async def _update_loop(self):
"""更新循环"""
# try:
# start_time = time.time()
# messages = await self._fetch_new_messages_before(start_time)
# for message in messages:
# await self._add_message_to_history(message)
# logger.debug(f"[私聊][{self.private_name}]缓冲消息: {messages}")
# except Exception as e:
# logger.error(f"[私聊][{self.private_name}]缓冲消息出错: {e}")
while self._running:
try:
# 等待事件或超时1秒
try:
# print("等待事件")
await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError:
# print("超时")
pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件
self._update_complete.clear() # 重置完成事件
# 获取新消息
new_messages = await self._fetch_new_messages()
if new_messages:
# 处理新消息
for message in new_messages:
await self._add_message_to_history(message)
# 设置完成事件
self._update_complete.set()
except Exception as e:
logger.error(f"[私聊][{self.private_name}]更新循环出错: {e}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
self._update_complete.set() # 即使出错也要设置完成事件
def trigger_update(self):
"""触发一次立即更新"""
self._update_event.set()
async def wait_for_update(self, timeout: float = 5.0) -> bool:
"""等待更新完成
Args:
timeout: 超时时间
Returns:
bool: 是否成功完成更新False表示超时
"""
try:
await asyncio.wait_for(self._update_complete.wait(), timeout=timeout)
return True
except asyncio.TimeoutError:
logger.warning(f"[私聊][{self.private_name}]等待更新完成超时({timeout}秒)")
return False
def start(self):
"""启动观察器"""
if self._running:
return
self._running = True
self._task = asyncio.create_task(self._update_loop())
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} started")
def stop(self):
"""停止观察器"""
self._running = False
self._update_event.set() # 设置事件以解除等待
self._update_complete.set() # 设置完成事件以解除等待
if self._task:
self._task.cancel()
logger.debug(f"[私聊][{self.private_name}]ChatObserver for {self.stream_id} stopped")
async def process_chat_history(self, messages: list):
"""处理聊天历史
Args:
messages: 消息列表
"""
self.update_check_time()
for msg in messages:
try:
user_info = UserInfo.from_dict(msg.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
self.update_bot_speak_time(msg["time"])
else:
self.update_user_speak_time(msg["time"])
except Exception as e:
logger.warning(f"[私聊][{self.private_name}]处理消息时间时出错: {e}")
continue
def update_check_time(self):
"""更新查看时间"""
self.last_check_time = time.time()
def update_bot_speak_time(self, speak_time: Optional[float] = None):
"""更新机器人说话时间"""
self.last_bot_speak_time = speak_time or time.time()
def update_user_speak_time(self, speak_time: Optional[float] = None):
"""更新用户说话时间"""
self.last_user_speak_time = speak_time or time.time()
def get_time_info(self) -> str:
"""获取时间信息文本"""
current_time = time.time()
time_info = ""
if self.last_bot_speak_time:
bot_speak_ago = current_time - self.last_bot_speak_time
time_info += f"\n距离你上次发言已经过去了{int(bot_speak_ago)}"
if self.last_user_speak_time:
user_speak_ago = current_time - self.last_user_speak_time
time_info += f"\n距离对方上次发言已经过去了{int(user_speak_ago)}"
return time_info
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史
Args:
limit: 获取的最大消息数量默认50
Returns:
List[Dict[str, Any]]: 缓存的消息历史列表
"""
return self.message_cache[-limit:]
def get_last_message(self) -> Optional[Dict[str, Any]]:
"""获取最后一条消息
Returns:
Optional[Dict[str, Any]]: 最后一条消息如果没有则返回None
"""
if not self.message_cache:
return None
return self.message_cache[-1]
def __str__(self):
return f"ChatObserver for {self.stream_id}"

View File

@ -0,0 +1,290 @@
from enum import Enum, auto
from typing import Optional, Dict, Any, List, Set
from dataclasses import dataclass
from datetime import datetime
from abc import ABC, abstractmethod
class ChatState(Enum):
"""聊天状态枚举"""
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态
ACTIVE_CHAT = auto() # 活跃状态
BOT_SPEAKING = auto() # 机器人正在说话
USER_SPEAKING = auto() # 用户正在说话
SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
class NotificationType(Enum):
"""通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知
BOT_SPEAKING = auto() # 机器人说话通知
USER_SPEAKING = auto() # 用户说话通知
MESSAGE_DELETED = auto() # 消息删除通知
USER_JOINED = auto() # 用户加入通知
USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
@dataclass
class ChatStateInfo:
"""聊天状态信息"""
state: ChatState
last_message_time: Optional[float] = None
last_message_content: Optional[str] = None
last_speaker: Optional[str] = None
message_count: int = 0
cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass
class Notification:
"""通知基类"""
type: NotificationType
timestamp: float
sender: str # 发送者标识
target: str # 接收者标识
data: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data}
@dataclass
class StateNotification(Notification):
"""持续状态通知"""
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict()
base_dict["is_active"] = self.is_active
return base_dict
class NotificationHandler(ABC):
"""通知处理器接口"""
@abstractmethod
async def handle_notification(self, notification: Notification):
"""处理通知"""
pass
class NotificationManager:
"""通知管理器"""
def __init__(self):
# 按接收者和通知类型存储处理器
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
self._active_states: Set[NotificationType] = set()
self._notification_history: List[Notification] = []
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注册通知处理器
Args:
target: 接收者标识例如"pfc"
notification_type: 要处理的通知类型
handler: 处理器实例
"""
if target not in self._handlers:
self._handlers[target] = {}
if notification_type not in self._handlers[target]:
self._handlers[target][notification_type] = []
# print(self._handlers[target][notification_type])
self._handlers[target][notification_type].append(handler)
# print(self._handlers[target][notification_type])
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器
Args:
target: 接收者标识
notification_type: 通知类型
handler: 要注销的处理器实例
"""
if target in self._handlers and notification_type in self._handlers[target]:
handlers = self._handlers[target][notification_type]
if handler in handlers:
handlers.remove(handler)
# 如果该类型的处理器列表为空,删除该类型
if not handlers:
del self._handlers[target][notification_type]
# 如果该目标没有任何处理器,删除该目标
if not self._handlers[target]:
del self._handlers[target]
async def send_notification(self, notification: Notification):
"""发送通知"""
self._notification_history.append(notification)
# 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification):
if notification.is_active:
self._active_states.add(notification.type)
else:
self._active_states.discard(notification.type)
# 调用目标接收者的处理器
target = notification.target
if target in self._handlers:
handlers = self._handlers[target].get(notification.type, [])
# print(handlers)
for handler in handlers:
# print(f"调用处理器: {handler}")
await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]:
"""获取当前活跃的状态"""
return self._active_states.copy()
def is_state_active(self, state_type: NotificationType) -> bool:
"""检查特定状态是否活跃"""
return state_type in self._active_states
def get_notification_history(
self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None
) -> List[Notification]:
"""获取通知历史
Args:
sender: 过滤特定发送者的通知
target: 过滤特定接收者的通知
limit: 限制返回数量
"""
history = self._notification_history
if sender:
history = [n for n in history if n.sender == sender]
if target:
history = [n for n in history if n.target == target]
if limit is not None:
history = history[-limit:]
return history
def __str__(self):
str = ""
for target, handlers in self._handlers.items():
for notification_type, handler_list in handlers.items():
str += f"NotificationManager for {target} {notification_type} {handler_list}"
return str
# 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知"""
return Notification(
type=NotificationType.NEW_MESSAGE,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={
"message_id": message.get("message_id"),
"processed_plain_text": message.get("processed_plain_text"),
"detailed_plain_text": message.get("detailed_plain_text"),
"user_info": message.get("user_info"),
"time": message.get("time"),
},
)
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知"""
return StateNotification(
type=NotificationType.COLD_CHAT,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={"is_cold": is_cold},
is_active=is_cold,
)
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知"""
return StateNotification(
type=NotificationType.ACTIVE_CHAT,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={"is_active": is_active},
is_active=is_active,
)
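A minimal end-to-end use of NotificationManager above (a handler subclass, registration keyed by target and notification type, then dispatch); it relies only on the names defined earlier in this file, so it assumes it runs in the same module:

import asyncio

class PrintHandler(NotificationHandler):
    async def handle_notification(self, notification: Notification):
        print("pfc got:", notification.type.name, notification.data.get("processed_plain_text"))

async def demo():
    manager = NotificationManager()
    manager.register_handler("pfc", NotificationType.NEW_MESSAGE, PrintHandler())
    note = create_new_message_notification(
        sender="chat_observer",
        target="pfc",
        message={"message_id": "m1", "processed_plain_text": "hello", "time": 0.0},
    )
    await manager.send_notification(note)  # routed to every handler registered for ("pfc", NEW_MESSAGE)

asyncio.run(demo())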
class ChatStateManager:
"""聊天状态管理器"""
def __init__(self):
self.current_state = ChatState.NORMAL
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
self.state_history: list[ChatStateInfo] = []
def update_state(self, new_state: ChatState, **kwargs):
"""更新聊天状态
Args:
new_state: 新的状态
**kwargs: 其他状态信息
"""
self.current_state = new_state
self.state_info.state = new_state
# 更新其他状态信息
for key, value in kwargs.items():
if hasattr(self.state_info, key):
setattr(self.state_info, key, value)
# 记录状态历史
self.state_history.append(self.state_info)
def get_current_state_info(self) -> ChatStateInfo:
"""获取当前状态信息"""
return self.state_info
def get_state_history(self) -> list[ChatStateInfo]:
"""获取状态历史"""
return self.state_history
def is_cold_chat(self, threshold: float = 60.0) -> bool:
"""判断是否处于冷场状态
Args:
threshold: 冷场阈值
Returns:
bool: 是否冷场
"""
if not self.state_info.last_message_time:
return True
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) > threshold
def is_active_chat(self, threshold: float = 5.0) -> bool:
"""判断是否处于活跃状态
Args:
threshold: 活跃阈值
Returns:
bool: 是否活跃
"""
if not self.state_info.last_message_time:
return False
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold
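A small sketch of how ChatStateManager is meant to be driven, assuming ChatStateInfo (defined earlier in this file) carries a last_message_time field, which is_cold_chat/is_active_chat read:

from datetime import datetime

manager = ChatStateManager()
# Record an incoming message "now"; last_message_time is assumed to be a ChatStateInfo field
manager.update_state(ChatState.NORMAL, last_message_time=datetime.now().timestamp())

print(manager.is_active_chat(threshold=5.0))   # True: a message just arrived
print(manager.is_cold_chat(threshold=60.0))    # False: less than 60s of silence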

View File

@ -0,0 +1,701 @@
import time
import asyncio
import datetime
# from .message_storage import MongoDBMessageStorage
from src.plugins.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
# from ...config.config import global_config
from typing import Dict, Any, Optional
from ..chat.message import Message
from .pfc_types import ConversationState
from .pfc import ChatObserver, GoalAnalyzer
from .message_sender import DirectMessageSender
from src.common.logger_manager import get_logger
from .action_planner import ActionPlanner
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo # 确保导入 ConversationInfo
from .reply_generator import ReplyGenerator
from ..chat.chat_stream import ChatStream
from maim_message import UserInfo
from src.plugins.chat.chat_stream import chat_manager
from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("pfc")
class Conversation:
"""对话类,负责管理单个对话的状态和行为"""
def __init__(self, stream_id: str, private_name: str):
"""初始化对话实例
Args:
stream_id: 聊天流ID
"""
self.stream_id = stream_id
self.private_name = private_name
self.state = ConversationState.INIT
self.should_continue = False
self.ignore_until_timestamp: Optional[float] = None
# 回复相关
self.generated_reply = ""
async def _initialize(self):
"""初始化实例,注册所有组件"""
try:
self.action_planner = ActionPlanner(self.stream_id, self.private_name)
self.goal_analyzer = GoalAnalyzer(self.stream_id, self.private_name)
self.reply_generator = ReplyGenerator(self.stream_id, self.private_name)
self.knowledge_fetcher = KnowledgeFetcher(self.private_name)
self.waiter = Waiter(self.stream_id, self.private_name)
self.direct_sender = DirectMessageSender(self.private_name)
# 获取聊天流信息
self.chat_stream = chat_manager.get_stream(self.stream_id)
self.stop_action_planner = False
except Exception as e:
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册运行组件失败: {e}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
raise
try:
# 决策所需要的信息,包括自身自信和观察信息两部分
# 注册观察器和观测信息
self.chat_observer = ChatObserver.get_instance(self.stream_id, self.private_name)
self.chat_observer.start()
self.observation_info = ObservationInfo(self.private_name)
self.observation_info.bind_to_chat_observer(self.chat_observer)
self.conversation_info = ConversationInfo()
except Exception as e:
logger.error(f"[私聊][{self.private_name}]初始化对话实例:注册信息组件失败: {e}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
raise
try:
logger.info(f"[私聊][{self.private_name}]为 {self.stream_id} 加载初始聊天记录...")
initial_messages = get_raw_msg_before_timestamp_with_chat( #
chat_id=self.stream_id,
timestamp=time.time(),
limit=30, # 加载最近30条作为初始上下文可以调整
)
chat_talking_prompt = await build_readable_messages(
initial_messages,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
if initial_messages:
# 将加载的消息填充到 ObservationInfo 的 chat_history
self.observation_info.chat_history = initial_messages
self.observation_info.chat_history_str = chat_talking_prompt + "\n"
self.observation_info.chat_history_count = len(initial_messages)
# 更新 ObservationInfo 中的时间戳等信息
last_msg = initial_messages[-1]
self.observation_info.last_message_time = last_msg.get("time")
last_user_info = UserInfo.from_dict(last_msg.get("user_info", {}))
self.observation_info.last_message_sender = last_user_info.user_id
self.observation_info.last_message_content = last_msg.get("processed_plain_text", "")
logger.info(
f"[私聊][{self.private_name}]成功加载 {len(initial_messages)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
)
# 让 ChatObserver 从加载的最后一条消息之后开始同步
self.chat_observer.last_message_time = self.observation_info.last_message_time
self.chat_observer.last_message_read = last_msg # 更新 observer 的最后读取记录
else:
logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")
except Exception as load_err:
logger.error(f"[私聊][{self.private_name}]加载初始聊天记录时出错: {load_err}")
# 出错也要继续,只是没有历史记录而已
# 组件准备完成,启动该轮对话
self.should_continue = True
asyncio.create_task(self.start())
async def start(self):
"""开始对话流程"""
try:
logger.info(f"[私聊][{self.private_name}]对话系统启动中...")
asyncio.create_task(self._plan_and_action_loop())
except Exception as e:
logger.error(f"[私聊][{self.private_name}]启动对话系统失败: {e}")
raise
async def _plan_and_action_loop(self):
"""思考步PFC核心循环模块"""
while self.should_continue:
# 忽略逻辑
if self.ignore_until_timestamp and time.time() < self.ignore_until_timestamp:
await asyncio.sleep(30)
continue
elif self.ignore_until_timestamp and time.time() >= self.ignore_until_timestamp:
logger.info(f"[私聊][{self.private_name}]忽略时间已到 {self.stream_id},准备结束对话。")
self.ignore_until_timestamp = None
self.should_continue = False
continue
try:
# --- 在规划前记录当前新消息数量 ---
initial_new_message_count = 0
if hasattr(self.observation_info, "new_messages_count"):
initial_new_message_count = self.observation_info.new_messages_count + 1 # 算上麦麦自己发的那一条
else:
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' before planning."
)
# --- 调用 Action Planner ---
# 传递 self.conversation_info.last_successful_reply_action
action, reason = await self.action_planner.plan(
self.observation_info, self.conversation_info, self.conversation_info.last_successful_reply_action
)
# --- 规划后检查是否有 *更多* 新消息到达 ---
current_new_message_count = 0
if hasattr(self.observation_info, "new_messages_count"):
current_new_message_count = self.observation_info.new_messages_count
else:
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo missing 'new_messages_count' after planning."
)
if current_new_message_count > initial_new_message_count + 2:
logger.info(
f"[私聊][{self.private_name}]规划期间发现新增消息 ({initial_new_message_count} -> {current_new_message_count}),跳过本次行动,重新规划"
)
# 如果规划期间有新消息,也应该重置上次回复状态,因为现在要响应新消息了
self.conversation_info.last_successful_reply_action = None
await asyncio.sleep(0.1)
continue
# 包含 send_new_message
if initial_new_message_count > 0 and action in ["direct_reply", "send_new_message"]:
if hasattr(self.observation_info, "clear_unprocessed_messages"):
logger.debug(
f"[私聊][{self.private_name}]准备执行 {action},清理 {initial_new_message_count} 条规划时已知的新消息。"
)
await self.observation_info.clear_unprocessed_messages()
if hasattr(self.observation_info, "new_messages_count"):
self.observation_info.new_messages_count = 0
else:
logger.error(
f"[私聊][{self.private_name}]无法清理未处理消息: ObservationInfo 缺少 clear_unprocessed_messages 方法!"
)
await self._handle_action(action, reason, self.observation_info, self.conversation_info)
# 检查是否需要结束对话 (逻辑不变)
goal_ended = False
if hasattr(self.conversation_info, "goal_list") and self.conversation_info.goal_list:
for goal_item in self.conversation_info.goal_list:
if isinstance(goal_item, dict):
current_goal = goal_item.get("goal")
if current_goal == "结束对话":
goal_ended = True
break
if goal_ended:
self.should_continue = False
logger.info(f"[私聊][{self.private_name}]检测到'结束对话'目标,停止循环。")
except Exception as loop_err:
logger.error(f"[私聊][{self.private_name}]PFC主循环出错: {loop_err}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
await asyncio.sleep(1)
if self.should_continue:
await asyncio.sleep(0.1)
logger.info(f"[私聊][{self.private_name}]PFC 循环结束 for stream_id: {self.stream_id}")
def _check_new_messages_after_planning(self):
"""检查在规划后是否有新消息"""
# 检查 ObservationInfo 是否已初始化并且有 new_messages_count 属性
if not hasattr(self, "observation_info") or not hasattr(self.observation_info, "new_messages_count"):
logger.warning(
f"[私聊][{self.private_name}]ObservationInfo 未初始化或缺少 'new_messages_count' 属性,无法检查新消息。"
)
return False # 或者根据需要抛出错误
if self.observation_info.new_messages_count > 2:
logger.info(
f"[私聊][{self.private_name}]生成/执行动作期间收到 {self.observation_info.new_messages_count} 条新消息,取消当前动作并重新规划"
)
# 如果有新消息,也应该重置上次回复状态
if hasattr(self, "conversation_info"): # 确保 conversation_info 已初始化
self.conversation_info.last_successful_reply_action = None
else:
logger.warning(
f"[私聊][{self.private_name}]ConversationInfo 未初始化,无法重置 last_successful_reply_action。"
)
return True
return False
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象"""
try:
# 尝试从 msg_dict 直接获取 chat_stream如果失败则从全局 chat_manager 获取
chat_info = msg_dict.get("chat_info")
if chat_info and isinstance(chat_info, dict):
chat_stream = ChatStream.from_dict(chat_info)
elif self.chat_stream: # 使用实例变量中的 chat_stream
chat_stream = self.chat_stream
else: # Fallback: 尝试从 manager 获取 (可能需要 stream_id)
chat_stream = chat_manager.get_stream(self.stream_id)
if not chat_stream:
raise ValueError(f"无法确定 ChatStream for stream_id {self.stream_id}")
user_info = UserInfo.from_dict(msg_dict.get("user_info", {}))
return Message(
message_id=msg_dict.get("message_id", f"gen_{time.time()}"), # 提供默认 ID
chat_stream=chat_stream, # 使用确定的 chat_stream
time=msg_dict.get("time", time.time()), # 提供默认时间
user_info=user_info,
processed_plain_text=msg_dict.get("processed_plain_text", ""),
detailed_plain_text=msg_dict.get("detailed_plain_text", ""),
)
except Exception as e:
logger.warning(f"[私聊][{self.private_name}]转换消息时出错: {e}")
# 可以选择返回 None 或重新抛出异常,这里选择重新抛出以指示问题
raise ValueError(f"无法将字典转换为 Message 对象: {e}") from e
async def _handle_action(
self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo
):
"""处理规划的行动"""
logger.debug(f"[私聊][{self.private_name}]执行行动: {action}, 原因: {reason}")
# 记录action历史 (逻辑不变)
current_action_record = {
"action": action,
"plan_reason": reason,
"status": "start",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
"final_reason": None,
}
# 确保 done_action 列表存在
if not hasattr(conversation_info, "done_action"):
conversation_info.done_action = []
conversation_info.done_action.append(current_action_record)
action_index = len(conversation_info.done_action) - 1
action_successful = False # 用于标记动作是否成功完成
# --- 根据不同的 action 执行 ---
# send_new_message 失败后执行 wait
if action == "send_new_message":
max_reply_attempts = 3
reply_attempt_count = 0
is_suitable = False
need_replan = False
check_reason = "未进行尝试"
final_reply_to_send = ""
while reply_attempt_count < max_reply_attempts and not is_suitable:
reply_attempt_count += 1
logger.info(
f"[私聊][{self.private_name}]尝试生成追问回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
)
self.state = ConversationState.GENERATING
# 1. 生成回复 (调用 generate 时传入 action_type)
self.generated_reply = await self.reply_generator.generate(
observation_info, conversation_info, action_type="send_new_message"
)
logger.info(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的追问回复: {self.generated_reply}"
)
# 2. 检查回复 (逻辑不变)
self.state = ConversationState.CHECKING
try:
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
reply=self.generated_reply,
goal=current_goal_str,
chat_history=observation_info.chat_history,
chat_history_str=observation_info.chat_history_str,
retry_count=reply_attempt_count - 1,
)
logger.info(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
)
if is_suitable:
final_reply_to_send = self.generated_reply
break
elif need_replan:
logger.warning(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次追问检查建议重新规划,停止尝试。原因: {check_reason}"
)
break
except Exception as check_err:
logger.error(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (追问) 时出错: {check_err}"
)
check_reason = f"{reply_attempt_count} 次检查过程出错: {check_err}"
break
# 循环结束,处理最终结果
if is_suitable:
# 检查是否有新消息
if self._check_new_messages_after_planning():
logger.info(f"[私聊][{self.private_name}]生成追问回复期间收到新消息,取消发送,重新规划行动")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"有新消息,取消发送追问: {final_reply_to_send}"}
)
return # 直接返回,重新规划
# 发送合适的回复
self.generated_reply = final_reply_to_send
# --- 在这里调用 _send_reply ---
await self._send_reply() # <--- 调用恢复后的函数
# 更新状态: 标记上次成功是 send_new_message
self.conversation_info.last_successful_reply_action = "send_new_message"
action_successful = True # 标记动作成功
elif need_replan:
# 打回动作决策
logger.warning(
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,追问回复决定打回动作决策。打回原因: {check_reason}"
)
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后打回: {check_reason}"}
)
else:
# 追问失败
logger.warning(
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的追问回复。最终原因: {check_reason}"
)
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"追问尝试{reply_attempt_count}次后失败: {check_reason}"}
)
# 重置状态: 追问失败,下次用初始 prompt
self.conversation_info.last_successful_reply_action = None
# 执行 Wait 操作
logger.info(f"[私聊][{self.private_name}]由于无法生成合适追问回复,执行 'wait' 操作...")
self.state = ConversationState.WAITING
await self.waiter.wait(self.conversation_info)
wait_action_record = {
"action": "wait",
"plan_reason": "因 send_new_message 多次尝试失败而执行的后备等待",
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
"final_reason": None,
}
conversation_info.done_action.append(wait_action_record)
elif action == "direct_reply":
max_reply_attempts = 3
reply_attempt_count = 0
is_suitable = False
need_replan = False
check_reason = "未进行尝试"
final_reply_to_send = ""
while reply_attempt_count < max_reply_attempts and not is_suitable:
reply_attempt_count += 1
logger.info(
f"[私聊][{self.private_name}]尝试生成首次回复 (第 {reply_attempt_count}/{max_reply_attempts} 次)..."
)
self.state = ConversationState.GENERATING
# 1. 生成回复
self.generated_reply = await self.reply_generator.generate(
observation_info, conversation_info, action_type="direct_reply"
)
logger.info(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次生成的首次回复: {self.generated_reply}"
)
# 2. 检查回复
self.state = ConversationState.CHECKING
try:
current_goal_str = conversation_info.goal_list[0]["goal"] if conversation_info.goal_list else ""
is_suitable, check_reason, need_replan = await self.reply_generator.check_reply(
reply=self.generated_reply,
goal=current_goal_str,
chat_history=observation_info.chat_history,
chat_history_str=observation_info.chat_history_str,
retry_count=reply_attempt_count - 1,
)
logger.info(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查结果: 合适={is_suitable}, 原因='{check_reason}', 需重新规划={need_replan}"
)
if is_suitable:
final_reply_to_send = self.generated_reply
break
elif need_replan:
logger.warning(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次首次回复检查建议重新规划,停止尝试。原因: {check_reason}"
)
break
except Exception as check_err:
logger.error(
f"[私聊][{self.private_name}]第 {reply_attempt_count} 次调用 ReplyChecker (首次回复) 时出错: {check_err}"
)
check_reason = f"{reply_attempt_count} 次检查过程出错: {check_err}"
break
# 循环结束,处理最终结果
if is_suitable:
# 检查是否有新消息
if self._check_new_messages_after_planning():
logger.info(f"[私聊][{self.private_name}]生成首次回复期间收到新消息,取消发送,重新规划行动")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"有新消息,取消发送首次回复: {final_reply_to_send}"}
)
return # 直接返回,重新规划
# 发送合适的回复
self.generated_reply = final_reply_to_send
# --- 在这里调用 _send_reply ---
await self._send_reply() # <--- 调用恢复后的函数
# 更新状态: 标记上次成功是 direct_reply
self.conversation_info.last_successful_reply_action = "direct_reply"
action_successful = True # 标记动作成功
elif need_replan:
# 打回动作决策
logger.warning(
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,首次回复决定打回动作决策。打回原因: {check_reason}"
)
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后打回: {check_reason}"}
)
else:
# 首次回复失败
logger.warning(
f"[私聊][{self.private_name}]经过 {reply_attempt_count} 次尝试,未能生成合适的首次回复。最终原因: {check_reason}"
)
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"首次回复尝试{reply_attempt_count}次后失败: {check_reason}"}
)
# 重置状态: 首次回复失败,下次还是用初始 prompt
self.conversation_info.last_successful_reply_action = None
# 执行 Wait 操作 (保持原有逻辑)
logger.info(f"[私聊][{self.private_name}]由于无法生成合适首次回复,执行 'wait' 操作...")
self.state = ConversationState.WAITING
await self.waiter.wait(self.conversation_info)
wait_action_record = {
"action": "wait",
"plan_reason": "因 direct_reply 多次尝试失败而执行的后备等待",
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
"final_reason": None,
}
conversation_info.done_action.append(wait_action_record)
elif action == "fetch_knowledge":
self.state = ConversationState.FETCHING
knowledge_query = reason
try:
# 检查 knowledge_fetcher 是否存在
if not hasattr(self, "knowledge_fetcher"):
logger.error(f"[私聊][{self.private_name}]KnowledgeFetcher 未初始化,无法获取知识。")
raise AttributeError("KnowledgeFetcher not initialized")
knowledge, source = await self.knowledge_fetcher.fetch(knowledge_query, observation_info.chat_history)
logger.info(f"[私聊][{self.private_name}]获取到知识: {knowledge[:100]}..., 来源: {source}")
if knowledge:
# 确保 knowledge_list 存在
if not hasattr(conversation_info, "knowledge_list"):
conversation_info.knowledge_list = []
conversation_info.knowledge_list.append(
{"query": knowledge_query, "knowledge": knowledge, "source": source}
)
action_successful = True
except Exception as fetch_err:
logger.error(f"[私聊][{self.private_name}]获取知识时出错: {str(fetch_err)}")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"获取知识失败: {str(fetch_err)}"}
)
self.conversation_info.last_successful_reply_action = None # 重置状态
elif action == "rethink_goal":
self.state = ConversationState.RETHINKING
try:
# 检查 goal_analyzer 是否存在
if not hasattr(self, "goal_analyzer"):
logger.error(f"[私聊][{self.private_name}]GoalAnalyzer 未初始化,无法重新思考目标。")
raise AttributeError("GoalAnalyzer not initialized")
await self.goal_analyzer.analyze_goal(conversation_info, observation_info)
action_successful = True
except Exception as rethink_err:
logger.error(f"[私聊][{self.private_name}]重新思考目标时出错: {rethink_err}")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"重新思考目标失败: {rethink_err}"}
)
self.conversation_info.last_successful_reply_action = None # 重置状态
elif action == "listening":
self.state = ConversationState.LISTENING
logger.info(f"[私聊][{self.private_name}]倾听对方发言...")
try:
# 检查 waiter 是否存在
if not hasattr(self, "waiter"):
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法倾听。")
raise AttributeError("Waiter not initialized")
await self.waiter.wait_listening(conversation_info)
action_successful = True # Listening 完成就算成功
except Exception as listen_err:
logger.error(f"[私聊][{self.private_name}]倾听时出错: {listen_err}")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"倾听失败: {listen_err}"}
)
self.conversation_info.last_successful_reply_action = None # 重置状态
elif action == "say_goodbye":
self.state = ConversationState.GENERATING # 也可以定义一个新的状态,如 ENDING
logger.info(f"[私聊][{self.private_name}]执行行动: 生成并发送告别语...")
try:
# 1. 生成告别语 (使用 'say_goodbye' action_type)
self.generated_reply = await self.reply_generator.generate(
observation_info, conversation_info, action_type="say_goodbye"
)
logger.info(f"[私聊][{self.private_name}]生成的告别语: {self.generated_reply}")
# 2. 直接发送告别语 (不经过检查)
if self.generated_reply: # 确保生成了内容
await self._send_reply() # 调用发送方法
# 发送成功后,标记动作成功
action_successful = True
logger.info(f"[私聊][{self.private_name}]告别语已发送。")
else:
logger.warning(f"[私聊][{self.private_name}]未能生成告别语内容,无法发送。")
action_successful = False # 标记动作失败
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": "未能生成告别语内容"}
)
# 3. 无论是否发送成功,都准备结束对话
self.should_continue = False
logger.info(f"[私聊][{self.private_name}]发送告别语流程结束,即将停止对话实例。")
except Exception as goodbye_err:
logger.error(f"[私聊][{self.private_name}]生成或发送告别语时出错: {goodbye_err}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
# 即使出错,也结束对话
self.should_continue = False
action_successful = False # 标记动作失败
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"生成或发送告别语时出错: {goodbye_err}"}
)
elif action == "end_conversation":
# 这个分支现在只会在 action_planner 最终决定不告别时被调用
self.should_continue = False
logger.info(f"[私聊][{self.private_name}]收到最终结束指令,停止对话...")
action_successful = True # 标记这个指令本身是成功的
elif action == "block_and_ignore":
logger.info(f"[私聊][{self.private_name}]不想再理你了...")
ignore_duration_seconds = 10 * 60
self.ignore_until_timestamp = time.time() + ignore_duration_seconds
logger.info(
f"[私聊][{self.private_name}]将忽略此对话直到: {datetime.datetime.fromtimestamp(self.ignore_until_timestamp)}"
)
self.state = ConversationState.IGNORED
action_successful = True # 标记动作成功
else: # 对应 'wait' 动作
self.state = ConversationState.WAITING
logger.info(f"[私聊][{self.private_name}]等待更多信息...")
try:
# 检查 waiter 是否存在
if not hasattr(self, "waiter"):
logger.error(f"[私聊][{self.private_name}]Waiter 未初始化,无法等待。")
raise AttributeError("Waiter not initialized")
_timeout_occurred = await self.waiter.wait(self.conversation_info)
action_successful = True # Wait 完成就算成功
except Exception as wait_err:
logger.error(f"[私聊][{self.private_name}]等待时出错: {wait_err}")
conversation_info.done_action[action_index].update(
{"status": "recall", "final_reason": f"等待失败: {wait_err}"}
)
self.conversation_info.last_successful_reply_action = None # 重置状态
# --- 更新 Action History 状态 ---
# 只有当动作本身成功时,才更新状态为 done
if action_successful:
conversation_info.done_action[action_index].update(
{
"status": "done",
"time": datetime.datetime.now().strftime("%H:%M:%S"),
}
)
# 重置状态: 对于非回复类动作的成功,清除上次回复状态
if action not in ["direct_reply", "send_new_message"]:
self.conversation_info.last_successful_reply_action = None
logger.debug(f"[私聊][{self.private_name}]动作 {action} 成功完成,重置 last_successful_reply_action")
# 如果动作是 recall 状态,在各自的处理逻辑中已经更新了 done_action
async def _send_reply(self):
"""发送回复"""
if not self.generated_reply:
logger.warning(f"[私聊][{self.private_name}]没有生成回复内容,无法发送。")
return
try:
_current_time = time.time()
reply_content = self.generated_reply
# 发送消息 (确保 direct_sender 和 chat_stream 有效)
if not hasattr(self, "direct_sender") or not self.direct_sender:
logger.error(f"[私聊][{self.private_name}]DirectMessageSender 未初始化,无法发送回复。")
return
if not self.chat_stream:
logger.error(f"[私聊][{self.private_name}]ChatStream 未初始化,无法发送回复。")
return
await self.direct_sender.send_message(chat_stream=self.chat_stream, content=reply_content)
# 发送成功后,手动触发 observer 更新可能导致重复处理自己发送的消息
# 更好的做法是依赖 observer 的自动轮询或数据库触发器(如果支持)
# 暂时注释掉,观察是否影响 ObservationInfo 的更新
# self.chat_observer.trigger_update()
# if not await self.chat_observer.wait_for_update():
# logger.warning(f"[私聊][{self.private_name}]等待 ChatObserver 更新完成超时")
self.state = ConversationState.ANALYZING # 更新状态
except Exception as e:
logger.error(f"[私聊][{self.private_name}]发送消息或更新状态时失败: {str(e)}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}")
self.state = ConversationState.ANALYZING
async def _send_timeout_message(self):
"""发送超时结束消息"""
try:
messages = self.chat_observer.get_cached_messages(limit=1)
if not messages:
return
latest_message = self._convert_to_message(messages[0])
await self.direct_sender.send_message(
chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message
)
except Exception as e:
logger.error(f"[私聊][{self.private_name}]发送超时消息失败: {str(e)}")

View File

@ -0,0 +1,10 @@
from typing import Optional
class ConversationInfo:
def __init__(self):
self.done_action = []
self.goal_list = []
self.knowledge_list = []
self.memory_list = []
self.last_successful_reply_action: Optional[str] = None

View File

@ -0,0 +1,81 @@
import time
from typing import Optional
from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream
from ..chat.message import Message
from maim_message import UserInfo, Seg
from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager
from ..storage.storage import MessageStorage
from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("message_sender")
class DirectMessageSender:
"""直接消息发送器"""
def __init__(self, private_name: str):
self.private_name = private_name
self.storage = MessageStorage()
async def send_message(
self,
chat_stream: ChatStream,
content: str,
reply_to_message: Optional[Message] = None,
) -> None:
"""发送消息到聊天流
Args:
chat_stream: 聊天流
content: 消息内容
reply_to_message: 要回复的消息可选
"""
try:
# 创建消息内容
segments = Seg(type="seglist", data=[Seg(type="text", data=content)])
# 获取麦麦的信息
bot_user_info = UserInfo(
user_id=global_config.BOT_QQ,
user_nickname=global_config.BOT_NICKNAME,
platform=chat_stream.platform,
)
# 用当前时间作为message_id和之前那套sender一样
message_id = f"dm{round(time.time(), 2)}"
# 构建消息对象
message = MessageSending(
message_id=message_id,
chat_stream=chat_stream,
bot_user_info=bot_user_info,
sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
message_segment=segments,
reply=reply_to_message,
is_head=True,
is_emoji=False,
thinking_start_time=time.time(),
)
# 处理消息
await message.process()
# 不知道有什么用先留下来了和之前那套sender一样
_message_json = message.to_dict()
# 发送消息
message_set = MessageSet(chat_stream, message_id)
message_set.add_message(message)
await message_manager.add_message(message_set)
await self.storage.store_message(message, chat_stream)
logger.info(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")
raise
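A hedged usage sketch for DirectMessageSender; the chat stream lookup mirrors what conversation.py does with chat_manager, and the stream id is a placeholder:

from src.plugins.chat.chat_stream import chat_manager  # same import used by conversation.py

async def send_demo(stream_id: str):
    sender = DirectMessageSender(private_name="demo_user")
    chat_stream = chat_manager.get_stream(stream_id)  # placeholder stream id resolved at runtime
    if chat_stream:
        await sender.send_message(chat_stream=chat_stream, content="你好")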

View File

@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from src.common.database import db
class MessageStorage(ABC):
"""消息存储接口"""
@abstractmethod
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
"""获取指定时间点之后的所有消息
Args:
chat_id: 聊天ID
message_time: 时间戳
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
chat_id: 聊天ID
time_point: 时间戳
limit: 最大消息数量
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
"""检查是否有新消息
Args:
chat_id: 聊天ID
after_time: 时间戳
Returns:
bool: 是否有新消息
"""
pass
class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$gt": message_time}}
# print(f"storage_check_message: {message_time}")
return list(db.messages.find(query).sort("time", 1))
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id, "time": {"$lt": time_point}}
messages = list(db.messages.find(query).sort("time", -1).limit(limit))
# 将消息按时间正序排列
messages.reverse()
return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {"chat_id": chat_id, "time": {"$gt": after_time}}
return db.messages.find_one(query) is not None
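A short sketch of how the storage interface above would be exercised, assuming a reachable MongoDB behind db and a placeholder chat id:

import asyncio

storage = MongoDBMessageStorage()

async def poll_demo(chat_id: str, after: float):
    # Check for anything newer than `after`, then fetch it in time order
    if await storage.has_new_messages(chat_id, after):
        new_messages = await storage.get_messages_after(chat_id, after)
        print(f"{len(new_messages)} new message(s)")

# asyncio.run(poll_demo("stream_123", 0.0))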
# # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试"""
# def __init__(self):
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = self.messages[chat_id]
# if not message_id:
# return messages
# # 找到message_id的索引
# try:
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
# return messages[index + 1:]
# except StopIteration:
# return []
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = [
# m for m in self.messages[chat_id]
# if m["time"] < time_point
# ]
# return messages[-limit:]
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# if chat_id not in self.messages:
# return False
# return any(m["time"] > after_time for m in self.messages[chat_id])
# # 测试辅助方法
# def add_message(self, chat_id: str, message: Dict[str, Any]):
# """添加测试消息"""
# if chat_id not in self.messages:
# self.messages[chat_id] = []
# self.messages[chat_id].append(message)
# self.messages[chat_id].sort(key=lambda m: m["time"])

View File

@ -0,0 +1,389 @@
from typing import List, Optional, Dict, Any, Set
from maim_message import UserInfo
import time
from src.common.logger import get_module_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification
from src.plugins.utils.chat_message_builder import build_readable_messages
import traceback # 导入 traceback 用于调试
logger = get_module_logger("observation_info")
class ObservationInfoHandler(NotificationHandler):
"""ObservationInfo的通知处理器"""
def __init__(self, observation_info: "ObservationInfo", private_name: str):
"""初始化处理器
Args:
observation_info: 要更新的ObservationInfo实例
private_name: 私聊对象的名称用于日志记录
"""
self.observation_info = observation_info
# 将 private_name 存储在 handler 实例中
self.private_name = private_name
async def handle_notification(self, notification: Notification): # 添加类型提示
# 获取通知类型和数据
notification_type = notification.type
data = notification.data
try: # 添加错误处理块
if notification_type == NotificationType.NEW_MESSAGE:
# 处理新消息通知
# logger.debug(f"[私聊][{self.private_name}]收到新消息通知data: {data}") # 可以在需要时取消注释
message_id = data.get("message_id")
processed_plain_text = data.get("processed_plain_text")
detailed_plain_text = data.get("detailed_plain_text")
user_info_dict = data.get("user_info") # 先获取字典
time_value = data.get("time")
# 确保 user_info 是字典类型再创建 UserInfo 对象
user_info = None
if isinstance(user_info_dict, dict):
try:
user_info = UserInfo.from_dict(user_info_dict)
except Exception as e:
logger.error(
f"[私聊][{self.private_name}]从字典创建 UserInfo 时出错: {e}, 字典内容: {user_info_dict}"
)
# 可以选择在这里返回或记录错误,避免后续代码出错
return
elif user_info_dict is not None:
logger.warning(
f"[私聊][{self.private_name}]收到的 user_info 不是预期的字典类型: {type(user_info_dict)}"
)
# 根据需要处理非字典情况,这里暂时返回
return
message = {
"message_id": message_id,
"processed_plain_text": processed_plain_text,
"detailed_plain_text": detailed_plain_text,
"user_info": user_info_dict, # 存储原始字典或 UserInfo 对象,取决于你的 update_from_message 如何处理
"time": time_value,
}
# 传递 UserInfo 对象(如果成功创建)或原始字典
await self.observation_info.update_from_message(message, user_info) # 修改:传递 user_info 对象
elif notification_type == NotificationType.COLD_CHAT:
# 处理冷场通知
is_cold = data.get("is_cold", False)
await self.observation_info.update_cold_chat_status(is_cold, time.time()) # 修改:改为 await 调用
elif notification_type == NotificationType.ACTIVE_CHAT:
# 处理活跃通知 (通常由 COLD_CHAT 的反向状态处理)
is_active = data.get("is_active", False)
self.observation_info.is_cold = not is_active
elif notification_type == NotificationType.BOT_SPEAKING:
# 处理机器人说话通知 (按需实现)
self.observation_info.is_typing = False
self.observation_info.last_bot_speak_time = time.time()
elif notification_type == NotificationType.USER_SPEAKING:
# 处理用户说话通知
self.observation_info.is_typing = False
self.observation_info.last_user_speak_time = time.time()
elif notification_type == NotificationType.MESSAGE_DELETED:
# 处理消息删除通知
message_id = data.get("message_id")
# 从 unprocessed_messages 中移除被删除的消息
original_count = len(self.observation_info.unprocessed_messages)
self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
]
if len(self.observation_info.unprocessed_messages) < original_count:
logger.info(f"[私聊][{self.private_name}]移除了未处理的消息 (ID: {message_id})")
elif notification_type == NotificationType.USER_JOINED:
# 处理用户加入通知 (如果适用私聊场景)
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.add(str(user_id)) # 确保是字符串
elif notification_type == NotificationType.USER_LEFT:
# 处理用户离开通知 (如果适用私聊场景)
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.discard(str(user_id)) # 确保是字符串
elif notification_type == NotificationType.ERROR:
# 处理错误通知
error_msg = data.get("error", "未提供错误信息")
logger.error(f"[私聊][{self.private_name}]收到错误通知: {error_msg}")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]处理通知时发生错误: {e}")
logger.error(traceback.format_exc()) # 打印详细堆栈信息
# @dataclass <-- 这个,不需要了(递黄瓜)
class ObservationInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息 (手动实现 __init__)"""
# 类型提示保留,可用于文档和静态分析
private_name: str
chat_history: List[Dict[str, Any]]
chat_history_str: str
unprocessed_messages: List[Dict[str, Any]]
active_users: Set[str]
last_bot_speak_time: Optional[float]
last_user_speak_time: Optional[float]
last_message_time: Optional[float]
last_message_id: Optional[str]
last_message_content: str
last_message_sender: Optional[str]
bot_id: Optional[str]
chat_history_count: int
new_messages_count: int
cold_chat_start_time: Optional[float]
cold_chat_duration: float
is_typing: bool
is_cold_chat: bool
changed: bool
chat_observer: Optional[ChatObserver]
handler: Optional[ObservationInfoHandler]
def __init__(self, private_name: str):
"""
手动初始化 ObservationInfo 的所有实例变量
"""
# 接收的参数
self.private_name: str = private_name
# data_list
self.chat_history: List[Dict[str, Any]] = []
self.chat_history_str: str = ""
self.unprocessed_messages: List[Dict[str, Any]] = []
self.active_users: Set[str] = set()
# data
self.last_bot_speak_time: Optional[float] = None
self.last_user_speak_time: Optional[float] = None
self.last_message_time: Optional[float] = None
self.last_message_id: Optional[str] = None
self.last_message_content: str = ""
self.last_message_sender: Optional[str] = None
self.bot_id: Optional[str] = None
self.chat_history_count: int = 0
self.new_messages_count: int = 0
self.cold_chat_start_time: Optional[float] = None
self.cold_chat_duration: float = 0.0
# state
self.is_typing: bool = False
self.is_cold_chat: bool = False
self.changed: bool = False
# 关联对象
self.chat_observer: Optional[ChatObserver] = None
self.handler: ObservationInfoHandler = ObservationInfoHandler(self, self.private_name)
def bind_to_chat_observer(self, chat_observer: ChatObserver):
"""绑定到指定的chat_observer
Args:
chat_observer: 要绑定的 ChatObserver 实例
"""
if self.chat_observer:
logger.warning(f"[私聊][{self.private_name}]尝试重复绑定 ChatObserver")
return
self.chat_observer = chat_observer
try:
if not self.handler: # 确保 handler 已经被创建
logger.error(f"[私聊][{self.private_name}] 尝试绑定时 handler 未初始化!")
self.chat_observer = None # 重置,防止后续错误
return
# 注册关心的通知类型
self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
)
self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
)
# 可以根据需要注册更多通知类型
# self.chat_observer.notification_manager.register_handler(
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
# )
logger.info(f"[私聊][{self.private_name}]成功绑定到 ChatObserver")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]绑定到 ChatObserver 时出错: {e}")
self.chat_observer = None # 绑定失败,重置
def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定"""
if (
self.chat_observer and hasattr(self.chat_observer, "notification_manager") and self.handler
): # 增加 handler 检查
try:
self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
)
self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
)
# 如果注册了其他类型,也要在这里注销
# self.chat_observer.notification_manager.unregister_handler(
# target="observation_info", notification_type=NotificationType.MESSAGE_DELETED, handler=self.handler
# )
logger.info(f"[私聊][{self.private_name}]成功从 ChatObserver 解绑")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]从 ChatObserver 解绑时出错: {e}")
finally: # 确保 chat_observer 被重置
self.chat_observer = None
else:
logger.warning(f"[私聊][{self.private_name}]尝试解绑时 ChatObserver 不存在、无效或 handler 未设置")
# 修改update_from_message 接收 UserInfo 对象
async def update_from_message(self, message: Dict[str, Any], user_info: Optional[UserInfo]):
"""从消息更新信息
Args:
message: 消息数据字典
user_info: 解析后的 UserInfo 对象 (可能为 None)
"""
message_time = message.get("time")
message_id = message.get("message_id")
processed_text = message.get("processed_plain_text", "")
# 只有在新消息到达时才更新 last_message 相关信息
if message_time and message_time > (self.last_message_time or 0):
self.last_message_time = message_time
self.last_message_id = message_id
self.last_message_content = processed_text
# 重置冷场计时器
self.is_cold_chat = False
self.cold_chat_start_time = None
self.cold_chat_duration = 0.0
if user_info:
sender_id = str(user_info.user_id) # 确保是字符串
self.last_message_sender = sender_id
# 更新发言时间
if sender_id == self.bot_id:
self.last_bot_speak_time = message_time
else:
self.last_user_speak_time = message_time
self.active_users.add(sender_id) # 用户发言则认为其活跃
else:
logger.warning(
f"[私聊][{self.private_name}]处理消息更新时缺少有效的 UserInfo 对象, message_id: {message_id}"
)
self.last_message_sender = None # 发送者未知
# 将原始消息字典添加到未处理列表
self.unprocessed_messages.append(message)
self.new_messages_count = len(self.unprocessed_messages) # 直接用列表长度
# logger.debug(f"[私聊][{self.private_name}]消息更新: last_time={self.last_message_time}, new_count={self.new_messages_count}")
self.update_changed() # 标记状态已改变
else:
# 如果消息时间戳不是最新的,可能不需要处理,或者记录一个警告
pass
# logger.warning(f"[私聊][{self.private_name}]收到过时或无效时间戳的消息: ID={message_id}, time={message_time}")
def update_changed(self):
"""标记状态已改变,并重置标记"""
# logger.debug(f"[私聊][{self.private_name}]状态标记为已改变 (changed=True)")
self.changed = True
async def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态
Args:
is_cold: 是否处于冷场状态
current_time: 当前时间戳
"""
if is_cold != self.is_cold_chat: # 仅在状态变化时更新
self.is_cold_chat = is_cold
if is_cold:
# 进入冷场状态
self.cold_chat_start_time = (
self.last_message_time or current_time
) # 从最后消息时间开始算,或从当前时间开始
logger.info(f"[私聊][{self.private_name}]进入冷场状态,开始时间: {self.cold_chat_start_time}")
else:
# 结束冷场状态
if self.cold_chat_start_time:
self.cold_chat_duration = current_time - self.cold_chat_start_time
logger.info(f"[私聊][{self.private_name}]结束冷场状态,持续时间: {self.cold_chat_duration:.2f}")
self.cold_chat_start_time = None # 重置开始时间
self.update_changed() # 状态变化,标记改变
# 即使状态没变,如果是冷场状态,也更新持续时间
if self.is_cold_chat and self.cold_chat_start_time:
self.cold_chat_duration = current_time - self.cold_chat_start_time
def get_active_duration(self) -> float:
"""获取当前活跃时长 (距离最后一条消息的时间)
Returns:
float: 最后一条消息到现在的时长
"""
if not self.last_message_time:
return 0.0
return time.time() - self.last_message_time
def get_user_response_time(self) -> Optional[float]:
"""获取用户最后响应时间 (距离用户最后发言的时间)
Returns:
Optional[float]: 用户最后发言到现在的时长如果没有用户发言则返回None
"""
if not self.last_user_speak_time:
return None
return time.time() - self.last_user_speak_time
def get_bot_response_time(self) -> Optional[float]:
"""获取机器人最后响应时间 (距离机器人最后发言的时间)
Returns:
Optional[float]: 机器人最后发言到现在的时长如果没有机器人发言则返回None
"""
if not self.last_bot_speak_time:
return None
return time.time() - self.last_bot_speak_time
async def clear_unprocessed_messages(self):
"""将未处理消息移入历史记录,并更新相关状态"""
if not self.unprocessed_messages:
return # 没有未处理消息,直接返回
# logger.debug(f"[私聊][{self.private_name}]处理 {len(self.unprocessed_messages)} 条未处理消息...")
# 将未处理消息添加到历史记录中 (确保历史记录有长度限制,避免无限增长)
max_history_len = 100 # 示例最多保留100条历史记录
self.chat_history.extend(self.unprocessed_messages)
if len(self.chat_history) > max_history_len:
self.chat_history = self.chat_history[-max_history_len:]
# 更新历史记录字符串 (只使用最近一部分生成例如20条)
history_slice_for_str = self.chat_history[-20:]
try:
self.chat_history_str = await build_readable_messages(
history_slice_for_str,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0, # read_mark 可能需要根据逻辑调整
)
except Exception as e:
logger.error(f"[私聊][{self.private_name}]构建聊天记录字符串时出错: {e}")
self.chat_history_str = "[构建聊天记录出错]" # 提供错误提示
# 清空未处理消息列表和计数
# cleared_count = len(self.unprocessed_messages)
self.unprocessed_messages.clear()
self.new_messages_count = 0
# self.has_unread_messages = False # 这个状态可以通过 new_messages_count 判断
self.chat_history_count = len(self.chat_history) # 更新历史记录总数
# logger.debug(f"[私聊][{self.private_name}]已处理 {cleared_count} 条消息,当前历史记录 {self.chat_history_count} 条。")
self.update_changed() # 状态改变

View File

@ -0,0 +1,345 @@
from typing import List, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from rich.traceback import install
install(extra_lines=3)
if TYPE_CHECKING:
pass
logger = get_module_logger("pfc")
def _calculate_similarity(goal1: str, goal2: str) -> float:
"""简单计算两个目标之间的相似度
这里使用一个简单的实现实际可以使用更复杂的文本相似度算法
Args:
goal1: 第一个目标
goal2: 第二个目标
Returns:
float: 相似度得分 (0-1)
"""
# 简单实现:检查重叠字数比例
words1 = set(goal1)
words2 = set(goal2)
overlap = len(words1.intersection(words2))
total = len(words1.union(words2))
return overlap / total if total > 0 else 0
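A worked example of the character-set Jaccard similarity implemented above; the numbers follow directly from the set definitions:

# "回答问题" -> {回, 答, 问, 题}, "回答提问" -> {回, 答, 提, 问}
# intersection = {回, 答, 问} (3 chars), union has 5 chars, so similarity = 3 / 5 = 0.6
print(_calculate_similarity("回答问题", "回答提问"))  # 0.6, below the 0.7 threshold used in _update_goals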
class GoalAnalyzer:
"""对话目标分析器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
# 多目标存储结构
self.goals = [] # 存储多个目标
self.max_goals = 3 # 同时保持的最大目标数量
self.current_goal_and_reason = None
async def analyze_goal(self, conversation_info: ConversationInfo, observation_info: ObservationInfo):
"""分析对话历史并设定目标
Args:
conversation_info: 对话信息
observation_info: 观察信息
Returns:
Tuple[str, str, str]: (目标, 方法, 原因)
"""
# 构建对话目标
goals_str = ""
if conversation_info.goal_list:
for goal_reason in conversation_info.goal_list:
if isinstance(goal_reason, dict):
goal = goal_reason.get("goal", "目标内容缺失")
reasoning = goal_reason.get("reasoning", "没有明确原因")
else:
goal = str(goal_reason)
reasoning = "没有明确原因"
goal_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
goals_str += goal_str
else:
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
goals_str = f"目标:{goal},产生该对话目标的原因:{reasoning}\n"
# 获取聊天历史记录
chat_history_text = observation_info.chat_history_str
if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
# await observation_info.clear_unprocessed_messages()
persona_text = f"你的名字是{self.name}{self.personality_info}"
# 构建action历史文本
action_history_list = conversation_info.done_action
action_history_text = "你之前做的事情是:"
for action in action_history_list:
action_history_text += f"{action}\n"
prompt = f"""{persona_text}。现在你在参与一场QQ聊天请分析以下聊天记录并根据你的性格特征确定多个明确的对话目标。
这些目标应该反映出对话的不同方面和意图
{action_history_text}
当前对话目标
{goals_str}
聊天记录
{chat_history_text}
请分析当前对话并确定最适合的对话目标你可以
1. 保持现有目标不变
2. 修改现有目标
3. 添加新目标
4. 删除不再相关的目标
5. 如果你想结束对话请设置一个目标目标goal为"结束对话"原因reasoning为你希望结束对话
请以JSON数组格式输出当前的所有对话目标每个目标包含以下字段
1. goal: 对话目标简短的一句话
2. reasoning: 对话原因为什么设定这个目标简要解释
输出格式示例
[
{{
"goal": "回答用户关于Python编程的具体问题",
"reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
}},
{{
"goal": "回答用户关于python安装的具体问题",
"reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
}}
]"""
logger.debug(f"[私聊][{self.private_name}]发送到LLM的提示词: {prompt}")
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
except Exception as e:
logger.error(f"[私聊][{self.private_name}]分析对话目标时出错: {str(e)}")
content = ""
# 使用改进后的get_items_from_json函数处理JSON数组
success, result = get_items_from_json(
content,
self.private_name,
"goal",
"reasoning",
required_types={"goal": str, "reasoning": str},
allow_array=True,
)
if success:
# 判断结果是单个字典还是字典列表
if isinstance(result, list):
# 清空现有目标列表并添加新目标
conversation_info.goal_list = []
for item in result:
conversation_info.goal_list.append(item)
# 返回第一个目标作为当前主要目标(如果有)
if result:
first_goal = result[0]
return first_goal.get("goal", ""), "", first_goal.get("reasoning", "")
else:
# 单个目标的情况
conversation_info.goal_list.append(result)
return goal, "", reasoning
# 如果解析失败,返回默认值
return "", "", ""
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表
Args:
new_goal: 新的目标
method: 实现目标的方法
reasoning: 目标的原因
"""
# 检查新目标是否与现有目标相似
for i, (existing_goal, _, _) in enumerate(self.goals):
if _calculate_similarity(new_goal, existing_goal) > 0.7: # 相似度阈值
# 更新现有目标
self.goals[i] = (new_goal, method, reasoning)
# 将此目标移到列表前面(最主要的位置)
self.goals.insert(0, self.goals.pop(i))
return
# 添加新目标到列表前面
self.goals.insert(0, (new_goal, method, reasoning))
# 限制目标数量
if len(self.goals) > self.max_goals:
self.goals.pop() # 移除最老的目标
async def get_all_goals(self) -> List[Tuple[str, str, str]]:
"""获取所有当前目标
Returns:
List[Tuple[str, str, str]]: 目标列表每项为(目标, 方法, 原因)
"""
return self.goals.copy()
async def get_alternative_goals(self) -> List[Tuple[str, str, str]]:
"""获取除了当前主要目标外的其他备选目标
Returns:
List[Tuple[str, str, str]]: 备选目标列表
"""
if len(self.goals) <= 1:
return []
return self.goals[1:].copy()
async def analyze_conversation(self, goal, reasoning):
messages = self.chat_observer.get_cached_messages()
chat_history_text = await build_readable_messages(
messages,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
persona_text = f"你的名字是{self.name}{self.personality_info}"
# ===> Persona 文本构建结束 <===
# --- 修改 Prompt 字符串,使用 persona_text ---
prompt = f"""{persona_text}。现在你在参与一场QQ聊天
当前对话目标{goal}
产生该对话目标的原因{reasoning}
请分析以下聊天记录并根据你的性格特征评估该目标是否已经达到或者你是否希望停止该次对话
聊天记录
{chat_history_text}
请以JSON格式输出包含以下字段
1. goal_achieved: 对话目标是否已经达到true/false
2. stop_conversation: 是否希望停止该次对话true/false
3. reason: 为什么希望停止该次对话简要解释
输出格式示例
{{
"goal_achieved": true,
"stop_conversation": false,
"reason": "虽然目标已达成,但对话仍然有继续的价值"
}}"""
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"[私聊][{self.private_name}]LLM原始返回内容: {content}")
# 尝试解析JSON
success, result = get_items_from_json(
content,
self.private_name,
"goal_achieved",
"stop_conversation",
"reason",
required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str},
)
if not success:
logger.error(f"[私聊][{self.private_name}]无法解析对话分析结果JSON")
return False, False, "解析结果失败"
goal_achieved = result["goal_achieved"]
stop_conversation = result["stop_conversation"]
reason = result["reason"]
return goal_achieved, stop_conversation, reason
except Exception as e:
logger.error(f"[私聊][{self.private_name}]分析对话状态时出错: {str(e)}")
return False, False, f"分析出错: {str(e)}"
# 先注释掉,万一以后出问题了还能开回来(((
# class DirectMessageSender:
# """直接发送消息到平台的发送器"""
# def __init__(self, private_name: str):
# self.logger = get_module_logger("direct_sender")
# self.storage = MessageStorage()
# self.private_name = private_name
# async def send_via_ws(self, message: MessageSending) -> None:
# try:
# await global_api.send_message(message)
# except Exception as e:
# raise ValueError(f"未找到平台:{message.message_info.platform} 的url配置请检查配置文件") from e
# async def send_message(
# self,
# chat_stream: ChatStream,
# content: str,
# reply_to_message: Optional[Message] = None,
# ) -> None:
# """直接发送消息到平台
# Args:
# chat_stream: 聊天流
# content: 消息内容
# reply_to_message: 要回复的消息
# """
# # 构建消息对象
# message_segment = Seg(type="text", data=content)
# bot_user_info = UserInfo(
# user_id=global_config.BOT_QQ,
# user_nickname=global_config.BOT_NICKNAME,
# platform=chat_stream.platform,
# )
# message = MessageSending(
# message_id=f"dm{round(time.time(), 2)}",
# chat_stream=chat_stream,
# bot_user_info=bot_user_info,
# sender_info=reply_to_message.message_info.user_info if reply_to_message else None,
# message_segment=message_segment,
# reply=reply_to_message,
# is_head=True,
# is_emoji=False,
# thinking_start_time=time.time(),
# )
# # 处理消息
# await message.process()
# _message_json = message.to_dict()
# # 发送消息
# try:
# await self.send_via_ws(message)
# await self.storage.store_message(message, chat_stream)
# logger.success(f"[私聊][{self.private_name}]PFC消息已发送: {content}")
# except Exception as e:
# logger.error(f"[私聊][{self.private_name}]PFC消息发送失败: {str(e)}")

View File

@ -0,0 +1,85 @@
from typing import List, Tuple
from src.common.logger import get_module_logger
from src.plugins.memory_system.Hippocampus import HippocampusManager
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from ..chat.message import Message
from ..knowledge.knowledge_lib import qa_manager
from ..utils.chat_message_builder import build_readable_messages
logger = get_module_logger("knowledge_fetcher")
class KnowledgeFetcher:
"""知识调取器"""
def __init__(self, private_name: str):
self.llm = LLMRequest(
model=global_config.llm_normal,
temperature=global_config.llm_normal["temp"],
max_tokens=1000,
request_type="knowledge_fetch",
)
self.private_name = private_name
def _lpmm_get_knowledge(self, query: str) -> str:
"""获取相关知识
Args:
query: 查询内容
Returns:
str: 构造好的,带相关度的知识
"""
logger.debug(f"[私聊][{self.private_name}]正在从LPMM知识库中获取知识")
try:
knowledge_info = qa_manager.get_knowledge(query)
logger.debug(f"[私聊][{self.private_name}]LPMM知识库查询结果: {knowledge_info:150}")
return knowledge_info
except Exception as e:
logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
return "未找到匹配的知识"
async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
"""获取相关知识
Args:
query: 查询内容
chat_history: 聊天历史
Returns:
Tuple[str, str]: (获取的知识, 知识来源)
"""
# 构建查询上下文
chat_history_text = await build_readable_messages(
chat_history,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
# 从记忆中获取相关知识
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=f"{query}\n{chat_history_text}",
max_memory_num=3,
max_memory_length=2,
max_depth=3,
fast_retrieval=False,
)
knowledge_text = ""
sources_text = "无记忆匹配" # 默认值
if related_memory:
sources = []
for memory in related_memory:
knowledge_text += memory[1] + "\n"
sources.append(f"记忆片段{memory[0]}")
knowledge_text = knowledge_text.strip()
sources_text = "".join(sources)
knowledge_text += "\n现在有以下**知识**可供参考:\n "
knowledge_text += self._lpmm_get_knowledge(query)
knowledge_text += "\n请记住这些**知识**,并根据**知识**回答问题。\n"
return knowledge_text or "未找到相关知识", sources_text or "无记忆匹配"

View File

@ -0,0 +1,115 @@
import time
from typing import Dict, Optional
from src.common.logger import get_module_logger
from .conversation import Conversation
import traceback
logger = get_module_logger("pfc_manager")
class PFCManager:
"""PFC对话管理器负责管理所有对话实例"""
# 单例模式
_instance = None
# 会话实例管理
_instances: Dict[str, Conversation] = {}
_initializing: Dict[str, bool] = {}
@classmethod
def get_instance(cls) -> "PFCManager":
"""获取管理器单例
Returns:
PFCManager: 管理器实例
"""
if cls._instance is None:
cls._instance = PFCManager()
return cls._instance
async def get_or_create_conversation(self, stream_id: str, private_name: str) -> Optional[Conversation]:
"""获取或创建对话实例
Args:
stream_id: 聊天流ID
private_name: 私聊名称
Returns:
Optional[Conversation]: 对话实例创建失败则返回None
"""
# 检查是否已经有实例
if stream_id in self._initializing and self._initializing[stream_id]:
logger.debug(f"[私聊][{private_name}]会话实例正在初始化中: {stream_id}")
return None
if stream_id in self._instances and self._instances[stream_id].should_continue:
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
return self._instances[stream_id]
if stream_id in self._instances:
instance = self._instances[stream_id]
if (
hasattr(instance, "ignore_until_timestamp")
and instance.ignore_until_timestamp
and time.time() < instance.ignore_until_timestamp
):
logger.debug(f"[私聊][{private_name}]会话实例当前处于忽略状态: {stream_id}")
# 返回 None 阻止交互。或者可以返回实例但标记它被忽略了喵?
# 还是返回 None 吧喵。
return None
# 检查 should_continue 状态
if instance.should_continue:
logger.debug(f"[私聊][{private_name}]使用现有会话实例: {stream_id}")
return instance
# else: 实例存在但不应继续
try:
# 创建新实例
logger.info(f"[私聊][{private_name}]创建新的对话实例: {stream_id}")
self._initializing[stream_id] = True
# 创建实例
conversation_instance = Conversation(stream_id, private_name)
self._instances[stream_id] = conversation_instance
# 启动实例初始化
await self._initialize_conversation(conversation_instance)
except Exception as e:
logger.error(f"[私聊][{private_name}]创建会话实例失败: {stream_id}, 错误: {e}")
return None
return conversation_instance
async def _initialize_conversation(self, conversation: Conversation):
"""初始化会话实例
Args:
conversation: 要初始化的会话实例
"""
stream_id = conversation.stream_id
private_name = conversation.private_name
try:
logger.info(f"[私聊][{private_name}]开始初始化会话实例: {stream_id}")
# 启动初始化流程
await conversation._initialize()
# 标记初始化完成
self._initializing[stream_id] = False
logger.info(f"[私聊][{private_name}]会话实例 {stream_id} 初始化完成")
except Exception as e:
logger.error(f"[私聊][{private_name}]管理器初始化会话实例失败: {stream_id}, 错误: {e}")
logger.error(f"[私聊][{private_name}]{traceback.format_exc()}")
# 清理失败的初始化
async def get_conversation(self, stream_id: str) -> Optional[Conversation]:
"""获取已存在的会话实例
Args:
stream_id: 聊天流ID
Returns:
Optional[Conversation]: 会话实例不存在则返回None
"""
return self._instances.get(stream_id)
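A usage sketch for PFCManager; the stream id and display name are placeholders, and the commented-out call is where the heavy Conversation._initialize work would actually run:

import asyncio

async def demo():
    manager = PFCManager.get_instance()
    conversation = await manager.get_or_create_conversation("stream_123", "demo_user")
    if conversation is None:
        print("instance is still initializing or currently ignored")

# asyncio.run(demo())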

View File

@ -0,0 +1,23 @@
from enum import Enum
from typing import Literal
class ConversationState(Enum):
"""对话状态"""
INIT = "初始化"
RETHINKING = "重新思考"
ANALYZING = "分析历史"
PLANNING = "规划目标"
GENERATING = "生成回复"
CHECKING = "检查回复"
SENDING = "发送消息"
FETCHING = "获取知识"
WAITING = "等待"
LISTENING = "倾听"
ENDED = "结束"
JUDGING = "判断"
IGNORED = "屏蔽"
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]

View File

@ -0,0 +1,127 @@
import json
import re
from typing import Dict, Any, Optional, Tuple, List, Union
from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
def get_items_from_json(
content: str,
private_name: str,
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None,
allow_array: bool = True,
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""从文本中提取JSON内容并获取指定字段
Args:
content: 包含JSON的文本
private_name: 私聊名称
*items: 要提取的字段名
default_values: 字段的默认值格式为 {字段名: 默认值}
required_types: 字段的必需类型格式为 {字段名: 类型}
allow_array: 是否允许解析JSON数组
Returns:
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 首先尝试解析为JSON数组
if allow_array:
try:
# 尝试找到文本中的JSON数组
array_pattern = r"\[[\s\S]*\]"
array_match = re.search(array_pattern, content)
if array_match:
array_content = array_match.group()
json_array = json.loads(array_content)
# 确认是数组类型
if isinstance(json_array, list):
# 验证数组中的每个项目是否包含所有必需字段
valid_items = []
for item in json_array:
if not isinstance(item, dict):
continue
# 检查是否有所有必需字段
if all(field in item for field in items):
# 验证字段类型
if required_types:
type_valid = True
for field, expected_type in required_types.items():
if field in item and not isinstance(item[field], expected_type):
type_valid = False
break
if not type_valid:
continue
# 验证字符串字段不为空
string_valid = True
for field in items:
if isinstance(item[field], str) and not item[field].strip():
string_valid = False
break
if not string_valid:
continue
valid_items.append(item)
if valid_items:
return True, valid_items
except json.JSONDecodeError:
logger.debug(f"[私聊][{private_name}]JSON数组解析失败尝试解析单个JSON对象")
except Exception as e:
logger.debug(f"[私聊][{private_name}]尝试解析JSON数组时出错: {str(e)}")
# 尝试解析JSON对象
try:
json_data = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
json_data = json.loads(json_match.group())
except json.JSONDecodeError:
logger.error(f"[私聊][{private_name}]提取的JSON内容解析失败")
return False, result
else:
logger.error(f"[私聊][{private_name}]无法在返回内容中找到有效的JSON")
return False, result
# 提取字段
for item in items:
if item in json_data:
result[item] = json_data[item]
# 验证必需字段
if not all(item in result for item in items):
logger.error(f"[私聊][{private_name}]JSON缺少必要字段实际内容: {json_data}")
return False, result
# 验证字段类型
if required_types:
for field, expected_type in required_types.items():
if field in result and not isinstance(result[field], expected_type):
logger.error(f"[私聊][{private_name}]{field} 必须是 {expected_type.__name__} 类型")
return False, result
# 验证字符串字段不为空
for field in items:
if isinstance(result[field], str) and not result[field].strip():
logger.error(f"[私聊][{private_name}]{field} 不能为空")
return False, result
return True, result
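A minimal usage sketch of the helper above (illustrative only; the flat import is an assumption, the real module lives under the package path this diff belongs to):

# Illustrative usage of get_items_from_json; the import path is assumed, not part of this diff.
from pfc_utils import get_items_from_json

llm_output = '思考过程... {"suitable": true, "reason": "内容得体"}'
ok, data = get_items_from_json(
    llm_output,
    "测试会话",  # private_name, 仅用于日志前缀
    "suitable", "reason",
    default_values={"suitable": False},
    required_types={"reason": str},
)
if ok:
    print(data["suitable"], data["reason"])  # True 内容得体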

View File

@ -0,0 +1,183 @@
import json
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from .chat_observer import ChatObserver
from maim_message import UserInfo
logger = get_module_logger("reply_checker")
class ReplyChecker:
"""回复检查器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
)
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.max_retries = 3 # 最大重试次数
async def check(
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查生成的回复是否合适
Args:
reply: 生成的回复
goal: 对话目标
chat_history: 对话历史记录
chat_history_text: 对话历史记录文本
retry_count: 当前重试次数
Returns:
Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划)
"""
# 不再从 observer 获取,直接使用传入的 chat_history
# messages = self.chat_observer.get_cached_messages(limit=20)
try:
# 筛选出最近由 Bot 自己发送的消息
bot_messages = []
for msg in reversed(chat_history):
user_info = UserInfo.from_dict(msg.get("user_info", {}))
if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
bot_messages.append(msg.get("processed_plain_text", ""))
if len(bot_messages) >= 2: # 只和最近的两条比较
break
# 进行比较
if bot_messages:
# 可以用简单比较,或者更复杂的相似度库 (如 difflib)
# 简单比较:是否完全相同
if reply == bot_messages[0]: # 和最近一条完全一样
logger.warning(
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息完全相同: '{reply}'"
)
return (
False,
"被逻辑检查拒绝:回复内容与你上一条发言完全相同,可以选择深入话题或寻找其它话题或等待",
True,
) # 不合适,需要返回至决策层
# 2. 相似度检查 (如果精确匹配未通过)
import difflib # 导入 difflib 库
# 计算编辑距离相似度ratio() 返回 0 到 1 之间的浮点数
similarity_ratio = difflib.SequenceMatcher(None, reply, bot_messages[0]).ratio()
logger.debug(f"[私聊][{self.private_name}]ReplyChecker - 相似度: {similarity_ratio:.2f}")
# 设置一个相似度阈值
similarity_threshold = 0.9
if similarity_ratio > similarity_threshold:
logger.warning(
f"[私聊][{self.private_name}]ReplyChecker 检测到回复与上一条 Bot 消息高度相似 (相似度 {similarity_ratio:.2f}): '{reply}'"
)
return (
False,
f"被逻辑检查拒绝:回复内容与你上一条发言高度相似 (相似度 {similarity_ratio:.2f}),可以选择深入话题或寻找其它话题或等待。",
True,
)
except Exception as e:
import traceback
logger.error(f"[私聊][{self.private_name}]检查回复时出错: 类型={type(e)}, 值={e}")
logger.error(f"[私聊][{self.private_name}]{traceback.format_exc()}") # 打印详细的回溯信息
prompt = f"""你是一个聊天逻辑检查器,请检查以下回复或消息是否合适:
当前对话目标{goal}
最新的对话记录
{chat_history_text}
待检查的消息
{reply}
请结合聊天记录检查以下几点
1. 这条消息是否依然符合当前对话目标和实现方式
2. 这条消息是否与最新的对话记录保持一致性
3. 是否存在重复发言或重复表达同质内容尤其是只是换一种方式表达了相同的含义
4. 这条消息是否包含违规内容例如血腥暴力政治敏感等
5. 这条消息是否以发送者的角度发言不要让发送者自己回复自己的消息
6. 这条消息是否通俗易懂
7. 这条消息是否有些多余例如在对方没有回复的情况下依然连续多次消息轰炸尤其是已经连续发送3条信息的情况这很可能不合理需要着重判断
8. 这条消息是否使用了完全没必要的修辞
9. 这条消息是否逻辑通顺
10. 这条消息是否太过冗长了通常私聊的每条消息长度在20字以内除非特殊情况
11. 在连续多次发送消息的情况下这条消息是否衔接自然会不会显得奇怪例如连续两条消息中部分内容重叠
请以JSON格式输出包含以下字段
1. suitable: 是否合适 (true/false)
2. reason: 原因说明
3. need_replan: 是否需要重新决策 (true/false)当你认为此时已经不适合发消息需要规划其它行动时设为true
输出格式示例
{{
"suitable": true,
"reason": "回复符合要求,虽然有可能略微偏离目标,但是整体内容流畅得体",
"need_replan": false
}}
注意请严格按照JSON格式输出不要包含任何其他内容"""
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"[私聊][{self.private_name}]检查回复的原始返回: {content}")
# 清理内容尝试提取JSON部分
content = content.strip()
try:
# 尝试直接解析
result = json.loads(content)
except json.JSONDecodeError:
# 如果直接解析失败尝试查找和提取JSON部分
import re
json_pattern = r"\{[^{}]*\}"
json_match = re.search(json_pattern, content)
if json_match:
try:
result = json.loads(json_match.group())
except json.JSONDecodeError:
# 如果JSON解析失败尝试从文本中提取结果
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
else:
# 如果找不到JSON从文本中判断
is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()
reason = content[:100] if content else "无法解析响应"
need_replan = "重新规划" in content.lower() or "目标不适合" in content.lower()
return is_suitable, reason, need_replan
# 验证JSON字段
suitable = result.get("suitable", None)
reason = result.get("reason", "未提供原因")
need_replan = result.get("need_replan", False)
# 如果suitable字段是字符串转换为布尔值
if isinstance(suitable, str):
suitable = suitable.lower() == "true"
# 如果suitable字段不存在或不是布尔值从reason中判断
if suitable is None:
suitable = "不合适" not in reason.lower() and "违规" not in reason.lower()
# 如果不合适且未达到最大重试次数,返回需要重试
if not suitable and retry_count < self.max_retries:
return False, reason, False
# 如果不合适且已达到最大重试次数,返回需要重新规划
if not suitable and retry_count >= self.max_retries:
return False, f"多次重试后仍不合适: {reason}", True
return suitable, reason, need_replan
except Exception as e:
logger.error(f"[私聊][{self.private_name}]检查回复时出错: {e}")
# 如果出错且已达到最大重试次数,建议重新规划
if retry_count >= self.max_retries:
return False, "多次检查失败,建议重新规划", True
return False, f"检查过程出错,建议重试: {str(e)}", False
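The duplicate-reply guard above ultimately reduces to a difflib ratio against the bot's most recent message; a standalone, stdlib-only sketch of that check (the 0.9 threshold mirrors the value used above):

# Standalone sketch of the similarity guard used in ReplyChecker.check (stdlib only).
import difflib

def is_near_duplicate(candidate: str, last_bot_message: str, threshold: float = 0.9) -> bool:
    if candidate == last_bot_message:
        return True  # 完全相同
    ratio = difflib.SequenceMatcher(None, candidate, last_bot_message).ratio()
    return ratio > threshold  # 高度相似

print(is_near_duplicate("今天天气不错哦", "今天天气不错哦~"))  # True
print(is_near_duplicate("今天天气不错哦", "你吃饭了吗"))      # False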

View File

@ -0,0 +1,228 @@
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from ..models.utils_model import LLMRequest
from ...config.config import global_config
from .chat_observer import ChatObserver
from .reply_checker import ReplyChecker
from src.individuality.individuality import Individuality
from .observation_info import ObservationInfo
from .conversation_info import ConversationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
logger = get_module_logger("reply_generator")
# --- 定义 Prompt 模板 ---
# Prompt for direct_reply (首次回复)
PROMPT_DIRECT_REPLY = """{persona_text}。现在你在参与一场QQ私聊请根据以下信息生成一条回复
当前对话目标{goals_str}
{knowledge_info_str}
最近的聊天记录
{chat_history_text}
请根据上述信息结合聊天记录回复对方该回复应该
1. 符合对话目标,从"你"的角度发言,不要自己与自己对话
2. 符合你的性格特征和身份细节
3. 通俗易懂自然流畅像正常聊天一样简短通常20字以内除非特殊情况
4. 可以适当利用相关知识但不要生硬引用
5. 自然得体结合聊天记录逻辑合理且没有重复表达同质内容
请注意把握聊天内容,不要回复的太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话
可以回复得自然随意自然一些就像真人一样注意把握聊天内容整体风格可以平和简短不要刻意突出自身学科背景不要说你说过的话可以简短多简短都可以但是避免冗长
请你注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
请直接输出回复内容不需要任何额外格式"""
# Prompt for send_new_message (追问/补充)
PROMPT_SEND_NEW_MESSAGE = """{persona_text}。现在你在参与一场QQ私聊**刚刚你已经发送了一条或多条消息**,现在请根据以下信息再发一条新消息:
当前对话目标{goals_str}
{knowledge_info_str}
最近的聊天记录
{chat_history_text}
请根据上述信息结合聊天记录继续发一条新消息例如对之前消息的补充深入话题或追问等等该消息应该
1. 符合对话目标,从"你"的角度发言,不要自己与自己对话
2. 符合你的性格特征和身份细节
3. 通俗易懂自然流畅像正常聊天一样简短通常20字以内除非特殊情况
4. 可以适当利用相关知识但不要生硬引用
5. 跟之前你发的消息自然的衔接逻辑合理且没有重复表达同质内容或部分重叠内容
请注意把握聊天内容,不用太有条理,可以有个性。请分清"你"和对方说的话,不要把"你"说的话当做对方说的话,这是你自己说的话
这条消息可以自然随意自然一些就像真人一样注意把握聊天内容整体风格可以平和简短不要刻意突出自身学科背景不要说你说过的话可以简短多简短都可以但是避免冗长
请你注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出消息内容
不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
请直接输出回复内容不需要任何额外格式"""
# Prompt for say_goodbye (告别语生成)
PROMPT_FAREWELL = """{persona_text}。你在参与一场 QQ 私聊,现在对话似乎已经结束,你决定再发一条最后的消息来圆满结束。
最近的聊天记录
{chat_history_text}
请根据上述信息结合聊天记录构思一条**简短自然符合你人设**的最后的消息
这条消息应该
1. 从你自己的角度发言
2. 符合你的性格特征和身份细节
3. 通俗易懂自然流畅通常很简短
4. 自然地为这场对话画上句号避免开启新话题或显得冗长刻意
请像真人一样随意自然**简洁是关键**
不要输出多余内容包括前后缀冒号引号括号表情包at或@等
请直接输出最终的告别消息内容不需要任何额外格式"""
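These module-level templates are filled with str.format in the generate method further down; a quick illustration of the placeholder contract:

# Quick illustration of the placeholder contract shared by the three templates above.
template = "{persona_text}。当前对话目标:{goals_str}最近的聊天记录:\n{chat_history_text}"
print(template.format(
    persona_text="你的名字是麦麦",
    goals_str="- 目标:闲聊\n  原因:对方主动打招呼\n",
    chat_history_text="对方: 在吗?",
))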
class ReplyGenerator:
"""回复生成器"""
def __init__(self, stream_id: str, private_name: str):
self.llm = LLMRequest(
model=global_config.llm_PFC_chat,
temperature=global_config.llm_PFC_chat["temp"],
max_tokens=300,
request_type="reply_generation",
)
self.personality_info = Individuality.get_instance().get_prompt(x_person=2, level=3)
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.reply_checker = ReplyChecker(stream_id, private_name)
# 修改 generate 方法签名,增加 action_type 参数
async def generate(
self, observation_info: ObservationInfo, conversation_info: ConversationInfo, action_type: str
) -> str:
"""生成回复
Args:
observation_info: 观察信息
conversation_info: 对话信息
action_type: 当前执行的动作类型 ('direct_reply' 'send_new_message')
Returns:
str: 生成的回复
"""
# 构建提示词
logger.debug(
f"[私聊][{self.private_name}]开始生成回复 (动作类型: {action_type}):当前目标: {conversation_info.goal_list}"
)
# --- 构建通用 Prompt 参数 ---
# (这部分逻辑基本不变)
# 构建对话目标 (goals_str)
goals_str = ""
if conversation_info.goal_list:
for goal_reason in conversation_info.goal_list:
if isinstance(goal_reason, dict):
goal = goal_reason.get("goal", "目标内容缺失")
reasoning = goal_reason.get("reasoning", "没有明确原因")
else:
goal = str(goal_reason)
reasoning = "没有明确原因"
goal = str(goal) if goal is not None else "目标内容缺失"
reasoning = str(reasoning) if reasoning is not None else "没有明确原因"
goals_str += f"- 目标:{goal}\n 原因:{reasoning}\n"
else:
goals_str = "- 目前没有明确对话目标\n" # 简化无目标情况
# --- 新增:构建知识信息字符串 ---
knowledge_info_str = "【供参考的相关知识和记忆】\n" # 稍微改下标题,表明是供参考
try:
# 检查 conversation_info 是否有 knowledge_list 并且不为空
if hasattr(conversation_info, "knowledge_list") and conversation_info.knowledge_list:
# 最多只显示最近的 5 条知识
recent_knowledge = conversation_info.knowledge_list[-5:]
for i, knowledge_item in enumerate(recent_knowledge):
if isinstance(knowledge_item, dict):
query = knowledge_item.get("query", "未知查询")
knowledge = knowledge_item.get("knowledge", "无知识内容")
source = knowledge_item.get("source", "未知来源")
# 只取知识内容的前 2000 个字
knowledge_snippet = knowledge[:2000] + "..." if len(knowledge) > 2000 else knowledge
knowledge_info_str += (
f"{i + 1}. 关于 '{query}' (来源: {source}): {knowledge_snippet}\n" # 格式微调,更简洁
)
else:
knowledge_info_str += f"{i + 1}. 发现一条格式不正确的知识记录。\n"
if not recent_knowledge:
knowledge_info_str += "- 暂无。\n" # 更简洁的提示
else:
knowledge_info_str += "- 暂无。\n"
except AttributeError:
logger.warning(f"[私聊][{self.private_name}]ConversationInfo 对象可能缺少 knowledge_list 属性。")
knowledge_info_str += "- 获取知识列表时出错。\n"
except Exception as e:
logger.error(f"[私聊][{self.private_name}]构建知识信息字符串时出错: {e}")
knowledge_info_str += "- 处理知识列表时出错。\n"
# 获取聊天历史记录 (chat_history_text)
chat_history_text = observation_info.chat_history_str
if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
new_messages_list = observation_info.unprocessed_messages
new_messages_str = await build_readable_messages(
new_messages_list,
replace_bot_name=True,
merge_messages=False,
timestamp_mode="relative",
read_mark=0.0,
)
chat_history_text += f"\n--- 以下是 {observation_info.new_messages_count} 条新消息 ---\n{new_messages_str}"
elif not chat_history_text:
chat_history_text = "还没有聊天记录。"
# 构建 Persona 文本 (persona_text)
persona_text = f"你的名字是{self.name}{self.personality_info}"
# --- 选择 Prompt ---
if action_type == "send_new_message":
prompt_template = PROMPT_SEND_NEW_MESSAGE
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_SEND_NEW_MESSAGE (追问生成)")
elif action_type == "say_goodbye": # 处理告别动作
prompt_template = PROMPT_FAREWELL
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_FAREWELL (告别语生成)")
else: # 默认使用 direct_reply 的 prompt (包括 'direct_reply' 或其他未明确处理的类型)
prompt_template = PROMPT_DIRECT_REPLY
logger.info(f"[私聊][{self.private_name}]使用 PROMPT_DIRECT_REPLY (首次/非连续回复生成)")
# --- 格式化最终的 Prompt ---
prompt = prompt_template.format(
persona_text=persona_text,
goals_str=goals_str,
chat_history_text=chat_history_text,
knowledge_info_str=knowledge_info_str,
)
# --- 调用 LLM 生成 ---
logger.debug(f"[私聊][{self.private_name}]发送到LLM的生成提示词:\n------\n{prompt}\n------")
try:
content, _ = await self.llm.generate_response_async(prompt)
logger.debug(f"[私聊][{self.private_name}]生成的回复: {content}")
# 移除旧的检查新消息逻辑,这应该由 conversation 控制流处理
return content
except Exception as e:
logger.error(f"[私聊][{self.private_name}]生成回复时出错: {e}")
return "抱歉,我现在有点混乱,让我重新思考一下..."
# check_reply 方法保持不变
async def check_reply(
self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_str: str, retry_count: int = 0
) -> Tuple[bool, str, bool]:
"""检查回复是否合适
(此方法逻辑保持不变)
"""
return await self.reply_checker.check(reply, goal, chat_history, chat_history_str, retry_count)

View File

@ -0,0 +1,79 @@
from src.common.logger import get_module_logger
from .chat_observer import ChatObserver
from .conversation_info import ConversationInfo
# from src.individuality.individuality import Individuality # 不再需要
from ...config.config import global_config
import time
import asyncio
logger = get_module_logger("waiter")
# --- 在这里设定你想要的超时时间(秒) ---
# 例如: 120 秒 = 2 分钟
DESIRED_TIMEOUT_SECONDS = 300
class Waiter:
"""等待处理类"""
def __init__(self, stream_id: str, private_name: str):
self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
self.name = global_config.BOT_NICKNAME
self.private_name = private_name
# self.wait_accumulated_time = 0 # 不再需要累加计时
async def wait(self, conversation_info: ConversationInfo) -> bool:
"""等待用户新消息或超时"""
wait_start_time = time.time()
logger.info(f"[私聊][{self.private_name}]进入常规等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
while True:
# 检查是否有新消息
if self.chat_observer.new_message_after(wait_start_time):
logger.info(f"[私聊][{self.private_name}]等待结束,收到新消息")
return False # 返回 False 表示不是超时
# 检查是否超时
elapsed_time = time.time() - wait_start_time
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
logger.info(f"[私聊][{self.private_name}]等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
wait_goal = {
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,注意可能在对方看来聊天已经结束,思考接下来要做什么",
"reasoning": "对方很久没有回复你的消息了",
}
conversation_info.goal_list.append(wait_goal)
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
return True # 返回 True 表示超时
await asyncio.sleep(5) # 每 5 秒检查一次
logger.debug(
f"[私聊][{self.private_name}]等待中..."
) # 可以考虑把这个频繁日志注释掉,只在超时或收到消息时输出
async def wait_listening(self, conversation_info: ConversationInfo) -> bool:
"""倾听用户发言或超时"""
wait_start_time = time.time()
logger.info(f"[私聊][{self.private_name}]进入倾听等待状态 (超时: {DESIRED_TIMEOUT_SECONDS} 秒)...")
while True:
# 检查是否有新消息
if self.chat_observer.new_message_after(wait_start_time):
logger.info(f"[私聊][{self.private_name}]倾听等待结束,收到新消息")
return False # 返回 False 表示不是超时
# 检查是否超时
elapsed_time = time.time() - wait_start_time
if elapsed_time > DESIRED_TIMEOUT_SECONDS:
logger.info(f"[私聊][{self.private_name}]倾听等待超过 {DESIRED_TIMEOUT_SECONDS} 秒...添加思考目标。")
wait_goal = {
# 保持 goal 文本一致
"goal": f"你等待了{elapsed_time / 60:.1f}分钟,对方似乎话说一半突然消失了,可能忙去了?也可能忘记了回复?要问问吗?还是结束对话?或继续等待?思考接下来要做什么",
"reasoning": "对方话说一半消失了,很久没有回复",
}
conversation_info.goal_list.append(wait_goal)
logger.info(f"[私聊][{self.private_name}]添加目标: {wait_goal}")
return True # 返回 True 表示超时
await asyncio.sleep(5) # 每 5 秒检查一次
logger.debug(f"[私聊][{self.private_name}]倾听等待中...") # 同上,可以考虑注释掉
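Both wait methods above share one poll-until-new-message-or-timeout loop; a condensed standalone sketch of that pattern, with a generic predicate standing in for chat_observer.new_message_after (an assumption, since ChatObserver itself is not shown in this diff):

# Generic poll-with-timeout sketch mirroring Waiter.wait / wait_listening.
import asyncio
import time
from typing import Callable

async def wait_for(new_message_after: Callable[[float], bool],
                   timeout_seconds: float = 300.0,
                   poll_interval: float = 5.0) -> bool:
    """Return True on timeout, False as soon as a new message is seen (same convention as Waiter)."""
    start = time.time()
    while True:
        if new_message_after(start):
            return False  # 收到新消息
        if time.time() - start > timeout_seconds:
            return True   # 超时
        await asyncio.sleep(poll_interval)

# A predicate that never fires, so the call times out quickly for demonstration.
print(asyncio.run(wait_for(lambda _: False, timeout_seconds=0.2, poll_interval=0.05)))  # True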

View File

@ -16,7 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.express.expression_learner import expression_learner_manager
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
@ -96,6 +97,9 @@ class BrainChatting:
self.more_plan = False
# 最近一次是否成功进行了 reply用于选择 BrainPlanner 的 Prompt
self._last_successful_reply: bool = False
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@ -157,6 +161,7 @@ class BrainChatting:
)
async def _loopbody(self): # sourcery skip: hoist-if-from-if
# 获取最新消息(用于上下文,但不影响是否调用 observe
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
@ -165,17 +170,25 @@ class BrainChatting:
limit_mode="latest",
filter_mai=True,
filter_command=False,
filter_no_read_command=True,
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
await self._observe(recent_messages_list=recent_messages_list)
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
return True
# 总是执行一次思考迭代(不管有没有新消息)
# wait 动作会在其内部等待,不需要在这里处理
should_continue = await self._observe(recent_messages_list=recent_messages_list)
if not should_continue:
# 选择了 complete_talk返回 False 表示需要等待新消息
return False
# 继续下一次迭代(除非选择了 complete_talk
# 短暂等待后再继续,避免过于频繁的循环
await asyncio.sleep(0.1)
return True
async def _send_and_store_reply(
@ -240,7 +253,7 @@ class BrainChatting:
# ReflectTracker Check
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答
# -------------------------------------------------------------------------
from src.express.reflect_tracker import reflect_tracker_manager
from src.bw_learner.reflect_tracker import reflect_tracker_manager
tracker = reflect_tracker_manager.get_tracker(self.stream_id)
if tracker:
@ -253,13 +266,15 @@ class BrainChatting:
# Expression Reflection Check
# 检查是否需要提问表达反思
# -------------------------------------------------------------------------
from src.express.expression_reflector import expression_reflector_manager
from src.bw_learner.expression_reflector import expression_reflector_manager
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
asyncio.create_task(reflector.check_and_ask())
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
@ -272,14 +287,16 @@ class BrainChatting:
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 执行planner
# 获取必要信息
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
# 一次思考迭代Think - Act - Observe
# 获取聊天上下文
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
@ -290,12 +307,11 @@ class BrainChatting:
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
prompt_key="brain_planner_prompt_react",
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
@ -311,7 +327,10 @@ class BrainChatting:
available_actions=available_actions,
)
# 3. 并行执行所有动作
# 检查是否有 complete_talk 动作(会停止后续迭代)
has_complete_talk = any(action.action_type == "complete_talk" for action in action_to_use_info)
# 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
@ -343,7 +362,14 @@ class BrainChatting:
else:
logger.warning(f"{self.log_prefix} 回复动作执行失败")
# 构建最终的循环信息
# 更新观察时间标记
self.action_planner.last_obs_time_mark = time.time()
# 如果选择了 complete_talk标记为完成不再继续迭代
if has_complete_talk:
logger.info(f"{self.log_prefix} 检测到 complete_talk 动作,本次思考完成")
# 构建循环信息
if reply_loop_info:
# 如果有回复信息使用回复的loop_info作为基础
loop_info = reply_loop_info
@ -369,10 +395,16 @@ class BrainChatting:
}
_reply_text = action_reply_text
# 如果选择了 complete_talk返回 False 以停止 _loopbody 的循环
# 否则返回 True让 _loopbody 继续下一次迭代
should_continue = not has_complete_talk
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
return True
# 如果选择了 complete_talk返回 False 停止循环
# 否则返回 True继续下一次思考迭代
return should_continue
async def _main_chat_loop(self):
"""主循环,持续进行计划并可能回复消息,直到被外部取消。"""
@ -380,9 +412,13 @@ class BrainChatting:
while self.running:
# 主循环
success = await self._loopbody()
await asyncio.sleep(0.1)
if not success:
break
# 选择了 complete等待新消息
logger.info(f"{self.log_prefix} 选择了 complete等待新消息...")
await self._wait_for_new_message()
# 有新消息后继续循环
continue
await asyncio.sleep(0.1)
except asyncio.CancelledError:
# 设置了关闭标志位后被取消是正常流程
logger.info(f"{self.log_prefix} 麦麦已关闭聊天")
@ -393,6 +429,33 @@ class BrainChatting:
self._loop_task = asyncio.create_task(self._main_chat_loop())
logger.error(f"{self.log_prefix} 结束了当前聊天循环")
async def _wait_for_new_message(self):
"""等待新消息到达"""
last_check_time = self.last_read_time
check_interval = 1.0 # 每秒检查一次
while self.running:
# 检查是否有新消息
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=last_check_time,
end_time=time.time(),
limit=20,
limit_mode="latest",
filter_mai=True,
filter_command=False,
filter_intercept_message_level=1,
)
# 如果有新消息,更新 last_read_time 并返回
if len(recent_messages_list) >= 1:
self.last_read_time = time.time()
logger.info(f"{self.log_prefix} 检测到新消息,恢复循环")
return
# 等待一段时间后再次检查
await asyncio.sleep(check_interval)
async def _handle_action(
self,
action: str,
@ -506,12 +569,12 @@ class BrainChatting:
"""执行单个动作的通用函数"""
try:
with Timer(f"动作{action_planner_info.action_type}", cycle_timers):
if action_planner_info.action_type == "no_reply":
# 直接处理no_reply逻辑,不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
# logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
if action_planner_info.action_type == "complete_talk":
# 直接处理complete_talk逻辑,不再通过动作系统
reason = action_planner_info.reasoning or "选择完成对话"
logger.info(f"{self.log_prefix} 选择完成对话,原因: {reason}")
# 存储no_reply信息到数据库
# 存储complete_talk信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
@ -519,9 +582,9 @@ class BrainChatting:
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_reply",
action_name="complete_talk",
)
return {"action_type": "no_reply", "success": True, "reply_text": "", "command": ""}
return {"action_type": "complete_talk", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type == "reply":
try:
@ -543,11 +606,17 @@ class BrainChatting:
)
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
return {
"action_type": "reply",
"success": False,
"reply_text": "",
"loop_info": None,
}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
@ -558,6 +627,8 @@ class BrainChatting:
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
# 标记这次循环已经成功进行了回复
self._last_successful_reply = True
return {
"action_type": "reply",
"success": True,
@ -567,7 +638,88 @@ class BrainChatting:
# 其他动作
else:
# 执行普通动作
# 内建 wait / listening不通过插件系统直接在这里处理
if action_planner_info.action_type in ["wait", "listening"]:
reason = action_planner_info.reasoning or ""
action_data = action_planner_info.action_data or {}
if action_planner_info.action_type == "wait":
# 获取等待时间(必填)
wait_seconds = action_data.get("wait_seconds")
if wait_seconds is None:
logger.warning(f"{self.log_prefix} wait 动作缺少 wait_seconds 参数,使用默认值 5 秒")
wait_seconds = 5
else:
try:
wait_seconds = float(wait_seconds)
if wait_seconds < 0:
logger.warning(f"{self.log_prefix} wait_seconds 不能为负数,使用默认值 5 秒")
wait_seconds = 5
except (ValueError, TypeError):
logger.warning(f"{self.log_prefix} wait_seconds 参数格式错误,使用默认值 5 秒")
wait_seconds = 5
logger.info(f"{self.log_prefix} 执行 wait 动作,等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason or f"等待 {wait_seconds}",
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="wait",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} wait 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {
"action_type": "wait",
"success": True,
"reply_text": "",
"command": "",
}
# listening 已合并到 wait如果遇到则转换为 wait向后兼容
elif action_planner_info.action_type == "listening":
logger.debug(f"{self.log_prefix} 检测到 listening 动作,已合并到 wait自动转换")
# 使用默认等待时间
wait_seconds = 3
logger.info(f"{self.log_prefix} 执行 listening转换为 wait动作等待 {wait_seconds}")
# 记录动作信息
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason or f"倾听并等待 {wait_seconds}",
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason, "wait_seconds": wait_seconds},
action_name="listening",
)
# 等待指定时间
await asyncio.sleep(wait_seconds)
logger.info(f"{self.log_prefix} listening 动作完成,继续下一次思考")
# 这些动作本身不产生文本回复
self._last_successful_reply = False
return {
"action_type": "listening",
"success": True,
"reply_text": "",
"command": "",
}
# 其余动作:走原有插件 Action 体系
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
@ -577,6 +729,10 @@ class BrainChatting:
thinking_id,
action_planner_info.action_message,
)
# 非 reply 类动作执行成功时,清空最近成功回复标记,让下一轮回到 initial Prompt
if success and action_planner_info.action_type != "reply":
self._last_successful_reply = False
return {
"action_type": action_planner_info.action_type,
"success": success,

View File

@ -35,12 +35,13 @@ install(extra_lines=3)
def init_prompt():
# ReAct 形式的 Planner Prompt
Prompt(
"""
{time_block}
{name_block}
你的兴趣是{interest}
{chat_context_description}以下是具体的聊天内容
**聊天内容**
{chat_content_block}
@ -57,11 +58,35 @@ reply
"reason":"回复的原因"
}}
no_reply
wait
动作描述
等待保持沉默等待对方发言
暂时不再发言等待指定时间适用于以下情况
- 你已经表达清楚一轮想给对方留出空间
- 你感觉对方的话还没说完或者自己刚刚发了好几条连续消息
- 你想要等待一定时间来让对方把话说完或者等待对方反应
- 你想保持安静,专注"听"而不是马上回复
请你根据上下文来判断要等待多久请你灵活判断
- 如果你们交流间隔时间很短聊的很频繁不宜等待太久
- 如果你们交流间隔时间很长聊的很少可以等待较长时间
{{
"action": "no_reply",
"action": "wait",
"target_message_id":"想要作为这次等待依据的消息id通常是对方的最新消息",
"wait_seconds": 等待的秒数必填例如5 表示等待5秒,
"reason":"选择等待的原因"
}}
complete_talk
动作描述
当前聊天暂时结束了对方离开没有更多话题了
你可以使用该动作来暂时休息等待对方有新发言再继续
- 多次wait之后对方迟迟不回复消息才用
- 如果对方只是短暂不回复应该使用wait而不是complete_talk
- 聊天内容显示当前聊天已经结束或者没有新内容时候选择complete_talk
选择此动作后将不再继续循环思考直到收到对方的新消息
{{
"action": "complete_talk",
"target_message_id":"触发完成对话的消息id通常是对方的最新消息",
"reason":"选择完成对话的原因"
}}
{action_options_text}
@ -92,7 +117,7 @@ no_reply
```
""",
"brain_planner_prompt",
"brain_planner_prompt_react",
)
Prompt(
@ -123,6 +148,9 @@ class BrainPlanner:
self.last_obs_time_mark = 0.0
# 计划日志记录
self.plan_log: List[Tuple[str, float, List[ActionPlannerInfo]]] = []
def find_message_by_id(
self, message_id: str, message_id_list: List[Tuple[str, "DatabaseMessages"]]
) -> Optional["DatabaseMessages"]:
@ -152,10 +180,11 @@ class BrainPlanner:
action_planner_infos = []
try:
action = action_json.get("action", "no_reply")
action = action_json.get("action", "complete_talk")
logger.debug(f"{self.log_prefix}解析动作JSON: action={action}, json={action_json}")
reasoning = action_json.get("reason", "未提供原因")
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 非no_reply动作需要target_message_id
# 非complete_talk动作需要target_message_id
target_message = None
if target_message_id := action_json.get("target_message_id"):
@ -171,16 +200,28 @@ class BrainPlanner:
# 验证action是否可用
available_action_names = [action_name for action_name, _ in current_available_actions]
internal_action_names = ["no_reply", "reply", "wait_time"]
# 内部保留动作(不依赖插件系统)
# 注意listening 已合并到 wait 中,如果遇到 listening 则转换为 wait
internal_action_names = ["complete_talk", "reply", "wait_time", "wait", "listening"]
logger.debug(
f"{self.log_prefix}动作验证: action={action}, internal={internal_action_names}, available={available_action_names}"
)
# 将 listening 转换为 wait向后兼容
if action == "listening":
logger.debug(f"{self.log_prefix}检测到 listening 动作,已合并到 wait自动转换")
action = "wait"
if action not in internal_action_names and action not in available_action_names:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {available_action_names}),将强制使用 'no_reply'"
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (内部动作: {internal_action_names}, 可用插件动作: {available_action_names}),将强制使用 'complete_talk'"
)
reasoning = (
f"LLM 返回了当前不可用的动作 '{action}' (可用: {available_action_names})。原始理由: {reasoning}"
)
action = "no_reply"
action = "complete_talk"
logger.warning(f"{self.log_prefix}动作已转换为 complete_talk")
# 创建ActionPlannerInfo对象
# 将列表转换为字典格式
@ -201,7 +242,7 @@ class BrainPlanner:
available_actions_dict = dict(current_available_actions)
action_planner_infos.append(
ActionPlannerInfo(
action_type="no_reply",
action_type="complete_talk",
reasoning=f"解析单个action时出错: {e}",
action_data={},
action_message=None,
@ -218,7 +259,7 @@ class BrainPlanner:
) -> List[ActionPlannerInfo]:
# sourcery skip: use-named-expression
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
规划器 (Planner): 使用LLM根据上下文决定做出什么动作ReAct模式
"""
# 获取聊天上下文
@ -226,7 +267,7 @@ class BrainPlanner:
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
chat_content_block, message_id_list = build_readable_messages_with_id(
@ -257,18 +298,19 @@ class BrainPlanner:
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
# 构建包含所有动作的提示词
# 构建包含所有动作的提示词:使用统一的 ReAct Prompt
prompt_key = "brain_planner_prompt_react"
# 这里不记录日志,避免重复打印,由调用方按需控制 log_prompt
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=filtered_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
prompt_key=prompt_key,
)
# 调用LLM获取决策
actions = await self._execute_main_planner(
reasoning, actions = await self._execute_main_planner(
prompt=prompt,
message_id_list=message_id_list,
filtered_actions=filtered_actions,
@ -276,16 +318,22 @@ class BrainPlanner:
loop_start_time=loop_start_time,
)
# 记录和展示计划日志
logger.info(
f"{self.log_prefix}Planner: {reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
self.add_plan_log(reasoning, actions)
return actions
async def build_planner_prompt(
self,
is_group_chat: bool,
chat_target_info: Optional["TargetPersonInfo"],
current_available_actions: Dict[str, ActionInfo],
message_id_list: List[Tuple[str, "DatabaseMessages"]],
chat_content_block: str = "",
interest: str = "",
prompt_key: str = "brain_planner_prompt_react",
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
@ -321,7 +369,7 @@ class BrainPlanner:
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
# 获取主规划器模板并填充
planner_prompt_template = await global_prompt_manager.get_prompt_async("brain_planner_prompt")
planner_prompt_template = await global_prompt_manager.get_prompt_async(prompt_key)
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
@ -431,17 +479,18 @@ class BrainPlanner:
filtered_actions: Dict[str, ActionInfo],
available_actions: Dict[str, ActionInfo],
loop_start_time: float,
) -> List[ActionPlannerInfo]:
) -> Tuple[str, List[ActionPlannerInfo]]:
"""执行主规划器"""
llm_content = None
actions: List[ActionPlannerInfo] = []
extracted_reasoning = ""
try:
# 调用LLM
llm_content, (reasoning_content, _, _) = await self.planner_llm.generate_response_async(prompt=prompt)
# logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
# logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
logger.info(f"{self.log_prefix}规划器原始响应: {llm_content}")
if global_config.debug.show_planner_prompt:
logger.info(f"{self.log_prefix}规划器原始提示词: {prompt}")
@ -456,10 +505,11 @@ class BrainPlanner:
except Exception as req_e:
logger.error(f"{self.log_prefix}LLM 请求执行失败: {req_e}")
return [
extracted_reasoning = f"LLM 请求失败,模型出现问题: {req_e}"
return extracted_reasoning, [
ActionPlannerInfo(
action_type="no_reply",
reasoning=f"LLM 请求失败,模型出现问题: {req_e}",
action_type="complete_talk",
reasoning=extracted_reasoning,
action_data={},
action_message=None,
available_actions=available_actions,
@ -469,24 +519,32 @@ class BrainPlanner:
# 解析LLM响应
if llm_content:
try:
if json_objects := self._extract_json_from_markdown(llm_content):
logger.debug(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
json_objects, extracted_reasoning = self._extract_json_from_markdown(llm_content)
if json_objects:
logger.info(f"{self.log_prefix}从响应中提取到{len(json_objects)}个JSON对象")
for i, json_obj in enumerate(json_objects):
logger.info(f"{self.log_prefix}解析第{i + 1}个JSON对象: {json_obj}")
filtered_actions_list = list(filtered_actions.items())
for json_obj in json_objects:
actions.extend(self._parse_single_action(json_obj, message_id_list, filtered_actions_list))
parsed_actions = self._parse_single_action(json_obj, message_id_list, filtered_actions_list)
logger.info(f"{self.log_prefix}解析后的动作: {[a.action_type for a in parsed_actions]}")
actions.extend(parsed_actions)
else:
# 尝试解析为直接的JSON
logger.warning(f"{self.log_prefix}LLM没有返回可用动作: {llm_content}")
actions = self._create_no_reply("LLM没有返回可用动作", available_actions)
extracted_reasoning = extracted_reasoning or "LLM没有返回可用动作"
actions = self._create_complete_talk(extracted_reasoning, available_actions)
except Exception as json_e:
logger.warning(f"{self.log_prefix}解析LLM响应JSON失败 {json_e}. LLM原始输出: '{llm_content}'")
actions = self._create_no_reply(f"解析LLM响应JSON失败: {json_e}", available_actions)
extracted_reasoning = f"解析LLM响应JSON失败: {json_e}"
actions = self._create_complete_talk(extracted_reasoning, available_actions)
traceback.print_exc()
else:
actions = self._create_no_reply("规划器没有获得LLM响应", available_actions)
extracted_reasoning = "规划器没有获得LLM响应"
actions = self._create_complete_talk(extracted_reasoning, available_actions)
# 添加循环开始时间到所有非no_reply动作
# 添加循环开始时间到所有动作
for action in actions:
action.action_data = action.action_data or {}
action.action_data["loop_start_time"] = loop_start_time
@ -495,13 +553,15 @@ class BrainPlanner:
f"{self.log_prefix}规划器决定执行{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
return actions
return extracted_reasoning, actions
def _create_no_reply(self, reasoning: str, available_actions: Dict[str, ActionInfo]) -> List[ActionPlannerInfo]:
"""创建no_reply"""
def _create_complete_talk(
self, reasoning: str, available_actions: Dict[str, ActionInfo]
) -> List[ActionPlannerInfo]:
"""创建complete_talk"""
return [
ActionPlannerInfo(
action_type="no_reply",
action_type="complete_talk",
reasoning=reasoning,
action_data={},
action_message=None,
@ -509,33 +569,122 @@ class BrainPlanner:
)
]
def _extract_json_from_markdown(self, content: str) -> List[dict]:
def add_plan_log(self, reasoning: str, actions: List[ActionPlannerInfo]):
"""添加计划日志"""
self.plan_log.append((reasoning, time.time(), actions))
if len(self.plan_log) > 20:
self.plan_log.pop(0)
def _extract_json_from_markdown(self, content: str) -> Tuple[List[dict], str]:
# sourcery skip: for-append-to-extend
"""从Markdown格式的内容中提取JSON对象"""
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
reasoning_content = ""
# 使用正则表达式查找```json包裹的JSON内容
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, content, re.DOTALL)
markdown_matches = re.findall(json_pattern, content, re.DOTALL)
for match in matches:
# 提取JSON之前的内容作为推理文本
first_json_pos = len(content)
if markdown_matches:
# 找到第一个```json的位置
first_json_pos = content.find("```json")
if first_json_pos > 0:
reasoning_content = content[:first_json_pos].strip()
# 清理推理内容中的注释标记
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
reasoning_content = reasoning_content.strip()
# 处理```json包裹的JSON
for match in markdown_matches:
try:
# 清理可能的注释和格式问题
json_str = re.sub(r"//.*?\n", "\n", match) # 移除单行注释
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL) # 移除多行注释
if json_str := json_str.strip():
json_obj = json.loads(repair_json(json_str))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
# 先尝试将整个块作为一个JSON对象或数组适用于多行JSON
try:
json_obj = json.loads(repair_json(json_str))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except json.JSONDecodeError:
# 如果整个块解析失败尝试按行分割适用于多个单行JSON对象
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines:
try:
# 尝试解析每一行作为独立的JSON对象
json_obj = json.loads(repair_json(line))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except json.JSONDecodeError:
# 单行解析失败,继续下一行
continue
except Exception as e:
logger.warning(f"解析JSON块失败: {e}, 块内容: {match[:100]}...")
logger.warning(f"{self.log_prefix}解析JSON块失败: {e}, 块内容: {match[:100]}...")
continue
return json_objects
# 如果没有找到完整的```json```块,尝试查找不完整的代码块(缺少结尾```
if not json_objects:
json_start_pos = content.find("```json")
if json_start_pos != -1:
# 找到```json之后的内容
json_content_start = json_start_pos + 7 # ```json的长度
# 提取从```json之后到内容结尾的所有内容
incomplete_json_str = content[json_content_start:].strip()
# 提取JSON之前的内容作为推理文本
if json_start_pos > 0:
reasoning_content = content[:json_start_pos].strip()
reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
reasoning_content = reasoning_content.strip()
if incomplete_json_str:
try:
# 清理可能的注释和格式问题
json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
json_str = json_str.strip()
if json_str:
# 尝试按行分割每行可能是一个JSON对象
lines = [line.strip() for line in json_str.split("\n") if line.strip()]
for line in lines:
try:
json_obj = json.loads(repair_json(line))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except json.JSONDecodeError:
pass
# 如果按行解析没有成功尝试将整个块作为一个JSON对象或数组
if not json_objects:
try:
json_obj = json.loads(repair_json(json_str))
if isinstance(json_obj, dict):
json_objects.append(json_obj)
elif isinstance(json_obj, list):
for item in json_obj:
if isinstance(item, dict):
json_objects.append(item)
except Exception as e:
logger.debug(f"尝试解析不完整的JSON代码块失败: {e}")
except Exception as e:
logger.debug(f"处理不完整的JSON代码块时出错: {e}")
return json_objects, reasoning_content
init_prompt()
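For reference, the happy path of _extract_json_from_markdown can be reproduced with a few stdlib lines; repair_json is assumed to come from the third-party json_repair package (its import sits outside the hunks shown here), so this sketch uses plain json.loads instead:

# Minimal stdlib sketch of pulling ```json fenced blocks plus the leading reasoning text
# out of an LLM reply (the real method additionally repairs malformed JSON before parsing).
import json
import re

def extract_json_blocks(content: str) -> tuple[list[dict], str]:
    matches = re.findall(r"```json\s*(.*?)\s*```", content, re.DOTALL)
    reasoning = content.split("```json", 1)[0].strip() if matches else ""
    objects: list[dict] = []
    for block in matches:
        try:
            parsed = json.loads(block)
        except json.JSONDecodeError:
            continue
        candidates = parsed if isinstance(parsed, list) else [parsed]
        objects.extend(item for item in candidates if isinstance(item, dict))
    return objects, reasoning

demo = '我决定先等待对方。\n```json\n{"action": "wait", "wait_seconds": 5, "reason": "等对方说完"}\n```'
print(extract_json_blocks(demo))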

View File

@ -271,7 +271,7 @@ def _to_emoji_objects(data: Any) -> Tuple[List["MaiEmoji"], int]:
emoji.description = emoji_data.description
# Deserialize emotion string from DB to list
emoji.emotion = emoji_data.emotion.split(",") if emoji_data.emotion else []
emoji.emotion = emoji_data.emotion.replace("", ",").split(",") if emoji_data.emotion else []
emoji.usage_count = emoji_data.usage_count
db_last_used_time = emoji_data.last_used_time
@ -732,7 +732,7 @@ class EmojiManager:
emoji_record = Emoji.get_or_none(Emoji.emoji_hash == emoji_hash)
if emoji_record and emoji_record.emotion:
logger.info(f"[缓存命中] 从数据库获取表情包情感标签: {emoji_record.emotion[:50]}...")
return emoji_record.emotion.split(",")
return emoji_record.emotion.replace("", ",").split(",")
except Exception as e:
logger.error(f"从数据库查询表情包情感标签时出错: {e}")
@ -993,7 +993,7 @@ class EmojiManager:
)
# 处理情感列表
emotions = [e.strip() for e in emotions_text.split(",") if e.strip()]
emotions = [e.strip() for e in emotions_text.replace("", ",").split(",") if e.strip()]
# 根据情感标签数量随机选择 - 超过5个选3个超过2个选2个
if len(emotions) > 5:

View File

@ -1,163 +0,0 @@
from datetime import datetime
import time
import asyncio
from typing import Dict
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat,
build_readable_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.plugin_system.apis import frequency_api
def init_prompt():
Prompt(
"""{name_block}
{time_block}
你现在正在聊天请根据下面的聊天记录判断是否有用户觉得你的发言过于频繁或者发言过少
{message_str}
如果用户觉得你的发言过于频繁请输出"过于频繁"否则输出"正常"
如果用户觉得你的发言过少请输出"过少"否则输出"正常"
**你只能输出以下三个词之一不要输出任何其他文字解释或标点**
- 正常
- 过于频繁
- 过少
""",
"frequency_adjust_prompt",
)
logger = get_logger("frequency_control")
class FrequencyControl:
"""简化的频率控制类仅管理不同chat_id的频率值"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
# 发言频率调整值
self.talk_frequency_adjust: float = 1.0
self.last_frequency_adjust_time: float = 0.0
self.frequency_model = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="frequency.adjust"
)
# 频率调整锁,防止并发执行
self._adjust_lock = asyncio.Lock()
def get_talk_frequency_adjust(self) -> float:
"""获取发言频率调整值"""
return self.talk_frequency_adjust
def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value))
async def trigger_frequency_adjust(self) -> None:
# 使用异步锁防止并发执行
async with self._adjust_lock:
# 在锁内检查,避免并发触发
current_time = time.time()
previous_adjust_time = self.last_frequency_adjust_time
msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=previous_adjust_time,
timestamp_end=current_time,
)
if current_time - previous_adjust_time < 160 or len(msg_list) <= 20:
return
# 立即更新调整时间,防止并发触发
self.last_frequency_adjust_time = current_time
try:
new_msg_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=previous_adjust_time,
timestamp_end=current_time,
limit=20,
limit_mode="latest",
)
message_str = build_readable_messages(
new_msg_list,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
show_actions=False,
)
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
bot_name = global_config.bot.nickname
bot_nickname = (
f",也有人叫你{','.join(global_config.bot.alias_names)}" if global_config.bot.alias_names else ""
)
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
prompt = await global_prompt_manager.format_prompt(
"frequency_adjust_prompt",
name_block=name_block,
time_block=time_block,
message_str=message_str,
)
response, (reasoning_content, _, _) = await self.frequency_model.generate_response_async(
prompt,
)
# logger.info(f"频率调整 prompt: {prompt}")
# logger.info(f"频率调整 response: {response}")
if global_config.debug.show_prompt:
logger.info(f"频率调整 prompt: {prompt}")
logger.info(f"频率调整 response: {response}")
logger.info(f"频率调整 reasoning_content: {reasoning_content}")
final_value_by_api = frequency_api.get_current_talk_value(self.chat_id)
# LLM依然输出过多内容时取消本次调整。合法最多4个字但有的模型可能会输出一些markdown换行符等需要长度宽限
if len(response) < 20:
if "过于频繁" in response:
logger.info(f"频率调整: 过于频繁,调整值到{final_value_by_api}")
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 0.8))
elif "过少" in response:
logger.info(f"频率调整: 过少,调整值到{final_value_by_api}")
self.talk_frequency_adjust = max(0.1, min(1.5, self.talk_frequency_adjust * 1.2))
except Exception as e:
logger.error(f"频率调整失败: {e}")
# 即使失败也保持时间戳更新,避免频繁重试
class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例"""
def __init__(self):
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
"""获取或创建指定聊天流的频率控制实例"""
if chat_id not in self.frequency_control_dict:
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
return self.frequency_control_dict[chat_id]
def remove_frequency_control(self, chat_id: str) -> bool:
"""移除指定聊天流的频率控制实例"""
if chat_id in self.frequency_control_dict:
del self.frequency_control_dict[chat_id]
return True
return False
def get_all_chat_ids(self) -> list[str]:
"""获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys())
init_prompt()
# 创建全局实例
frequency_control_manager = FrequencyControlManager()

View File

@ -0,0 +1,50 @@
from typing import Dict
from src.common.logger import get_logger
logger = get_logger("frequency_control")
class FrequencyControl:
"""简化的频率控制类仅管理不同chat_id的频率值"""
def __init__(self, chat_id: str):
self.chat_id = chat_id
# 发言频率调整值
self.talk_frequency_adjust: float = 1.0
def get_talk_frequency_adjust(self) -> float:
"""获取发言频率调整值"""
return self.talk_frequency_adjust
def set_talk_frequency_adjust(self, value: float) -> None:
"""设置发言频率调整值"""
self.talk_frequency_adjust = max(0.1, min(5.0, value))
class FrequencyControlManager:
"""频率控制管理器,管理多个聊天流的频率控制实例"""
def __init__(self):
self.frequency_control_dict: Dict[str, FrequencyControl] = {}
def get_or_create_frequency_control(self, chat_id: str) -> FrequencyControl:
"""获取或创建指定聊天流的频率控制实例"""
if chat_id not in self.frequency_control_dict:
self.frequency_control_dict[chat_id] = FrequencyControl(chat_id)
return self.frequency_control_dict[chat_id]
def remove_frequency_control(self, chat_id: str) -> bool:
"""移除指定聊天流的频率控制实例"""
if chat_id in self.frequency_control_dict:
del self.frequency_control_dict[chat_id]
return True
return False
def get_all_chat_ids(self) -> list[str]:
"""获取所有有频率控制的聊天ID"""
return list(self.frequency_control_dict.keys())
# 创建全局实例
frequency_control_manager = FrequencyControlManager()
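Usage of the slimmed-down manager above is straightforward (a sketch; the import path is taken from the updated import that appears further down in this diff):

# Sketch of using the new frequency control manager.
from src.chat.heart_flow.frequency_control import frequency_control_manager

fc = frequency_control_manager.get_or_create_frequency_control("chat_123")
fc.set_talk_frequency_adjust(8.0)                      # 被钳制到 [0.1, 5.0]
print(fc.get_talk_frequency_adjust())                  # 5.0
print(frequency_control_manager.get_all_chat_ids())    # ['chat_123']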

View File

@ -16,11 +16,11 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail
from src.express.expression_learner import expression_learner_manager
from src.chat.frequency_control.frequency_control import frequency_control_manager
from src.express.reflect_tracker import reflect_tracker_manager
from src.express.expression_reflector import expression_reflector_manager
from src.jargon import extract_and_store_jargon
from src.bw_learner.expression_learner import expression_learner_manager
from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.bw_learner.reflect_tracker import reflect_tracker_manager
from src.bw_learner.expression_reflector import expression_reflector_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager
@ -29,6 +29,7 @@ from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
)
from src.chat.utils.utils import record_replyer_action_temp
from src.hippo_memorizer.chat_history_summarizer import ChatHistorySummarizer
if TYPE_CHECKING:
@ -99,7 +100,6 @@ class HeartFChatting:
self._current_cycle_detail: CycleDetail = None # type: ignore
self.last_read_time = time.time() - 2
self.no_reply_until_call = False
self.is_mute = False
@ -190,7 +190,7 @@ class HeartFChatting:
limit_mode="latest",
filter_mai=True,
filter_command=False,
filter_no_read_command=True,
filter_intercept_message_level=0,
)
# 根据连续 no_reply 次数动态调整阈值
@ -207,23 +207,6 @@ class HeartFChatting:
if len(recent_messages_list) >= threshold:
# for message in recent_messages_list:
# print(message.processed_plain_text)
# !处理no_reply_until_call逻辑
if self.no_reply_until_call:
for message in recent_messages_list:
if (
message.is_mentioned
or message.is_at
or len(recent_messages_list) >= 8
or time.time() - self.last_read_time > 600
):
self.no_reply_until_call = False
self.last_read_time = time.time()
break
# 没有提到,继续保持沉默
if self.no_reply_until_call:
# logger.info(f"{self.log_prefix} 没有提到,继续保持沉默")
await asyncio.sleep(1)
return True
self.last_read_time = time.time()
@ -303,90 +286,6 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _run_planner_without_reply(
self,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
) -> List[ActionPlannerInfo]:
"""执行planner但不包含reply动作用于并行执行场景提及时使用简化版提示词"""
try:
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
is_mentioned=True, # 标记为提及时,使用简化版提示词
)
# 过滤掉reply动作虽然提及时不应该有reply但为了安全还是过滤一下
return [action for action in action_to_use_info if action.action_type != "reply"]
except Exception as e:
logger.error(f"{self.log_prefix} Planner执行失败: {e}")
traceback.print_exc()
return []
async def _generate_mentioned_reply(
self,
force_reply_message: "DatabaseMessages",
thinking_id: str,
cycle_timers: Dict[str, float],
available_actions: Dict[str, ActionInfo],
) -> Dict[str, Any]:
"""当被提及时,独立生成回复的任务"""
try:
self.questioned = False
# 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0
reason = ""
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="reply",
action_reasoning=reason,
)
with Timer("提及回复生成", cycle_timers):
success, llm_response = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=force_reply_message,
available_actions=available_actions,
chosen_actions=[], # 独立回复不依赖planner的动作
reply_reason=reason,
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
reply_time_point=self.last_read_time,
)
if not success or not llm_response or not llm_response.reply_set:
logger.warning(f"{self.log_prefix} 提及回复生成失败")
return {"action_type": "reply", "success": False, "result": "提及回复生成失败", "loop_info": None}
response_set = llm_response.reply_set
selected_expressions = llm_response.selected_expressions
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=force_reply_message,
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=[], # 独立回复不依赖planner的动作
selected_expressions=selected_expressions,
)
self.last_active_time = time.time()
return {
"action_type": "reply",
"success": True,
"result": f"你回复内容{reply_text}",
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 提及回复生成异常: {e}")
traceback.print_exc()
return {"action_type": "reply", "success": False, "result": f"提及回复生成异常: {e}", "loop_info": None}
async def _observe(
self, # interest_value: float = 0.0,
recent_messages_list: Optional[List["DatabaseMessages"]] = None,
@ -412,15 +311,12 @@ class HeartFChatting:
start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
asyncio.create_task(
frequency_control_manager.get_or_create_frequency_control(self.stream_id).trigger_frequency_adjust()
)
# 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
# 添加curious检测任务 - 检测聊天记录中的矛盾、冲突或需要提问的内容
# asyncio.create_task(check_and_make_question(self.stream_id))
# 添加jargon提取任务 - 提取聊天中的黑话/俚语并入库(内部自行取消息并带冷却)
asyncio.create_task(extract_and_store_jargon(self.stream_id))
# 添加聊天内容概括任务 - 累积、打包和压缩聊天记录
# 注意后台循环已在start()中启动,这里作为额外触发点,在有思考时立即处理
# asyncio.create_task(self.chat_history_summarizer.process())
@ -438,95 +334,50 @@ class HeartFChatting:
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 如果被提及让回复生成和planner并行执行
if force_reply_message:
logger.info(f"{self.log_prefix} 检测到提及回复生成与planner并行执行")
# 执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
# 并行执行planner和回复生成
planner_task = asyncio.create_task(
self._run_planner_without_reply(
available_actions=available_actions,
cycle_timers=cycle_timers,
)
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_intercept_message_level=1,
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
force_reply_message=force_reply_message,
)
reply_task = asyncio.create_task(
self._generate_mentioned_reply(
force_reply_message=force_reply_message,
thinking_id=thinking_id,
cycle_timers=cycle_timers,
available_actions=available_actions,
)
)
# 等待两个任务完成
planner_result, reply_result = await asyncio.gather(planner_task, reply_task, return_exceptions=True)
# 处理planner结果
if isinstance(planner_result, BaseException):
logger.error(f"{self.log_prefix} Planner执行异常: {planner_result}")
action_to_use_info = []
else:
action_to_use_info = planner_result
# 处理回复结果
if isinstance(reply_result, BaseException):
logger.error(f"{self.log_prefix} 回复生成异常: {reply_result}")
reply_result = {
"action_type": "reply",
"success": False,
"result": "回复生成异常",
"loop_info": None,
}
else:
# 正常流程只执行planner
is_group_chat, chat_target_info, _ = self.action_planner.get_necessary_info()
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
chat_id=self.stream_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_no_read_command=True,
)
chat_content_block, message_id_list = build_readable_messages_with_id(
messages=message_list_before_now,
timestamp_mode="normal_no_YMD",
read_mark=self.action_planner.last_obs_time_mark,
truncate=True,
show_actions=True,
)
prompt_info = await self.action_planner.build_planner_prompt(
is_group_chat=is_group_chat,
chat_target_info=chat_target_info,
current_available_actions=available_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
)
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
action_to_use_info = await self.action_planner.plan(
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
reply_result = None
# 只在提及情况下过滤掉planner返回的reply动作提及时已有独立回复生成
if force_reply_message:
action_to_use_info = [action for action in action_to_use_info if action.action_type != "reply"]
logger.info(
f"{self.log_prefix} 决定执行{len(action_to_use_info)}个动作: {' '.join([a.action_type for a in action_to_use_info])}"
)
# 3. 并行执行所有动作不包括replyreply已经独立执行
# 3. 并行执行所有动作
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
@ -537,10 +388,6 @@ class HeartFChatting:
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
# 如果有独立的回复结果,添加到结果列表中
if reply_result:
results = list(results) + [reply_result]
# 处理执行结果
reply_loop_info = None
reply_text_from_reply = ""
@ -751,31 +598,6 @@ class HeartFChatting:
return {"action_type": "no_reply", "success": True, "result": "选择不回复", "command": ""}
elif action_planner_info.action_type == "no_reply_until_call":
# 直接当场执行no_reply_until_call逻辑
logger.info(f"{self.log_prefix} 保持沉默,直到有人直接叫的名字")
reason = action_planner_info.reasoning or "选择不回复"
# 增加连续 no_reply 计数
self.consecutive_no_reply_count += 1
self.no_reply_until_call = True
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={},
action_name="no_reply_until_call",
action_reasoning=reason,
)
return {
"action_type": "no_reply_until_call",
"success": True,
"result": "保持沉默,直到有人直接叫的名字",
"command": "",
}
elif action_planner_info.action_type == "reply":
# 直接当场执行reply逻辑
self.questioned = False
@ -784,8 +606,27 @@ class HeartFChatting:
self.consecutive_no_reply_count = 0
reason = action_planner_info.reasoning or ""
# 根据 think_mode 配置决定 think_level 的值
think_mode = global_config.chat.think_mode
if think_mode == "default":
think_level = 0
elif think_mode == "deep":
think_level = 1
elif think_mode == "dynamic":
# dynamic 模式:从 planner 返回的 action_data 中获取
think_level = action_planner_info.action_data.get("think_level", 1)
else:
# 默认使用 default 模式
think_level = 0
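# 示例(仅作说明):think_mode="deep" 时所有回复固定使用 think_level=1;
# think_mode="dynamic" 时由 planner 在 action_data 中给出 0 或 1,缺省按 1 处理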
# 使用 action_reasoningplanner 的整体思考理由)作为 reply_reason
planner_reasoning = action_planner_info.action_reasoning or reason
record_replyer_action_temp(
chat_id=self.stream_id,
reason=reason,
think_level=think_level,
)
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
@ -807,6 +648,7 @@ class HeartFChatting:
request_type="replyer",
from_plugin=False,
reply_time_point=action_planner_info.action_data.get("loop_start_time", time.time()),
think_level=think_level,
)
if not success or not llm_response or not llm_response.reply_set:

View File

@ -7,7 +7,6 @@ from maim_message import UserInfo, Seg, GroupInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
@ -73,7 +72,6 @@ class ChatBot:
def __init__(self):
self.bot = None # bot 实例引用
self._started = False
self.mood_manager = mood_manager # 获取情绪管理器单例
self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
async def _ensure_started(self):
@ -83,7 +81,7 @@ class ChatBot:
self._started = True
async def _process_commands_with_new_system(self, message: MessageRecv):
async def _process_commands(self, message: MessageRecv):
# sourcery skip: use-named-expression
"""使用新插件系统处理命令"""
try:
@ -115,17 +113,21 @@ class ChatBot:
try:
# 执行命令
success, response, intercept_message = await command_instance.execute()
message.is_no_read_command = bool(intercept_message)
success, response, intercept_message_level = await command_instance.execute()
message.intercept_message_level = intercept_message_level
# 记录命令执行结果
if success:
logger.info(f"命令执行成功: {command_class.__name__} (拦截: {intercept_message})")
logger.info(f"命令执行成功: {command_class.__name__} (拦截等级: {intercept_message_level})")
else:
logger.warning(f"命令执行失败: {command_class.__name__} - {response}")
# 根据命令的拦截设置决定是否继续处理消息
return True, response, not intercept_message # 找到命令根据intercept_message决定是否继续
return (
True,
response,
not bool(intercept_message_level),
) # 找到命令根据intercept_message_level决定是否继续
except Exception as e:
logger.error(f"执行命令时出错: {command_class.__name__} - {e}")
@ -295,7 +297,7 @@ class ChatBot:
# return
# 命令处理 - 使用新插件系统检查并处理命令
is_command, cmd_result, continue_process = await self._process_commands_with_new_system(message)
is_command, cmd_result, continue_process = await self._process_commands(message)
# 如果是命令且不需要继续处理,则直接返回
if is_command and not continue_process:

View File

@ -122,7 +122,7 @@ class MessageRecv(Message):
self.is_notify = False
self.is_command = False
self.is_no_read_command = False
self.intercept_message_level = 0
self.priority_mode = "interest"
self.priority_info = None
@ -213,6 +213,68 @@ class MessageRecv(Message):
}
"""
return ""
elif segment.type == "video_card":
# 处理视频卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
file_name = segment.data.get("file", "未知视频")
file_size = segment.data.get("file_size", "")
url = segment.data.get("url", "")
text = f"[视频: {file_name}"
if file_size:
text += f", 大小: {file_size}字节"
text += "]"
if url:
text += f" 链接: {url}"
return text
return "[视频]"
elif segment.type == "music_card":
# 处理音乐卡片消息
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "未知歌曲")
singer = segment.data.get("singer", "")
tag = segment.data.get("tag", "") # 音乐来源,如"网易云音乐"
jump_url = segment.data.get("jump_url", "")
music_url = segment.data.get("music_url", "")
text = f"[音乐: {title}"
if singer:
text += f" - {singer}"
if tag:
text += f" ({tag})"
text += "]"
if jump_url:
text += f" 跳转链接: {jump_url}"
if music_url:
text += f" 音乐链接: {music_url}"
return text
return "[音乐]"
elif segment.type == "miniapp_card":
# 处理小程序分享卡片如B站视频分享
self.is_picid = False
self.is_emoji = False
self.is_voice = False
if isinstance(segment.data, dict):
title = segment.data.get("title", "") # 小程序名称
desc = segment.data.get("desc", "") # 内容描述
source_url = segment.data.get("source_url", "") # 原始链接
url = segment.data.get("url", "") # 小程序链接
text = "[小程序分享"
if title:
text += f" - {title}"
text += "]"
if desc:
text += f" {desc}"
if source_url:
text += f" 链接: {source_url}"
elif url:
text += f" 链接: {url}"
return text
return "[小程序分享]"
else:
return ""
except Exception as e:

View File

@ -72,7 +72,7 @@ class MessageStorage:
key_words = ""
key_words_lite = ""
selected_expressions = message.selected_expressions
is_no_read_command = False
intercept_message_level = 0
else:
filtered_display_message = ""
interest_value = message.interest_value
@ -86,7 +86,7 @@ class MessageStorage:
is_picid = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
is_no_read_command = getattr(message, "is_no_read_command", False)
intercept_message_level = getattr(message, "intercept_message_level", 0)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
@ -138,7 +138,7 @@ class MessageStorage:
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
is_no_read_command=is_no_read_command,
intercept_message_level=intercept_message_level,
key_words=key_words,
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,

View File

@ -40,6 +40,93 @@ def is_webui_virtual_group(group_id: str) -> bool:
return group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX)
def parse_message_segments(segment) -> list:
"""解析消息段,转换为 WebUI 可用的格式
参考 NapCat 适配器的消息解析逻辑
Args:
segment: Seg 消息段对象
Returns:
list: 消息段列表每个元素为 {"type": "...", "data": ...}
"""
result = []
if segment is None:
return result
if segment.type == "seglist":
# 处理消息段列表
if segment.data:
for seg in segment.data:
result.extend(parse_message_segments(seg))
elif segment.type == "text":
# 文本消息
if segment.data:
result.append({"type": "text", "data": segment.data})
elif segment.type == "image":
# 图片消息base64
if segment.data:
result.append({"type": "image", "data": f"data:image/png;base64,{segment.data}"})
elif segment.type == "emoji":
# 表情包消息base64
if segment.data:
result.append({"type": "emoji", "data": f"data:image/gif;base64,{segment.data}"})
elif segment.type == "imageurl":
# 图片链接消息
if segment.data:
result.append({"type": "image", "data": segment.data})
elif segment.type == "face":
# 原生表情
result.append({"type": "face", "data": segment.data})
elif segment.type == "voice":
# 语音消息base64
if segment.data:
result.append({"type": "voice", "data": f"data:audio/wav;base64,{segment.data}"})
elif segment.type == "voiceurl":
# 语音链接
if segment.data:
result.append({"type": "voice", "data": segment.data})
elif segment.type == "video":
# 视频消息base64
if segment.data:
result.append({"type": "video", "data": f"data:video/mp4;base64,{segment.data}"})
elif segment.type == "videourl":
# 视频链接
if segment.data:
result.append({"type": "video", "data": segment.data})
elif segment.type == "music":
# 音乐消息
result.append({"type": "music", "data": segment.data})
elif segment.type == "file":
# 文件消息
result.append({"type": "file", "data": segment.data})
elif segment.type == "reply":
# 回复消息
result.append({"type": "reply", "data": segment.data})
elif segment.type == "forward":
# 转发消息
forward_items = []
if segment.data:
for item in segment.data:
forward_items.append(
{
"content": parse_message_segments(item.get("message_segment", {}))
if isinstance(item, dict)
else []
}
)
result.append({"type": "forward", "data": forward_items})
else:
# 未知类型,尝试作为文本处理
if segment.data:
result.append({"type": "unknown", "original_type": segment.type, "data": str(segment.data)})
return result
async def _send_message(message: MessageSending, show_log=True) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=200)
@ -50,17 +137,31 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
# 检查是否是 WebUI 平台的消息,或者是 WebUI 虚拟群的消息
chat_manager, webui_platform = get_webui_chat_broadcaster()
is_webui_message = (platform == webui_platform) or is_webui_virtual_group(group_id)
if is_webui_message and chat_manager is not None:
# WebUI 聊天室消息(包括虚拟身份模式),通过 WebSocket 广播
import time
from src.config.config import global_config
# 解析消息段,获取富文本内容
message_segments = parse_message_segments(message.message_segment)
# 判断消息类型
# 如果只有一个文本段,使用简单的 text 类型
# 否则使用 rich 类型,包含完整的消息段
if len(message_segments) == 1 and message_segments[0].get("type") == "text":
message_type = "text"
segments = None
else:
message_type = "rich"
segments = message_segments
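# 示例(仅作说明):若消息只有一个文本段,则 message_type="text" 且 segments=None;
# 图文混排等情况则 message_type="rich",segments 形如 [{"type": "text", ...}, {"type": "image", ...}]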
await chat_manager.broadcast(
{
"type": "bot_message",
"content": message.processed_plain_text,
"message_type": "text",
"message_type": message_type,
"segments": segments, # 富文本消息段
"timestamp": time.time(),
"group_id": group_id, # 包含群 ID 以便前端区分不同的聊天标签
"sender": {
@ -81,11 +182,70 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
return True
# 直接调用API发送消息
await get_global_api().send_message(message)
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
# Fallback 逻辑: 尝试通过 API Server 发送
async def send_with_new_api(legacy_exception=None):
try:
from src.config.config import global_config
# 如果未开启 API Server直接跳过 Fallback
if not global_config.maim_message.enable_api_server:
if legacy_exception:
raise legacy_exception
return False
global_api = get_global_api()
extra_server = getattr(global_api, "extra_server", None)
if extra_server and extra_server.is_running():
# Fallback: 使用极其简单的 Platform -> API Key 映射
# 只有收到过该平台的消息,我们才知道该平台的 API Key才能回传消息
platform_map = getattr(global_api, "platform_map", {})
target_api_key = platform_map.get(platform)
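# 示例(假设的取值,仅作说明):若此前收到过 platform="qq" 的消息并记录了其 API Key,
# 则 platform_map 形如 {"qq": "<api_key>"},此处即可取回对应的 key 用于回传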
if target_api_key:
# 构造 APIMessageBase
from maim_message.message import APIMessageBase, MessageDim
msg_dim = MessageDim(api_key=target_api_key, platform=platform)
api_message = APIMessageBase(
message_info=message.message_info,
message_segment=message.message_segment,
message_dim=msg_dim,
)
# 直接调用 Server 的 send_message 接口,它会自动处理路由
results = await extra_server.send_message(api_message)
# 检查是否有任何连接发送成功
if any(results.values()):
if show_log:
logger.info(
f"已通过API Server Fallback将消息 '{message_preview}' 发往平台'{platform}' (key: {target_api_key})"
)
return True
except Exception:
pass
# 如果 Fallback 失败,且存在 legacy 异常,则抛出 legacy 异常
if legacy_exception:
raise legacy_exception
return False
try:
send_result = await get_global_api().send_message(message)
# if send_result:
if show_log:
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
# Legacy API 返回 False (发送失败但未报错),尝试 Fallback
# return await send_with_new_api()
except Exception as legacy_e:
# Legacy API 抛出异常,尝试 Fallback
# 如果 Fallback 也失败,将重新抛出 legacy_e
return await send_with_new_api(legacy_exception=legacy_e)
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")

View File

@ -69,7 +69,7 @@ class ActionModifier:
chat_id=self.chat_stream.stream_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 10),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
chat_content = build_readable_messages(

View File

@ -15,12 +15,15 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.planner_actions.action_manager import ActionManager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.plugin_system.base.component_types import ActionInfo, ComponentType, ActionActivationType
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.person_info.person_info import Person
if TYPE_CHECKING:
from src.common.data_models.info_data_model import TargetPersonInfo
@ -36,7 +39,6 @@ def init_prompt():
"""
{time_block}
{name_block}
你的兴趣是{interest}
{chat_context_description}以下是具体的聊天内容
**聊天内容**
{chat_content_block}
@ -46,9 +48,10 @@ reply
动作描述
1.你可以选择呼叫了你的名字但是你没有做出回应的消息进行回复
2.你可以自然的顺着正在进行的聊天内容进行回复或自然的提出一个问题
3.不要回复你自己发送的消息
4.不要单独对表情包进行回复
{{"action":"reply", "target_message_id":"消息id(m+数字)", "reason":"原因"}}
3.最好一次对一个话题进行回复免得啰嗦或者回复内容太乱
4.不要选择回复你自己发送的消息
5.不要单独对表情包进行回复
{reply_action_example}
no_reply
动作描述
@ -56,75 +59,37 @@ no_reply
控制聊天频率不要太过频繁的发言
{{"action":"no_reply"}}
{no_reply_until_call_block}
{action_options_text}
**你之前的action执行和思考记录**
{actions_before_now_block}
请选择**可选的**且符合使用条件的action并说明触发action的消息id(消息id格式:m+数字)
不要回复你自己发送的消息
先输出你的简短的选择思考理由再输出你选择的action理由不要分点精简
**动作选择要求**
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
{plan_style}
{moderation_prompt}
请选择所有符合使用要求的action动作用json格式输出```json包裹如果输出多个json每个json都要单独一行放在同一个```json代码块内你可以重复使用同一个动作或不同动作:
target_message_id为必填表示触发消息的id
请选择所有符合使用要求的action每个动作最多选择一次但是可以选择多个动作
动作用json格式输出```json包裹如果输出多个json每个json都要单独一行放在同一个```json代码块内:
**示例**
// 理由文本简短
```json
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
{{"action":"动作名", "target_message_id":"m123", .....}}
{{"action":"动作名", "target_message_id":"m456", .....}}
```""",
"planner_prompt",
)
Prompt(
"""{time_block}
{name_block}
{chat_context_description}以下是具体的聊天内容
**聊天内容**
{chat_content_block}
**可选的action**
no_reply
动作描述
没有合适的可以使用的动作不使用action
{{"action":"no_reply"}}
{action_options_text}
**你之前的action执行和思考记录**
{actions_before_now_block}
请选择**可选的**且符合使用条件的action并说明触发action的消息id(消息id格式:m+数字)
先输出你的简短的选择思考理由再输出你选择的action理由不要分点精简
**动作选择要求**
请你根据聊天内容,用户的最新消息和以下标准选择合适的动作:
1.思考**所有**的可用的action中的**每个动作**是否符合当下条件如果动作使用条件符合聊天内容就使用
2.如果相同的内容已经被执行请不要重复执行
{moderation_prompt}
请选择所有符合使用要求的action动作用json格式输出```json包裹如果输出多个json每个json都要单独一行放在同一个```json代码块内你可以重复使用同一个动作或不同动作:
**示例**
// 理由文本简短
```json
{{"action":"动作名", "target_message_id":"m123", "reason":"原因"}}
{{"action":"动作名", "target_message_id":"m456", "reason":"原因"}}
```""",
"planner_prompt_mentioned",
)
Prompt(
"""
{action_name}
动作描述{action_description}
使用条件{parallel_text}
{action_require}
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)", "reason":"原因"}}
{{"action":"{action_name}",{action_parameters}, "target_message_id":"消息id(m+数字)"}}
""",
"action_prompt",
)
@ -195,11 +160,41 @@ class ActionPlanner:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 未找到对应消息,保持原样")
return msg_id
msg_text = (message.processed_plain_text or message.display_message or "").strip()
msg_text = (message.processed_plain_text or "").strip()
if not msg_text:
logger.warning(f"{self.log_prefix}planner理由引用 {msg_id} 的消息内容为空,保持原样")
return msg_id
# 替换 [picid:xxx] 为 [图片:描述]
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(pic_match: re.Match) -> str:
pic_id = pic_match.group(1)
description = translate_pid_to_description(pic_id)
return f"[图片:{description}]"
msg_text = re.sub(pic_pattern, replace_pic_id, msg_text)
# 替换用户引用格式:回复<aaa:bbb> 和 @<aaa:bbb>
platform = (
(getattr(message, "user_info", None) and message.user_info.platform)
or (getattr(message, "chat_info", None) and message.chat_info.platform)
or "qq"
)
msg_text = replace_user_references(msg_text, platform, replace_bot_name=True)
# 替换单独的 <用户名:用户ID> 格式replace_user_references 已处理回复<和@<格式)
# 匹配所有 <aaa:bbb> 格式,由于 replace_user_references 已经替换了回复<和@<格式,
# 这里匹配到的应该都是单独的格式
user_ref_pattern = r"<([^:<>]+):([^:<>]+)>"
def replace_user_ref(user_match: re.Match) -> str:
user_name = user_match.group(1)
user_id = user_match.group(2)
try:
# 检查是否是机器人自己
if user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_name
except Exception:
# 如果解析失败,使用原始昵称
return user_name
msg_text = re.sub(user_ref_pattern, replace_user_ref, msg_text)
preview = msg_text if len(msg_text) <= 100 else f"{msg_text[:97]}..."
logger.info(f"{self.log_prefix}planner理由引用 {msg_id} -> 消息({preview}")
return f"消息({msg_text}"
@ -218,11 +213,14 @@ class ActionPlanner:
try:
action = action_json.get("action", "no_reply")
original_reasoning = action_json.get("reason", "未提供原因")
reasoning = self._replace_message_ids_with_text(original_reasoning, message_id_list)
if reasoning is None:
reasoning = original_reasoning
action_data = {key: value for key, value in action_json.items() if key not in ["action", "reason"]}
# 使用 extracted_reasoning整体推理文本作为 reasoning
if extracted_reasoning:
reasoning = self._replace_message_ids_with_text(extracted_reasoning, message_id_list)
if reasoning is None:
reasoning = extracted_reasoning
else:
reasoning = "未提供原因"
action_data = {key: value for key, value in action_json.items() if key not in ["action"]}
# 非no_reply动作需要target_message_id
target_message = None
@ -248,7 +246,7 @@ class ActionPlanner:
# 验证action是否可用
available_action_names = [action_name for action_name, _ in current_available_actions]
internal_action_names = ["no_reply", "reply", "wait_time", "no_reply_until_call"]
internal_action_names = ["no_reply", "reply", "wait_time"]
if action not in internal_action_names and action not in available_action_names:
logger.warning(
@ -304,7 +302,7 @@ class ActionPlanner:
self,
available_actions: Dict[str, ActionInfo],
loop_start_time: float = 0.0,
is_mentioned: bool = False,
force_reply_message: Optional["DatabaseMessages"] = None,
) -> List[ActionPlannerInfo]:
# sourcery skip: use-named-expression
"""
@ -316,7 +314,7 @@ class ActionPlanner:
chat_id=self.chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.6),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
message_id_list: list[Tuple[str, "DatabaseMessages"]] = []
chat_content_block, message_id_list = build_readable_messages_with_id(
@ -345,11 +343,6 @@ class ActionPlanner:
logger.debug(f"{self.log_prefix}过滤后有{len(filtered_actions)}个可用动作")
# 如果是提及时且没有可用动作直接返回空列表不调用LLM以节省token
if is_mentioned and not filtered_actions:
logger.info(f"{self.log_prefix}提及时没有可用动作跳过plan调用")
return []
# 构建包含所有动作的提示词
prompt, message_id_list = await self.build_planner_prompt(
is_group_chat=is_group_chat,
@ -357,8 +350,6 @@ class ActionPlanner:
current_available_actions=filtered_actions,
chat_content_block=chat_content_block,
message_id_list=message_id_list,
interest=global_config.personality.interest,
is_mentioned=is_mentioned,
)
# 调用LLM获取决策
@ -370,6 +361,34 @@ class ActionPlanner:
loop_start_time=loop_start_time,
)
# 如果有强制回复消息,确保回复该消息
if force_reply_message:
# 检查是否已经有回复该消息的 action
has_reply_to_force_message = False
for action in actions:
if action.action_type == "reply" and action.action_message and action.action_message.message_id == force_reply_message.message_id:
has_reply_to_force_message = True
break
# 如果没有回复该消息,强制添加回复 action
if not has_reply_to_force_message:
# 移除所有 no_reply action如果有
actions = [a for a in actions if a.action_type != "no_reply"]
# 创建强制回复 action
available_actions_dict = dict(current_available_actions)
force_reply_action = ActionPlannerInfo(
action_type="reply",
reasoning="用户提及了我,必须回复该消息",
action_data={"loop_start_time": loop_start_time},
action_message=force_reply_message,
available_actions=available_actions_dict,
action_reasoning=None,
)
# 将强制回复 action 放在最前面
actions.insert(0, force_reply_action)
logger.info(f"{self.log_prefix} 检测到强制回复消息,已添加回复动作")
logger.info(
f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}"
)
@ -430,32 +449,6 @@ class ActionPlanner:
return plan_log_str
def _has_consecutive_no_reply(self, min_count: int = 3) -> bool:
"""
检查是否有连续min_count次以上的no_reply
Args:
min_count: 需要连续的最少次数默认3
Returns:
如果有连续min_count次以上no_reply返回True否则返回False
"""
consecutive_count = 0
# 从后往前遍历plan_log检查最新的连续记录
for _reasoning, _timestamp, content in reversed(self.plan_log):
if isinstance(content, list) and all(isinstance(action, ActionPlannerInfo) for action in content):
# 检查所有action是否都是no_reply
if all(action.action_type == "no_reply" for action in content):
consecutive_count += 1
if consecutive_count >= min_count:
return True
else:
# 如果遇到非no_reply的action重置计数
break
return False
async def build_planner_prompt(
self,
is_group_chat: bool,
@ -464,7 +457,6 @@ class ActionPlanner:
message_id_list: List[Tuple[str, "DatabaseMessages"]],
chat_content_block: str = "",
interest: str = "",
is_mentioned: bool = False,
) -> tuple[str, List[Tuple[str, "DatabaseMessages"]]]:
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
@ -485,47 +477,25 @@ class ActionPlanner:
)
name_block = f"你的名字是{bot_name}{bot_nickname},请注意哪些是你自己的发言。"
# 根据是否是提及时选择不同的模板
if is_mentioned:
# 提及时使用简化版提示词不需要reply、no_reply、no_reply_until_call
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt_mentioned")
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block,
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
name_block=name_block,
interest=interest,
plan_style=global_config.personality.plan_style,
)
# 根据 think_mode 配置决定 reply action 的示例 JSON
if global_config.chat.think_mode == "classic":
reply_action_example = '{{"action":"reply", "target_messamge_id":"消息id(m+数字)"}}'
else:
# 正常流程使用完整版提示词
# 检查是否有连续3次以上no_reply如果有则添加no_reply_until_call选项
no_reply_until_call_block = ""
if self._has_consecutive_no_reply(min_count=3):
no_reply_until_call_block = """no_reply_until_call
动作描述
保持沉默直到有人直接叫你的名字
当前话题不感兴趣时使用或有人不喜欢你的发言时使用
当你频繁选择no_reply时使用表示话题暂时与你无关
{{"action":"no_reply_until_call"}}
"""
reply_action_example = '5.think_level表示思考深度:0表示该回复不需要思考和回忆,1表示该回复需要进行回忆和思考\n{{"action":"reply", "think_level":数值等级(0或1), "target_message_id":"消息id(m+数字)"}}'
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block,
action_options_text=action_options_block,
no_reply_until_call_block=no_reply_until_call_block,
moderation_prompt=moderation_prompt_block,
name_block=name_block,
interest=interest,
plan_style=global_config.personality.plan_style,
)
planner_prompt_template = await global_prompt_manager.get_prompt_async("planner_prompt")
prompt = planner_prompt_template.format(
time_block=time_block,
chat_context_description=chat_context_description,
chat_content_block=chat_content_block,
actions_before_now_block=actions_before_now_block,
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
name_block=name_block,
interest=interest,
plan_style=global_config.personality.plan_style,
reply_action_example=reply_action_example,
)
return prompt, message_id_list
except Exception as e:

View File

@ -18,13 +18,12 @@ from src.chat.message_receive.uni_message_sender import UniversalMessageSender
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.utils.utils import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import global_prompt_manager
from src.mood.mood_manager import mood_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.express.expression_selector import expression_selector
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator
@ -36,7 +35,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
from src.jargon.jargon_explainer import explain_jargon_in_context
from src.bw_learner.jargon_explainer import explain_jargon_in_context
init_lpmm_prompt()
init_replyer_prompt()
@ -73,6 +72,7 @@ class DefaultReplyer:
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
think_level: int = 1,
) -> Tuple[bool, LLMGenerationDataModel]:
# sourcery skip: merge-nested-ifs
"""
@ -98,8 +98,10 @@ class DefaultReplyer:
available_actions = {}
try:
# 3. 构建 Prompt
timing_logs = []
almost_zero_str = ""
with Timer("构建Prompt", {}): # 内部计时器,可选保留
prompt, selected_expressions = await self.build_prompt_reply_context(
prompt, selected_expressions, timing_logs, almost_zero_str = await self.build_prompt_reply_context(
extra_info=extra_info,
available_actions=available_actions,
chosen_actions=chosen_actions,
@ -107,6 +109,7 @@ class DefaultReplyer:
reply_message=reply_message,
reply_reason=reply_reason,
reply_time_point=reply_time_point,
think_level=think_level,
)
llm_response.prompt = prompt
llm_response.selected_expressions = selected_expressions
@ -135,10 +138,22 @@ class DefaultReplyer:
content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
# logger.debug(f"replyer生成内容: {content}")
logger.info(f"replyer生成内容: {content}")
if global_config.debug.show_replyer_reasoning:
logger.info(f"replyer生成推理:\n{reasoning_content}")
logger.info(f"replyer生成模型: {model_name}")
# 统一输出所有日志信息使用try-except确保即使某个步骤出错也能输出
try:
# 1. 输出回复准备日志
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
logger.info(timing_log_str)
# 2. 输出Prompt日志
if global_config.debug.show_replyer_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
# 3. 输出模型生成内容和推理日志
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成内容: {content}")
if global_config.debug.show_replyer_reasoning and reasoning_content:
logger.info(f"模型: [{model_name}][思考等级:{think_level}]生成推理:\n{reasoning_content}")
except Exception as e:
logger.warning(f"输出日志时出错: {e}")
llm_response.content = content
llm_response.reasoning = reasoning_content
@ -162,6 +177,21 @@ class DefaultReplyer:
except Exception as llm_e:
# 精简报错信息
logger.error(f"LLM 生成失败: {llm_e}")
# 即使LLM生成失败也尝试输出已收集的日志信息
try:
# 1. 输出回复准备日志
timing_log_str = f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s" if timing_logs or almost_zero_str else "回复准备: 无计时信息"
logger.info(timing_log_str)
# 2. 输出Prompt日志
if global_config.debug.show_replyer_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
# 3. 输出模型生成失败信息
logger.info("模型生成失败,无法输出生成内容和推理")
except Exception as log_e:
logger.warning(f"输出日志时出错: {log_e}")
return False, llm_response # LLM 调用失败则无法生成回复
return True, llm_response
@ -228,7 +258,7 @@ class DefaultReplyer:
return False, llm_response
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
self, chat_history: str, target: str, reply_reason: str = "", think_level: int = 1
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@ -237,6 +267,7 @@ class DefaultReplyer:
chat_history: 聊天历史记录
target: 目标消息内容
reply_reason: planner给出的回复理由
think_level: 思考级别0/1/2
Returns:
str: 表达习惯信息字符串
@ -249,14 +280,19 @@ class DefaultReplyer:
# 使用从处理器传来的选中表达方式
# 使用模型预测选择表达方式
selected_expressions, selected_ids = await expression_selector.select_suitable_expressions(
self.chat_stream.stream_id, chat_history, max_num=8, target_message=target, reply_reason=reply_reason
self.chat_stream.stream_id,
chat_history,
max_num=8,
target_message=target,
reply_reason=reply_reason,
think_level=think_level,
)
if selected_expressions:
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habits.append(f"{expr['situation']},使用 {expr['style']}")
style_habits.append(f"{expr['situation']}{expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
@ -272,13 +308,6 @@ class DefaultReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示"""
if not global_config.mood.enable_mood:
return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -459,6 +488,10 @@ class DefaultReplyer:
duration = end_time - start_time
return name, result, duration
async def _build_disabled_jargon_explanation(self) -> str:
"""当关闭黑话解释时使用的占位协程避免额外的LLM调用"""
return ""
def build_chat_history_prompts(
self, message_list_before_now: List[DatabaseMessages], target_user_id: str, sender: str
) -> Tuple[str, str]:
@ -606,7 +639,7 @@ class DefaultReplyer:
# 获取基础personality
prompt_personality = global_config.personality.personality
# 检查是否需要随机替换为状态
# 检查是否需要随机替换为状态personality 本体)
if (
global_config.personality.states
and global_config.personality.state_probability > 0
@ -705,7 +738,8 @@ class DefaultReplyer:
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
reply_time_point: Optional[float] = time.time(),
) -> Tuple[str, List[int]]:
think_level: int = 1,
) -> Tuple[str, List[int], List[str], str]:
"""
构建回复器上下文
@ -751,14 +785,14 @@ class DefaultReplyer:
chat_id=chat_id,
timestamp=reply_time_point,
limit=global_config.chat.max_context_size * 1,
filter_no_read_command=True,
filter_intercept_message_level=1,
)
message_list_before_short = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=reply_time_point,
limit=int(global_config.chat.max_context_size * 0.33),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
person_list_short: List[Person] = []
@ -789,10 +823,18 @@ class DefaultReplyer:
show_actions=True,
)
# 并行执行八个构建任务(包括黑话解释)
# 根据配置决定是否启用黑话解释
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
if enable_jargon_explanation:
jargon_coroutine = explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short)
else:
jargon_coroutine = self._build_disabled_jargon_explanation()
# 并行执行八个构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason, think_level=think_level),
"expression_habits",
),
self._time_and_run_task(
self.build_tool_info(chat_talking_prompt_short, sender, target, enable_tool=enable_tool), "tool_info"
@ -800,17 +842,13 @@ class DefaultReplyer:
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
chat_talking_prompt_short, sender, target, self.chat_stream, think_level=think_level
),
"memory_retrieval",
),
self._time_and_run_task(
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
"jargon_explanation",
),
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
)
# 任务名称中英文映射
@ -821,7 +859,6 @@ class DefaultReplyer:
"prompt_info": "获取知识",
"actions_info": "动作信息",
"personality_prompt": "人格信息",
"mood_state_prompt": "情绪状态",
"memory_retrieval": "记忆检索",
"jargon_explanation": "黑话解释",
}
@ -839,7 +876,8 @@ class DefaultReplyer:
continue
timing_logs.append(f"{chinese_name}: {duration:.1f}s")
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
# 不再在这里输出日志,而是返回给调用者统一输出
# logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.1s")
expression_habits_block, selected_expressions = results_dict["expression_habits"]
expression_habits_block: str
@ -851,14 +889,8 @@ class DefaultReplyer:
personality_prompt: str = results_dict["personality_prompt"]
memory_retrieval: str = results_dict["memory_retrieval"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
mood_state_prompt: str = results_dict["mood_state_prompt"]
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
# 从 chosen_actions 中提取 planner 的整体思考理由
planner_reasoning = ""
if global_config.chat.include_planner_reasoning and reply_reason:
# 如果没有 chosen_actions使用 reply_reason 作为备选
planner_reasoning = f"你的想法是:{reply_reason}"
planner_reasoning = f"你的想法是:{reply_reason}"
if extra_info:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
@ -893,14 +925,31 @@ class DefaultReplyer:
chat_prompt_content = self.get_chat_prompt_for_chat(chat_id)
chat_prompt_block = f"{chat_prompt_content}\n" if chat_prompt_content else ""
# 固定使用群聊回复模板
# 根据think_level选择不同的回复模板
# think_level=0: 轻量回复(简短平淡)
# think_level=1: 中等回复(日常口语化)
if think_level == 0:
prompt_name = "replyer_prompt_0"
else: # think_level == 1 或默认
prompt_name = "replyer_prompt"
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
try:
reply_style = random.choice(list(multi_styles))
except Exception:
# 兜底:即使 multiple_reply_style 配置异常也不影响正常回复
reply_style = global_config.personality.reply_style
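# 示例(仅作说明):multiple_probability=0.3 时,约三成回复会改用从 multiple_reply_style 中随机抽取的一条风格,
# 其余情况仍使用默认的 reply_style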
return await global_prompt_manager.format_prompt(
"replyer_prompt",
prompt_name,
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
bot_name=global_config.bot.nickname,
knowledge_prompt=prompt_info,
mood_state=mood_state_prompt,
# relation_info_block=relation_info,
extra_info_block=extra_info_block,
jargon_explanation=jargon_explanation,
@ -910,13 +959,13 @@ class DefaultReplyer:
dialogue_prompt=dialogue_prompt,
time_block=time_block,
reply_target_block=reply_target_block,
reply_style=global_config.personality.reply_style,
reply_style=reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
memory_retrieval=memory_retrieval,
chat_prompt=chat_prompt_block,
planner_reasoning=planner_reasoning,
), selected_expressions
), selected_expressions, timing_logs, almost_zero_str
async def build_prompt_rewrite_context(
self,
@ -926,8 +975,6 @@ class DefaultReplyer:
) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
chat_stream = self.chat_stream
chat_id = chat_stream.stream_id
is_group_chat = bool(chat_stream.group_info)
sender, target = self._parse_reply_target(reply_to)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
@ -941,7 +988,7 @@ class DefaultReplyer:
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
chat_talking_prompt_half = build_readable_messages(
message_list_before_now_half,
@ -967,61 +1014,42 @@ class DefaultReplyer:
if sender and target:
# 使用预先分析的内容类型结果
if is_group_chat:
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else: # private chat
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
reply_target_block = "现在,你想要回复。"
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else:
reply_target_block = ""
if is_group_chat:
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
else:
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
template_name = "default_expressor_prompt"
# 根据配置构建最终的 reply_style支持 multiple_reply_style 按概率随机替换
reply_style = global_config.personality.reply_style
multi_styles = getattr(global_config.personality, "multiple_reply_style", None) or []
multi_prob = getattr(global_config.personality, "multiple_probability", 0.0) or 0.0
if multi_styles and multi_prob > 0 and random.random() < multi_prob:
try:
reply_style = random.choice(list(multi_styles))
except Exception:
reply_style = global_config.personality.reply_style
return await global_prompt_manager.format_prompt(
template_name,
expression_habits_block=expression_habits_block,
@ -1034,7 +1062,7 @@ class DefaultReplyer:
reply_target_block=reply_target_block,
raw_reply=raw_reply,
reason=reason,
reply_style=global_config.personality.reply_style,
reply_style=reply_style,
keywords_reaction_prompt=keywords_reaction_prompt,
moderation_prompt=moderation_prompt_block,
)
@ -1078,10 +1106,11 @@ class DefaultReplyer:
# 直接使用已初始化的模型实例
# logger.info(f"\n{prompt}\n")
if global_config.debug.show_replyer_prompt:
logger.info(f"\n{prompt}\n")
else:
logger.debug(f"\nreplyer_Prompt:{prompt}\n")
# 不再在这里输出日志,而是返回给调用者统一输出
# if global_config.debug.show_replyer_prompt:
# logger.info(f"\n{prompt}\n")
# else:
# logger.debug(f"\nreplyer_Prompt:{prompt}\n")
content, (reasoning_content, model_name, tool_calls) = await self.express_model.generate_response_async(
prompt
@ -1090,7 +1119,7 @@ class DefaultReplyer:
# 移除 content 前后的换行符和空格
content = content.strip()
logger.info(f"使用 {model_name} 生成回复内容: {content}")
# logger.info(f"使用 {model_name} 生成回复内容: {content}")
return content, reasoning_content, model_name, tool_calls
async def get_prompt_info(self, message: str, sender: str, target: str):

View File

@ -23,9 +23,8 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
replace_user_references,
)
from src.express.expression_selector import expression_selector
from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description
from src.mood.mood_manager import mood_manager
# from src.memory_system.memory_activator import MemoryActivator
@ -34,13 +33,13 @@ from src.plugin_system.base.component_types import ActionInfo, EventType
from src.plugin_system.apis import llm_api
from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.replyer_private_prompt import init_replyer_private_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
from src.jargon.jargon_explainer import explain_jargon_in_context
from src.bw_learner.jargon_explainer import explain_jargon_in_context
init_lpmm_prompt()
init_replyer_prompt()
init_replyer_private_prompt()
init_rewrite_prompt()
init_memory_retrieval_prompt()
@ -72,6 +71,7 @@ class PrivateReplyer:
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
from_plugin: bool = True,
think_level: int = 1,
stream_id: Optional[str] = None,
reply_message: Optional[DatabaseMessages] = None,
reply_time_point: Optional[float] = time.time(),
@ -271,7 +271,7 @@ class PrivateReplyer:
logger.debug(f"使用处理器选中的{len(selected_expressions)}个表达方式")
for expr in selected_expressions:
if isinstance(expr, dict) and "situation" in expr and "style" in expr:
style_habits.append(f"{expr['situation']},使用 {expr['style']}")
style_habits.append(f"{expr['situation']}{expr['style']}")
else:
logger.debug("没有从处理器获得表达方式,将使用空的表达方式")
# 不再在replyer中进行随机选择全部交给处理器处理
@ -287,13 +287,6 @@ class PrivateReplyer:
return f"{expression_habits_title}\n{expression_habits_block}", selected_ids
async def build_mood_state_prompt(self) -> str:
"""构建情绪状态提示"""
if not global_config.mood.enable_mood:
return ""
mood_state = await mood_manager.get_mood_by_chat_id(self.chat_stream.stream_id).get_mood()
return f"你现在的心情是:{mood_state}"
async def build_tool_info(self, chat_history: str, sender: str, target: str, enable_tool: bool = True) -> str:
"""构建工具信息块
@ -474,6 +467,10 @@ class PrivateReplyer:
duration = end_time - start_time
return name, result, duration
async def _build_disabled_jargon_explanation(self) -> str:
"""当关闭黑话解释时使用的占位协程避免额外的LLM调用"""
return ""
async def build_actions_prompt(
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str:
@ -663,7 +660,7 @@ class PrivateReplyer:
chat_id=chat_id,
timestamp=time.time(),
limit=global_config.chat.max_context_size,
filter_no_read_command=True,
filter_intercept_message_level=1,
)
dialogue_prompt = build_readable_messages(
@ -678,7 +675,7 @@ class PrivateReplyer:
chat_id=chat_id,
timestamp=time.time(),
limit=int(global_config.chat.max_context_size * 0.33),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
person_list_short: List[Person] = []
@ -709,7 +706,14 @@ class PrivateReplyer:
show_actions=True,
)
# 并行执行九个构建任务(包括黑话解释)
# 根据配置决定是否启用黑话解释
enable_jargon_explanation = getattr(global_config.expression, "enable_jargon_explanation", True)
if enable_jargon_explanation:
jargon_coroutine = explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short)
else:
jargon_coroutine = self._build_disabled_jargon_explanation()
# 并行执行九个构建任务(包括黑话解释,可配置关闭)
task_results = await asyncio.gather(
self._time_and_run_task(
self.build_expression_habits(chat_talking_prompt_short, target, reply_reason), "expression_habits"
@ -721,17 +725,13 @@ class PrivateReplyer:
self._time_and_run_task(self.get_prompt_info(chat_talking_prompt_short, sender, target), "prompt_info"),
self._time_and_run_task(self.build_actions_prompt(available_actions, chosen_actions), "actions_info"),
self._time_and_run_task(self.build_personality_prompt(), "personality_prompt"),
self._time_and_run_task(self.build_mood_state_prompt(), "mood_state_prompt"),
self._time_and_run_task(
build_memory_retrieval_prompt(
chat_talking_prompt_short, sender, target, self.chat_stream, self.tool_executor
),
"memory_retrieval",
),
self._time_and_run_task(
explain_jargon_in_context(chat_id, message_list_before_short, chat_talking_prompt_short),
"jargon_explanation",
),
self._time_and_run_task(jargon_coroutine, "jargon_explanation"),
)
# 任务名称中英文映射
@ -742,7 +742,6 @@ class PrivateReplyer:
"prompt_info": "获取知识",
"actions_info": "动作信息",
"personality_prompt": "人格信息",
"mood_state_prompt": "情绪状态",
"memory_retrieval": "记忆检索",
"jargon_explanation": "黑话解释",
}
@ -770,16 +769,10 @@ class PrivateReplyer:
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
personality_prompt: str = results_dict["personality_prompt"]
mood_state_prompt: str = results_dict["mood_state_prompt"]
memory_retrieval: str = results_dict["memory_retrieval"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
jargon_explanation: str = results_dict.get("jargon_explanation") or ""
# 从 chosen_actions 中提取 planner 的整体思考理由
planner_reasoning = ""
if global_config.chat.include_planner_reasoning and reply_reason:
# 如果没有 chosen_actions使用 reply_reason 作为备选
planner_reasoning = f"你的想法是:{reply_reason}"
planner_reasoning = f"你的想法是:{reply_reason}"
if extra_info:
extra_info_block = f"以下是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策\n{extra_info}\n以上是你在回复时需要参考的信息,现在请你阅读以下内容,进行决策"
@ -814,7 +807,6 @@ class PrivateReplyer:
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
mood_state=mood_state_prompt,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
@ -837,7 +829,6 @@ class PrivateReplyer:
expression_habits_block=expression_habits_block,
tool_info_block=tool_info,
knowledge_prompt=prompt_info,
mood_state=mood_state_prompt,
relation_info_block=relation_info,
extra_info_block=extra_info_block,
identity=personality_prompt,
@ -878,7 +869,7 @@ class PrivateReplyer:
chat_id=chat_id,
timestamp=time.time(),
limit=min(int(global_config.chat.max_context_size * 0.33), 15),
filter_no_read_command=True,
filter_intercept_message_level=1,
)
chat_talking_prompt_half = build_readable_messages(
message_list_before_now_half,
@ -904,59 +895,30 @@ class PrivateReplyer:
)
if sender and target:
# 使用预先分析的内容类型结果
if is_group_chat:
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = (
f"现在{sender}发送的图片:{pic_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
else:
# 只包含文字
reply_target_block = (
f"现在{sender}说的:{text_part}。引起了你的注意,你想要在群里发言或者回复这条消息。"
)
elif target:
reply_target_block = f"现在{target}引起了你的注意,你想要在群里发言或者回复这条消息。"
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else:
reply_target_block = "现在,你想要在群里发言或者回复消息。"
else: # private chat
if sender:
if has_only_pics and not has_text:
# 只包含图片
reply_target_block = f"现在{sender}发送的图片:{pic_part}。引起了你的注意,针对这条消息回复。"
elif has_text and pic_part:
# 既有图片又有文字
reply_target_block = (
f"现在{sender}发送了图片:{pic_part},并说:{text_part}。引起了你的注意,针对这条消息回复。"
)
else:
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
reply_target_block = "现在,你想要回复。"
# 只包含文字
reply_target_block = f"现在{sender}说的:{text_part}。引起了你的注意,针对这条消息回复。"
elif target:
reply_target_block = f"现在{target}引起了你的注意,针对这条消息回复。"
else:
reply_target_block = "现在,你想要回复。"
else:
reply_target_block = ""
if is_group_chat:
chat_target_1 = await global_prompt_manager.get_prompt_async("chat_target_group1")
chat_target_2 = await global_prompt_manager.get_prompt_async("chat_target_group2")
else:
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt(
"chat_target_private1", sender_name=chat_target_name
)
chat_target_2 = await global_prompt_manager.format_prompt(
"chat_target_private2", sender_name=chat_target_name
)
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = self.chat_target_info.person_name or self.chat_target_info.user_nickname or "对方"
chat_target_1 = await global_prompt_manager.format_prompt("chat_target_private1", sender_name=chat_target_name)
chat_target_2 = await global_prompt_manager.format_prompt("chat_target_private2", sender_name=chat_target_name)
template_name = "default_expressor_prompt"

View File

@ -0,0 +1,41 @@
from src.chat.utils.prompt_builder import Prompt
def init_replyer_private_prompt():
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
{reply_target_block}
{planner_reasoning}
{identity}
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""",
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
你现在想补充说明你刚刚自己的发言内容{target}原因是{reason}
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复注意保持上下文的连贯性
{identity}
{chat_prompt}尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )
""",
"private_replyer_self_prompt",
)

View File

@ -3,8 +3,27 @@ from src.chat.utils.prompt_builder import Prompt
def init_replyer_prompt():
Prompt("正在群里聊天", "chat_target_group2")
Prompt("{sender_name}聊天", "chat_target_private2")
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在qq群里聊天下面是群里正在聊的内容其中包含聊天记录和聊天中的图片
其中标注 {bot_name}(你) 的发言是你自己的发言,请注意区分:
{time_block}
{dialogue_prompt}
{reply_target_block}
{planner_reasoning}
{identity}
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复
尽量简短一些{keywords_reaction_prompt}
请注意把握聊天内容不要回复的太有条理
{reply_style}
请注意不要输出多余内容(包括不必要的前后缀冒号括号表情包at或 @等 )只输出发言内容就好
最好一次对一个话题进行回复免得啰嗦或者回复内容太乱
现在你说""",
"replyer_prompt_0",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
@ -18,49 +37,12 @@ def init_replyer_prompt():
{reply_target_block}
{planner_reasoning}
{identity}
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些{mood_state}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录把握当前的话题然后给出日常且简短的回复
最好一次对一个话题进行回复免得啰嗦或者回复内容太乱
{keywords_reaction_prompt}
请注意把握聊天内容
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出一句回复内容就好
不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )
请注意不要输出多余内容(包括不必要的前后缀冒号括号at或 @等 )只输出发言内容就好
现在你说""",
"replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
{reply_target_block}
{planner_reasoning}
{identity}
{chat_prompt}你正在和{sender_name}聊天,现在请你读读之前的聊天记录然后给出日常且口语化的回复平淡一些{mood_state}
尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括前后缀冒号和引号括号表情包at或 @等 )""",
"private_replyer_prompt",
)
Prompt(
"""{knowledge_prompt}{tool_info_block}{extra_info_block}
{expression_habits_block}{memory_retrieval}{jargon_explanation}
你正在和{sender_name}聊天这是你们之前聊的内容:
{time_block}
{dialogue_prompt}
你现在想补充说明你刚刚自己的发言内容{target}原因是{reason}
请你根据聊天内容组织一条新回复注意{target} 是刚刚你自己的发言你要在这基础上进一步发言请按照你自己的角度来继续进行回复注意保持上下文的连贯性{mood_state}
{identity}
{chat_prompt}尽量简短一些{keywords_reaction_prompt}请注意把握聊天内容不要回复的太有条理可以有个性
{reply_style}
请注意不要输出多余内容(包括前后缀冒号和引号括号表情等)只输出回复内容
{moderation_prompt}不要输出多余内容(包括冒号和引号括号表情包at或 @等 )
""",
"private_replyer_self_prompt",
)

View File

@ -120,7 +120,7 @@ def get_raw_msg_by_timestamp_with_chat(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
filter_no_read_command=False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@ -138,7 +138,7 @@ def get_raw_msg_by_timestamp_with_chat(
limit_mode=limit_mode,
filter_bot=filter_bot,
filter_command=filter_command,
filter_no_read_command=filter_no_read_command,
filter_intercept_message_level=filter_intercept_message_level,
)
@ -150,7 +150,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
filter_no_read_command=False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""获取在特定聊天从指定时间戳到指定时间戳的消息(包含边界),按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@ -167,7 +167,7 @@ def get_raw_msg_by_timestamp_with_chat_inclusive(
limit_mode=limit_mode,
filter_bot=filter_bot,
filter_command=filter_command,
filter_no_read_command=filter_no_read_command,
filter_intercept_message_level=filter_intercept_message_level,
)
@ -303,7 +303,7 @@ def get_raw_msg_before_timestamp(timestamp: float, limit: int = 0) -> List[Datab
def get_raw_msg_before_timestamp_with_chat(
chat_id: str, timestamp: float, limit: int = 0, filter_no_read_command: bool = False
chat_id: str, timestamp: float, limit: int = 0, filter_intercept_message_level: Optional[int] = None
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@ -311,7 +311,10 @@ def get_raw_msg_before_timestamp_with_chat(
filter_query = {"chat_id": chat_id, "time": {"$lt": timestamp}}
sort_order = [("time", 1)]
return find_messages(
message_filter=filter_query, sort=sort_order, limit=limit, filter_no_read_command=filter_no_read_command
message_filter=filter_query,
sort=sort_order,
limit=limit,
filter_intercept_message_level=filter_intercept_message_level,
)

View File

@ -8,7 +8,7 @@ from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords
from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage
from src.config.config import global_config
@ -505,13 +505,6 @@ class StatisticOutputTask(AsyncTask):
for period_key, _ in collect_period
}
# 获取bot的QQ账号
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp
@ -537,7 +530,7 @@ class StatisticOutputTask(AsyncTask):
if not chat_id: # Should not happen if above logic is correct
continue
# Update name_mapping
# Update name_mapping(仅用于展示聊天名称)
try:
if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
@ -549,19 +542,30 @@ class StatisticOutputTask(AsyncTask):
# 重置为正确的格式
self.name_mapping[chat_id] = (chat_name, message_time_ts)
# 检查是否是bot发送的消息回复
is_bot_reply = False
if bot_qq_account and message.user_id == bot_qq_account:
is_bot_reply = True
for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
if is_bot_reply:
stats[period_key][TOTAL_REPLY_CNT] += 1
break
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1].timestamp()
for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore
# 仅统计已完成的 reply 动作
if action.action_name != "reply" or not action.action_done:
continue
action_time_ts = action.time
for idx, (_, period_start_dt) in enumerate(collect_period):
if action_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REPLY_CNT] += 1
break
except Exception as e:
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0错误信息{e}")
return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
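
The reply count is now derived from completed "reply" actions in ActionRecords instead of matching the bot's QQ account against message senders. A minimal standalone sketch of the cumulative period bucketing this code uses for both message and reply counts; the period keys and windows here are illustrative:

from collections import defaultdict
from datetime import datetime, timedelta

now = datetime.now()
# Ordered from the shortest window (latest start) to the longest (earliest start),
# which is what the slice-then-break pattern above relies on.
collect_period = [
    ("last_hour", now - timedelta(hours=1)),
    ("last_day", now - timedelta(days=1)),
    ("last_week", now - timedelta(weeks=1)),
]
reply_counts = defaultdict(int)
reply_action_times = [
    (now - timedelta(minutes=5)).timestamp(),  # counts for all three windows
    (now - timedelta(hours=6)).timestamp(),    # too old for last_hour
]
for action_time_ts in reply_action_times:
    for idx, (_, period_start) in enumerate(collect_period):
        if action_time_ts >= period_start.timestamp():
            # A hit on the shortest matching window also belongs to every longer one.
            for period_key, _ in collect_period[idx:]:
                reply_counts[period_key] += 1
            break
print(dict(reply_counts))  # {'last_hour': 1, 'last_day': 2, 'last_week': 2}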
@ -742,7 +746,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模型分类统计:",
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@ -755,11 +759,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@ -767,7 +771,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,
@ -796,7 +800,7 @@ class StatisticOutputTask(AsyncTask):
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
output = [
"按模块分类统计:",
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
@ -809,11 +813,11 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODULE][module_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
# 计算每次回复平均值
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
@ -821,7 +825,7 @@ class StatisticOutputTask(AsyncTask):
formatted_tokens = _format_large_number(tokens)
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
output.append(
data_fmt.format(
name,

View File

@ -4,6 +4,8 @@ import time
import jieba
import json
import ast
import os
from datetime import datetime
from typing import Optional, Tuple, List, TYPE_CHECKING
@ -196,21 +198,54 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
List[str]: 分割和合并后的句子列表
"""
# 预处理:处理多余的换行符
# 1. 将连续的换行符替换为单个换行符
# 1. 将连续的换行符替换为单个换行符(保留换行符用于分割)
text = re.sub(r"\n\s*\n+", "\n", text)
# 2. 处理换行符和其他分隔符的组合
text = re.sub(r"\n\s*([,。;\s])", r"\1", text)
text = re.sub(r"([,。;\s])\s*\n", r"\1", text)
# 2. 处理换行符和其他分隔符的组合(保留换行符,删除其他分隔符)
text = re.sub(r"\n\s*([,。;\s])", r"\n\1", text)
text = re.sub(r"([,。;\s])\s*\n", r"\1\n", text)
# 处理两个汉字中间的换行符
text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text)
# 处理两个汉字中间的换行符(保留换行符,不替换为句号,让换行符强制分割)
# text = re.sub(r"([\u4e00-\u9fff])\n([\u4e00-\u9fff])", r"\1。\2", text) # 注释掉,保留换行符用于分割
len_text = len(text)
if len_text < 3:
return list(text) if random.random() < 0.01 else [text]
# 定义分隔符
separators = {"", ",", " ", "", ";"}
# 先标记哪些位置位于成对引号内部,避免在引号内部进行句子分割
# 支持的引号包括:中英文单/双引号和常见中文书名号/引号
quote_chars = {
'"',
"'",
"“",
"”",
"‘",
"’",
"「",
"」",
"《",
"》",
}
inside_quote = [False] * len_text
in_quote = False
current_quote_char = ""
for idx, ch in enumerate(text):
if ch in quote_chars:
# 遇到引号时切换状态(英文引号本身开闭相同,用同一个字符表示)
if not in_quote:
in_quote = True
current_quote_char = ch
inside_quote[idx] = False
else:
# 只有遇到同一类引号才视为关闭
if ch == current_quote_char or (ch in {'"', "'"} and current_quote_char in {'"', "'"}):
in_quote = False
current_quote_char = ""
inside_quote[idx] = False
else:
inside_quote[idx] = in_quote
# 定义分隔符(包含换行符)
separators = {"", ",", " ", "", ";", "\n"}
segments = []
current_segment = ""
@ -219,24 +254,42 @@ def split_into_sentences_w_remove_punctuation(text: str) -> list[str]:
while i < len(text):
char = text[i]
if char in separators:
# 检查分割条件:如果空格左右都是英文字母、数字,或数字和英文之间,则不分割(仅对空格应用此规则)
can_split = True
if 0 < i < len(text) - 1:
prev_char = text[i - 1]
next_char = text[i + 1]
# 只对空格应用"不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格"规则
if char == " ":
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
can_split = False
# 引号内部一律不作为分割点(包括换行)
if inside_quote[i]:
can_split = False
else:
# 换行符在不在引号内时都强制分割
if char == "\n":
can_split = True
else:
# 检查分割条件
can_split = True
# 检查分隔符左右是否有冒号(中英文),如果有则不分割
if i > 0:
prev_char = text[i - 1]
if prev_char in {":", ""}:
can_split = False
if i < len(text) - 1:
next_char = text[i + 1]
if next_char in {":", ""}:
can_split = False
# 如果左右没有冒号,再检查空格的特殊情况
if can_split and char == " " and i > 0 and i < len(text) - 1:
prev_char = text[i - 1]
next_char = text[i + 1]
# 不分割数字和数字、数字和英文、英文和数字、英文和英文之间的空格
prev_is_alnum = prev_char.isdigit() or is_english_letter(prev_char)
next_is_alnum = next_char.isdigit() or is_english_letter(next_char)
if prev_is_alnum and next_is_alnum:
can_split = False
if can_split:
# 只有当当前段不为空时才添加
if current_segment:
segments.append((current_segment, char))
# 如果当前段为空,但分隔符是空格,则也添加一个空段(保留空格)
elif char == " ":
# 如果当前段为空,但分隔符是空格或换行符,则也添加一个空段(保留分隔符
elif char in {" ", "\n"}:
segments.append(("", char))
current_segment = ""
else:
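
The splitter now refuses to split inside paired quotes and treats newlines as hard separators. A minimal standalone sketch of the quote-masking idea, using a reduced quote set and a naive separator loop rather than the project's full splitter:

def inside_quote_mask(text: str) -> list[bool]:
    # Marks characters that sit between an opening quote and its matching close,
    # so separators in that span are never treated as split points.
    quote_chars = {'"', "'", "\u201c", "\u201d"}  # straight quotes plus curly double quotes
    mask = [False] * len(text)
    in_quote = False
    current = ""
    for i, ch in enumerate(text):
        if ch in quote_chars:
            if not in_quote:
                in_quote, current = True, ch
            elif ch == current or {ch, current} <= {"\u201c", "\u201d"}:
                in_quote, current = False, ""
            mask[i] = False
        else:
            mask[i] = in_quote
    return mask

def naive_split(text: str) -> list[str]:
    separators = {",", "\uff0c", "\u3002", "\n"}  # comma, full-width comma, full stop, newline
    mask = inside_quote_mask(text)
    parts, cur = [], ""
    for i, ch in enumerate(text):
        if ch in separators and not mask[i]:
            if cur:
                parts.append(cur)
            cur = ""
        else:
            cur += ch
    if cur:
        parts.append(cur)
    return parts

print(naive_split('他说"你好,再见"然后走了,没回头'))
# ['他说"你好,再见"然后走了', '没回头']  -- the comma inside the quotes does not split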
@ -641,6 +694,42 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional["TargetP
return is_group_chat, chat_target_info
def record_replyer_action_temp(chat_id: str, reason: str, think_level: int) -> None:
"""
临时记录replyer动作被选择的信息仅群聊
Args:
chat_id: 聊天ID
reason: 选择理由
think_level: 思考深度等级
"""
try:
# 确保data/temp目录存在
temp_dir = "data/temp"
os.makedirs(temp_dir, exist_ok=True)
# 创建记录数据
record_data = {
"chat_id": chat_id,
"reason": reason,
"think_level": think_level,
"timestamp": datetime.now().isoformat(),
}
# 生成文件名(使用时间戳避免冲突)
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
filename = f"replyer_action_{timestamp_str}.json"
filepath = os.path.join(temp_dir, filename)
# 写入文件
with open(filepath, "w", encoding="utf-8") as f:
json.dump(record_data, f, ensure_ascii=False, indent=2)
logger.debug(f"已记录replyer动作选择: chat_id={chat_id}, think_level={think_level}")
except Exception as e:
logger.warning(f"记录replyer动作选择失败: {e}")
def assign_message_ids(messages: List[DatabaseMessages]) -> List[Tuple[str, DatabaseMessages]]:
"""
为消息列表中的每个消息分配唯一的简短随机ID

View File

@ -130,12 +130,10 @@ class ImageManager:
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = (
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
logger.info(
@ -166,7 +164,7 @@ class ImageManager:
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
"""如果启用了steal_emoji且表情包未注册保存文件到data/emoji目录
Args:
image_base64: 图片的base64编码
image_hash: 图片的MD5哈希值
@ -174,7 +172,7 @@ class ImageManager:
"""
if not global_config.emoji.steal_emoji:
return
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
@ -236,12 +234,16 @@ class ImageManager:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)

View File

@ -77,7 +77,7 @@ class DatabaseMessages(BaseDataModel):
is_emoji: bool = False,
is_picid: bool = False,
is_command: bool = False,
is_no_read_command: bool = False,
intercept_message_level: int = 0,
is_notify: bool = False,
selected_expressions: Optional[str] = None,
user_id: str = "",
@ -120,7 +120,7 @@ class DatabaseMessages(BaseDataModel):
self.is_emoji = is_emoji
self.is_picid = is_picid
self.is_command = is_command
self.is_no_read_command = is_no_read_command
self.intercept_message_level = intercept_message_level
self.is_notify = is_notify
self.selected_expressions = selected_expressions
@ -188,7 +188,7 @@ class DatabaseMessages(BaseDataModel):
"is_emoji": self.is_emoji,
"is_picid": self.is_picid,
"is_command": self.is_command,
"is_no_read_command": self.is_no_read_command,
"intercept_message_level": self.intercept_message_level,
"is_notify": self.is_notify,
"selected_expressions": self.selected_expressions,
"user_id": self.user_info.user_id,

View File

@ -22,7 +22,7 @@ class MessageAndActionModel(BaseDataModel):
is_action_record: bool = field(default=False)
action_name: Optional[str] = None
is_command: bool = field(default=False)
is_no_read_command: bool = field(default=False)
intercept_message_level: int = field(default=0)
@classmethod
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
@ -37,7 +37,7 @@ class MessageAndActionModel(BaseDataModel):
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
is_command=message.is_command,
is_no_read_command=getattr(message, "is_no_read_command", False),
intercept_message_level=getattr(message, "intercept_message_level", 0),
)

View File

@ -170,7 +170,7 @@ class Messages(BaseModel):
is_emoji = BooleanField(default=False)
is_picid = BooleanField(default=False)
is_command = BooleanField(default=False)
is_no_read_command = BooleanField(default=False)
intercept_message_level = IntegerField(default=0)
is_notify = BooleanField(default=False)
selected_expressions = TextField(null=True)
@ -324,9 +324,9 @@ class Expression(BaseModel):
# new mode fields
context = TextField(null=True)
up_content = TextField(null=True)
content_list = TextField(null=True)
style_list = TextField(null=True)  # 存储相似的 style(格式与 content_list 相同,JSON 数组)
count = IntegerField(default=1)
last_active_time = FloatField()
chat_id = TextField(index=True)
@ -593,22 +593,41 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
db.execute_sql(f"CREATE TABLE {backup_table} AS SELECT * FROM {table_name}")
logger.info(f"已创建备份表 '{backup_table}'")
# 2. 删除原表
# 2. 获取原始行数(在删除表之前)
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
logger.info(f"备份表 '{backup_table}' 包含 {original_count} 行数据")
# 3. 删除原表
db.execute_sql(f"DROP TABLE {table_name}")
logger.info(f"已删除原表 '{table_name}'")
# 3. 重新创建表(使用当前模型定义)
# 4. 重新创建表(使用当前模型定义)
db.create_tables([model])
logger.info(f"已重新创建表 '{table_name}' 使用新的约束")
# 4. 从备份表恢复数据
# 获取字段列表
# 5. 从备份表恢复数据
# 获取字段列表,排除主键字段(让数据库自动生成新的主键)
fields = list(model._meta.fields.keys())
fields_str = ", ".join(fields)
# Peewee 默认使用 'id' 作为主键字段名
# 尝试获取主键字段名,如果获取失败则默认使用 'id'
primary_key_name = "id" # 默认值
try:
if hasattr(model._meta, "primary_key") and model._meta.primary_key:
if hasattr(model._meta.primary_key, "name"):
primary_key_name = model._meta.primary_key.name
elif isinstance(model._meta.primary_key, str):
primary_key_name = model._meta.primary_key
except Exception:
pass # 如果获取失败,使用默认值 'id'
# 对于需要从 NOT NULL 改为 NULL 的字段,直接复制数据
# 对于需要从 NULL 改为 NOT NULL 的字段,需要处理 NULL 值
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
# 如果字段列表包含主键,则排除它
if primary_key_name in fields:
fields_without_pk = [f for f in fields if f != primary_key_name]
logger.info(f"排除主键字段 '{primary_key_name}',让数据库自动生成新的主键")
else:
fields_without_pk = fields
fields_str = ", ".join(fields_without_pk)
# 检查是否有字段需要从 NULL 改为 NOT NULL
null_to_notnull_fields = [
@ -621,7 +640,7 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
# 构建更复杂的 SELECT 语句来处理 NULL 值
select_fields = []
for field_name in fields:
for field_name in fields_without_pk:
if field_name in null_to_notnull_fields:
field_obj = model._meta.fields[field_name]
# 根据字段类型设置默认值
@ -642,12 +661,13 @@ def _fix_table_constraints(table_name, model, constraints_to_fix):
select_str = ", ".join(select_fields)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {select_str} FROM {backup_table}"
else:
# 没有需要处理 NULL 的字段,直接复制数据(排除主键)
insert_sql = f"INSERT INTO {table_name} ({fields_str}) SELECT {fields_str} FROM {backup_table}"
db.execute_sql(insert_sql)
logger.info(f"已从备份表恢复数据到 '{table_name}'")
# 5. 验证数据完整性
original_count = db.execute_sql(f"SELECT COUNT(*) FROM {backup_table}").fetchone()[0]
new_count = db.execute_sql(f"SELECT COUNT(*) FROM {table_name}").fetchone()[0]
if original_count == new_count:
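
A minimal standalone sketch of the same rebuild pattern against a plain sqlite3 connection, assuming an auto-increment id primary key as in the Peewee models above; the backup table is dropped at the end here for brevity:

import sqlite3

def rebuild_table_without_pk(conn: sqlite3.Connection, table: str, new_ddl: str, columns: list[str]) -> None:
    """Rebuild `table` with new constraints, copying everything except the primary key
    so SQLite assigns fresh ids (mirrors the backup -> drop -> recreate -> insert flow above)."""
    backup = f"{table}_backup"
    cols = ", ".join(c for c in columns if c != "id")
    cur = conn.cursor()
    cur.execute(f"CREATE TABLE {backup} AS SELECT * FROM {table}")
    original = cur.execute(f"SELECT COUNT(*) FROM {backup}").fetchone()[0]
    cur.execute(f"DROP TABLE {table}")
    cur.execute(new_ddl)  # CREATE TABLE with the updated constraints
    cur.execute(f"INSERT INTO {table} ({cols}) SELECT {cols} FROM {backup}")
    restored = cur.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
    assert original == restored, "row count changed during rebuild"
    cur.execute(f"DROP TABLE {backup}")
    conn.commit()

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT)")
conn.executemany("INSERT INTO demo (name) VALUES (?)", [("a",), ("b",)])
rebuild_table_without_pk(conn, "demo", "CREATE TABLE demo (id INTEGER PRIMARY KEY, name TEXT NOT NULL)", ["id", "name"])
print(conn.execute("SELECT COUNT(*) FROM demo").fetchone()[0])  # 2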

View File

@ -20,6 +20,9 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
_file_handler = None
_console_handler = None
_ws_handler = None
# 全局标志,防止重复初始化
_logging_initialized = False
_cleanup_task_started = False
def get_file_handler():
@ -869,29 +872,41 @@ def get_logger(name: Optional[str]) -> structlog.stdlib.BoundLogger:
return logger
def initialize_logging():
def initialize_logging(verbose: bool = True):
"""手动初始化日志系统确保所有logger都使用正确的配置
在应用程序的早期调用此函数确保所有模块都使用统一的日志配置
Args:
verbose: 是否输出详细的初始化信息默认为 True
Runner 进程中可以设置为 False 以避免重复的初始化日志
"""
global LOG_CONFIG
global LOG_CONFIG, _logging_initialized
# 防止重复初始化(在同一进程内)
if _logging_initialized:
return
_logging_initialized = True
LOG_CONFIG = load_log_config()
# print(LOG_CONFIG)
configure_third_party_loggers()
reconfigure_existing_loggers()
# 启动日志清理任务
start_log_cleanup_task()
start_log_cleanup_task(verbose=verbose)
# 输出初始化信息
logger = get_logger("logger")
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
# 只在 verbose=True 时输出详细的初始化信息
if verbose:
logger = get_logger("logger")
console_level = LOG_CONFIG.get("console_log_level", LOG_CONFIG.get("log_level", "INFO"))
file_level = LOG_CONFIG.get("file_log_level", LOG_CONFIG.get("log_level", "INFO"))
logger.info("日志系统已初始化:")
logger.info(f" - 控制台级别: {console_level}")
logger.info(f" - 文件级别: {file_level}")
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
logger.info("日志系统已初始化:")
logger.info(f" - 控制台级别: {console_level}")
logger.info(f" - 文件级别: {file_level}")
logger.info(" - 轮转份数: 30个文件|自动清理: 30天前的日志")
def cleanup_old_logs():
@ -924,8 +939,19 @@ def cleanup_old_logs():
logger.error(f"清理旧日志文件时出错: {e}")
def start_log_cleanup_task():
"""启动日志清理任务"""
def start_log_cleanup_task(verbose: bool = True):
"""启动日志清理任务
Args:
verbose: 是否输出启动信息默认为 True
"""
global _cleanup_task_started
# 防止重复启动清理任务
if _cleanup_task_started:
return
_cleanup_task_started = True
def cleanup_task():
while True:
@ -935,8 +961,9 @@ def start_log_cleanup_task():
cleanup_thread = threading.Thread(target=cleanup_task, daemon=True)
cleanup_thread.start()
logger = get_logger("logger")
logger.info("已启动日志清理任务将自动清理30天前的日志文件轮转份数限制: 30个文件")
if verbose:
logger = get_logger("logger")
logger.info("已启动日志清理任务将自动清理30天前的日志文件轮转份数限制: 30个文件")
def shutdown_logging():

View File

@ -15,14 +15,18 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
# 检查maim_message版本
try:
maim_message_version = importlib.metadata.version("maim_message")
version_compatible = [int(x) for x in maim_message_version.split(".")] >= [0, 3, 3]
version_int = [int(x) for x in maim_message_version.split(".")]
version_compatible = version_int >= [0, 3, 3]
# Check for API Server feature (>= 0.6.0)
has_api_server_feature = version_int >= [0, 6, 0]
except (importlib.metadata.PackageNotFoundError, ValueError):
version_compatible = False
has_api_server_feature = False
# 读取配置项
maim_message_config = global_config.maim_message
# 设置基本参数
# 设置基本参数 (Legacy Server Mode)
kwargs = {
"host": os.environ["HOST"],
"port": int(os.environ["PORT"]),
@ -39,21 +43,129 @@ def get_global_api() -> MessageServer: # sourcery skip: extract-method
if maim_message_config.auth_token and len(maim_message_config.auth_token) > 0:
kwargs["enable_token"] = True
if maim_message_config.use_custom:
# 添加WSS模式支持
del kwargs["app"]
kwargs["host"] = maim_message_config.host
kwargs["port"] = maim_message_config.port
kwargs["mode"] = maim_message_config.mode
if maim_message_config.use_wss:
if maim_message_config.cert_file:
kwargs["ssl_certfile"] = maim_message_config.cert_file
if maim_message_config.key_file:
kwargs["ssl_keyfile"] = maim_message_config.key_file
kwargs["enable_custom_uvicorn_logger"] = False
# Removed legacy custom config block (use_custom) as requested.
kwargs["enable_custom_uvicorn_logger"] = False
global_api = MessageServer(**kwargs)
if version_compatible and maim_message_config.auth_token:
for token in maim_message_config.auth_token:
global_api.add_valid_token(token)
# ---------------------------------------------------------------------
# Additional API Server Configuration (maim_message >= 0.6.0)
# ---------------------------------------------------------------------
enable_api_server = maim_message_config.enable_api_server
# 如果版本支持且启用了API Server则初始化额外服务器
if has_api_server_feature and enable_api_server:
try:
from maim_message.server import WebSocketServer, ServerConfig
from maim_message.message import APIMessageBase
api_logger = get_logger("maim_message_api_server")
# 1. Prepare Config
api_server_host = maim_message_config.api_server_host
api_server_port = maim_message_config.api_server_port
use_wss = maim_message_config.api_server_use_wss
server_config = ServerConfig(
host=api_server_host,
port=api_server_port,
ssl_enabled=use_wss,
ssl_certfile=maim_message_config.api_server_cert_file if use_wss else None,
ssl_keyfile=maim_message_config.api_server_key_file if use_wss else None,
)
# 2. Setup Auth Handler
async def auth_handler(metadata: dict) -> bool:
allowed_keys = maim_message_config.api_server_allowed_api_keys
# If list is empty/None, allow all (default behavior of returning True)
if not allowed_keys:
return True
api_key = metadata.get("api_key")
if api_key in allowed_keys:
return True
api_logger.warning(f"Rejected connection with invalid API Key: {api_key}")
return False
server_config.on_auth = auth_handler
# 3. Setup Message Bridge
# Initialize refined route map if not exists
if not hasattr(global_api, "platform_map"):
global_api.platform_map = {}
async def bridge_message_handler(message: APIMessageBase, metadata: dict):
# Bridge message to the main bot logic
# We convert APIMessageBase to dict to be compatible with legacy handlers
# that MainBot (ChatManager) expects.
msg_dict = message.to_dict()
# Compatibility Layer: Flatten sender_info to top-level user_info/group_info
# Legacy MessageBase expects message_info to have user_info and group_info directly.
if "message_info" in msg_dict:
msg_info = msg_dict["message_info"]
sender_info = msg_info.get("sender_info")
if sender_info:
# If direct user_info/group_info are missing, populate them from sender_info
if "user_info" not in msg_info and (ui := sender_info.get("user_info")):
msg_info["user_info"] = ui
if "group_info" not in msg_info and (gi := sender_info.get("group_info")):
msg_info["group_info"] = gi
# Route Caching Logic: Simply map platform to API Key
# This allows us to send messages back to the correct API client for this platform
try:
api_key = metadata.get("api_key")
if api_key:
platform = msg_info.get("platform")
if platform:
global_api.platform_map[platform] = api_key
except Exception as e:
api_logger.warning(f"Failed to update platform map: {e}")
# Compatibility Layer: Ensure raw_message exists (even if None) as it's part of MessageBase
if "raw_message" not in msg_dict:
msg_dict["raw_message"] = None
await global_api.process_message(msg_dict)
server_config.on_message = bridge_message_handler
# 4. Initialize Server
extra_server = WebSocketServer(config=server_config)
# 5. Patch global_api lifecycle methods to manage both servers
original_run = global_api.run
original_stop = global_api.stop
async def patched_run():
api_logger.info(f"Starting Additional API Server on {api_server_host}:{api_server_port} (WSS: {use_wss})")
# Start the extra server (non-blocking start)
await extra_server.start()
# Run the original legacy server (this usually keeps running)
await original_run()
async def patched_stop():
api_logger.info("Stopping Additional API Server...")
await extra_server.stop()
await original_stop()
global_api.run = patched_run
global_api.stop = patched_stop
# Attach for reference
global_api.extra_server = extra_server
except ImportError:
get_logger("maim_message").error("Cannot import maim_message.server components. Is maim_message >= 0.6.0 installed?")
except Exception as e:
get_logger("maim_message").error(f"Failed to initialize Additional API Server: {e}")
import traceback
get_logger("maim_message").debug(traceback.format_exc())
return global_api
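
A minimal standalone sketch of the compatibility flattening performed by bridge_message_handler, applied to a plain dict; the field names follow the code above and the sample payload is made up:

def flatten_sender_info(msg_dict: dict) -> dict:
    # Copies user_info/group_info out of sender_info so legacy MessageBase-style
    # consumers find them at the top level of message_info, and guarantees raw_message.
    msg_info = msg_dict.get("message_info", {})
    sender_info = msg_info.get("sender_info") or {}
    if "user_info" not in msg_info and (ui := sender_info.get("user_info")):
        msg_info["user_info"] = ui
    if "group_info" not in msg_info and (gi := sender_info.get("group_info")):
        msg_info["group_info"] = gi
    msg_dict.setdefault("raw_message", None)
    return msg_dict

incoming = {"message_info": {"platform": "qq", "sender_info": {"user_info": {"user_id": "10001"}}}}
flattened = flatten_sender_info(incoming)
print(flattened["message_info"]["user_info"], flattened["raw_message"])  # {'user_id': '10001'} None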

View File

@ -25,7 +25,7 @@ def find_messages(
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
filter_no_read_command=False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
"""
根据提供的过滤器排序和限制条件查找消息
@ -85,8 +85,9 @@ def find_messages(
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
query = query.where(~Messages.is_command)
if filter_no_read_command:
query = query.where(~Messages.is_no_read_command)
if filter_intercept_message_level is not None:
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
if limit > 0:
if limit_mode == "earliest":

View File

@ -4,6 +4,7 @@ TOML 工具函数
提供 TOML 文件的格式化保存功能确保数组等元素以美观的多行格式输出
"""
import re
from typing import Any
import tomlkit
from tomlkit.items import AoT, Table, Array
@ -33,7 +34,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj
# 决定是否多行:仅在顶层且长度超过阈值时
should_multiline = (depth == 0 and len(obj) > threshold)
should_multiline = depth == 0 and len(obj) > threshold
# 如果已经是 tomlkit Array原地修改以保留注释
if isinstance(obj, Array):
@ -45,7 +46,7 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
# 普通 list转换为 tomlkit 数组
arr = tomlkit.array()
arr.multiline(should_multiline)
for item in obj:
arr.append(_format_toml_value(item, threshold, depth + 1))
return arr
@ -54,14 +55,71 @@ def _format_toml_value(obj: Any, threshold: int, depth: int = 0) -> Any:
return obj
def save_toml_with_format(data: Any, file_path: str, multiline_threshold: int = 1) -> None:
"""格式化 TOML 数据并保存到文件"""
def _update_toml_doc(target: Any, source: Any) -> None:
"""
递归地将字典 source 的值合并更新到 target 中,保留 target 的注释和格式:
- 已存在的键:更新值(嵌套字典递归处理)
- 新增的键:添加到 target
- 跳过 version 字段
"""
if isinstance(source, list) or not isinstance(source, dict) or not isinstance(target, dict):
return
for key, value in source.items():
if key == "version":
continue
if key in target:
# 已存在的键:递归更新或直接赋值
target_value = target[key]
if isinstance(value, dict) and isinstance(target_value, dict):
_update_toml_doc(target_value, value)
else:
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
else:
# 新增的键:添加到 target
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
def save_toml_with_format(
data: Any, file_path: str, multiline_threshold: int = 1, preserve_comments: bool = True
) -> None:
"""
格式化 TOML 数据并保存到文件
Args:
data: 要保存的数据(dict 或 tomlkit 文档)
file_path: 保存路径
multiline_threshold: 数组多行格式化阈值,-1 表示不格式化
preserve_comments: 是否保留原文件的注释和格式,默认 True。
若为 True、文件已存在且 data 不是 tomlkit 文档,会先读取原文件,再将 data 合并进去
"""
import os
from tomlkit import TOMLDocument
# 如果需要保留注释、文件存在、且 data 不是已有的 tomlkit 文档,先读取原文件再合并
if preserve_comments and os.path.exists(file_path) and not isinstance(data, TOMLDocument):
with open(file_path, "r", encoding="utf-8") as f:
doc = tomlkit.load(f)
_update_toml_doc(doc, data)
data = doc
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
output = re.sub(r"\n{3,}", "\n\n", output)
with open(file_path, "w", encoding="utf-8") as f:
tomlkit.dump(formatted, f)
f.write(output)
def format_toml_string(data: Any, multiline_threshold: int = 1) -> str:
"""格式化 TOML 数据并返回字符串"""
formatted = _format_toml_value(data, multiline_threshold) if multiline_threshold >= 0 else data
return tomlkit.dumps(formatted)
output = tomlkit.dumps(formatted)
# 规范化:将 3+ 连续空行压缩为 1 个空行,防止空行累积
return re.sub(r"\n{3,}", "\n\n", output)
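
A minimal standalone sketch of the comment-preserving merge, done directly with tomlkit instead of calling save_toml_with_format; file I/O is omitted and the config keys are made up:

import tomlkit

original = """\
# how chatty the bot is
[chat]
talk_value = 0.5
"""

updates = {"version": "ignored", "chat": {"talk_value": 0.8, "max_context_size": 20}}

def update_doc(target, source):
    # Mirrors _update_toml_doc above: merge values, recurse into tables, skip "version".
    for key, value in source.items():
        if key == "version":
            continue
        if key in target and isinstance(value, dict) and isinstance(target[key], dict):
            update_doc(target[key], value)
        else:
            target[key] = tomlkit.item(value)

doc = tomlkit.parse(original)
update_doc(doc, updates)
print(tomlkit.dumps(doc))
# The "# how chatty the bot is" comment survives, talk_value becomes 0.8,
# max_context_size is appended, and the top-level "version" key is ignored.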

View File

@ -60,6 +60,12 @@ class ModelInfo(ConfigBase):
price_out: float = field(default=0.0)
"""每M token输出价格"""
temperature: float | None = field(default=None)
"""模型级别温度(可选),会覆盖任务配置中的温度"""
max_tokens: int | None = field(default=None)
"""模型级别最大token数可选会覆盖任务配置中的max_tokens"""
force_stream_mode: bool = field(default=False)
"""是否强制使用流式输出模式"""

View File

@ -31,10 +31,9 @@ from src.config.official_configs import (
RelationshipConfig,
ToolConfig,
VoiceConfig,
MoodConfig,
MemoryConfig,
DebugConfig,
JargonConfig,
DreamConfig,
)
from .api_ada_configs import (
@ -57,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.6"
MMC_VERSION = "0.12.0"
def get_key_comment(toml_table, key):
@ -354,9 +353,8 @@ class Config(ConfigBase):
tool: ToolConfig
memory: MemoryConfig
debug: DebugConfig
mood: MoodConfig
voice: VoiceConfig
jargon: JargonConfig
dream: DreamConfig
@dataclass

View File

@ -43,10 +43,13 @@ class PersonalityConfig(ConfigBase):
"""人格"""
reply_style: str = ""
"""表达风格"""
"""默认表达风格"""
interest: str = ""
"""兴趣"""
multiple_reply_style: list[str] = field(default_factory=lambda: [])
"""可选的多种表达风格列表,当配置不为空时可按概率随机替换 reply_style"""
multiple_probability: float = 0.0
"""每次构建回复时,从 multiple_reply_style 中随机替换 reply_style 的概率0.0-1.0"""
plan_style: str = ""
"""说话规则,行为风格"""
@ -79,12 +82,6 @@ class ChatConfig(ConfigBase):
max_context_size: int = 18
"""上下文长度"""
interest_rate_mode: Literal["fast", "accurate"] = "fast"
"""兴趣值计算模式fast为快速计算accurate为精确计算"""
planner_size: float = 1.5
"""副规划器大小越小麦麦的动作执行能力越精细但是消耗更多token调大可以缓解429类错误"""
mentioned_bot_reply: bool = True
"""是否启用提及必回复"""
@ -117,8 +114,13 @@ class ChatConfig(ConfigBase):
时间区间支持跨夜例如 "23:00-02:00"
"""
include_planner_reasoning: bool = False
"""是否将planner推理加入replyer默认关闭不加入"""
think_mode: Literal["classic", "deep", "dynamic"] = "classic"
"""
思考模式配置
- classic: 默认think_level为0,轻量回复,不需要思考和回忆
- deep: 默认think_level为1,深度回复,需要进行回忆和思考
- dynamic: think_level由planner动态给出,根据planner返回的think_level决定
"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""与 ChatStream.get_stream_id 一致地从 "platform:id:type" 生成 chat_id。"""
@ -173,7 +175,11 @@ class ChatConfig(ConfigBase):
def get_talk_value(self, chat_id: Optional[str]) -> float:
"""根据规则返回当前 chat 的动态 talk_value未匹配则回退到基础值。"""
if not self.enable_talk_value_rules or not self.talk_value_rules:
return self.talk_value
result = self.talk_value
# 防止返回0值,自动转换为极小值 0.0000001
if result == 0:
return 0.0000001
return result
now_min = self._now_minutes()
@ -199,7 +205,11 @@ class ChatConfig(ConfigBase):
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
result = float(value)
# 防止返回0值,自动转换为极小值 0.0000001
if result == 0:
return 0.0000001
return result
except Exception:
continue
@ -218,12 +228,20 @@ class ChatConfig(ConfigBase):
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
try:
return float(value)
result = float(value)
# 防止返回0值,自动转换为极小值 0.0000001
if result == 0:
return 0.0000001
return result
except Exception:
continue
# 3) 未命中规则返回基础值
return self.talk_value
result = self.talk_value
# 防止返回0值,自动转换为极小值 0.0000001
if result == 0:
return 0.0000001
return result
@dataclass
@ -244,13 +262,21 @@ class MemoryConfig(ConfigBase):
max_agent_iterations: int = 5
"""Agent最多迭代轮数最低为1"""
agent_timeout_seconds: float = 120.0
"""Agent超时时间"""
enable_jargon_detection: bool = True
"""记忆检索过程中是否启用黑话识别"""
global_memory: bool = False
"""是否允许记忆检索在聊天记录中进行全局查询忽略当前chat_id仅对 search_chat_history 等工具生效)"""
def __post_init__(self):
"""验证配置值"""
if self.max_agent_iterations < 1:
raise ValueError(f"max_agent_iterations 必须至少为1当前值: {self.max_agent_iterations}")
if self.agent_timeout_seconds <= 0:
raise ValueError(f"agent_timeout_seconds 必须大于0当前值: {self.agent_timeout_seconds}")
@dataclass
@ -260,20 +286,20 @@ class ExpressionConfig(ConfigBase):
learning_list: list[list] = field(default_factory=lambda: [])
"""
表达学习配置列表支持按聊天流配置
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...]
格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...]
示例:
[
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5
["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习
["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
]
说明:
- 第一位: chat_stream_id空字符串表示全局配置
- 第二位: 是否使用学到的表达 ("enable"/"disable")
- 第三位: 是否学习表达 ("enable"/"disable")
- 第四位: 学习强度浮点数影响学习频率最短学习时间间隔 = 300/学习强度
- 第四位: 是否启用jargon学习 ("enable"/"disable")
"""
expression_groups: list[list[str]] = field(default_factory=list)
@ -296,6 +322,12 @@ class ExpressionConfig(ConfigBase):
如果列表为空则所有聊天流都可以进行表达反思前提是 reflect = true
"""
all_global_jargon: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id。注意此功能关闭后已经记录的全局黑话不会改变需要手动删除"""
enable_jargon_explanation: bool = True
"""是否在回复前尝试对上下文中的黑话进行解释关闭可减少一次LLM调用仅影响回复前的黑话匹配与解释不影响黑话学习"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
"""
解析流配置字符串并生成对应的 chat_id
@ -331,7 +363,7 @@ class ExpressionConfig(ConfigBase):
except (ValueError, IndexError):
return None
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]:
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
"""
根据聊天流ID获取表达配置
@ -339,11 +371,11 @@ class ExpressionConfig(ConfigBase):
chat_stream_id: 聊天流ID格式为哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
"""
if not self.learning_list:
# 如果没有配置,使用默认值:启用表达,启用学习,300秒间隔
return True, True, 300
# 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习
return True, True, True
# 优先检查聊天流特定的配置
if chat_stream_id:
@ -356,10 +388,10 @@ class ExpressionConfig(ConfigBase):
if global_expression_config is not None:
return global_expression_config
# 如果都没有匹配,返回默认值
return True, True, 300
# 如果都没有匹配,返回默认值启用表达启用学习启用jargon学习
return True, True, True
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]:
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]:
"""
获取特定聊天流的表达配置
@ -367,7 +399,7 @@ class ExpressionConfig(ConfigBase):
chat_stream_id: 聊天流ID哈希值
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)如果没有配置则返回 None
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)如果没有配置则返回 None
"""
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
@ -392,19 +424,19 @@ class ExpressionConfig(ConfigBase):
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity: float = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
enable_jargon_learning: bool = config_item[3].lower() == "enable"
return use_expression, enable_learning, enable_jargon_learning # type: ignore
except (ValueError, IndexError):
continue
return None
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]:
def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]:
"""
获取全局表达配置
Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)如果没有配置则返回 None
tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)如果没有配置则返回 None
"""
for config_item in self.learning_list:
if not config_item or len(config_item) < 4:
@ -415,8 +447,8 @@ class ExpressionConfig(ConfigBase):
try:
use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3])
return use_expression, enable_learning, learning_intensity # type: ignore
enable_jargon_learning: bool = config_item[3].lower() == "enable"
return use_expression, enable_learning, enable_jargon_learning # type: ignore
except (ValueError, IndexError):
continue
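
A standalone sketch of the lookup order implemented above: a stream-specific entry first, then the global entry with an empty first field, then the all-enabled default. The real code additionally converts the configured "platform:id:type" string into a hashed chat_id, which is skipped here:

def resolve_expression_config(learning_list, chat_stream_id):
    def parse(item):
        return item[1] == "enable", item[2] == "enable", item[3] == "enable"

    if not learning_list:
        return True, True, True
    for item in learning_list:
        if len(item) >= 4 and item[0] == chat_stream_id:
            return parse(item)
    for item in learning_list:
        if len(item) >= 4 and item[0] == "":
            return parse(item)
    return True, True, True

learning_list = [
    ["", "enable", "enable", "enable"],
    ["qq:114514:private", "enable", "disable", "disable"],
]
print(resolve_expression_config(learning_list, "qq:114514:private"))  # (True, False, False)
print(resolve_expression_config(learning_list, "qq:42:group"))        # (True, True, True), from the global entry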
@ -431,20 +463,6 @@ class ToolConfig(ConfigBase):
"""是否在聊天中启用工具"""
@dataclass
class MoodConfig(ConfigBase):
"""情绪配置类"""
enable_mood: bool = True
"""是否启用情绪系统"""
mood_update_threshold: float = 1
"""情绪更新阈值,越高,更新越慢"""
emotion_style: str = "情绪较为稳定,但遭遇特定事件的时候起伏较大"
"""情感特征,影响情绪的变化情况"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@ -639,29 +657,29 @@ class ExperimentalConfig(ConfigBase):
class MaimMessageConfig(ConfigBase):
"""maim_message配置类"""
use_custom: bool = False
"""是否使用自定义的maim_message配置"""
host: str = "127.0.0.1"
"""主机地址"""
port: int = 8090
""""端口号"""
mode: Literal["ws", "tcp"] = "ws"
"""连接模式支持ws和tcp"""
use_wss: bool = False
"""是否使用WSS安全连接"""
cert_file: str = ""
"""SSL证书文件路径仅在use_wss=True时有效"""
key_file: str = ""
"""SSL密钥文件路径仅在use_wss=True时有效"""
auth_token: list[str] = field(default_factory=lambda: [])
"""认证令牌用于API验证为空则不启用验证"""
"""认证令牌用于旧版API验证为空则不启用验证"""
enable_api_server: bool = False
"""是否启用额外的新版API Server"""
api_server_host: str = "0.0.0.0"
"""新版API Server主机地址"""
api_server_port: int = 8090
"""新版API Server端口号"""
api_server_use_wss: bool = False
"""新版API Server是否启用WSS"""
api_server_cert_file: str = ""
"""新版API Server SSL证书文件路径"""
api_server_key_file: str = ""
"""新版API Server SSL密钥文件路径"""
api_server_allowed_api_keys: list[str] = field(default_factory=lambda: [])
"""新版API Server允许的API Key列表为空则允许所有连接"""
@dataclass
@ -709,8 +727,86 @@ class LPMMKnowledgeConfig(ConfigBase):
@dataclass
class JargonConfig(ConfigBase):
"""Jargon配置类"""
class DreamConfig(ConfigBase):
"""Dream配置类"""
all_global: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""
interval_minutes: int = 30
"""做梦时间间隔(分钟),默认30分钟"""
max_iterations: int = 20
"""做梦最大轮次,默认20轮"""
first_delay_seconds: int = 60
"""程序启动后首次做梦前的延迟时间(秒),默认60秒"""
dream_time_ranges: list[str] = field(default_factory=lambda: [])
"""
做梦时间段配置列表格式["HH:MM-HH:MM", ...]
如果列表为空则表示全天允许做梦
如果配置了时间段则只有在这些时间段内才会实际执行做梦流程
时间段外调度器仍会按间隔检查但不会进入做梦流程
示例:
[
"09:00-22:00", # 白天允许做梦
"23:00-02:00", # 跨夜时间段23:00到次日02:00
]
支持跨夜区间例如 "23:00-02:00" 表示从23:00到次日02:00
"""
def _now_minutes(self) -> int:
"""返回本地时间的分钟数(0-1439)。"""
lt = time.localtime()
return lt.tm_hour * 60 + lt.tm_min
def _parse_range(self, range_str: str) -> Optional[tuple[int, int]]:
"""解析 "HH:MM-HH:MM" 到 (start_min, end_min)。"""
try:
start_str, end_str = [s.strip() for s in range_str.split("-")]
sh, sm = [int(x) for x in start_str.split(":")]
eh, em = [int(x) for x in end_str.split(":")]
return sh * 60 + sm, eh * 60 + em
except Exception:
return None
def _in_range(self, now_min: int, start_min: int, end_min: int) -> bool:
"""
判断 now_min 是否在 [start_min, end_min] 区间内
支持跨夜如果 start > end则表示跨越午夜
"""
if start_min <= end_min:
return start_min <= now_min <= end_min
# 跨夜:例如 23:00-02:00
return now_min >= start_min or now_min <= end_min
def is_in_dream_time(self) -> bool:
"""
检查当前时间是否在允许做梦的时间段内
如果 dream_time_ranges 为空则返回 True全天允许
"""
if not self.dream_time_ranges:
return True
now_min = self._now_minutes()
for time_range in self.dream_time_ranges:
if not isinstance(time_range, str):
continue
parsed = self._parse_range(time_range)
if not parsed:
continue
start_min, end_min = parsed
if self._in_range(now_min, start_min, end_min):
return True
return False
def __post_init__(self):
"""验证配置值"""
if self.interval_minutes < 1:
raise ValueError(f"interval_minutes 必须至少为1当前值: {self.interval_minutes}")
if self.max_iterations < 1:
raise ValueError(f"max_iterations 必须至少为1当前值: {self.max_iterations}")
if self.first_delay_seconds < 0:
raise ValueError(f"first_delay_seconds 不能为负数,当前值: {self.first_delay_seconds}")

View File

@ -0,0 +1,580 @@
import asyncio
import random
import time
from typing import Any, Dict, List, Optional, Tuple
from peewee import fn
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.common.database.database_model import ChatHistory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.plugin_system.apis import llm_api
from src.dream.dream_generator import generate_dream_summary
# dream 工具工厂函数
from src.dream.tools.search_chat_history_tool import make_search_chat_history
from src.dream.tools.get_chat_history_detail_tool import make_get_chat_history_detail
from src.dream.tools.delete_chat_history_tool import make_delete_chat_history
from src.dream.tools.create_chat_history_tool import make_create_chat_history
from src.dream.tools.update_chat_history_tool import make_update_chat_history
from src.dream.tools.finish_maintenance_tool import make_finish_maintenance
from src.dream.tools.search_jargon_tool import make_search_jargon
from src.dream.tools.delete_jargon_tool import make_delete_jargon
from src.dream.tools.update_jargon_tool import make_update_jargon
logger = get_logger("dream_agent")
def init_dream_prompts() -> None:
"""初始化 dream agent 的提示词"""
Prompt(
"""
你的名字是{bot_name},你现在处于"梦境维护模式(dream agent)"。
你可以自由地在 ChatHistory 库中探索整理创建和删改记录以帮助自己在未来更好地回忆和理解对话历史
本轮要维护的聊天ID{chat_id}
本轮随机选中的起始记忆 ID{start_memory_id}
请优先以这条起始记忆为切入点先理解它的内容与上下文再决定如何在其附近进行创建新概括重写或删除等整理操作如果起始记忆为空则由你自行选择合适的切入点
你可以使用的工具包括
**ChatHistory 维护工具**
- search_chat_history根据关键词或参与人搜索该 chat_id 下的历史记忆概括列表
- get_chat_history_detail查看某条概括的详细内容
- create_chat_history根据整理后的理解创建一条新的 ChatHistory 概括记录主题概括关键词关键信息等
- update_chat_history在不改变事实的前提下重写或精炼主题概括关键词关键信息
- delete_chat_history删除明显冗余噪声错误或无意义的记录或者非常有时效性的信息或者无太多有用信息的日常互动
你也可以先用 create_chat_history 创建一条新的综合概括再对旧的冗余记录执行多次 delete_chat_history 来完成合并效果
**Jargon黑话维护工具只读禁止修改**
- search_jargon根据一个或多个关键词搜索Jargon 记录通常是含义不明确的词条或者特殊的缩写
**通用工具**
- finish_maintenance当你认为当前维护工作已经完成没有更多需要整理的内容时调用此工具来结束本次运行
**工作目标**
- 发现冗余重复或高度相似的记录并进行合并或删除
- 发现主题/概括过于含糊啰嗦或缺少关键信息的记录进行重写和精简
- summary要尽可能保持有用的信息
- 尽量保持信息的真实与可用性不要凭空捏造事实
**合并准则**
- 你可以新建一个记录然后删除旧记录来实现合并
- 如果两个或多个记录的主题相似内容是对主题不同方面的信息或讨论且信息量较少则可以合并为一条记录
- 如果两个记录冲突可以根据逻辑保留一个或者进行整合也可以采取更新的记录删除旧的记录
**轮次信息**
- 本次维护最多执行 {max_iterations}
- 每轮开始时系统会告知你当前是第几轮还剩多少轮
- 如果提前完成维护工作可以调用 finish_maintenance 工具主动结束
**每一轮的执行方式必须遵守**
- 第一步先用一小段中文自然语言写出你的思考和本轮计划例如要查什么准备怎么合并/修改
- 第二步在这段思考之后再通过工具调用来执行你的计划可以调用 0~N 个工具
- 第三步收到工具结果后在下一轮继续先写出新的思考再视情况继续调用工具
请不要在没有先写出思考的情况下直接调用工具
只输出你的思考内容或工具调用结果由系统负责真正执行工具调用
""",
name="dream_react_head_prompt",
)
class DreamTool:
"""dream 模块内部使用的简易工具封装"""
def __init__(self, name: str, description: str, parameters: List[Tuple], execute_func):
self.name = name
self.description = description
self.parameters = parameters
self.execute_func = execute_func
def get_tool_definition(self) -> Dict[str, Any]:
return {
"name": self.name,
"description": self.description,
"parameters": self.parameters,
}
async def execute(self, **kwargs) -> str:
return await self.execute_func(**kwargs)
class DreamToolRegistry:
def __init__(self) -> None:
self.tools: Dict[str, DreamTool] = {}
def register_tool(self, tool: DreamTool) -> None:
"""
注册或更新 dream 工具
注意dream agent 每个 chat_id 会重新初始化工具这里允许覆盖已有同名工具
"""
self.tools[tool.name] = tool
logger.info(f"注册/更新 dream 工具: {tool.name}")
def get_tool(self, name: str) -> Optional[DreamTool]:
return self.tools.get(name)
def get_tool_definitions(self) -> List[Dict[str, Any]]:
return [tool.get_tool_definition() for tool in self.tools.values()]
_dream_tool_registry = DreamToolRegistry()
def get_dream_tool_registry() -> DreamToolRegistry:
return _dream_tool_registry
def init_dream_tools(chat_id: str) -> None:
"""注册 dream agent 可用的 ChatHistory / Jargon 相关工具(限定在当前 chat_id 作用域内)"""
from src.llm_models.payload_content.tool_option import ToolParamType
# 通过工厂函数生成绑定当前 chat_id 的工具实现
search_chat_history = make_search_chat_history(chat_id)
get_chat_history_detail = make_get_chat_history_detail(chat_id)
delete_chat_history = make_delete_chat_history(chat_id)
create_chat_history = make_create_chat_history(chat_id)
update_chat_history = make_update_chat_history(chat_id)
finish_maintenance = make_finish_maintenance(chat_id)
search_jargon = make_search_jargon(chat_id)
delete_jargon = make_delete_jargon(chat_id)
update_jargon = make_update_jargon(chat_id)
_dream_tool_registry.register_tool(
DreamTool(
"search_chat_history",
"根据关键词或参与人查询当前 chat_id 下的 ChatHistory 概览,便于快速定位相关记忆。",
[
(
"keyword",
ToolParamType.STRING,
"关键词(可选,支持多个关键词,可用空格、逗号等分隔)。",
False,
None,
),
("participant", ToolParamType.STRING, "参与人昵称(可选)。", False, None),
],
search_chat_history,
)
)
_dream_tool_registry.register_tool(
DreamTool(
"get_chat_history_detail",
"根据 memory_id 获取单条 ChatHistory 的详细内容,包含主题、概括、关键词、关键信息等字段(不包含原文)。",
[
("memory_id", ToolParamType.INTEGER, "ChatHistory 主键 ID。", True, None),
],
get_chat_history_detail,
)
)
_dream_tool_registry.register_tool(
DreamTool(
"delete_chat_history",
"根据 memory_id 删除一条 ChatHistory 记录(请谨慎使用)。",
[
("memory_id", ToolParamType.INTEGER, "需要删除的 ChatHistory 主键 ID。", True, None),
],
delete_chat_history,
)
)
_dream_tool_registry.register_tool(
DreamTool(
"update_chat_history",
"按字段更新 ChatHistory 记录,可用于清理、重写或补充信息。",
[
("memory_id", ToolParamType.INTEGER, "需要更新的 ChatHistory 主键 ID。", True, None),
("theme", ToolParamType.STRING, "新的主题标题,如果不需要修改可不填。", False, None),
("summary", ToolParamType.STRING, "新的概括内容,如果不需要修改可不填。", False, None),
("keywords", ToolParamType.STRING, "新的关键词 JSON 字符串,如 ['关键词1','关键词2']。", False, None),
("key_point", ToolParamType.STRING, "新的关键信息 JSON 字符串,如 ['要点1','要点2']。", False, None),
],
update_chat_history,
)
)
_dream_tool_registry.register_tool(
DreamTool(
"create_chat_history",
"根据整理后的理解创建一条新的 ChatHistory 概括记录(主题、概括、关键词、关键信息等)。",
[
("theme", ToolParamType.STRING, "新的主题标题(必填)。", True, None),
("summary", ToolParamType.STRING, "新的概括内容(必填)。", True, None),
(
"keywords",
ToolParamType.STRING,
"新的关键词 JSON 字符串,如 ['关键词1','关键词2'](必填)。",
True,
None,
),
(
"key_point",
ToolParamType.STRING,
"新的关键信息 JSON 字符串,如 ['要点1','要点2'](必填)。",
True,
None,
),
("start_time", ToolParamType.STRING, "起始时间戳Unix 时间,必填)。", True, None),
("end_time", ToolParamType.STRING, "结束时间戳Unix 时间,必填)。", True, None),
],
create_chat_history,
)
)
_dream_tool_registry.register_tool(
DreamTool(
"finish_maintenance",
"结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理、合并或修改的内容时,调用此工具来主动结束本次运行。",
[
(
"reason",
ToolParamType.STRING,
"结束维护的原因说明(可选),例如 '已完成所有记录的整理''当前记录质量良好,无需进一步维护'",
False,
None,
),
],
finish_maintenance,
)
)
# ==================== Jargon 维护工具 ====================
# 注册 Jargon 工具
_dream_tool_registry.register_tool(
DreamTool(
"search_jargon",
"根据一个或多个关键词搜索当前 chat_id 相关的 Jargon 记录概览(只包含 is_jargon=True含全局 Jargon便于快速理解黑话库。",
[
("keyword", ToolParamType.STRING, "按一个或多个关键词搜索内容/含义/推断结果(必填)。", True, None),
],
search_jargon,
)
)
async def run_dream_agent_once(
chat_id: str,
max_iterations: Optional[int] = None,
start_memory_id: Optional[int] = None,
) -> None:
"""
运行一次 dream agent对指定 chat_id ChatHistory 进行最多 max_iterations 轮的整理
如果 max_iterations None则使用配置文件中的默认值
"""
if max_iterations is None:
max_iterations = global_config.dream.max_iterations
start_ts = time.time()
logger.info(f"[dream] 开始对 chat_id={chat_id} 进行 dream 维护,最多迭代 {max_iterations}")
# 初始化工具(作用域限定在当前 chat_id
init_dream_tools(chat_id)
tool_registry = get_dream_tool_registry()
tool_defs = tool_registry.get_tool_definitions()
bot_name = global_config.bot.nickname
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
head_prompt = await global_prompt_manager.format_prompt(
"dream_react_head_prompt",
bot_name=bot_name,
time_now=time_now,
chat_id=chat_id,
start_memory_id=start_memory_id if start_memory_id is not None else "无(本轮由你自由选择切入点)",
max_iterations=max_iterations,
)
conversation_messages: List[Message] = []
# 如果提供了起始记忆 ID则在对话正式开始前先把这条记忆的详细信息放入上下文
# 避免 LLM 还需要额外调用一次 get_chat_history_detail 才能看到起始记忆内容。
if start_memory_id is not None:
try:
record = ChatHistory.get_or_none(ChatHistory.id == start_memory_id)
if record:
start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time))
if record.start_time
else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
detail_text = (
f"ID={record.id}\n"
f"chat_id={record.chat_id}\n"
f"时间范围={start_time_str}{end_time_str}\n"
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
)
logger.debug(
f"[dream] 预加载起始记忆详情 memory_id={start_memory_id}"
f"预览: {detail_text[:200].replace(chr(10), ' ')}"
)
start_detail_builder = MessageBuilder()
start_detail_builder.set_role(RoleType.User)
start_detail_builder.add_text_content(
"【起始记忆详情】以下是本轮随机/指定的起始记忆的详细信息,供你在整理时优先参考:\n\n" + detail_text
)
conversation_messages.append(start_detail_builder.build())
else:
logger.warning(
f"[dream] 提供的 start_memory_id={start_memory_id} 未找到对应 ChatHistory 记录,"
"将不预加载起始记忆详情。"
)
except Exception as e:
logger.error(f"[dream] 预加载起始记忆详情失败 start_memory_id={start_memory_id}: {e}")
# 注意message_factory 必须是同步函数,返回消息列表(不能是 async/coroutine
def message_factory(
_client,
*,
_head_prompt: str = head_prompt,
_conversation_messages: List[Message] = conversation_messages,
) -> List[Message]:
messages: List[Message] = []
system_builder = MessageBuilder()
system_builder.set_role(RoleType.System)
system_builder.add_text_content(_head_prompt)
messages.append(system_builder.build())
messages.extend(_conversation_messages)
return messages
for iteration in range(1, max_iterations + 1):
# 在每轮开始时,添加轮次信息到对话中
remaining_rounds = max_iterations - iteration + 1
round_info_builder = MessageBuilder()
round_info_builder.set_role(RoleType.User)
round_info_builder.add_text_content(
f"【轮次信息】当前是第 {iteration}/{max_iterations} 轮,还剩 {remaining_rounds} 轮。"
)
conversation_messages.append(round_info_builder.build())
# 调用 LLM 让其决定是否要使用工具
(
success,
response,
reasoning_content,
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_defs,
request_type="dream.react",
)
if not success:
logger.error(f"[dream] 第 {iteration} 轮 LLM 调用失败: {response}")
break
# 先输出「思考」内容,再输出工具调用信息(思考文本较长,仅在 debug 下输出)
thought_log = reasoning_content or (response[:300] if response else "")
if thought_log:
logger.debug(f"[dream] 第 {iteration} 轮思考内容: {thought_log}")
logger.info(
f"[dream] 第 {iteration} 轮响应,模型={model_name},工具调用数={len(tool_calls) if tool_calls else 0}"
)
assistant_msg: Optional[Message] = None
if tool_calls:
builder = MessageBuilder()
builder.set_role(RoleType.Assistant)
if response and response.strip():
builder.add_text_content(response)
builder.set_tool_calls(tool_calls)
assistant_msg = builder.build()
elif response and response.strip():
builder = MessageBuilder()
builder.set_role(RoleType.Assistant)
builder.add_text_content(response)
assistant_msg = builder.build()
if assistant_msg:
conversation_messages.append(assistant_msg)
# 如果本轮没有工具调用,仅作为思考记录,继续下一轮
if not tool_calls:
logger.debug(f"[dream] 第 {iteration} 轮未调用任何工具,仅记录思考。")
continue
# 执行所有工具调用
tasks = []
finish_maintenance_called = False
for tc in tool_calls:
tool = tool_registry.get_tool(tc.func_name)
if not tool:
logger.warning(f"[dream] 未知工具:{tc.func_name}")
continue
# 检测是否调用了 finish_maintenance 工具
if tc.func_name == "finish_maintenance":
finish_maintenance_called = True
params = tc.args or {}
async def _run_single(t: DreamTool, p: Dict[str, Any], call_id: str, it: int):
try:
result = await t.execute(**p)
logger.debug(f"[dream] 第 {it} 轮 工具 {t.name} 执行完成")
return call_id, result
except Exception as e:
logger.error(f"[dream] 工具 {t.name} 执行失败: {e}")
return call_id, f"工具 {t.name} 执行失败: {e}"
tasks.append(_run_single(tool, params, tc.call_id, iteration))
if not tasks:
continue
tool_results = await asyncio.gather(*tasks, return_exceptions=False)
# 将工具结果作为 Tool 消息追加
for call_id, obs in tool_results:
tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(str(obs))
tool_builder.add_tool_call(call_id)
conversation_messages.append(tool_builder.build())
# 如果调用了 finish_maintenance 工具,提前结束本次运行
if finish_maintenance_called:
logger.info(f"[dream] 第 {iteration} 轮检测到 finish_maintenance 工具调用,提前结束本次维护。")
break
cost = time.time() - start_ts
logger.info(f"[dream] 对 chat_id={chat_id} 的 dream 维护结束,共迭代 {iteration} 轮,耗时 {cost:.1f}")
# 生成梦境总结
await generate_dream_summary(chat_id, conversation_messages, iteration, cost)
def _pick_random_chat_id() -> Optional[str]:
"""从 ChatHistory 中随机选择一个 chat_id用于 dream agent 本次维护
规则
- 只在 chat_id 所属的 ChatHistory 记录数 >= 10 时才会参与随机选择
- 记录数不足 10 chat_id 将被跳过不会触发做梦 react
"""
try:
# 统计每个 chat_id 的记录数,只保留记录数 >= 10 的 chat_id
rows = (
ChatHistory.select(ChatHistory.chat_id, fn.COUNT(ChatHistory.id).alias("cnt"))
.group_by(ChatHistory.chat_id)
.having(fn.COUNT(ChatHistory.id) >= 10)
.order_by(ChatHistory.chat_id)
.limit(200)
)
eligible_ids = [r.chat_id for r in rows]
if not eligible_ids:
logger.warning("[dream] ChatHistory 中暂无满足条件(记录数 >= 10的 chat_id本轮 dream 任务跳过。")
return None
chosen = random.choice(eligible_ids)
logger.info(f"[dream] 从 {len(eligible_ids)} 个满足条件的 chat_id 中随机选择:{chosen}")
return chosen
except Exception as e:
logger.error(f"[dream] 随机选择 chat_id 失败: {e}")
return None
def _pick_random_memory_for_chat(chat_id: str) -> Optional[int]:
"""
在给定 chat_id 下随机选择一条 ChatHistory 记录作为本轮整理的起始记忆
"""
try:
rows = (
ChatHistory.select(ChatHistory.id)
.where(ChatHistory.chat_id == chat_id)
.order_by(ChatHistory.start_time.asc())
.limit(200)
)
ids = [r.id for r in rows]
if not ids:
logger.warning(f"[dream] chat_id={chat_id} 下暂无 ChatHistory 记录,无法选择起始记忆。")
return None
return random.choice(ids)
except Exception as e:
logger.error(f"[dream] 在 chat_id={chat_id} 下随机选择起始记忆失败: {e}")
return None
async def run_dream_cycle_once() -> None:
"""
单次 dream 周期
- 随机选择一个 chat_id
- 在该 chat_id 下随机选择一条 ChatHistory 作为起始记忆
- 以这条起始记忆为切入点,对该 chat_id 运行一次 dream agent(最多轮次由配置 dream.max_iterations 决定)
"""
chat_id = _pick_random_chat_id()
if not chat_id:
return
start_memory_id = _pick_random_memory_for_chat(chat_id)
await run_dream_agent_once(
chat_id=chat_id,
max_iterations=None, # 使用配置文件中的默认值
start_memory_id=start_memory_id,
)
async def start_dream_scheduler(
first_delay_seconds: Optional[int] = None,
interval_seconds: Optional[int] = None,
stop_event: Optional[asyncio.Event] = None,
) -> None:
"""
dream 调度器
- 程序启动后先等待 first_delay_seconds如果为 None则使用配置文件中的值默认 60s
- 然后每隔 interval_seconds如果为 None则使用配置文件中的值默认 30 分钟运行一次 dream agent 周期
- 如果提供 stop_event则在 stop_event set() 后优雅退出循环
"""
if first_delay_seconds is None:
first_delay_seconds = global_config.dream.first_delay_seconds
if interval_seconds is None:
interval_seconds = global_config.dream.interval_minutes * 60
logger.info(
f"[dream] dream 调度器启动:首次延迟 {first_delay_seconds}s之后每隔 {interval_seconds}s ({interval_seconds // 60} 分钟) 运行一次 dream agent"
)
try:
await asyncio.sleep(first_delay_seconds)
while True:
if stop_event is not None and stop_event.is_set():
logger.info("[dream] 收到停止事件,结束 dream 调度器循环。")
break
start_ts = time.time()
# 检查当前时间是否在允许做梦的时间段内
if not global_config.dream.is_in_dream_time():
logger.debug("[dream] 当前时间不在允许做梦的时间段内,跳过本次执行")
else:
try:
await run_dream_cycle_once()
except Exception as e:
logger.error(f"[dream] 单次 dream 周期执行异常: {e}")
elapsed = time.time() - start_ts
# 保证两次执行之间至少间隔 interval_seconds
to_sleep = max(0.0, interval_seconds - elapsed)
await asyncio.sleep(to_sleep)
except asyncio.CancelledError:
logger.info("[dream] dream 调度器任务被取消,准备退出。")
raise
# 初始化提示词
init_dream_prompts()
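
A hedged sketch of wiring the scheduler into an asyncio entry point; the import path follows the src.dream.* package names used in this file and is an assumption about the final layout:

import asyncio
from src.dream.dream_agent import start_dream_scheduler  # assumed module path

async def main():
    stop_event = asyncio.Event()
    task = asyncio.create_task(start_dream_scheduler(stop_event=stop_event))
    try:
        await asyncio.sleep(5)  # stand-in for the rest of the bot's lifetime
    finally:
        stop_event.set()   # the loop checks this flag at the start of each cycle
        task.cancel()      # or cancel outright; the scheduler handles CancelledError
        await asyncio.gather(task, return_exceptions=True)

asyncio.run(main())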

View File

@ -0,0 +1,203 @@
import random
from typing import List, Optional
from src.common.logger import get_logger
from src.config.config import model_config
from src.chat.utils.prompt_builder import Prompt
from src.llm_models.payload_content.message import RoleType, Message
from src.llm_models.utils_model import LLMRequest
logger = get_logger("dream_generator")
# 初始化 utils 模型用于生成梦境总结
_dream_summary_model: Optional[LLMRequest] = None
# 梦境风格列表21种
DREAM_STYLES = [
"保持诗意和想象力,自由编写",
"诗意朦胧,如薄雾笼罩的清晨",
"奇幻冒险,充满未知与探索",
"温暖怀旧,带着时光的痕迹",
"神秘悬疑,暗藏深意",
"浪漫唯美,如诗如画",
"科幻未来,科技与想象交织",
"自然清新,如山林间的微风",
"深沉哲思,引人深思",
"轻松幽默,充满趣味",
"悲伤忧郁,带着淡淡哀愁",
"激昂热烈,充满活力",
"宁静平和,如湖面般平静",
"荒诞离奇,打破常规",
"细腻温柔,如春风拂面",
"壮阔宏大,气势磅礴",
"简约纯粹,返璞归真",
"复杂多变,层次丰富",
"梦幻迷离,虚实难辨",
"现实写意,贴近生活",
"抽象概念,超越具象",
]
def get_random_dream_styles(count: int = 2) -> List[str]:
"""从梦境风格列表中随机选择指定数量的风格"""
return random.sample(DREAM_STYLES, min(count, len(DREAM_STYLES)))
def get_dream_summary_model() -> LLMRequest:
"""获取用于生成梦境总结的 utils 模型实例"""
global _dream_summary_model
if _dream_summary_model is None:
_dream_summary_model = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="dream.summary",
)
return _dream_summary_model
def init_dream_summary_prompt() -> None:
"""初始化梦境总结的提示词"""
Prompt(
"""
你刚刚完成了一次对聊天记录的记忆整理工作以下是整理过程的摘要
整理过程
{conversation_text}
请将这次整理涉及的相关信息改写为一个富有诗意和想象力的"梦境"请你仅使用具体的记忆的内容而不是整理过程编写
要求
1. 使用第一人称视角
2. 叙述直白不要复杂修辞口语化
3. 长度控制在200-800
4. 用中文输出
梦境风格
{dream_styles}
请直接输出梦境内容不要添加其他说明
""",
name="dream_summary_prompt",
)
async def generate_dream_summary(
chat_id: str,
conversation_messages: List[Message],
total_iterations: int,
time_cost: float,
) -> None:
"""生成梦境总结并输出到日志"""
try:
import json
from src.chat.utils.prompt_builder import global_prompt_manager
# 第一步:建立工具调用结果映射 (call_id -> result)
tool_results_map: dict[str, str] = {}
for msg in conversation_messages:
if msg.role == RoleType.Tool and msg.tool_call_id:
content = ""
if msg.content:
if isinstance(msg.content, list) and msg.content:
content = msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
else:
content = str(msg.content)
tool_results_map[msg.tool_call_id] = content
# 第二步:详细记录所有工具调用操作和结果到日志
tool_call_count = 0
logger.info(f"[dream][工具调用详情] 开始记录 chat_id={chat_id} 的所有工具调用操作:")
for msg in conversation_messages:
if msg.role == RoleType.Assistant and msg.tool_calls:
tool_call_count += 1
# 提取思考内容
thought_content = ""
if msg.content:
if isinstance(msg.content, list) and msg.content:
thought_content = (
msg.content[0].text if hasattr(msg.content[0], "text") else str(msg.content[0])
)
else:
thought_content = str(msg.content)
logger.info(f"[dream][工具调用详情] === 第 {tool_call_count} 组工具调用 ===")
if thought_content:
logger.info(
f"[dream][工具调用详情] 思考内容:{thought_content[:500]}{'...' if len(thought_content) > 500 else ''}"
)
# 记录每个工具调用的详细信息
for idx, tool_call in enumerate(msg.tool_calls, 1):
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
tool_call_id = tool_call.call_id
tool_result = tool_results_map.get(tool_call_id, "未找到执行结果")
# 格式化参数
try:
args_str = json.dumps(tool_args, ensure_ascii=False, indent=2) if tool_args else "无参数"
except Exception:
args_str = str(tool_args)
logger.info(f"[dream][工具调用详情] --- 工具 {idx}: {tool_name} ---")
logger.info(f"[dream][工具调用详情] 调用参数:\n{args_str}")
logger.info(f"[dream][工具调用详情] 执行结果:\n{tool_result}")
logger.info(f"[dream][工具调用详情] {'-' * 60}")
logger.info(f"[dream][工具调用详情] 共记录了 {tool_call_count} 组工具调用操作")
# 第三步:构建对话历史摘要(用于生成梦境)
conversation_summary = []
for msg in conversation_messages:
role = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
content = ""
if msg.content:
content = msg.content[0].text if isinstance(msg.content, list) and msg.content else str(msg.content)
if role == "user" and "轮次信息" in content:
# 跳过轮次信息消息
continue
if role == "assistant":
# 只保留思考内容,简化工具调用信息
if content:
# 截取前500字符避免过长
content_preview = content[:500] + ("..." if len(content) > 500 else "")
conversation_summary.append(f"[{role}] {content_preview}")
elif role == "tool":
# 工具结果,只保留关键信息
if content:
# 截取前300字符
content_preview = content[:300] + ("..." if len(content) > 300 else "")
conversation_summary.append(f"[工具执行] {content_preview}")
conversation_text = "\n".join(conversation_summary[-20:]) # 只保留最后20条消息
# 随机选择2个梦境风格
selected_styles = get_random_dream_styles(2)
dream_styles_text = "\n".join([f"{i + 1}. {style}" for i, style in enumerate(selected_styles)])
# 使用 Prompt 管理器格式化梦境生成 prompt
dream_prompt = await global_prompt_manager.format_prompt(
"dream_summary_prompt",
chat_id=chat_id,
total_iterations=total_iterations,
time_cost=time_cost,
conversation_text=conversation_text,
dream_styles=dream_styles_text,
)
# 调用 utils 模型生成梦境
summary_model = get_dream_summary_model()
dream_content, (reasoning, model_name, _) = await summary_model.generate_response_async(
dream_prompt,
max_tokens=512,
temperature=0.8,
)
if dream_content:
logger.info(f"[dream][梦境总结] 对 chat_id={chat_id} 的整理过程梦境:\n{dream_content}")
else:
logger.warning("[dream][梦境总结] 未能生成梦境总结")
except Exception as e:
logger.error(f"[dream][梦境总结] 生成梦境总结失败: {e}", exc_info=True)
init_dream_summary_prompt()

View File

@ -0,0 +1,6 @@
"""
dream agent 工具实现模块
每个工具的具体实现放在独立文件中,通过 make_xxx(chat_id) 工厂函数
生成绑定到特定 chat_id 的协程函数,由 dream_agent.init_dream_tools 统一注册。
"""

View File

@ -0,0 +1,62 @@
import time
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
logger = get_logger("dream_agent")
def make_create_chat_history(chat_id: str):
async def create_chat_history(
theme: str,
summary: str,
keywords: str,
key_point: str,
start_time: float,
end_time: float,
) -> str:
"""创建一条新的 ChatHistory 概括记录(用于整理/合并后的新记忆)"""
try:
logger.info(
f"[dream][tool] 调用 create_chat_history("
f"theme={bool(theme)}, summary={bool(summary)}, "
f"keywords={bool(keywords)}, key_point={bool(key_point)}, "
f"start_time={start_time}, end_time={end_time}) (chat_id={chat_id})"
)
now_ts = time.time()
# 将传入的 start_time/end_time(如果有)解析为时间戳,否则回退为当前时间
def _parse_ts(value, default):
if value is None:
return default
try:
return float(value)
except (TypeError, ValueError):
return default
start_ts = _parse_ts(start_time, now_ts)
end_ts = _parse_ts(end_time, now_ts)
record = ChatHistory.create(
chat_id=chat_id,
theme=theme,
summary=summary,
keywords=keywords,
key_point=key_point,
# 对于由 dream 整理产生的新概括,时间范围优先使用工具提供的时间,否则使用当前时间占位
start_time=start_ts,
end_time=end_ts,
)
msg = (
f"已创建新的 ChatHistory 记录ID={record.id}"
f"theme={record.theme or ''}summary={'' if record.summary else ''}"
)
logger.info(f"[dream][tool] create_chat_history 完成: {msg}")
return msg
except Exception as e:
logger.error(f"create_chat_history 失败: {e}")
return f"create_chat_history 执行失败: {e}"
return create_chat_history

View File

@ -0,0 +1,25 @@
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
logger = get_logger("dream_agent")
def make_delete_chat_history(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def delete_chat_history(memory_id: int) -> str:
"""删除一条 chat_history 记录"""
try:
logger.info(f"[dream][tool] 调用 delete_chat_history(memory_id={memory_id})")
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
if not record:
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录,无法删除。"
logger.info(f"[dream][tool] delete_chat_history 未找到记录: {msg}")
return msg
rows = ChatHistory.delete().where(ChatHistory.id == memory_id).execute()
msg = f"已删除 ID={memory_id} 的 ChatHistory 记录,受影响行数={rows}"
logger.info(f"[dream][tool] delete_chat_history 完成: {msg}")
return msg
except Exception as e:
logger.error(f"delete_chat_history 失败: {e}")
return f"delete_chat_history 执行失败: {e}"
return delete_chat_history

View File

@ -0,0 +1,25 @@
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
logger = get_logger("dream_agent")
def make_delete_jargon(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def delete_jargon(jargon_id: int) -> str:
"""删除一条 Jargon 记录"""
try:
logger.info(f"[dream][tool] 调用 delete_jargon(jargon_id={jargon_id})")
record = Jargon.get_or_none(Jargon.id == jargon_id)
if not record:
msg = f"未找到 ID={jargon_id} 的 Jargon 记录,无法删除。"
logger.info(f"[dream][tool] delete_jargon 未找到记录: {msg}")
return msg
rows = Jargon.delete().where(Jargon.id == jargon_id).execute()
msg = f"已删除 ID={jargon_id} 的 Jargon 记录(内容:{record.content}),受影响行数={rows}"
logger.info(f"[dream][tool] delete_jargon 完成: {msg}")
return msg
except Exception as e:
logger.error(f"delete_jargon 失败: {e}")
return f"delete_jargon 执行失败: {e}"
return delete_jargon

View File

@ -0,0 +1,16 @@
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("dream_agent")
def make_finish_maintenance(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def finish_maintenance(reason: Optional[str] = None) -> str:
"""结束本次 dream 维护任务。当你认为当前 chat_id 下的维护工作已经完成,没有更多需要整理的内容时,调用此工具来结束本次运行。"""
reason_text = f",原因:{reason}" if reason else ""
msg = f"DREAM_MAINTENANCE_COMPLETE{reason_text}"
logger.info(f"[dream][tool] 调用 finish_maintenance结束本次维护{reason_text}")
return msg
return finish_maintenance
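A small hedged sketch of how a caller might detect the sentinel string returned above; the real dispatch loop lives in dream_agent and is not shown in this diff:
def is_maintenance_complete(tool_result: str) -> bool:
    # Completion is signalled by a plain string prefix, so a startswith()
    # check is enough -- no structured return type is required.
    return tool_result.startswith("DREAM_MAINTENANCE_COMPLETE")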

View File

@ -0,0 +1,44 @@
import time
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
logger = get_logger("dream_agent")
def make_get_chat_history_detail(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def get_chat_history_detail(memory_id: int) -> str:
"""获取单条 chat_history 的完整内容"""
try:
logger.info(f"[dream][tool] 调用 get_chat_history_detail(memory_id={memory_id})")
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
if not record:
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录。"
logger.info(f"[dream][tool] get_chat_history_detail 未找到记录: {msg}")
return msg
# 将时间戳转换为可读时间格式
start_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.start_time)) if record.start_time else "未知"
)
end_time_str = (
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(record.end_time)) if record.end_time else "未知"
)
result = (
f"ID={record.id}\n"
# f"chat_id={record.chat_id}\n"
f"时间范围={start_time_str}{end_time_str}\n"
f"主题={record.theme or ''}\n"
f"关键词={record.keywords or ''}\n"
f"参与者={record.participants or ''}\n"
f"概括={record.summary or ''}\n"
f"关键信息={record.key_point or ''}"
)
logger.debug(f"[dream][tool] get_chat_history_detail 成功,预览: {result[:200].replace(chr(10), ' ')}")
return result
except Exception as e:
logger.error(f"get_chat_history_detail 失败: {e}")
return f"get_chat_history_detail 执行失败: {e}"
return get_chat_history_detail

View File

@ -0,0 +1,214 @@
import json
from typing import List, Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.chat.utils.utils import parse_keywords_string
logger = get_logger("dream_agent")
def make_search_chat_history(chat_id: str):
async def search_chat_history(
keyword: Optional[str] = None,
participant: Optional[str] = None,
) -> str:
"""根据关键词或参与人查询记忆返回匹配的记忆id、记忆标题theme和关键词keywordsdream 维护专用版本)"""
try:
# 检查参数
if not keyword and not participant:
return "未指定查询参数需要提供keyword或participant之一"
logger.info(
f"[dream][tool] 调用 search_chat_history(keyword={keyword}, participant={participant}) "
f"(作用域 chat_id={chat_id})"
)
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 执行查询(按时间倒序,最近的在前)
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
filtered_records: List[ChatHistory] = []
for record in records:
participant_matched = True # 如果没有participant条件默认为True
keyword_matched = True # 如果没有keyword条件默认为True
# 检查参与人匹配
if participant:
participant_matched = False
participants_list: List[str] = []
if record.participants:
try:
participants_data = (
json.loads(record.participants)
if isinstance(record.participants, str)
else record.participants
)
if isinstance(participants_data, list):
participants_list = [str(p).lower() for p in participants_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
participant_lower = participant.lower().strip()
if participant_lower and any(participant_lower in p for p in participants_list):
participant_matched = True
# 检查关键词匹配
if keyword:
keyword_matched = False
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if keywords_lower:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list: List[str] = []
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 有容错的全匹配:如果关键词数量>2,允许n-1个关键词匹配;否则必须全部匹配
matched_count = 0
for kw in keywords_lower:
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if kw_matched:
matched_count += 1
# 计算需要匹配的关键词数量
total_keywords = len(keywords_lower)
if total_keywords > 2:
# 关键词数量>2允许n-1个关键词匹配
required_matches = total_keywords - 1
else:
# 关键词数量<=2必须全部匹配
required_matches = total_keywords
keyword_matched = matched_count >= required_matches
# 两者都匹配:如果同时有participant和keyword,需要两者都匹配;如果只有一个条件,只需要该条件匹配
matched = participant_matched and keyword_matched
if matched:
filtered_records.append(record)
if not filtered_records:
if keyword and participant:
keywords_str = "".join(parse_keywords_string(keyword) if keyword else [])
return f"未找到包含关键词'{keywords_str}'且参与人包含'{participant}'的聊天记录"
elif keyword:
keywords_list = parse_keywords_string(keyword)
keywords_str = "".join(keywords_list)
if len(keywords_list) > 2:
required_count = len(keywords_list) - 1
return f"未找到包含至少{required_count}个关键词(共{len(keywords_list)}个)'{keywords_str}'的聊天记录"
else:
return f"未找到包含所有关键词'{keywords_str}'的聊天记录"
elif participant:
return f"未找到参与人包含'{participant}'的聊天记录"
else:
return "未找到相关聊天记录"
# 如果匹配结果超过20条,不返回具体记录,只返回提示和所有相关关键词
if len(filtered_records) > 20:
all_keywords_set = set()
for record in filtered_records:
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
for k in keywords_data:
k_str = str(k).strip()
if k_str:
all_keywords_set.add(k_str)
except (json.JSONDecodeError, TypeError, ValueError):
continue
search_label = keyword or participant or "当前条件"
if all_keywords_set:
keywords_str = "".join(sorted(all_keywords_set))
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f'有关"{search_label}"的关键词:\n'
f"{keywords_str}"
)
else:
response_text = (
f"包含“{search_label}”的结果过多,请尝试更多关键词精确查找\n\n"
f'有关"{search_label}"的关键词信息为空'
)
logger.info(
f"[dream][tool] search_chat_history 匹配结果超过20条返回关键词汇总提示总数={len(filtered_records)}"
)
return response_text
# 构建结果文本,返回id、theme和keywords(最多20条)
results: List[str] = []
for record in filtered_records[:20]:
result_parts: List[str] = []
# 记忆ID
result_parts.append(f"记忆ID{record.id}")
# 主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
else:
result_parts.append("主题:(无)")
# 关键词
if record.keywords:
try:
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data])
result_parts.append(f"关键词:{keywords_str}")
else:
result_parts.append("关键词:(无)")
except (json.JSONDecodeError, TypeError, ValueError):
result_parts.append("关键词:(无)")
else:
result_parts.append("关键词:(无)")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录"
response_text = "\n\n---\n\n".join(results)
logger.info(f"[dream][tool] search_chat_history 返回 {len(filtered_records)} 条匹配记录")
return response_text
except Exception as e:
logger.error(f"search_chat_history 失败: {e}")
return f"search_chat_history 执行失败: {e}"
return search_chat_history
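A standalone illustration of the fault-tolerant keyword rule implemented above: with more than two keywords one miss is tolerated, otherwise every keyword must match (sketch only, it mirrors the in-function logic):
def required_matches(total_keywords: int) -> int:
    return total_keywords - 1 if total_keywords > 2 else total_keywords

assert required_matches(1) == 1  # a single keyword must match
assert required_matches(2) == 2  # both keywords must match
assert required_matches(4) == 3  # one of four keywords may miss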

View File

@ -0,0 +1,102 @@
from typing import List
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.config.config import global_config
from src.chat.utils.utils import parse_keywords_string
from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains
logger = get_logger("dream_agent")
def make_search_jargon(chat_id: str):
async def search_jargon(keyword: str) -> str:
"""根据一个或多个关键词搜索当前 chat_id 相关的 Jargon 记录概览(只包含 is_jargon=True是否跨 chat_id 由 all_global 决定)"""
try:
if not keyword or not keyword.strip():
return "未指定查询关键词(参数 keyword 为必填,且不能为空)"
logger.info(f"[dream][tool] 调用 search_jargon(keyword={keyword}) (作用域 chat_id={chat_id})")
# 基础条件:只查 is_jargon=True 的记录
query = Jargon.select().where(Jargon.is_jargon)
# 根据 all_global 配置决定 chat_id 作用域
if global_config.expression.all_global_jargon:
# 开启全局黑话:只看 is_global=True 的记录,不区分 chat_id
query = query.where(Jargon.is_global)
else:
# 关闭全局黑话:后续在 Python 层按 chat_id 列表过滤(包含 is_global=True
pass
# 先按使用次数排序取一批候选,做一个安全上限
query = query.order_by(Jargon.count.desc()).limit(200)
candidates = list(query)
if not candidates:
msg = "未找到符合条件的 Jargon 记录。"
logger.info(f"[dream][tool] search_jargon 无记录: {msg}")
return msg
# 关键词为必填,因此此处必然执行关键词过滤(支持多个关键词,大小写不敏感)
keywords_list = parse_keywords_string(keyword) or []
if not keywords_list and keyword.strip():
keywords_list = [keyword.strip()]
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
# 先按关键词过滤(仅对 content 字段进行匹配)
filtered_keyword: List[Jargon] = []
for r in candidates:
content = (r.content or "").lower()
# 只要命中任意一个关键词即可视为匹配OR 逻辑)
any_matched = False
for kw in keywords_lower:
if not kw:
continue
if kw in content:
any_matched = True
break
if any_matched:
filtered_keyword.append(r)
if global_config.expression.all_global_jargon:
# 全局黑话模式:不再做 chat_id 过滤,直接使用关键词过滤结果
records = filtered_keyword
else:
# 非全局模式:仅保留全局黑话或 chat_id 列表中包含当前 chat_id 的记录
records = []
for r in filtered_keyword:
if r.is_global:
records.append(r)
continue
chat_id_list = parse_chat_id_list(r.chat_id)
if chat_id_list_contains(chat_id_list, chat_id):
records.append(r)
if not records:
scope_note = (
"(当前为全局黑话模式,仅统计 is_global=True 的条目)"
if global_config.expression.all_global_jargon
else "(当前为按 chat_id 作用域模式,仅统计全局黑话或与当前 chat_id 相关的条目)"
)
return f"未找到包含关键词'{keyword}'的 Jargon 记录{scope_note}"
lines: List[str] = []
for r in records:
is_jargon_str = "" if r.is_jargon else "" if r.is_jargon is False else "未判定"
is_global_str = "全局" if r.is_global else "非全局"
lines.append(
f"ID={r.id} | 内容={r.content} | 含义={r.meaning or ''} | "
f"chat_id={r.chat_id} | {is_global_str} | 是否黑话={is_jargon_str}"
)
result = "\n".join(lines)
logger.info(f"[dream][tool] search_jargon 返回 {len(records)} 条记录")
return result
except Exception as e:
logger.error(f"search_jargon 失败: {e}")
return f"search_jargon 执行失败: {e}"
return search_jargon
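An illustration of the [[chat_id, count], ...] scoping format consumed by chat_id_list_contains above; the chat_id values are made up for the example:
from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains

raw = '[["chat_a", 3], ["chat_b", 1]]'  # JSON string as stored in Jargon.chat_id
chat_id_list = parse_chat_id_list(raw)
print(chat_id_list_contains(chat_id_list, "chat_a"))  # True  -> record stays in scope
print(chat_id_list_contains(chat_id_list, "chat_c"))  # False -> filtered out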

View File

@ -0,0 +1,51 @@
from typing import Any, Dict, Optional
from src.common.logger import get_logger
from src.common.database.database_model import ChatHistory
from src.plugin_system.apis import database_api
logger = get_logger("dream_agent")
def make_update_chat_history(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def update_chat_history(
memory_id: int,
theme: Optional[str] = None,
summary: Optional[str] = None,
keywords: Optional[str] = None,
key_point: Optional[str] = None,
) -> str:
"""按字段更新 chat_history字符串字段要求 JSON 的字段须传入已序列化的字符串)"""
try:
logger.info(
f"[dream][tool] 调用 update_chat_history(memory_id={memory_id}, "
f"theme={bool(theme)}, summary={bool(summary)}, keywords={bool(keywords)}, key_point={bool(key_point)})"
)
record = ChatHistory.get_or_none(ChatHistory.id == memory_id)
if not record:
msg = f"未找到 ID={memory_id} 的 ChatHistory 记录,无法更新。"
logger.info(f"[dream][tool] update_chat_history 未找到记录: {msg}")
return msg
data: Dict[str, Any] = {}
if theme is not None:
data["theme"] = theme
if summary is not None:
data["summary"] = summary
if keywords is not None:
data["keywords"] = keywords
if key_point is not None:
data["key_point"] = key_point
if not data:
return "未提供任何需要更新的字段。"
await database_api.db_save(ChatHistory, data=data, key_field="id", key_value=memory_id)
msg = f"已更新 ChatHistory 记录 ID={memory_id},更新字段={list(data.keys())}"
logger.info(f"[dream][tool] update_chat_history 完成: {msg}")
return msg
except Exception as e:
logger.error(f"update_chat_history 失败: {e}")
return f"update_chat_history 执行失败: {e}"
return update_chat_history
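A hedged usage sketch of the factory above (values are illustrative and the database must already be initialised); note that JSON-backed string fields such as keywords are passed pre-serialised, as the docstring requires:
import asyncio
import json

async def demo() -> None:
    update_chat_history = make_update_chat_history("some_chat_id")
    result = await update_chat_history(
        memory_id=42,
        theme="周末出游计划",
        keywords=json.dumps(["出游", "计划"], ensure_ascii=False),  # pre-serialised JSON string
    )
    print(result)

# asyncio.run(demo())  # requires the project's Peewee database to be set up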

View File

@ -0,0 +1,51 @@
from typing import Any, Dict, Optional
from src.common.logger import get_logger
from src.common.database.database_model import Jargon
from src.plugin_system.apis import database_api
logger = get_logger("dream_agent")
def make_update_jargon(chat_id: str): # chat_id 目前未直接使用,预留以备扩展
async def update_jargon(
jargon_id: int,
meaning: Optional[str] = None,
is_global: Optional[bool] = None,
is_jargon: Optional[bool] = None,
content: Optional[str] = None,
) -> str:
"""按字段更新 Jargon 记录,可用于修正含义、调整全局性、标记是否为黑话等"""
try:
logger.info(
f"[dream][tool] 调用 update_jargon(jargon_id={jargon_id}, "
f"meaning={bool(meaning)}, is_global={is_global}, is_jargon={is_jargon}, content={bool(content)})"
)
record = Jargon.get_or_none(Jargon.id == jargon_id)
if not record:
msg = f"未找到 ID={jargon_id} 的 Jargon 记录,无法更新。"
logger.info(f"[dream][tool] update_jargon 未找到记录: {msg}")
return msg
data: Dict[str, Any] = {}
if meaning is not None:
data["meaning"] = meaning
if is_global is not None:
data["is_global"] = is_global
if is_jargon is not None:
data["is_jargon"] = is_jargon
if content is not None:
data["content"] = content
if not data:
return "未提供任何需要更新的字段。"
await database_api.db_save(Jargon, data=data, key_field="id", key_value=jargon_id)
msg = f"已更新 Jargon 记录 ID={jargon_id},更新字段={list(data.keys())}"
logger.info(f"[dream][tool] update_jargon 完成: {msg}")
return msg
except Exception as e:
logger.error(f"update_jargon 失败: {e}")
return f"update_jargon 执行失败: {e}"
return update_jargon

View File

@ -1,145 +0,0 @@
import re
import difflib
import random
from datetime import datetime
from typing import Optional, List, Dict
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID
content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值范围0-1
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间
count越高权重越高但最多为基础权重的5倍
如果表达已checked权重会再乘以3倍
"""
if not population:
return []
counts = []
checked_flags = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
# 获取checked状态
checked = item.get("checked", False)
checked_flags.append(bool(checked))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
base_weights = [1.0 for _ in counts]
else:
base_weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
base_weights.append(1.0 + normalized * 4.0)  # 1~5
# 如果checked权重乘以3
weights = []
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
if checked:
weights.append(base_weight * 3.0)
else:
weights.append(base_weight)
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected

View File

@ -1,708 +0,0 @@
import time
import json
import os
import re
import asyncio
from typing import List, Optional, Tuple
import traceback
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.express_utils import filter_message_content, calculate_similarity
from json_repair import repair_json
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def init_prompt() -> None:
learn_style_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字不要考虑表情包和图片
2. 不要涉及具体的人名但是可以涉及具体名词
3. 思考有没有特殊的梗一并总结成语言风格
4. 例子仅供参考请严格根据群聊内容总结!!!
注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
例如
"对某件事表示十分惊叹"使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"使用"对对对"
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
"当涉及游戏相关时,夸赞,略带戏谑意味"使用"这么强!"
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
match_expression_context_prompt = """
**聊天内容**
{chat_str}
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context
如果找不到原句就不输出该句的匹配结果
以json格式输出
格式如下
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
...
现在请你输出匹配结果
"""
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
)
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock()
# 学习参数
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 120 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
检查是否应该触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
"""
# 检查是否允许学习
if not self.enable_learning:
return False
# 检查时间间隔
time_diff = time.time() - self.last_learning_time
if time_diff < self.min_learning_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
return False
return True
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
# 使用异步锁防止并发执行
async with self._learning_lock:
# 在锁内检查,避免并发触发
# 如果锁被持有,其他协程会等待,但等待期间条件可能已变化,所以需要再次检查
if not self.should_trigger_learning():
return
# 保存学习开始前的时间戳,用于获取消息范围
learning_start_timestamp = time.time()
previous_learning_time = self.last_learning_time
# 立即更新学习时间,防止并发触发
self.last_learning_time = learning_start_timestamp
try:
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格,传递学习开始前的时间戳
learnt_style = await self.learn_and_store(num=25, timestamp_start=previous_learning_time)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
return
async def learn_and_store(self, num: int = 10, timestamp_start: Optional[float] = None) -> List[Tuple[str, str, str, str]]:
"""
学习并存储表达方式
Args:
num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
"""
learnt_expressions = await self.learn_expression(num, timestamp_start=timestamp_start)
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
up_content,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
return learnt_expressions
async def match_expression_context(
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
prompt = "match_expression_context_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt,
expression_pairs=expression_pairs_str,
chat_str=random_msg_match_str,
)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 尝试提取JSON代码块如果存在
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
response = matches[0].strip()
# 移除可能的markdown代码块标记如果没有找到```json但可能有```
if not matches:
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
repaired_content = repair_json(response)
# 确保repaired_content是列表格式
if isinstance(repaired_content, str):
try:
parsed_data = json.loads(repaired_content)
if isinstance(parsed_data, dict):
# 如果是字典,包装成列表
match_responses = [parsed_data]
elif isinstance(parsed_data, list):
match_responses = parsed_data
else:
match_responses = []
except json.JSONDecodeError:
match_responses = []
elif isinstance(repaired_content, dict):
# 如果是字典,包装成列表
match_responses = [repaired_content]
elif isinstance(repaired_content, list):
match_responses = repaired_content
else:
match_responses = []
except json.JSONDecodeError:
# 如果还是失败尝试repair_json
repaired_content = repair_json(response)
if isinstance(repaired_content, str):
parsed_data = json.loads(repaired_content)
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
else:
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
except (json.JSONDecodeError, Exception) as e:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
# 确保 match_responses 是一个列表
if not isinstance(match_responses, list):
if isinstance(match_responses, dict):
match_responses = [match_responses]
else:
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
return []
# 清理和规范化 match_responses 中的元素
normalized_responses = []
for item in match_responses:
if isinstance(item, dict):
# 已经是字典,直接添加
normalized_responses.append(item)
elif isinstance(item, str):
# 如果是字符串,尝试解析为 JSON
try:
parsed = json.loads(item)
if isinstance(parsed, dict):
normalized_responses.append(parsed)
elif isinstance(parsed, list):
# 如果是列表,递归处理
for sub_item in parsed:
if isinstance(sub_item, dict):
normalized_responses.append(sub_item)
else:
logger.debug(f"跳过非字典类型的子元素: {type(sub_item)}, 内容: {sub_item}")
else:
logger.debug(f"跳过无法转换为字典的字符串元素: {item}")
except (json.JSONDecodeError, TypeError):
logger.debug(f"跳过无法解析为JSON的字符串元素: {item}")
elif isinstance(item, list):
# 如果是列表,展开并处理其中的字典
for sub_item in item:
if isinstance(sub_item, dict):
normalized_responses.append(sub_item)
elif isinstance(sub_item, str):
# 尝试解析字符串
try:
parsed = json.loads(sub_item)
if isinstance(parsed, dict):
normalized_responses.append(parsed)
else:
logger.debug(f"跳过非字典类型的解析结果: {type(parsed)}, 内容: {parsed}")
except (json.JSONDecodeError, TypeError):
logger.debug(f"跳过无法解析为JSON的字符串子元素: {sub_item}")
else:
logger.debug(f"跳过非字典类型的列表元素: {type(sub_item)}, 内容: {sub_item}")
else:
logger.debug(f"跳过无法处理的元素类型: {type(item)}, 内容: {item}")
match_responses = normalized_responses
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"规范化后的 match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"规范化后的 match_responses 内容: {match_responses}")
for match_response in match_responses:
try:
# 检查 match_response 的类型(此时应该都是字典)
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response.get("context", "")
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError, TypeError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
return matched_expressions
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
"""
current_time = time.time()
# 使用传入的时间戳如果没有则使用self.last_learning_time
start_timestamp = timestamp_start if timestamp_start is not None else self.last_learning_time
# 获取上次学习之后的消息
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=start_timestamp,
timestamp_end=current_time,
limit=num,
)
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 学习用
random_msg_str: str = await build_anonymous_messages(random_msg)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
"learn_style_prompt",
chat_str=random_msg_str,
)
# print(f"random_msg_str:{random_msg_str}")
# logger.info(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return None
# logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
# 为每条消息构建精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达
continue
# 检查目标句是否为空
target_content = bare_lines[pos][1]
if not target_content:
# 目标句为空,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content:
# 上一句为空,跳过该表达
continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up:
return None
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str]]:
"""
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
"""
expressions: List[Tuple[str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 查找"当"和下一个引号
idx_when = line.find('当"')
if idx_when == -1:
continue
idx_quote1 = idx_when + 1
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
if idx_use == -1:
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((situation, style))
return expressions
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str]] = []
removed_count = 0
for situation, style in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
if expr_obj:
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
context=context,
up_content=up_content,
current_time=current_time,
)
return
await self._create_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
async def _create_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = [situation]
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
Expression.create(
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
up_content=up_content,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = self._parse_content_list(expr_obj.content_list)
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time
expr_obj.context = context
expr_obj.up_content = up_content
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
expr_obj.save()
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表保留到原消息索引的映射
Args:
messages: 消息列表
Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
"""
bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages):
content = msg.processed_plain_text or ""
content = filter_message_content(content)
# 即使content为空也要记录防止错位
bare_lines.append((idx, content))
return bare_lines
init_prompt()
class ExpressionLearnerManager:
def __init__(self):
self.expression_learners = {}
self._ensure_expression_directories()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
expression_learner_manager = ExpressionLearnerManager()

View File

@ -7,6 +7,7 @@ import asyncio
import json
import time
import re
import difflib
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass, field
@ -30,16 +31,18 @@ HIPPO_CACHE_DIR = Path(__file__).resolve().parents[2] / "data" / "hippo_memorize
def init_prompt():
"""初始化提示词模板"""
topic_analysis_prompt = """
历史话题标题列表仅标题不含具体内容
topic_analysis_prompt = """【历史话题标题列表】(仅标题,不含具体内容):
{history_topics_block}
历史话题标题列表结束
本次聊天记录每条消息前有编号用于后续引用
{messages_block}
本次聊天记录结束
请完成以下任务
**识别话题**
1. 识别本次聊天记录中正在进行的一个或多个话题
2. 本次聊天记录的中的消息可能与历史话题有关也可能毫无关联
2. 判断历史话题标题列表中的话题是否在本次聊天记录中出现如果出现则直接使用该历史话题标题字符串
**选取消息**
@ -316,7 +319,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化
self._persist_topic_cache()
else:
@ -362,9 +367,7 @@ class ChatHistorySummarizer:
else:
time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.debug(
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
logger.debug(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
# 检查“话题检查”触发条件
should_check = False
@ -374,10 +377,10 @@ class ChatHistorySummarizer:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条")
# 条件2: 距离上一次检查 > 3600 1小时,触发一次检查
elif time_since_last_check > 2400:
# 条件2: 距离上一次检查 > 3600 * 8 秒(8小时)且消息数量 >= 20 条,触发一次检查
elif time_since_last_check > 3600 * 8 and message_count >= 20:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 1小时")
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 8小时且消息数量达到 {message_count} 条(阈值: 20条")
if should_check:
await self._run_topic_check_and_update_cache(messages)
@ -414,7 +417,7 @@ class ChatHistorySummarizer:
# 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False
for msg in messages:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
@ -427,20 +430,63 @@ class ChatHistorySummarizer:
return
# 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
# 3. 调用 LLM 识别话题,并得到 topic -> indices
existing_topics = list(self.topic_cache.keys())
success, topic_to_indices = await self._analyze_topics_with_llm(
numbered_lines=numbered_lines,
existing_topics=existing_topics,
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys())
max_retries = 3
attempt = 0
success = False
topic_to_indices: Dict[str, List[int]] = {}
while attempt < max_retries:
attempt += 1
success, topic_to_indices = await self._analyze_topics_with_llm(
numbered_lines=numbered_lines,
existing_topics=existing_topics,
)
if success and topic_to_indices:
if attempt > 1:
logger.info(
f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}"
)
break
logger.warning(
f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败"
+ ("" if attempt >= max_retries else ",准备重试")
)
if not success or not topic_to_indices:
logger.warning(f"{self.log_prefix} 话题识别失败或无有效话题,本次检查忽略")
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
logger.error(f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃")
# 即使识别失败,也认为是一次"检查",但不更新 no_update_checks保持原状
return
# 3.5. 检查新话题是否与历史话题相似(相似度>=90%则使用历史标题)
topic_mapping = self._build_topic_mapping(topic_to_indices, similarity_threshold=0.9)
# 应用话题映射:将相似的新话题标题替换为历史话题标题
if topic_mapping:
new_topic_to_indices: Dict[str, List[int]] = {}
for new_topic, indices in topic_to_indices.items():
# 如果这个新话题需要映射到历史话题
if new_topic in topic_mapping:
historical_topic = topic_mapping[new_topic]
# 如果历史话题已经存在,合并消息索引
if historical_topic in new_topic_to_indices:
# 合并索引并去重
combined_indices = list(set(new_topic_to_indices[historical_topic] + indices))
new_topic_to_indices[historical_topic] = combined_indices
else:
new_topic_to_indices[historical_topic] = indices
else:
# 不需要映射,保持原样
new_topic_to_indices[new_topic] = indices
topic_to_indices = new_topic_to_indices
# 4. 统计哪些话题在本次检查中有新增内容
updated_topics: Set[str] = set()
@ -507,6 +553,71 @@ class ChatHistorySummarizer:
# 无论成功与否,都从缓存中删除,避免重复
self.topic_cache.pop(topic, None)
def _find_most_similar_topic(
self, new_topic: str, existing_topics: List[str], similarity_threshold: float = 0.9
) -> Optional[tuple[str, float]]:
"""
查找与给定新话题最相似的历史话题
Args:
new_topic: 新话题标题
existing_topics: 历史话题标题列表
similarity_threshold: 相似度阈值,默认0.9(90%)
Returns:
Optional[tuple[str, float]]: 如果找到相似度>=阈值的历史话题,返回(历史话题标题, 相似度);
否则返回None
"""
if not existing_topics:
return None
best_match = None
best_similarity = 0.0
for existing_topic in existing_topics:
similarity = difflib.SequenceMatcher(None, new_topic, existing_topic).ratio()
if similarity > best_similarity:
best_similarity = similarity
best_match = existing_topic
# 如果相似度达到阈值,返回匹配结果
if best_match and best_similarity >= similarity_threshold:
return (best_match, best_similarity)
return None
def _build_topic_mapping(
self, topic_to_indices: Dict[str, List[int]], similarity_threshold: float = 0.9
) -> Dict[str, str]:
"""
构建新话题到历史话题的映射如果相似度>=阈值
Args:
topic_to_indices: 新话题到消息索引的映射
similarity_threshold: 相似度阈值,默认0.9(90%)
Returns:
Dict[str, str]: 新话题 -> 历史话题的映射字典
"""
existing_topics_list = list(self.topic_cache.keys())
topic_mapping: Dict[str, str] = {}
for new_topic in topic_to_indices.keys():
# 如果新话题已经在历史话题中,不需要检查
if new_topic in existing_topics_list:
continue
# 查找最相似的历史话题
result = self._find_most_similar_topic(new_topic, existing_topics_list, similarity_threshold)
if result:
historical_topic, similarity = result
topic_mapping[new_topic] = historical_topic
logger.info(
f"{self.log_prefix} 话题相似度检查: '{new_topic}' 与历史话题 '{historical_topic}' 相似度 {similarity:.2%},使用历史标题"
)
return topic_mapping
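A self-contained illustration of the 90% title-similarity threshold used by _find_most_similar_topic; the topic strings are invented for the example:
import difflib

new_topic = "周末聚餐安排"
historical_topic = "周末聚餐的安排"
ratio = difflib.SequenceMatcher(None, new_topic, historical_topic).ratio()
print(f"{ratio:.2%}")  # ~92%, above the 0.9 threshold
print(ratio >= 0.9)    # True -> the historical title would be reused as the cache key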
def _build_numbered_messages_for_llm(
self, messages: List[DatabaseMessages]
) -> tuple[List[str], Dict[int, str], Dict[int, str], Dict[int, Set[str]]]:
@ -589,9 +700,7 @@ class ChatHistorySummarizer:
if not numbered_lines:
return False, {}
history_topics_block = (
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt(
@ -603,8 +712,7 @@ class ChatHistorySummarizer:
try:
response, _ = await self.summarizer_llm.generate_response_async(
prompt=prompt,
temperature=0.2,
max_tokens=800,
temperature=0.3,
)
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
@ -614,17 +722,17 @@ class ChatHistorySummarizer:
json_str = None
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
# 找到JSON代码块使用第一个匹配
json_str = matches[0].strip()
else:
# 如果没有找到代码块尝试查找JSON数组的开始和结束位置
# 查找第一个 [ 和最后一个 ]
start_idx = response.find('[')
end_idx = response.rfind(']')
start_idx = response.find("[")
end_idx = response.rfind("]")
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
json_str = response[start_idx:end_idx + 1].strip()
json_str = response[start_idx : end_idx + 1].strip()
else:
# 如果还是找不到尝试直接使用整个响应移除可能的markdown标记
json_str = response.strip()
@ -921,4 +1029,3 @@ class ChatHistorySummarizer:
init_prompt()

View File

@ -28,10 +28,10 @@ class MemoryForgetTask(AsyncTask):
# logger.info("[记忆遗忘] 开始遗忘检查...")
# 执行4个阶段的遗忘检查
await self._forget_stage_1(current_time)
await self._forget_stage_2(current_time)
await self._forget_stage_3(current_time)
await self._forget_stage_4(current_time)
# await self._forget_stage_1(current_time)
# await self._forget_stage_2(current_time)
# await self._forget_stage_3(current_time)
# await self._forget_stage_4(current_time)
# logger.info("[记忆遗忘] 遗忘检查完成")
except Exception as e:

View File

@ -1,5 +0,0 @@
from .jargon_miner import extract_and_store_jargon
__all__ = [
"extract_and_store_jargon",
]

View File

@ -1,195 +0,0 @@
import json
from typing import List, Dict, Optional, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
build_readable_messages,
)
from src.chat.utils.utils import parse_platform_accounts
logger = get_logger("jargon")
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
"""
解析chat_id字段兼容旧格式字符串和新格式JSON列表
Args:
chat_id_value: 可能是字符串旧格式或JSON字符串新格式
Returns:
List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
"""
if not chat_id_value:
return []
# 如果是字符串尝试解析为JSON
if isinstance(chat_id_value, str):
# 尝试解析JSON
try:
parsed = json.loads(chat_id_value)
if isinstance(parsed, list):
# 新格式:已经是列表
return parsed
elif isinstance(parsed, str):
# 解析后还是字符串,说明是旧格式
return [[parsed, 1]]
else:
# 其他类型,当作旧格式处理
return [[str(chat_id_value), 1]]
except (json.JSONDecodeError, TypeError):
# 解析失败,当作旧格式(纯字符串)
return [[str(chat_id_value), 1]]
elif isinstance(chat_id_value, list):
# 已经是列表格式
return chat_id_value
else:
# 其他类型,转换为旧格式
return [[str(chat_id_value), 1]]
def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
"""
更新chat_id列表如果target_chat_id已存在则增加计数否则添加新条目
Args:
chat_id_list: 当前的chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要更新或添加的chat_id
increment: 增加的计数默认为1
Returns:
List[List[Any]]: 更新后的chat_id列表
"""
# 查找是否已存在该chat_id
found = False
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
# 找到匹配的chat_id增加计数
if len(item) >= 2:
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
else:
item.append(increment)
found = True
break
if not found:
# 未找到,添加新条目
chat_id_list.append([target_chat_id, increment])
return chat_id_list
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
"""
检查chat_id列表中是否包含指定的chat_id
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
bool: 如果包含则返回True
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return True
return False
def contains_bot_self_name(content: str) -> bool:
"""
判断词条是否包含机器人的昵称或别名
"""
if not content:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates if target)
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:
"""
构建包含中心消息上下文的段落前3条+后3条使用标准的 readable builder 输出
"""
if not messages or center_index < 0 or center_index >= len(messages):
return None
context_start = max(0, center_index - 3)
context_end = min(len(messages), center_index + 1 + 3)
context_messages = messages[context_start:context_end]
if not context_messages:
return None
try:
paragraph = build_readable_messages(
messages=context_messages,
replace_bot_name=True,
timestamp_mode="relative",
read_mark=0.0,
truncate=False,
show_actions=False,
show_pic=True,
message_id_list=None,
remove_emoji_stickers=False,
pic_single=True,
)
except Exception as e:
logger.warning(f"构建上下文段落失败: {e}")
return None
paragraph = paragraph.strip()
return paragraph or None
def is_bot_message(msg: Any) -> bool:
"""判断消息是否来自机器人自身"""
if msg is None:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
platform = (
str(getattr(msg, "user_platform", "") or getattr(getattr(msg, "user_info", None), "platform", "") or "")
.strip()
.lower()
)
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
if not platform or not user_id:
return False
platform_accounts = {}
try:
platform_accounts = parse_platform_accounts(getattr(bot_config, "platforms", []) or [])
except Exception:
platform_accounts = {}
bot_accounts: Dict[str, str] = {}
qq_account = str(getattr(bot_config, "qq_account", "") or "").strip()
if qq_account:
bot_accounts["qq"] = qq_account
telegram_account = str(getattr(bot_config, "telegram_account", "") or "").strip()
if telegram_account:
bot_accounts["telegram"] = telegram_account
for plat, account in platform_accounts.items():
if account and plat not in bot_accounts:
bot_accounts[plat] = account
bot_account = bot_accounts.get(platform)
return bool(bot_account and user_id == bot_account)

View File

@ -98,7 +98,10 @@ def _convert_messages(
content: List[Part] = []
for item in message.content:
if isinstance(item, tuple):
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
image_format = item[0].lower()
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
if image_format in ("jpg", "jpeg"):
image_format = "jpeg"
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
elif isinstance(item, str):
content.append(Part.from_text(text=item))
@ -143,10 +146,14 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
# JSON Schema要求使用"boolean"而不是"bool"
# JSON Schema 类型名称修正:
# - 布尔类型使用 "boolean" 而不是 "bool"
# - 浮点数使用 "number" 而不是 "float"
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
elif param_type_value == "float":
param_type_value = "number"
return_dict: dict[str, Any] = {
"type": param_type_value,

View File

@ -61,10 +61,16 @@ def _convert_messages(messages: list[Message]) -> list[ChatCompletionMessagePara
content = []
for item in message.content:
if isinstance(item, tuple):
image_format = item[0].lower()
# 规范 JPEG MIME 类型后缀,统一使用 image/jpeg
if image_format in ("jpg", "jpeg"):
mime_suffix = "jpeg"
else:
mime_suffix = image_format
content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/{item[0].lower()};base64,{item[1]}"},
"image_url": {"url": f"data:image/{mime_suffix};base64,{item[1]}"},
}
)
elif isinstance(item, str):
@ -118,10 +124,14 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
:param tool_option_param: 工具参数对象
:return: 转换后的工具参数字典
"""
# JSON Schema要求使用"boolean"而不是"bool"
# JSON Schema 类型名称修正:
# - 布尔类型使用 "boolean" 而不是 "bool"
# - 浮点数使用 "number" 而不是 "float"
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
elif param_type_value == "float":
param_type_value = "number"
return_dict: dict[str, Any] = {
"type": param_type_value,

View File

@ -49,7 +49,7 @@ class LLMRequest:
def _check_slow_request(self, time_cost: float, model_name: str) -> None:
"""检查请求是否过慢并输出警告日志
Args:
time_cost: 请求耗时
model_name: 使用的模型名称
@ -315,12 +315,30 @@ class LLMRequest:
while retry_remain > 0:
try:
if request_type == RequestType.RESPONSE:
# 温度优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
effective_temperature = temperature
if effective_temperature is None:
effective_temperature = model_info.temperature
if effective_temperature is None:
effective_temperature = (model_info.extra_params or {}).get("temperature")
if effective_temperature is None:
effective_temperature = self.model_for_task.temperature
# max_tokens 优先级:参数传入 > 模型级别配置 > extra_params > 任务配置
effective_max_tokens = max_tokens
if effective_max_tokens is None:
effective_max_tokens = model_info.max_tokens
if effective_max_tokens is None:
effective_max_tokens = (model_info.extra_params or {}).get("max_tokens")
if effective_max_tokens is None:
effective_max_tokens = self.model_for_task.max_tokens
return await client.get_response(
model_info=model_info,
message_list=(compressed_messages or message_list),
tool_options=tool_options,
max_tokens=self.model_for_task.max_tokens if max_tokens is None else max_tokens,
temperature=temperature if temperature is not None else (model_info.extra_params or {}).get("temperature", self.model_for_task.temperature),
max_tokens=effective_max_tokens,
temperature=effective_temperature,
response_format=response_format,
stream_response_handler=stream_response_handler,
async_response_parser=async_response_parser,
@ -348,7 +366,9 @@ class LLMRequest:
logger.error(f"模型 '{model_info.name}' 在多次出现空回复后仍然失败。{original_error_info}")
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}")
logger.warning(
f"模型 '{model_info.name}' 返回空回复(可重试){original_error_info}。剩余重试次数: {retry_remain}"
)
await asyncio.sleep(api_provider.retry_interval)
except NetworkConnectionError as e:
@ -376,7 +396,9 @@ class LLMRequest:
if e.status_code == 429 or e.status_code >= 500:
retry_remain -= 1
if retry_remain <= 0:
logger.error(f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}")
logger.error(
f"模型 '{model_info.name}' 在遇到 {e.status_code} 错误并用尽重试次数后仍然失败。{original_error_info}"
)
raise ModelAttemptFailed(f"模型 '{model_info.name}' 重试耗尽", original_exception=e) from e
logger.warning(
@ -522,7 +544,5 @@ class LLMRequest:
if e.__cause__:
original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__)
return (
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
return f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
return ""

View File

@ -13,9 +13,9 @@ from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.common.logger import get_logger
from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from src.chat.knowledge import lpmm_start_up
from rich.traceback import install
# from src.api.main import start_api_server
# 导入新的插件管理器
@ -23,6 +23,7 @@ from src.plugin_system.core.plugin_manager import plugin_manager
# 导入消息API和traceback模块
from src.common.message import get_global_api
from src.dream.dream_agent import start_dream_scheduler
# 插件系统现在使用统一的插件加载器
@ -50,23 +51,11 @@ class MainSystem:
logger.info("WebUI 已禁用")
return
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
try:
from src.webui.webui_server import get_webui_server
self.webui_server = get_webui_server()
if webui_mode == "development":
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
logger.info("💡 请手动启动前端开发服务器: cd MaiBot-Dashboard && bun dev")
logger.info("💡 前端将运行在 http://localhost:7999")
else:
logger.info("✅ WebUI 生产模式已启用")
logger.info("🌐 WebUI 将运行在 http://0.0.0.0:8001")
logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
except Exception as e:
logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")
@ -106,7 +95,7 @@ class MainSystem:
await async_task_manager.add_task(TelemetryHeartBeatTask())
# 添加记忆遗忘任务
from src.chat.utils.memory_forget_task import MemoryForgetTask
from src.hippo_memorizer.memory_forget_task import MemoryForgetTask
await async_task_manager.add_task(MemoryForgetTask())
@ -124,11 +113,6 @@ class MainSystem:
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")
# 启动情绪管理器
if global_config.mood.enable_mood:
await mood_manager.start()
logger.info("情绪管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())
@ -159,6 +143,7 @@ class MainSystem:
try:
tasks = [
get_emoji_manager().start_periodic_check_register(),
start_dream_scheduler(),
self.app.run(),
self.server.run(),
]
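The startup hunk adds `start_dream_scheduler()` to the list of long-running coroutines awaited together. A rough sketch of the same pattern with placeholder coroutines (none of these names are the real implementations):

```python
import asyncio


async def periodic_check(name: str, interval: float, rounds: int) -> None:
    # Stand-in for a long-running service such as the emoji check or scheduler.
    for _ in range(rounds):
        await asyncio.sleep(interval)
        print(f"{name}: tick")


async def schedule_all() -> None:
    # Collect the long-lived coroutines and run them side by side,
    # as the tasks list in the hunk above does.
    tasks = [
        periodic_check("emoji-check", 0.1, 3),
        periodic_check("dream-scheduler", 0.1, 3),
        periodic_check("server", 0.1, 3),
    ]
    await asyncio.gather(*tasks)


if __name__ == "__main__":
    asyncio.run(schedule_all())
```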

View File

@ -1,6 +1,7 @@
import time
import json
import asyncio
import re
from typing import List, Dict, Any, Optional, Tuple, Set
from src.common.logger import get_logger
from src.config.config import global_config, model_config
@ -10,7 +11,7 @@ from src.common.database.database_model import ThinkingBack
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.memory_utils import parse_questions_json
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
from src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval")
@ -77,20 +78,12 @@ def init_memory_retrieval_prompt():
问题要说明前因后果和上下文使其全面且精准
输出格式示例需要检索时
输出格式示例
```json
{{
"questions": ["张三在前几天干了什么"] #问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
}}
```
输出格式示例不需要检索时
```json
{{
"questions": []
}}
```
请只输出JSON对象不要输出其他内容
""",
name="memory_retrieval_question_prompt",
@ -104,17 +97,16 @@ def init_memory_retrieval_prompt():
已收集的信息
{collected_info}
**执行步骤**
**工具说明**
- 如果涉及过往事件或者查询某个过去可能提到过的概念或者某段时间发生的事件可以使用聊天记录查询工具查询过往事件
- 如果涉及人物可以使用人物信息查询工具查询人物信息
- 如果没有可靠信息且查询时间充足或者不确定查询类别也可以使用lpmm知识库查询作为辅助信息
- **如果信息不足需要使用tool说明需要查询什么并输出为纯文本说明然后调用相应工具查询可并行调用多个工具**
- **如果当前已收集的信息足够回答问题且能找到明确答案调用found_answer工具标记已找到答案**
**思考**
- 你可以对查询思路给出简短的思考思考要简短直接切入要点
- 如果信息不足你必须给出使用什么工具进行查询
- 如果信息足够你必须调用found_answer工具
- 先思考当前信息是否足够回答问题
- 如果信息不足则需要使用tool查询信息你必须给出使用什么工具进行查询
- 如果当前已收集的信息足够或信息不足确定无法找到答案你必须调用finish_search工具结束查询
""",
name="memory_retrieval_react_prompt_head",
)
@ -128,14 +120,12 @@ def init_memory_retrieval_prompt():
已收集的信息
{collected_info}
**执行步骤**
分析
- 当前信息是否足够回答问题
- **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
**重要规则**
- 你已经经过几轮查询尝试了信息搜集现在你需要总结信息选择回答问题或判断问题无法回答
- 必须严格使用检索到的信息回答问题不要编造信息
- 答案必须精简不要过多解释
- **只有在检索到明确具体的答案时才使用found_answer**
@ -146,8 +136,6 @@ def init_memory_retrieval_prompt():
)
def _log_conversation_messages(
conversation_messages: List[Message],
head_prompt: Optional[str] = None,
@ -167,7 +155,7 @@ def _log_conversation_messages(
# 如果有head_prompt先添加为第一条消息
if head_prompt:
msg_info = "========================================\n[消息 1] 角色: System 内容类型: 文本\n-----------------------------"
msg_info = "========================================\n[消息 1] 角色: System\n-----------------------------"
msg_info += f"\n{head_prompt}"
log_lines.append(msg_info)
start_idx = 2
@ -180,32 +168,24 @@ def _log_conversation_messages(
for idx, msg in enumerate(conversation_messages, start_idx):
role_name = msg.role.value if hasattr(msg.role, "value") else str(msg.role)
# 处理内容 - 显示完整内容,不截断
if isinstance(msg.content, str):
full_content = msg.content
content_type = "文本"
elif isinstance(msg.content, list):
text_parts = [item for item in msg.content if isinstance(item, str)]
image_count = len([item for item in msg.content if isinstance(item, tuple)])
full_content = "".join(text_parts) if text_parts else ""
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
else:
full_content = str(msg.content)
content_type = "未知"
# 构建单条消息的日志信息
msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
# msg_info = f"\n========================================\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n-----------------------------"
msg_info = (
f"\n========================================\n[消息 {idx}] 角色: {role_name}\n-----------------------------"
)
if full_content:
msg_info += f"\n{full_content}"
# if full_content:
# msg_info += f"\n{full_content}"
if msg.content:
msg_info += f"\n{msg.content}"
if msg.tool_calls:
msg_info += f"\n 工具调用: {len(msg.tool_calls)}"
for tool_call in msg.tool_calls:
msg_info += f"\n - {tool_call}"
msg_info += f"\n - {tool_call.func_name}: {json.dumps(tool_call.args, ensure_ascii=False)}"
if msg.tool_call_id:
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
# if msg.tool_call_id:
# msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
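The logging hunk above now prints each tool call as `name: {json args}` instead of the raw object. A small sketch of that formatting with a hypothetical tool-call structure (the real `Message`/tool-call classes are not shown in this diff):

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class FakeToolCall:
    # Hypothetical stand-in with the two attributes the log line uses.
    func_name: str
    args: Dict[str, Any] = field(default_factory=dict)


def format_tool_calls(tool_calls: List[FakeToolCall]) -> str:
    lines = [f" 工具调用: {len(tool_calls)} 个"]
    for call in tool_calls:
        # ensure_ascii=False keeps Chinese argument values readable in logs.
        lines.append(f" - {call.func_name}: {json.dumps(call.args, ensure_ascii=False)}")
    return "\n".join(lines)


print(format_tool_calls([FakeToolCall("query_chat_history", {"keyword": "张三"})]))
```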
@ -251,6 +231,7 @@ async def _react_agent_solve_question(
conversation_messages: List[Message] = []
first_head_prompt: Optional[str] = None # 保存第一次使用的head_prompt用于日志显示
# 正常迭代max_iterations 次(最终评估单独处理,不算在迭代中)
for iteration in range(max_iterations):
# 检查超时
if time.time() - start_time > timeout:
@ -270,7 +251,6 @@ async def _react_agent_solve_question(
# 计算剩余迭代次数
current_iteration = iteration + 1
remaining_iterations = max_iterations - current_iteration
is_final_iteration = current_iteration >= max_iterations
# 提取函数调用中参数的值,支持单引号和双引号
def extract_quoted_content(text, func_name, param_name):
@ -330,114 +310,10 @@ async def _react_agent_solve_question(
return None
# 如果是最后一次迭代使用final_prompt进行总结
if is_final_iteration:
evaluation_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_final_prompt",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
eval_success, eval_response, eval_reasoning_content, eval_model_name, eval_tool_calls = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
request_type="memory.react.final",
)
if not eval_success:
logger.error(f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段 LLM调用失败: {eval_response}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案最终评估阶段LLM调用失败",
)
return False, "最终评估阶段LLM调用失败", thinking_steps, False
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估响应: {eval_response}"
)
# 从最终评估响应中提取found_answer或not_enough_info
found_answer_content = None
not_enough_info_reason = None
if eval_response:
found_answer_content = extract_quoted_content(eval_response, "found_answer", "answer")
if not found_answer_content:
not_enough_info_reason = extract_quoted_content(eval_response, "not_enough_info", "reason")
# 如果找到答案,返回
if found_answer_content:
eval_step = {
"iteration": iteration + 1,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"]
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_content}",
)
return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息
if not_enough_info_reason:
eval_step = {
"iteration": iteration + 1,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"]
}
thinking_steps.append(eval_step)
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 最终评估阶段判断信息不足: {not_enough_info_reason}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}",
)
return False, "", thinking_steps, False
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = {
"iteration": iteration + 1,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": "已到达最后一次迭代,无法找到答案"}}],
"observations": ["已到达最后一次迭代,无法找到答案"]
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 已到达最后一次迭代,无法找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案:已到达最后一次迭代,无法找到答案",
)
return False, "", thinking_steps, False
# 前n-1次迭代使用head_prompt决定调用哪些工具包含found_answer工具
# 正常迭代使用head_prompt决定调用哪些工具包含finish_search工具
tool_definitions = tool_registry.get_tool_definitions()
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}"
)
# tool_names = [tool_def["name"] for tool_def in tool_definitions]
# logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具: {', '.join(tool_names)} (共{len(tool_definitions)}个)")
# head_prompt应该只构建一次使用初始的collected_info后续迭代都复用同一个
if first_head_prompt is None:
@ -453,7 +329,7 @@ async def _react_agent_solve_question(
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
# 后续迭代都复用第一次构建的head_prompt
head_prompt = first_head_prompt
@ -487,15 +363,15 @@ async def _react_agent_solve_question(
request_type="memory.react",
)
logger.debug(
f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
)
# logger.info(
# f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}"
# )
if not success:
logger.error(f"ReAct Agent LLM调用失败: {response}")
break
# 注意这里会检查found_answer工具调用如果检测到found_answer工具会直接返回答案
# 注意这里会检查finish_search工具调用如果检测到finish_search工具会根据found_answer参数决定返回答案或退出查询
assistant_message: Optional[Message] = None
if tool_calls:
@ -525,8 +401,113 @@ async def _react_agent_solve_question(
# 处理工具调用
if not tool_calls:
# 如果没有工具调用,记录思考过程,继续下一轮迭代(下一轮会再次评估)
# 如果没有工具调用,检查响应文本中是否包含finish_search函数调用格式
if response and response.strip():
# 尝试从文本中解析finish_search函数调用
def parse_finish_search_from_text(text: str):
"""从文本中解析finish_search函数调用返回(found_answer, answer)元组,如果未找到则返回(None, None)"""
if not text:
return None, None
# 查找finish_search函数调用位置不区分大小写
func_pattern = "finish_search"
text_lower = text.lower()
func_pos = text_lower.find(func_pattern)
if func_pos == -1:
return None, None
# 查找函数调用的开始和结束位置
# 从func_pos开始向后查找左括号
start_pos = text.find("(", func_pos)
if start_pos == -1:
return None, None
# 查找匹配的右括号(考虑嵌套)
paren_count = 0
end_pos = start_pos
for i in range(start_pos, len(text)):
if text[i] == "(":
paren_count += 1
elif text[i] == ")":
paren_count -= 1
if paren_count == 0:
end_pos = i
break
else:
# 没有找到匹配的右括号
return None, None
# 提取函数参数部分
params_text = text[start_pos + 1 : end_pos]
# 解析found_answer参数布尔值可能是true/false/True/False
found_answer = None
found_answer_patterns = [
r"found_answer\s*=\s*true",
r"found_answer\s*=\s*True",
r"found_answer\s*=\s*false",
r"found_answer\s*=\s*False",
]
for pattern in found_answer_patterns:
match = re.search(pattern, params_text, re.IGNORECASE)
if match:
found_answer = "true" in match.group(0).lower()
break
# 解析answer参数字符串使用extract_quoted_content
answer = extract_quoted_content(text, "finish_search", "answer")
return found_answer, answer
parsed_found_answer, parsed_answer = parse_finish_search_from_text(response)
if parsed_found_answer is not None:
# 检测到finish_search函数调用格式
if parsed_found_answer:
# 找到了答案
if parsed_answer:
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": parsed_answer},
}
)
step["observations"] = ["检测到finish_search文本格式调用找到答案"]
thinking_steps.append(step)
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式找到关于问题{question}的答案: {parsed_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{parsed_answer}",
)
return True, parsed_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误继续迭代
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search文本格式found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append(
{"action_type": "finish_search", "action_params": {"found_answer": False}}
)
step["observations"] = ["检测到finish_search文本格式调用未找到答案"]
thinking_steps.append(step)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search文本格式判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案通过finish_search文本格式判断未找到答案",
)
return False, "", thinking_steps, False
# 如果没有检测到finish_search格式记录思考过程继续下一轮迭代
step["observations"] = [f"思考完成,但未调用工具。响应: {response}"]
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 思考完成但未调用工具: {response}")
collected_info += f"思考: {response}"
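The fallback parser added above scans free text for a `finish_search(...)` call by matching parentheses and then pulling out `found_answer` and `answer`. A condensed, self-contained sketch of the same idea (the regex patterns and helper name are mine, not the project's):

```python
import re
from typing import Optional, Tuple


def parse_finish_search(text: str) -> Tuple[Optional[bool], Optional[str]]:
    """Best-effort extraction of finish_search(found_answer=..., answer="...")
    from model output that did not go through the structured tool-call path."""
    pos = text.lower().find("finish_search")
    if pos == -1:
        return None, None
    start = text.find("(", pos)
    if start == -1:
        return None, None
    depth, end = 0, -1
    for i in range(start, len(text)):  # find the matching ')'
        if text[i] == "(":
            depth += 1
        elif text[i] == ")":
            depth -= 1
            if depth == 0:
                end = i
                break
    if end == -1:
        return None, None
    params = text[start + 1 : end]
    flag = re.search(r"found_answer\s*=\s*(true|false)", params, re.IGNORECASE)
    found = flag.group(1).lower() == "true" if flag else None
    answer_match = re.search(r"""answer\s*=\s*(['"])(.*?)\1""", params, re.DOTALL)
    answer = answer_match.group(2) if answer_match else None
    return found, answer


print(parse_finish_search('思考结束。finish_search(found_answer=true, answer="昨天在群里讨论过")'))
# -> (True, '昨天在群里讨论过')
```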
@ -537,29 +518,60 @@ async def _react_agent_solve_question(
continue
# 处理工具调用
# 首先检查是否有found_answer工具调用如果有则立即返回不再处理其他工具
found_answer_from_tool = None
# 首先检查是否有finish_search工具调用如果有则立即返回不再处理其他工具
finish_search_found = None
finish_search_answer = None
for tool_call in tool_calls:
tool_name = tool_call.func_name
tool_args = tool_call.args or {}
if tool_name == "found_answer":
found_answer_from_tool = tool_args.get("answer", "")
if found_answer_from_tool:
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_from_tool}})
step["observations"] = ["检测到found_answer工具调用"]
if tool_name == "finish_search":
finish_search_found = tool_args.get("found_answer", False)
finish_search_answer = tool_args.get("answer", "")
if finish_search_found:
# 找到了答案
if finish_search_answer:
step["actions"].append(
{
"action_type": "finish_search",
"action_params": {"found_answer": True, "answer": finish_search_answer},
}
)
step["observations"] = ["检测到finish_search工具调用找到答案"]
thinking_steps.append(step)
logger.info(
f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具找到关于问题{question}的答案: {finish_search_answer}"
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{finish_search_answer}",
)
return True, finish_search_answer, thinking_steps, False
else:
# found_answer为True但没有提供answer视为错误
logger.warning(
f"ReAct Agent 第 {iteration + 1} 次迭代 finish_search工具found_answer为True但未提供answer"
)
else:
# 未找到答案,直接退出查询
step["actions"].append({"action_type": "finish_search", "action_params": {"found_answer": False}})
step["observations"] = ["检测到finish_search工具调用未找到答案"]
thinking_steps.append(step)
logger.debug(f"ReAct Agent 第 {iteration + 1} 次迭代 通过found_answer工具找到关于问题{question}的答案: {found_answer_from_tool}")
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 通过finish_search工具判断未找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_from_tool}",
final_status="未找到答案通过finish_search工具判断未找到答案",
)
return True, found_answer_from_tool, thinking_steps, False
# 如果没有found_answer工具调用或者found_answer工具调用没有答案继续处理其他工具
return False, "", thinking_steps, False
# 如果没有finish_search工具调用,继续处理其他工具
tool_tasks = []
for i, tool_call in enumerate(tool_calls):
tool_name = tool_call.func_name
@ -569,8 +581,8 @@ async def _react_agent_solve_question(
f"ReAct Agent 第 {iteration + 1} 次迭代 工具调用 {i + 1}/{len(tool_calls)}: {tool_name}({tool_args})"
)
# 跳过found_answer工具调用(已经在上面处理过了)
if tool_name == "found_answer":
# 跳过finish_search工具调用(已经在上面处理过了)
if tool_name == "finish_search":
continue
# 普通工具调用
@ -634,7 +646,7 @@ async def _react_agent_solve_question(
observation_text += f"\n\n{jargon_info}"
collected_info += f"\n{jargon_info}\n"
logger.info(f"工具输出触发黑话解析: {new_concepts}")
tool_builder = MessageBuilder()
tool_builder.set_role(RoleType.Tool)
tool_builder.add_text_content(observation_text)
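After each tool runs, its observation (plus any jargon explanation) is wrapped into a Tool-role message via the `MessageBuilder` chain shown above. A toy builder illustrating the pattern, since the real `MessageBuilder`/`RoleType` API is only partially visible in this diff:

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import List


class RoleType(Enum):
    Tool = "tool"
    Assistant = "assistant"


@dataclass
class ToyMessage:
    role: RoleType
    parts: List[str] = field(default_factory=list)


class ToyMessageBuilder:
    def __init__(self) -> None:
        self._role = RoleType.Assistant
        self._parts: List[str] = []

    def set_role(self, role: RoleType) -> "ToyMessageBuilder":
        self._role = role
        return self

    def add_text_content(self, text: str) -> "ToyMessageBuilder":
        self._parts.append(text)
        return self

    def build(self) -> ToyMessage:
        return ToyMessage(role=self._role, parts=list(self._parts))


observation_text = "查询到3条聊天记录\n\n(黑话解释)……"
msg = ToyMessageBuilder().set_role(RoleType.Tool).add_text_content(observation_text).build()
print(msg.role.value, len(msg.parts))
```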
@ -643,26 +655,196 @@ async def _react_agent_solve_question(
thinking_steps.append(step)
# 达到最大迭代次数或超时但Agent没有明确返回found_answer
# 迭代超时应该直接视为not_enough_info而不是使用已有信息
# 只有Agent明确返回found_answer时才认为找到了答案
if collected_info:
logger.warning(
f"ReAct Agent达到最大迭代次数或超时但未明确返回found_answer。已收集信息: {collected_info[:100]}..."
)
# 正常迭代结束后,如果达到最大迭代次数或超时,执行最终评估
# 最终评估单独处理,不算在迭代中
should_do_final_evaluation = False
if is_timeout:
logger.warning("ReAct Agent超时直接视为not_enough_info")
else:
logger.warning("ReAct Agent达到最大迭代次数直接视为not_enough_info")
# React完成时输出消息列表
timeout_reason = "超时" if is_timeout else "达到最大迭代次数"
should_do_final_evaluation = True
logger.warning(f"ReAct Agent超时已迭代{iteration + 1}次,进入最终评估")
elif iteration + 1 >= max_iterations:
should_do_final_evaluation = True
logger.info(f"ReAct Agent达到最大迭代次数已迭代{iteration + 1}次),进入最终评估")
if should_do_final_evaluation:
# 获取必要变量用于最终评估
tool_registry = get_tool_registry()
bot_name = global_config.bot.nickname
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
current_iteration = iteration + 1
remaining_iterations = 0
# 提取函数调用中参数的值,支持单引号和双引号
def extract_quoted_content(text, func_name, param_name):
"""从文本中提取函数调用中参数的值,支持单引号和双引号
Args:
text: 要搜索的文本
func_name: 函数名 'found_answer'
param_name: 参数名 'answer'
Returns:
提取的参数值如果未找到则返回None
"""
if not text:
return None
# 查找函数调用位置(不区分大小写)
func_pattern = func_name.lower()
text_lower = text.lower()
func_pos = text_lower.find(func_pattern)
if func_pos == -1:
return None
# 查找参数名和等号
param_pattern = f"{param_name}="
param_pos = text_lower.find(param_pattern, func_pos)
if param_pos == -1:
return None
# 跳过参数名、等号和空白
start_pos = param_pos + len(param_pattern)
while start_pos < len(text) and text[start_pos] in " \t\n":
start_pos += 1
if start_pos >= len(text):
return None
# 确定引号类型
quote_char = text[start_pos]
if quote_char not in ['"', "'"]:
return None
# 查找匹配的结束引号(考虑转义)
end_pos = start_pos + 1
while end_pos < len(text):
if text[end_pos] == quote_char:
# 检查是否是转义的引号
if end_pos > start_pos + 1 and text[end_pos - 1] == "\\":
end_pos += 1
continue
# 找到匹配的引号
content = text[start_pos + 1 : end_pos]
# 处理转义字符
content = content.replace('\\"', '"').replace("\\'", "'").replace("\\\\", "\\")
return content
end_pos += 1
return None
# 执行最终评估
evaluation_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_final_prompt",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
(
eval_success,
eval_response,
eval_reasoning_content,
eval_model_name,
eval_tool_calls,
) = await llm_api.generate_with_model_with_tools(
evaluation_prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=[], # 最终评估阶段不提供工具
request_type="memory.react.final",
)
if not eval_success:
logger.error(f"ReAct Agent 最终评估阶段 LLM调用失败: {eval_response}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案最终评估阶段LLM调用失败",
)
return False, "最终评估阶段LLM调用失败", thinking_steps, is_timeout
if global_config.debug.show_memory_prompt:
logger.info(f"ReAct Agent 最终评估Prompt: {evaluation_prompt}")
logger.info(f"ReAct Agent 最终评估响应: {eval_response}")
# 从最终评估响应中提取found_answer或not_enough_info
found_answer_content = None
not_enough_info_reason = None
if eval_response:
found_answer_content = extract_quoted_content(eval_response, "found_answer", "answer")
if not found_answer_content:
not_enough_info_reason = extract_quoted_content(eval_response, "not_enough_info", "reason")
# 如果找到答案,返回(找到答案时,无论是否超时,都视为成功完成)
if found_answer_content:
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "found_answer", "action_params": {"answer": found_answer_content}}],
"observations": ["最终评估阶段检测到found_answer"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段找到关于问题{question}的答案: {found_answer_content}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"找到答案:{found_answer_content}",
)
return True, found_answer_content, thinking_steps, False
# 如果评估为not_enough_info返回空字符串不返回任何信息
if not_enough_info_reason:
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [{"action_type": "not_enough_info", "action_params": {"reason": not_enough_info_reason}}],
"observations": ["最终评估阶段检测到not_enough_info"],
}
thinking_steps.append(eval_step)
logger.info(f"ReAct Agent 最终评估阶段判断信息不足: {not_enough_info_reason}")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{not_enough_info_reason}",
)
return False, "", thinking_steps, is_timeout
# 如果没有明确判断视为not_enough_info返回空字符串不返回任何信息
eval_step = {
"iteration": current_iteration,
"thought": f"[最终评估] {eval_response}",
"actions": [
{"action_type": "not_enough_info", "action_params": {"reason": "已到达最大迭代次数,无法找到答案"}}
],
"observations": ["已到达最大迭代次数,无法找到答案"],
}
thinking_steps.append(eval_step)
logger.info("ReAct Agent 已到达最大迭代次数,无法找到答案")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status="未找到答案:已到达最大迭代次数,无法找到答案",
)
return False, "", thinking_steps, is_timeout
# 如果正常迭代过程中提前找到答案返回,不会到达这里
# 如果正常迭代结束但没有触发最终评估(理论上不应该发生),直接返回
logger.warning("ReAct Agent正常迭代结束但未触发最终评估")
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"未找到答案:{timeout_reason}",
final_status="未找到答案:正常迭代结束",
)
return False, "", thinking_steps, is_timeout
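The final-evaluation path reuses the quote-aware `extract_quoted_content` helper to read either `found_answer(answer=...)` or `not_enough_info(reason=...)` out of the model's free-text reply. A short usage sketch with a trimmed, regex-based re-implementation (not the exact character-scanning code above):

```python
import re
from typing import Optional


def extract_quoted_content(text: str, func_name: str, param_name: str) -> Optional[str]:
    """Pull the quoted value of `param_name` from a `func_name(...)` call,
    accepting single or double quotes."""
    pattern = rf"""{re.escape(func_name)}.*?{re.escape(param_name)}\s*=\s*(['"])(.*?)\1"""
    match = re.search(pattern, text, re.IGNORECASE | re.DOTALL)
    return match.group(2) if match else None


eval_response = "综合已有信息,not_enough_info(reason='聊天记录里没有相关讨论')"
print(extract_quoted_content(eval_response, "found_answer", "answer"))     # None
print(extract_quoted_content(eval_response, "not_enough_info", "reason"))  # 聊天记录里没有相关讨论
```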
@ -817,6 +999,7 @@ async def _process_single_question(
context: str,
initial_info: str = "",
initial_jargon_concepts: Optional[List[str]] = None,
max_iterations: Optional[int] = None,
) -> Optional[str]:
"""处理单个问题的查询
@ -841,11 +1024,15 @@ async def _process_single_question(
jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None
# 如果未指定max_iterations使用配置的默认值
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question,
chat_id=chat_id,
max_iterations=global_config.memory.max_agent_iterations,
timeout=120.0,
max_iterations=max_iterations,
timeout=global_config.memory.agent_timeout_seconds,
initial_info=question_initial_info,
initial_jargon_concepts=jargon_concepts_for_agent,
)
@ -874,7 +1061,7 @@ async def build_memory_retrieval_prompt(
sender: str,
target: str,
chat_stream,
tool_executor,
think_level: int = 1,
) -> str:
"""构建记忆检索提示
使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案
@ -961,9 +1148,17 @@ async def build_memory_retrieval_prompt(
logger.info(f"无当次查询,不返回任何结果,耗时: {(end_time - start_time):.3f}")
return ""
# 第二步:并行处理所有问题(使用配置的最大迭代次数/120秒超时
max_iterations = global_config.memory.max_agent_iterations
logger.debug(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations},超时时间: 120秒")
# 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间)
base_max_iterations = global_config.memory.max_agent_iterations
# 根据think_level调整迭代次数think_level=1时不变think_level=0时减半
if think_level == 0:
max_iterations = max(1, base_max_iterations // 2) # 至少为1
else:
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(
f"问题数量: {len(questions)}think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}"
)
# 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [
@ -973,6 +1168,7 @@ async def build_memory_retrieval_prompt(
context=message,
initial_info=initial_info,
initial_jargon_concepts=concepts if enable_jargon_detection else None,
max_iterations=max_iterations,
)
for question in questions
]
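The hunk above derives the per-question iteration budget from `think_level` and then fans the questions out concurrently. A compact sketch of both steps with a dummy solver (names and numbers are illustrative only):

```python
import asyncio


def iterations_for(think_level: int, base_max_iterations: int) -> int:
    # think_level == 0 halves the budget but never drops below one iteration.
    return max(1, base_max_iterations // 2) if think_level == 0 else base_max_iterations


async def solve_question(question: str, max_iterations: int) -> str:
    await asyncio.sleep(0)  # placeholder for the real ReAct loop
    return f"问题:{question} 答案:(示例,预算 {max_iterations} 轮)"


async def main() -> None:
    questions = ["张三最近在忙什么", "上周讨论的活动定在哪天"]
    budget = iterations_for(think_level=0, base_max_iterations=6)  # -> 3
    results = await asyncio.gather(*(solve_question(q, budget) for q in questions))
    for line in results:
        print(line)


asyncio.run(main())
```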
@ -990,10 +1186,10 @@ async def build_memory_retrieval_prompt(
# 获取最近10分钟内已找到答案的缓存记录
cached_answers = _get_recent_found_answers(chat_id, time_window_seconds=600.0)
# 合并当前查询结果和缓存答案(去重:如果当前查询的问题在缓存中已存在,优先使用当前结果)
all_results = []
# 先添加当前查询的结果
current_questions = set()
for result in question_results:
@ -1003,7 +1199,7 @@ async def build_memory_retrieval_prompt(
if question_end != -1:
current_questions.add(result[4:question_end])
all_results.append(result)
# 添加缓存答案(排除当前查询中已存在的问题)
for cached_answer in cached_answers:
if cached_answer.startswith("问题:"):
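These two hunks merge the fresh query results with answers cached in the last ten minutes, keyed on the `问题:` prefix so that current results win over stale cache entries. A small standalone sketch of that merge; the list contents and the answer separator here are invented for illustration:

```python
from typing import List


def merge_results(current: List[str], cached: List[str]) -> List[str]:
    """Keep every current result, then append cached answers whose question
    was not answered in the current round."""
    merged = list(current)
    seen_questions = set()
    for item in current:
        if item.startswith("问题:"):
            end = item.find(" 答案:")  # question text runs up to the answer separator
            if end != -1:
                seen_questions.add(item[len("问题:"):end])
    for item in cached:
        if item.startswith("问题:"):
            end = item.find(" 答案:")
            question = item[len("问题:"):end] if end != -1 else None
            if question in seen_questions:
                continue
        merged.append(item)
    return merged


current = ["问题:张三去哪了 答案:出差了"]
cached = ["问题:张三去哪了 答案:旧答案", "问题:活动定在哪天 答案:周五"]
print(merge_results(current, cached))
# -> ['问题:张三去哪了 答案:出差了', '问题:活动定在哪天 答案:周五']
```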
@ -1031,4 +1227,3 @@ async def build_memory_retrieval_prompt(
except Exception as e:
logger.error(f"记忆检索时发生异常: {str(e)}")
return ""

View File

@ -17,7 +17,6 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
@ -68,6 +67,7 @@ def parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], []
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳
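The helper's docstring promises to accept several common datetime formats and return a Unix timestamp, but its body is cut off by the diff. A rough sketch of one way such a function can be written; the format list is a guess, not the project's:

```python
from datetime import datetime


def parse_datetime_to_timestamp_sketch(value: str) -> float:
    """Try a few common layouts and fall back to a bare date."""
    candidate_formats = (
        "%Y-%m-%d %H:%M:%S",
        "%Y/%m/%d %H:%M:%S",
        "%Y-%m-%d",
    )
    for fmt in candidate_formats:
        try:
            return datetime.strptime(value, fmt).timestamp()
        except ValueError:
            continue
    raise ValueError(f"unrecognized datetime string: {value!r}")


print(parse_datetime_to_timestamp_sketch("2025-12-18 18:58:10"))
```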

View File

@ -14,7 +14,7 @@ from .tool_registry import (
from .query_chat_history import register_tool as register_query_chat_history
from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from .found_answer import register_tool as register_found_answer
from .found_answer import register_tool as register_finish_search
from src.config.config import global_config
@ -22,7 +22,7 @@ def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_chat_history()
register_query_person_info()
register_found_answer() # 注册found_answer工具
register_finish_search() # 注册finish_search工具
if global_config.lpmm_knowledge.lpmm_mode == "agent":
register_lpmm_knowledge()
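The registration hunk renames the terminal tool to `finish_search` and keeps the LPMM knowledge tool behind the `lpmm_mode == "agent"` switch. A toy registry showing the same conditional wiring; the registry and config objects below are stand-ins, not the project's:

```python
from typing import Callable, Dict

TOOL_REGISTRY: Dict[str, Callable[..., str]] = {}


def register(name: str):
    def decorator(func: Callable[..., str]) -> Callable[..., str]:
        TOOL_REGISTRY[name] = func
        return func
    return decorator


@register("finish_search")
def finish_search(found_answer: bool, answer: str = "") -> str:
    return answer if found_answer else ""


def init_all_tools(lpmm_mode: str) -> None:
    # Always-on tools would be registered unconditionally here;
    # the knowledge-base tool is only wired up in agent mode.
    if lpmm_mode == "agent":
        @register("query_lpmm_knowledge")
        def query_lpmm_knowledge(query: str) -> str:
            return f"(示例)知识库检索: {query}"


init_all_tools(lpmm_mode="agent")
print(sorted(TOOL_REGISTRY))  # ['finish_search', 'query_lpmm_knowledge']
```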

Some files were not shown because too many files have changed in this diff