mirror of https://github.com/Mai-with-u/MaiBot.git
Added methods for event_handler to modify content
parent
0811cff8bf
commit
b636683fe4
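This commit lets event handlers hand a modified message back to the core: BaseEventHandler.execute now returns a fifth element (an optional modified MaiMessages), MaiMessages gains modify_* helper methods that record what changed in _modify_flags, and events_manager.handle_mai_events now returns (continue_flag, modified_message) so each call site applies only the fields a handler actually flagged. Below is a minimal sketch of a plugin handler written against the new contract; the handler itself, its name, and the import paths are illustrative assumptions, while the signatures and helper methods follow the diff.

# Illustrative sketch only: the module paths below are assumptions, not taken from this diff.
from typing import Optional, Tuple

from src.plugin_system.base.base_events_handler import BaseEventHandler  # assumed path
from src.plugin_system.base.component_types import CustomEventHandlerResult, MaiMessages  # assumed path


class PromptTaggingHandler(BaseEventHandler):
    """Hypothetical intercepting handler that prepends a tag to the LLM prompt."""

    handler_name = "prompt_tagging_handler"
    intercept_message = True  # only intercepting handlers get their modified message applied

    async def execute(
        self, message: MaiMessages | None
    ) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
        # Return order: (success, continue_processing, return_message, custom_result, modified_message)
        if message is None or message.llm_prompt is None:
            return True, True, None, None, None  # nothing to modify, let processing continue
        modified = message.deepcopy()  # work on a copy so the original message stays untouched
        modified.modify_llm_prompt(f"[plugin] {message.llm_prompt}")  # also sets _modify_flags.modify_llm_prompt
        return True, True, "prompt tagged", None, modified

On the dispatch side, every updated call site in this diff follows the same two-value pattern; roughly, mirroring the ON_MESSAGE call site shown further down (the wrapper function and the EventType import path are assumptions):

from src.plugin_system.core.events_manager import events_manager  # import as used in the replyer hunk below
from src.plugin_system.base.component_types import EventType  # assumed path


async def dispatch_on_message(message) -> None:  # `message` is the bot's received-message object
    continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
    if not continue_flag:
        return  # an intercepting handler asked to stop further processing
    if modified_message and modified_message._modify_flags.modify_plain_text:
        message.processed_plain_text = modified_message.plain_text  # apply only the flagged field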
@@ -8,7 +8,6 @@ from typing import List, Dict, Optional, Any, Tuple
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
@@ -1,9 +1,8 @@
import time
from typing import Optional, Dict, List
from typing import Optional, Dict
from src.plugin_system.apis import message_api
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value
@@ -5,7 +5,6 @@ import math
import random
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
@@ -101,7 +100,6 @@ class HeartFChatting:
self.last_read_time = time.time() - 10
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
@@ -178,7 +176,7 @@ class HeartFChatting:
f"耗时: {self._current_cycle_detail.end_time - self._current_cycle_detail.start_time:.1f}秒" # type: ignore
+ (f"\n详情: {'; '.join(timer_strings)}" if timer_strings else "")
)
async def caculate_interest_value(self, recent_messages_list: List["DatabaseMessages"]) -> float:
total_interest = 0.0
for msg in recent_messages_list:
@@ -197,10 +195,13 @@ class HeartFChatting:
filter_mai=True,
filter_command=True,
)
if recent_messages_list:
self.last_read_time = time.time()
await self._observe(interest_value=await self.caculate_interest_value(recent_messages_list),recent_messages_list=recent_messages_list)
await self._observe(
interest_value=await self.caculate_interest_value(recent_messages_list),
recent_messages_list=recent_messages_list,
)
else:
# Normal模式:消息数量不足,等待
await asyncio.sleep(0.2)
@@ -257,7 +258,7 @@ class HeartFChatting:
return loop_info, reply_text, cycle_timers
async def _observe(self, interest_value: float = 0.0,recent_messages_list: List["DatabaseMessages"] = []) -> bool:
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
# 使用sigmoid函数将interest_value转换为概率
@@ -274,12 +275,10 @@ class HeartFChatting:
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 2
* self.frequency_control.get_final_talk_frequency()
calculate_normal_mode_probability(interest_value) * 2 * self.frequency_control.get_final_talk_frequency()
)
#对呼唤名字进行增幅
# 对呼唤名字进行增幅
for msg in recent_messages_list:
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
normal_mode_probability += msg.reply_probability_boost
@@ -287,18 +286,15 @@ class HeartFChatting:
normal_mode_probability += global_config.chat.mentioned_bot_reply
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
# 根据概率决定使用直接回复
interest_triggerd = False
focus_triggerd = False
if random.random() < normal_mode_probability:
interest_triggerd = True
logger.info(
f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复"
)
logger.info(f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复")
if s4u_config.enable_s4u:
await send_typing()
@@ -307,21 +303,20 @@ class HeartFChatting:
await self.expression_learner.trigger_learning_for_chat()
available_actions: Dict[str, ActionInfo] = {}
#如果兴趣度不足以激活
if not interest_triggerd:
#看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
#专注值足够,仍然进入正式思考
focus_triggerd = True #都没触发,路边
# 如果兴趣度不足以激活
if not interest_triggerd:
# 看看专注值够不够
if random.random() < self.frequency_control.get_final_focus_value():
# 专注值足够,仍然进入正式思考
focus_triggerd = True # 都没触发,路边
# 任意一种触发都行
if interest_triggerd or focus_triggerd:
# 进入正式思考模式
cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")
# 第一步:动作检查
try:
await self.action_modifier.modify_actions()
@@ -353,17 +348,20 @@ class HeartFChatting:
# actions_before_now_block=actions_before_now_block,
message_id_list=message_id_list,
)
if not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
):
)
if not continue_flag:
return False
if modified_message and modified_message._modify_flags.modify_llm_prompt:
prompt_info = (modified_message.llm_prompt, prompt_info[1])
with Timer("规划器", cycle_timers):
# 根据不同触发,进入不同plan
if focus_triggerd:
mode = ChatMode.FOCUS
else:
mode = ChatMode.NORMAL
action_to_use_info, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
@@ -432,8 +430,7 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
@@ -25,7 +25,6 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
@@ -7,7 +7,7 @@ import re
import jieba
import networkx as nx
import numpy as np
from typing import List, Tuple, Set, Coroutine, Any, Dict
from typing import List, Tuple, Set, Coroutine, Any
from collections import Counter
import traceback
@@ -21,7 +21,6 @@ from src.common.logger import get_logger
from src.chat.utils.utils import cut_key_words
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_by_timestamp_with_chat_inclusive,
) # 导入 build_readable_messages
@@ -1183,9 +1182,7 @@ class ParahippocampalGyrus:
# 规范化输入为列表[str]
if isinstance(keywords, str):
# 支持中英文逗号、顿号、空格分隔
parts = (
keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
)
parts = keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
else:
keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
@@ -3,7 +3,7 @@ import os
import re
from typing import Dict, Any, Optional
from maim_message import UserInfo
from maim_message import UserInfo, Seg
from src.common.logger import get_logger
from src.config.config import global_config
@@ -169,8 +169,12 @@ class ChatBot:
# 处理消息内容
await message.process()
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
_ = Person.register_person(
platform=message.message_info.platform,
user_id=message.message_info.user_info.user_id,
nickname=user_info.user_nickname,
) # type: ignore
await self.s4u_message_processor.process_message(message)
@@ -220,10 +224,18 @@ class ChatBot:
user_info = message.message_info.user_info
if message.message_info.additional_config:
sent_message = message.message_info.additional_config.get("echo", False)
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
if sent_message: # 处理上报的自身消息,更新message_id,需要ada支持上报事件
await MessageStorage.update_message(message)
return
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.ON_MESSAGE_PRE_PROCESS, message
)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_message_segments:
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
@@ -258,8 +270,11 @@ class ChatBot:
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
return
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
if not continue_flag:
return
if modified_message and modified_message._modify_flags.modify_plain_text:
message.processed_plain_text = modified_message.plain_text
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
@@ -202,10 +202,14 @@ class DefaultReplyer:
from src.plugin_system.core.events_manager import events_manager
if not from_plugin:
if not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
):
)
if not continue_flag:
raise UserWarning("插件于请求前中断了内容生成")
if modified_message and modified_message._modify_flags.modify_llm_prompt:
llm_response.prompt = modified_message.llm_prompt
prompt = str(modified_message.llm_prompt)
# 4. 调用 LLM 生成回复
content = None
@@ -219,10 +223,19 @@ class DefaultReplyer:
llm_response.reasoning = reasoning_content
llm_response.model = model_name
llm_response.tool_calls = tool_call
if not from_plugin and not await events_manager.handle_mai_events(
continue_flag, modified_message = await events_manager.handle_mai_events(
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
):
)
if not from_plugin and not continue_flag:
raise UserWarning("插件于请求后取消了内容生成")
if modified_message:
if modified_message._modify_flags.modify_llm_prompt:
logger.warning("警告:插件在内容生成后才修改了prompt,此修改不会生效")
llm_response.prompt = modified_message.llm_prompt # 虽然我不知道为什么在这里需要改prompt
if modified_message._modify_flags.modify_llm_response_content:
llm_response.content = modified_message.llm_response_content
if modified_message._modify_flags.modify_llm_response_reasoning:
llm_response.reasoning = modified_message.llm_response_reasoning
except UserWarning as e:
raise e
except Exception as llm_e:
@@ -634,7 +647,7 @@ class DefaultReplyer:
"""构建动作提示"""
action_descriptions = ""
skip_names = ["emoji","build_memory","build_relation","reply"]
skip_names = ["emoji", "build_memory", "build_relation", "reply"]
if available_actions:
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
for action_name, action_info in available_actions.items():
@@ -671,9 +684,7 @@ class DefaultReplyer:
else:
bot_nickname = ""
prompt_personality = (
f"{global_config.personality.personality};"
)
prompt_personality = f"{global_config.personality.personality};"
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
async def build_prompt_reply_context(
@@ -809,11 +820,6 @@ class DefaultReplyer:
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
if sender:
if is_group_chat:
reply_target_block = (
@@ -1016,7 +1022,7 @@ class DefaultReplyer:
logger.info(f"使用模型集生成回复: {', '.join(map(str, self.express_model.model_for_task.model_list))}")
logger.info(f"\n{prompt}\n")
if global_config.debug.show_prompt:
logger.info(f"\n{prompt}\n")
else:
@@ -37,19 +37,19 @@ class BaseEventHandler(ABC):
@abstractmethod
async def execute(
self, message: MaiMessages | None
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
"""执行事件处理的抽象方法,子类必须实现
Args:
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
Returns:
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
"""
raise NotImplementedError("子类必须实现 execute 方法")
@classmethod
def get_handler_info(cls) -> "EventHandlerInfo":
"""获取事件处理器的信息"""
# 从类属性读取名称,如果没有定义则使用类名自动生成
# 从类属性读取名称,如果没有定义则使用类名自动生成S
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
if "." in name:
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
@@ -1,4 +1,5 @@
import copy
import warnings
from enum import Enum
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, field
@@ -7,6 +8,7 @@ from maim_message import Seg
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
# 组件类型枚举
class ComponentType(Enum):
"""组件类型枚举"""
@@ -56,6 +58,7 @@ class EventType(Enum):
ON_START = "on_start" # 启动事件,用于调用按时任务
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
ON_MESSAGE_PRE_PROCESS = "on_message_pre_process"
ON_MESSAGE = "on_message"
ON_PLAN = "on_plan"
POST_LLM = "post_llm"
@@ -116,9 +119,9 @@ class ActionInfo(ComponentInfo):
action_require: List[str] = field(default_factory=list) # 动作需求说明
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
# 激活类型相关
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
activation_type: ActionActivationType = ActionActivationType.ALWAYS
random_activation_probability: float = 0.0
llm_judge_prompt: str = ""
activation_keywords: List[str] = field(default_factory=list) # 激活关键词列表
@@ -154,7 +157,9 @@ class CommandInfo(ComponentInfo):
class ToolInfo(ComponentInfo):
"""工具组件信息"""
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
default_factory=list
) # 工具参数定义
tool_description: str = "" # 工具描述
def __post_init__(self):
@@ -233,6 +238,15 @@ class PluginInfo:
return [dep.get_pip_requirement() for dep in self.python_dependencies]
@dataclass
class ModifyFlag:
modify_message_segments: bool = False
modify_plain_text: bool = False
modify_llm_prompt: bool = False
modify_llm_response_content: bool = False
modify_llm_response_reasoning: bool = False
@dataclass
class MaiMessages:
"""MaiM插件消息"""
@@ -263,31 +277,129 @@ class MaiMessages:
llm_response_content: Optional[str] = None
"""LLM响应内容"""
llm_response_reasoning: Optional[str] = None
"""LLM响应推理内容"""
llm_response_model: Optional[str] = None
"""LLM响应模型名称"""
llm_response_tool_call: Optional[List[ToolCall]] = None
"""LLM使用的工具调用"""
action_usage: Optional[List[str]] = None
"""使用的Action"""
additional_data: Dict[Any, Any] = field(default_factory=dict)
"""附加数据,可以存储额外信息"""
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []
def deepcopy(self):
return copy.deepcopy(self)
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
"""
修改消息段列表
Warning:
在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致
Args:
new_segments (List[Seg]): 新的消息段列表
"""
if self.plain_text and not suppress_warning:
warnings.warn(
"修改消息段后,plain_text可能与消息段内容不一致,建议同时更新plain_text",
UserWarning,
stacklevel=2,
)
self.message_segments = new_segments
self._modify_flags.modify_message_segments = True
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
"""
修改LLM提示词
Warning:
在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效
Args:
new_prompt (str): 新的提示词内容
"""
if self.llm_prompt is None and not suppress_warning:
warnings.warn(
"当前llm_prompt为空,此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_prompt = new_prompt
self._modify_flags.modify_llm_prompt = True
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
"""
修改生成的plain_text内容
Warning:
在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效
Args:
new_text (str): 新的纯文本内容
"""
if not self.plain_text and not suppress_warning:
warnings.warn(
"当前plain_text为空,此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.plain_text = new_text
self._modify_flags.modify_plain_text = True
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
"""
修改生成的llm_response_content内容
Warning:
在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效
Args:
new_content (str): 新的LLM响应内容
"""
if not self.llm_response_content and not suppress_warning:
warnings.warn(
"当前llm_response_content为空,此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_content = new_content
self._modify_flags.modify_llm_response_content = True
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
"""
修改生成的llm_response_reasoning内容
Warning:
在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效
Args:
new_reasoning (str): 新的LLM响应推理内容
"""
if not self.llm_response_reasoning and not suppress_warning:
warnings.warn(
"当前llm_response_reasoning为空,此时调用方法可能导致修改无效",
UserWarning,
stacklevel=2,
)
self.llm_response_reasoning = new_reasoning
self._modify_flags.modify_llm_response_reasoning = True
@dataclass
class CustomEventHandlerResult:
message: str = ""
timestamp: float = 0.0
extra_info: Optional[Dict] = None
extra_info: Optional[Dict] = None
@@ -71,7 +71,7 @@ class EventsManager:
llm_response: Optional["LLMGenerationDataModel"] = None,
stream_id: Optional[str] = None,
action_usage: Optional[List[str]] = None,
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""
处理所有事件,根据事件类型分发给订阅的处理器。
"""
@@ -89,10 +89,10 @@ class EventsManager:
# 2. 获取并遍历处理器
handlers = self._events_subscribers.get(event_type, [])
if not handlers:
return True
return True, None
current_stream_id = transformed_message.stream_id if transformed_message else None
modified_message: Optional[MaiMessages] = None
for handler in handlers:
# 3. 前置检查和配置加载
if (
@@ -107,15 +107,19 @@ class EventsManager:
handler.set_plugin_config(plugin_config)
# 4. 根据类型分发任务
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
if (
handler.intercept_message or event_type == EventType.ON_STOP
): # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
# 阻塞执行,并更新 continue_flag
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
handler, event_type, modified_message or transformed_message
)
continue_flag = continue_flag and should_continue
else:
# 异步执行,不阻塞
self._dispatch_handler_task(handler, event_type, transformed_message)
return continue_flag
return continue_flag, modified_message
async def cancel_handler_tasks(self, handler_name: str) -> None:
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
@@ -327,16 +331,18 @@ class EventsManager:
except Exception as e:
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
async def _dispatch_intercepting_handler(
async def _dispatch_intercepting_handler_task(
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
) -> bool:
) -> Tuple[bool, Optional[MaiMessages]]:
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
if event_type == EventType.UNKNOWN:
raise ValueError("未知事件类型")
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, continue_processing, return_message, custom_result = await handler.execute(message)
success, continue_processing, return_message, custom_result, modified_message = await handler.execute(
message
)
if not success:
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
@@ -345,17 +351,17 @@ class EventsManager:
if self._history_enable_map[event_type] and custom_result:
self._events_result_history[event_type].append(custom_result)
return continue_processing
return continue_processing, modified_message
except KeyError:
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
return True
return True, None
except Exception as e:
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
return True # 发生异常时默认不中断其他处理
return True, None # 发生异常时默认不中断其他处理
def _task_done_callback(
self,
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
event_type: EventType | str,
):
"""任务完成回调"""
@@ -365,7 +371,7 @@ class EventsManager:
if event_type not in self._history_enable_map:
raise ValueError(f"事件类型 {event_type} 未注册")
try:
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
if success:
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
else: