增加了event_handler修改内容的方法

pull/1228/head
UnCLAS-Prommer 2025-09-07 01:15:21 +08:00
parent 0811cff8bf
commit b636683fe4
No known key found for this signature in database
10 changed files with 215 additions and 85 deletions

View File

@ -8,7 +8,6 @@ from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages

View File

@ -1,9 +1,8 @@
import time
from typing import Optional, Dict, List
from typing import Optional, Dict
from src.plugin_system.apis import message_api
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value

View File

@ -5,7 +5,6 @@ import math
import random
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
@ -101,7 +100,6 @@ class HeartFChatting:
self.last_read_time = time.time() - 10
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@ -178,7 +176,7 @@ class HeartFChatting:
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}" # type: ignore
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def caculate_interest_value(self, recent_messages_list: List["DatabaseMessages"]) -> float:
total_interest = 0.0
for msg in recent_messages_list:
@ -197,10 +195,13 @@ class HeartFChatting:
filter_mai=True,
filter_command=True,
)
if recent_messages_list:
self.last_read_time = time.time()
await self._observe(interest_value=await self.caculate_interest_value(recent_messages_list),recent_messages_list=recent_messages_list)
await self._observe(
interest_value=await self.caculate_interest_value(recent_messages_list),
recent_messages_list=recent_messages_list,
)
else:
# Normal模式消息数量不足等待
await asyncio.sleep(0.2)
@ -257,7 +258,7 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0,recent_messages_list: List["DatabaseMessages"] = []) -> bool:
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
reply_text = "" # 初始化reply_text变量避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
@ -274,12 +275,10 @@ class HeartFChatting:
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 2
* self.frequency_control.get_final_talk_frequency()
calculate_normal_mode_probability(interest_value) * 2 * self.frequency_control.get_final_talk_frequency()
)
#对呼唤名字进行增幅
# 对呼唤名字进行增幅
for msg in recent_messages_list:
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
normal_mode_probability += msg.reply_probability_boost
@ -287,18 +286,15 @@ class HeartFChatting:
normal_mode_probability += global_config.chat.mentioned_bot_reply
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
# 根据概率决定使用直接回复
interest_triggerd = False
focus_triggerd = False
if random.random() < normal_mode_probability:
interest_triggerd = True
logger.info(
f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复"
)
logger.info(f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复")
if s4u_config.enable_s4u:
await send_typing()
@ -307,21 +303,20 @@ class HeartFChatting:
await self.expression_learner.trigger_learning_for_chat()
available_actions: Dict[str, ActionInfo] = {}
#如果兴趣度不足以激活
if not interest_triggerd:
#看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
#专注值足够,仍然进入正式思考
focus_triggerd = True #都没触发,路边
# 如果兴趣度不足以激活
if not interest_triggerd:
# 看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
# 专注值足够,仍然进入正式思考
focus_triggerd = True # 都没触发,路边
# 任意一种触发都行
if interest_triggerd or focus_triggerd:
# 进入正式思考模式
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
try:
await self.action_modifier.modify_actions()
@ -353,17 +348,20 @@ class HeartFChatting:
# actions_before_now_block=actions_before_now_block,
message_id_list=message_id_list,
)
if not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
):
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
# 根据不同触发进入不同plan
if focus_triggerd:
mode = ChatMode.FOCUS
else:
mode = ChatMode.NORMAL
action_to_use_info, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
@ -432,8 +430,7 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)

View File

@ -25,7 +25,6 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
from src.chat.utils.utils import get_embedding
from src.config.config import global_config

View File

@ -7,7 +7,7 @@ import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict
from typing import List, Tuple, Set, Coroutine, Any
from collections import Counter
import traceback
@ -21,7 +21,6 @@ from src.common.logger import get_logger
from src.chat.utils.utils import cut_key_words
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages
@ -1183,9 +1182,7 @@ class ParahippocampalGyrus:
# 规范化输入为列表[str]
if isinstance(keywords, str):
# 支持中英文逗号、顿号、空格分隔
parts = (
keywords.replace("", ",").replace("", ",").replace(" ", ",").strip(", ")
)
parts = keywords.replace("", ",").replace("", ",").replace(" ", ",").strip(", ")
keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
else:
keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]

View File

@ -3,7 +3,7 @@ import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo
from maim_message import UserInfo, Seg
from src.common.logger import get_logger
from src.config.config import global_config
@ -169,8 +169,12 @@ class ChatBot:
# 处理消息内容
await message.process()
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
_ = Person.register_person(
platform=message.message_info.platform,
user_id=message.message_info.user_info.user_id,
nickname=user_info.user_nickname,
) # type: ignore
await self.s4u_message_processor.process_message(message)
@ -220,10 +224,18 @@ class ChatBot:
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
if sent_message: # 处理上报的自身消息更新message_id需要ada支持上报事件
await MessageStorage.update_message(message)
return
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_MESSAGE_PRE_PROCESS, message
)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
@ -258,8 +270,11 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:

View File

@ -202,10 +202,14 @@ class DefaultReplyer:
from src.plugin_system.core.events_manager import events_manager
if not from_plugin:
if not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
):
)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt:
llm_response.prompt = modified_message.llm_prompt
prompt = str(modified_message.llm_prompt)
# 4. 调用 LLM 生成回复
content = None
@ -219,10 +223,19 @@ class DefaultReplyer:
llm_response.reasoning = reasoning_content
llm_response.model = model_name
llm_response.tool_calls = tool_call
if not from_plugin and not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
):
)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")
if modified_message:
if modified_message._modify_flags.modify_llm_prompt:
logger.warning("警告插件在内容生成后才修改了prompt此修改不会生效")
llm_response.prompt = modified_message.llm_prompt # 虽然我不知道为什么在这里需要改prompt
if modified_message._modify_flags.modify_llm_response_content:
llm_response.content = modified_message.llm_response_content
if modified_message._modify_flags.modify_llm_response_reasoning:
llm_response.reasoning = modified_message.llm_response_reasoning
except UserWarning as e:
raise e
except Exception as llm_e:
@ -634,7 +647,7 @@ class DefaultReplyer:
"""构建动作提示"""
action_descriptions = ""
skip_names = ["emoji","build_memory","build_relation","reply"]
skip_names = ["emoji", "build_memory", "build_relation", "reply"]
if available_actions:
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
for action_name, action_info in available_actions.items():
@ -671,9 +684,7 @@ class DefaultReplyer:
else:
bot_nickname = ""
prompt_personality = (
f"{global_config.personality.personality};"
)
prompt_personality = f"{global_config.personality.personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def build_prompt_reply_context(
@ -809,11 +820,6 @@ class DefaultReplyer:
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
if sender:
if is_group_chat:
reply_target_block = (
@ -1016,7 +1022,7 @@ class DefaultReplyer:
logger.info(f"使用模型集生成回复: {', '.join(map(str, self.express_model.model_for_task.model_list))}")
logger.info(f"\n{prompt}\n")
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:

View File

@ -37,19 +37,19 @@ class BaseEventHandler(ABC):
@abstractmethod
async def execute(
self, message: MaiMessages | None
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
"""执行事件处理的抽象方法,子类必须实现
Args:
message (MaiMessages | None): 事件消息对象当你注册的事件为ON_START和ON_STOP时message为None
Returns:
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果可选的修改后消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
# 从类属性读取名称,如果没有定义则使用类名自动生成S
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")

View File

@ -1,4 +1,5 @@
import copy
import warnings
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
@ -7,6 +8,7 @@ from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
@ -56,6 +58,7 @@ class EventType(Enum):
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
ON_MESSAGE_PRE_PROCESS = "on_message_pre_process"
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
@ -116,9 +119,9 @@ class ActionInfo(ComponentInfo):
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
@ -154,7 +157,9 @@ class CommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
@ -233,6 +238,15 @@ class PluginInfo:
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class ModifyFlag:
modify_message_segments: bool = False
modify_plain_text: bool = False
modify_llm_prompt: bool = False
modify_llm_response_content: bool = False
modify_llm_response_reasoning: bool = False
@dataclass
class MaiMessages:
"""MaiM插件消息"""
@ -263,31 +277,129 @@ class MaiMessages:
llm_response_content: Optional[str] = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
"""使用的Action"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []
def deepcopy(self):
return copy.deepcopy(self)
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表
Warning:
在生成了plain_text的情况下调用此方法可能会导致plain_text内容与消息段不一致
Args:
new_segments (List[Seg]): 新的消息段列表
"""
if self.plain_text and not suppress_warning:
warnings.warn(
"修改消息段后plain_text可能与消息段内容不一致建议同时更新plain_text",
UserWarning,
stacklevel=2,
)
self.message_segments = new_segments
self._modify_flags.modify_message_segments = True
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
"""
修改LLM提示词
Warning:
在没有生成llm_prompt的情况下调用此方法可能会导致修改无效
Args:
new_prompt (str): 新的提示词内容
"""
if self.llm_prompt is None and not suppress_warning:
warnings.warn(
"当前llm_prompt为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_prompt = new_prompt
self._modify_flags.modify_llm_prompt = True
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
"""
修改生成的plain_text内容
Warning:
在未生成plain_text的情况下调用此方法可能会导致plain_text为空或者修改无效
Args:
new_text (str): 新的纯文本内容
"""
if not self.plain_text and not suppress_warning:
warnings.warn(
"当前plain_text为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.plain_text = new_text
self._modify_flags.modify_plain_text = True
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
"""
修改生成的llm_response_content内容
Warning:
在未生成llm_response_content的情况下调用此方法可能会导致llm_response_content为空或者修改无效
Args:
new_content (str): 新的LLM响应内容
"""
if not self.llm_response_content and not suppress_warning:
warnings.warn(
"当前llm_response_content为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_content = new_content
self._modify_flags.modify_llm_response_content = True
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
"""
修改生成的llm_response_reasoning内容
Warning:
在未生成llm_response_reasoning的情况下调用此方法可能会导致llm_response_reasoning为空或者修改无效
Args:
new_reasoning (str): 新的LLM响应推理内容
"""
if not self.llm_response_reasoning and not suppress_warning:
warnings.warn(
"当前llm_response_reasoning为空此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_reasoning = new_reasoning
self._modify_flags.modify_llm_response_reasoning = True
@dataclass
class CustomEventHandlerResult:
message: str = ""
timestamp: float = 0.0
extra_info: Optional[Dict] = None
extra_info: Optional[Dict] = None

View File

@ -71,7 +71,7 @@ class EventsManager:
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""
处理所有事件根据事件类型分发给订阅的处理器
"""
@ -89,10 +89,10 @@ class EventsManager:
# 2. 获取并遍历处理器
handlers = self._events_subscribers.get(event_type, [])
if not handlers:
return True
return True, None
current_stream_id = transformed_message.stream_id if transformed_message else None
modified_message: Optional[MaiMessages] = None
for handler in handlers:
# 3. 前置检查和配置加载
if (
@ -107,15 +107,19 @@ class EventsManager:
handler.set_plugin_config(plugin_config)
# 4. 根据类型分发任务
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
if (
handler.intercept_message or event_type == EventType.ON_STOP
): # 让ON_STOP的所有事件处理器都发挥作用防止还没执行即被取消
# 阻塞执行,并更新 continue_flag
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
handler, event_type, modified_message or transformed_message
)
continue_flag = continue_flag and should_continue
else:
# 异步执行,不阻塞
self._dispatch_handler_task(handler, event_type, transformed_message)
return continue_flag
return continue_flag, modified_message
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
@ -327,16 +331,18 @@ class EventsManager:
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
async def _dispatch_intercepting_handler(
async def _dispatch_intercepting_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, continue_processing, return_message, custom_result = await handler.execute(message)
success, continue_processing, return_message, custom_result, modified_message = await handler.execute(
message
)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
@ -345,17 +351,17 @@ class EventsManager:
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
return continue_processing
return continue_processing, modified_message
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
return True
return True, None
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True # 发生异常时默认不中断其他处理
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
event_type: EventType | str,
):
"""任务完成回调"""
@ -365,7 +371,7 @@ class EventsManager:
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else: