mirror of https://github.com/Mai-with-u/MaiBot.git
增加了event_handler修改内容的方法
parent
0811cff8bf
commit
b636683fe4
|
|
@ -8,7 +8,6 @@ from typing import List, Dict, Optional, Any, Tuple
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import model_config, global_config
|
from src.config.config import model_config, global_config
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,8 @@
|
||||||
import time
|
import time
|
||||||
from typing import Optional, Dict, List
|
from typing import Optional, Dict
|
||||||
from src.plugin_system.apis import message_api
|
from src.plugin_system.apis import message_api
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
|
from src.chat.frequency_control.talk_frequency_control import get_config_base_talk_frequency
|
||||||
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value
|
from src.chat.frequency_control.focus_value_control import get_config_base_focus_value
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ import math
|
||||||
import random
|
import random
|
||||||
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from collections import deque
|
|
||||||
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
@ -101,7 +100,6 @@ class HeartFChatting:
|
||||||
|
|
||||||
self.last_read_time = time.time() - 10
|
self.last_read_time = time.time() - 10
|
||||||
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""检查是否需要启动主循环,如果未激活则启动。"""
|
"""检查是否需要启动主循环,如果未激活则启动。"""
|
||||||
|
|
||||||
|
|
@ -200,7 +198,10 @@ class HeartFChatting:
|
||||||
|
|
||||||
if recent_messages_list:
|
if recent_messages_list:
|
||||||
self.last_read_time = time.time()
|
self.last_read_time = time.time()
|
||||||
await self._observe(interest_value=await self.caculate_interest_value(recent_messages_list),recent_messages_list=recent_messages_list)
|
await self._observe(
|
||||||
|
interest_value=await self.caculate_interest_value(recent_messages_list),
|
||||||
|
recent_messages_list=recent_messages_list,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Normal模式:消息数量不足,等待
|
# Normal模式:消息数量不足,等待
|
||||||
await asyncio.sleep(0.2)
|
await asyncio.sleep(0.2)
|
||||||
|
|
@ -257,7 +258,7 @@ class HeartFChatting:
|
||||||
|
|
||||||
return loop_info, reply_text, cycle_timers
|
return loop_info, reply_text, cycle_timers
|
||||||
|
|
||||||
async def _observe(self, interest_value: float = 0.0,recent_messages_list: List["DatabaseMessages"] = []) -> bool:
|
async def _observe(self, interest_value: float = 0.0, recent_messages_list: List["DatabaseMessages"] = []) -> bool:
|
||||||
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
reply_text = "" # 初始化reply_text变量,避免UnboundLocalError
|
||||||
|
|
||||||
# 使用sigmoid函数将interest_value转换为概率
|
# 使用sigmoid函数将interest_value转换为概率
|
||||||
|
|
@ -274,12 +275,10 @@ class HeartFChatting:
|
||||||
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
|
||||||
|
|
||||||
normal_mode_probability = (
|
normal_mode_probability = (
|
||||||
calculate_normal_mode_probability(interest_value)
|
calculate_normal_mode_probability(interest_value) * 2 * self.frequency_control.get_final_talk_frequency()
|
||||||
* 2
|
|
||||||
* self.frequency_control.get_final_talk_frequency()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
#对呼唤名字进行增幅
|
# 对呼唤名字进行增幅
|
||||||
for msg in recent_messages_list:
|
for msg in recent_messages_list:
|
||||||
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
|
if msg.reply_probability_boost is not None and msg.reply_probability_boost > 0.0:
|
||||||
normal_mode_probability += msg.reply_probability_boost
|
normal_mode_probability += msg.reply_probability_boost
|
||||||
|
|
@ -288,7 +287,6 @@ class HeartFChatting:
|
||||||
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
|
if global_config.chat.at_bot_inevitable_reply and msg.is_at:
|
||||||
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
|
normal_mode_probability += global_config.chat.at_bot_inevitable_reply
|
||||||
|
|
||||||
|
|
||||||
# 根据概率决定使用直接回复
|
# 根据概率决定使用直接回复
|
||||||
interest_triggerd = False
|
interest_triggerd = False
|
||||||
focus_triggerd = False
|
focus_triggerd = False
|
||||||
|
|
@ -296,9 +294,7 @@ class HeartFChatting:
|
||||||
if random.random() < normal_mode_probability:
|
if random.random() < normal_mode_probability:
|
||||||
interest_triggerd = True
|
interest_triggerd = True
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复")
|
||||||
f"{self.log_prefix} 有新消息,在{normal_mode_probability * 100:.0f}%概率下选择回复"
|
|
||||||
)
|
|
||||||
|
|
||||||
if s4u_config.enable_s4u:
|
if s4u_config.enable_s4u:
|
||||||
await send_typing()
|
await send_typing()
|
||||||
|
|
@ -308,13 +304,12 @@ class HeartFChatting:
|
||||||
|
|
||||||
available_actions: Dict[str, ActionInfo] = {}
|
available_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
||||||
#如果兴趣度不足以激活
|
# 如果兴趣度不足以激活
|
||||||
if not interest_triggerd:
|
if not interest_triggerd:
|
||||||
#看看专注值够不够
|
# 看看专注值够不够
|
||||||
if random.random() < self.frequency_control.get_final_focus_value():
|
if random.random() < self.frequency_control.get_final_focus_value():
|
||||||
#专注值足够,仍然进入正式思考
|
# 专注值足够,仍然进入正式思考
|
||||||
focus_triggerd = True #都没触发,路边
|
focus_triggerd = True # 都没触发,路边
|
||||||
|
|
||||||
|
|
||||||
# 任意一种触发都行
|
# 任意一种触发都行
|
||||||
if interest_triggerd or focus_triggerd:
|
if interest_triggerd or focus_triggerd:
|
||||||
|
|
@ -353,10 +348,13 @@ class HeartFChatting:
|
||||||
# actions_before_now_block=actions_before_now_block,
|
# actions_before_now_block=actions_before_now_block,
|
||||||
message_id_list=message_id_list,
|
message_id_list=message_id_list,
|
||||||
)
|
)
|
||||||
if not await events_manager.handle_mai_events(
|
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||||
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
EventType.ON_PLAN, None, prompt_info[0], None, self.chat_stream.stream_id
|
||||||
):
|
)
|
||||||
|
if not continue_flag:
|
||||||
return False
|
return False
|
||||||
|
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||||
|
prompt_info = (modified_message.llm_prompt, prompt_info[1])
|
||||||
with Timer("规划器", cycle_timers):
|
with Timer("规划器", cycle_timers):
|
||||||
# 根据不同触发,进入不同plan
|
# 根据不同触发,进入不同plan
|
||||||
if focus_triggerd:
|
if focus_triggerd:
|
||||||
|
|
@ -433,7 +431,6 @@ class HeartFChatting:
|
||||||
}
|
}
|
||||||
reply_text = action_reply_text
|
reply_text = action_reply_text
|
||||||
|
|
||||||
|
|
||||||
self.end_cycle(loop_info, cycle_timers)
|
self.end_cycle(loop_info, cycle_timers)
|
||||||
self.print_cycle_info(cycle_timers)
|
self.print_cycle_info(cycle_timers)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,6 @@ from rich.progress import (
|
||||||
SpinnerColumn,
|
SpinnerColumn,
|
||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
from src.chat.utils.utils import get_embedding
|
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ import re
|
||||||
import jieba
|
import jieba
|
||||||
import networkx as nx
|
import networkx as nx
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import List, Tuple, Set, Coroutine, Any, Dict
|
from typing import List, Tuple, Set, Coroutine, Any
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
@ -21,7 +21,6 @@ from src.common.logger import get_logger
|
||||||
from src.chat.utils.utils import cut_key_words
|
from src.chat.utils.utils import cut_key_words
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
|
||||||
) # 导入 build_readable_messages
|
) # 导入 build_readable_messages
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1183,9 +1182,7 @@ class ParahippocampalGyrus:
|
||||||
# 规范化输入为列表[str]
|
# 规范化输入为列表[str]
|
||||||
if isinstance(keywords, str):
|
if isinstance(keywords, str):
|
||||||
# 支持中英文逗号、顿号、空格分隔
|
# 支持中英文逗号、顿号、空格分隔
|
||||||
parts = (
|
parts = keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
|
||||||
keywords.replace(",", ",").replace("、", ",").replace(" ", ",").strip(", ")
|
|
||||||
)
|
|
||||||
keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
|
keyword_list = [p.strip() for p in parts.split(",") if p.strip()]
|
||||||
else:
|
else:
|
||||||
keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
|
keyword_list = [k.strip() for k in keywords if isinstance(k, str) and k.strip()]
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo, Seg
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -170,7 +170,11 @@ class ChatBot:
|
||||||
# 处理消息内容
|
# 处理消息内容
|
||||||
await message.process()
|
await message.process()
|
||||||
|
|
||||||
_ = Person.register_person(platform=message.message_info.platform, user_id=message.message_info.user_info.user_id,nickname=user_info.user_nickname) # type: ignore
|
_ = Person.register_person(
|
||||||
|
platform=message.message_info.platform,
|
||||||
|
user_id=message.message_info.user_info.user_id,
|
||||||
|
nickname=user_info.user_nickname,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
await self.s4u_message_processor.process_message(message)
|
await self.s4u_message_processor.process_message(message)
|
||||||
|
|
||||||
|
|
@ -220,10 +224,18 @@ class ChatBot:
|
||||||
user_info = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
if message.message_info.additional_config:
|
if message.message_info.additional_config:
|
||||||
sent_message = message.message_info.additional_config.get("echo", False)
|
sent_message = message.message_info.additional_config.get("echo", False)
|
||||||
if sent_message: # 这一段只是为了在一切处理前劫持上报的自身消息,用于更新message_id,需要ada支持上报事件,实际测试中不会对正常使用造成任何问题
|
if sent_message: # 处理上报的自身消息,更新message_id,需要ada支持上报事件
|
||||||
await MessageStorage.update_message(message)
|
await MessageStorage.update_message(message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||||
|
EventType.ON_MESSAGE_PRE_PROCESS, message
|
||||||
|
)
|
||||||
|
if not continue_flag:
|
||||||
|
return
|
||||||
|
if modified_message and modified_message._modify_flags.modify_message_segments:
|
||||||
|
message.message_segment = Seg(type="seglist", data=modified_message.message_segments)
|
||||||
|
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
|
|
@ -258,8 +270,11 @@ class ChatBot:
|
||||||
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
logger.info(f"命令处理完成,跳过后续消息处理: {cmd_result}")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not await events_manager.handle_mai_events(EventType.ON_MESSAGE, message):
|
continue_flag, modified_message = await events_manager.handle_mai_events(EventType.ON_MESSAGE, message)
|
||||||
|
if not continue_flag:
|
||||||
return
|
return
|
||||||
|
if modified_message and modified_message._modify_flags.modify_plain_text:
|
||||||
|
message.processed_plain_text = modified_message.plain_text
|
||||||
|
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
|
|
|
||||||
|
|
@ -202,10 +202,14 @@ class DefaultReplyer:
|
||||||
from src.plugin_system.core.events_manager import events_manager
|
from src.plugin_system.core.events_manager import events_manager
|
||||||
|
|
||||||
if not from_plugin:
|
if not from_plugin:
|
||||||
if not await events_manager.handle_mai_events(
|
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||||
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
EventType.POST_LLM, None, prompt, None, stream_id=stream_id
|
||||||
):
|
)
|
||||||
|
if not continue_flag:
|
||||||
raise UserWarning("插件于请求前中断了内容生成")
|
raise UserWarning("插件于请求前中断了内容生成")
|
||||||
|
if modified_message and modified_message._modify_flags.modify_llm_prompt:
|
||||||
|
llm_response.prompt = modified_message.llm_prompt
|
||||||
|
prompt = str(modified_message.llm_prompt)
|
||||||
|
|
||||||
# 4. 调用 LLM 生成回复
|
# 4. 调用 LLM 生成回复
|
||||||
content = None
|
content = None
|
||||||
|
|
@ -219,10 +223,19 @@ class DefaultReplyer:
|
||||||
llm_response.reasoning = reasoning_content
|
llm_response.reasoning = reasoning_content
|
||||||
llm_response.model = model_name
|
llm_response.model = model_name
|
||||||
llm_response.tool_calls = tool_call
|
llm_response.tool_calls = tool_call
|
||||||
if not from_plugin and not await events_manager.handle_mai_events(
|
continue_flag, modified_message = await events_manager.handle_mai_events(
|
||||||
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
|
||||||
):
|
)
|
||||||
|
if not from_plugin and not continue_flag:
|
||||||
raise UserWarning("插件于请求后取消了内容生成")
|
raise UserWarning("插件于请求后取消了内容生成")
|
||||||
|
if modified_message:
|
||||||
|
if modified_message._modify_flags.modify_llm_prompt:
|
||||||
|
logger.warning("警告:插件在内容生成后才修改了prompt,此修改不会生效")
|
||||||
|
llm_response.prompt = modified_message.llm_prompt # 虽然我不知道为什么在这里需要改prompt
|
||||||
|
if modified_message._modify_flags.modify_llm_response_content:
|
||||||
|
llm_response.content = modified_message.llm_response_content
|
||||||
|
if modified_message._modify_flags.modify_llm_response_reasoning:
|
||||||
|
llm_response.reasoning = modified_message.llm_response_reasoning
|
||||||
except UserWarning as e:
|
except UserWarning as e:
|
||||||
raise e
|
raise e
|
||||||
except Exception as llm_e:
|
except Exception as llm_e:
|
||||||
|
|
@ -634,7 +647,7 @@ class DefaultReplyer:
|
||||||
"""构建动作提示"""
|
"""构建动作提示"""
|
||||||
|
|
||||||
action_descriptions = ""
|
action_descriptions = ""
|
||||||
skip_names = ["emoji","build_memory","build_relation","reply"]
|
skip_names = ["emoji", "build_memory", "build_relation", "reply"]
|
||||||
if available_actions:
|
if available_actions:
|
||||||
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
|
action_descriptions = "除了进行回复之外,你可以做以下这些动作,不过这些动作由另一个模型决定,:\n"
|
||||||
for action_name, action_info in available_actions.items():
|
for action_name, action_info in available_actions.items():
|
||||||
|
|
@ -671,9 +684,7 @@ class DefaultReplyer:
|
||||||
else:
|
else:
|
||||||
bot_nickname = ""
|
bot_nickname = ""
|
||||||
|
|
||||||
prompt_personality = (
|
prompt_personality = f"{global_config.personality.personality};"
|
||||||
f"{global_config.personality.personality};"
|
|
||||||
)
|
|
||||||
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
|
||||||
|
|
||||||
async def build_prompt_reply_context(
|
async def build_prompt_reply_context(
|
||||||
|
|
@ -809,11 +820,6 @@ class DefaultReplyer:
|
||||||
|
|
||||||
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
moderation_prompt_block = "请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if sender:
|
if sender:
|
||||||
if is_group_chat:
|
if is_group_chat:
|
||||||
reply_target_block = (
|
reply_target_block = (
|
||||||
|
|
|
||||||
|
|
@ -37,19 +37,19 @@ class BaseEventHandler(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def execute(
|
async def execute(
|
||||||
self, message: MaiMessages | None
|
self, message: MaiMessages | None
|
||||||
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]:
|
) -> Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]:
|
||||||
"""执行事件处理的抽象方法,子类必须实现
|
"""执行事件处理的抽象方法,子类必须实现
|
||||||
Args:
|
Args:
|
||||||
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
message (MaiMessages | None): 事件消息对象,当你注册的事件为ON_START和ON_STOP时message为None
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果)
|
Tuple[bool, bool, Optional[str], Optional[CustomEventHandlerResult], Optional[MaiMessages]]: (是否执行成功, 是否需要继续处理, 可选的返回消息, 可选的自定义结果,可选的修改后消息)
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError("子类必须实现 execute 方法")
|
raise NotImplementedError("子类必须实现 execute 方法")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_handler_info(cls) -> "EventHandlerInfo":
|
def get_handler_info(cls) -> "EventHandlerInfo":
|
||||||
"""获取事件处理器的信息"""
|
"""获取事件处理器的信息"""
|
||||||
# 从类属性读取名称,如果没有定义则使用类名自动生成
|
# 从类属性读取名称,如果没有定义则使用类名自动生成S
|
||||||
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
name: str = getattr(cls, "handler_name", cls.__name__.lower().replace("handler", ""))
|
||||||
if "." in name:
|
if "." in name:
|
||||||
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
logger.error(f"事件处理器名称 '{name}' 包含非法字符 '.',请使用下划线替代")
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
@ -7,6 +8,7 @@ from maim_message import Seg
|
||||||
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
from src.llm_models.payload_content.tool_option import ToolParamType as ToolParamType
|
||||||
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
from src.llm_models.payload_content.tool_option import ToolCall as ToolCall
|
||||||
|
|
||||||
|
|
||||||
# 组件类型枚举
|
# 组件类型枚举
|
||||||
class ComponentType(Enum):
|
class ComponentType(Enum):
|
||||||
"""组件类型枚举"""
|
"""组件类型枚举"""
|
||||||
|
|
@ -56,6 +58,7 @@ class EventType(Enum):
|
||||||
|
|
||||||
ON_START = "on_start" # 启动事件,用于调用按时任务
|
ON_START = "on_start" # 启动事件,用于调用按时任务
|
||||||
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
ON_STOP = "on_stop" # 停止事件,用于调用按时任务
|
||||||
|
ON_MESSAGE_PRE_PROCESS = "on_message_pre_process"
|
||||||
ON_MESSAGE = "on_message"
|
ON_MESSAGE = "on_message"
|
||||||
ON_PLAN = "on_plan"
|
ON_PLAN = "on_plan"
|
||||||
POST_LLM = "post_llm"
|
POST_LLM = "post_llm"
|
||||||
|
|
@ -116,8 +119,8 @@ class ActionInfo(ComponentInfo):
|
||||||
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
action_require: List[str] = field(default_factory=list) # 动作需求说明
|
||||||
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
associated_types: List[str] = field(default_factory=list) # 关联的消息类型
|
||||||
# 激活类型相关
|
# 激活类型相关
|
||||||
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
|
focus_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||||
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS #已弃用
|
normal_activation_type: ActionActivationType = ActionActivationType.ALWAYS # 已弃用
|
||||||
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
activation_type: ActionActivationType = ActionActivationType.ALWAYS
|
||||||
random_activation_probability: float = 0.0
|
random_activation_probability: float = 0.0
|
||||||
llm_judge_prompt: str = ""
|
llm_judge_prompt: str = ""
|
||||||
|
|
@ -154,7 +157,9 @@ class CommandInfo(ComponentInfo):
|
||||||
class ToolInfo(ComponentInfo):
|
class ToolInfo(ComponentInfo):
|
||||||
"""工具组件信息"""
|
"""工具组件信息"""
|
||||||
|
|
||||||
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(default_factory=list) # 工具参数定义
|
tool_parameters: List[Tuple[str, ToolParamType, str, bool, List[str] | None]] = field(
|
||||||
|
default_factory=list
|
||||||
|
) # 工具参数定义
|
||||||
tool_description: str = "" # 工具描述
|
tool_description: str = "" # 工具描述
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|
@ -233,6 +238,15 @@ class PluginInfo:
|
||||||
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
return [dep.get_pip_requirement() for dep in self.python_dependencies]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModifyFlag:
|
||||||
|
modify_message_segments: bool = False
|
||||||
|
modify_plain_text: bool = False
|
||||||
|
modify_llm_prompt: bool = False
|
||||||
|
modify_llm_response_content: bool = False
|
||||||
|
modify_llm_response_reasoning: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MaiMessages:
|
class MaiMessages:
|
||||||
"""MaiM插件消息"""
|
"""MaiM插件消息"""
|
||||||
|
|
@ -279,6 +293,8 @@ class MaiMessages:
|
||||||
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
additional_data: Dict[Any, Any] = field(default_factory=dict)
|
||||||
"""附加数据,可以存储额外信息"""
|
"""附加数据,可以存储额外信息"""
|
||||||
|
|
||||||
|
_modify_flags: ModifyFlag = field(default_factory=ModifyFlag)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.message_segments is None:
|
if self.message_segments is None:
|
||||||
self.message_segments = []
|
self.message_segments = []
|
||||||
|
|
@ -286,6 +302,102 @@ class MaiMessages:
|
||||||
def deepcopy(self):
|
def deepcopy(self):
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
def modify_message_segments(self, new_segments: List[Seg], suppress_warning: bool = False):
|
||||||
|
"""
|
||||||
|
修改消息段列表
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
在生成了plain_text的情况下调用此方法,可能会导致plain_text内容与消息段不一致
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_segments (List[Seg]): 新的消息段列表
|
||||||
|
"""
|
||||||
|
if self.plain_text and not suppress_warning:
|
||||||
|
warnings.warn(
|
||||||
|
"修改消息段后,plain_text可能与消息段内容不一致,建议同时更新plain_text",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.message_segments = new_segments
|
||||||
|
self._modify_flags.modify_message_segments = True
|
||||||
|
|
||||||
|
def modify_llm_prompt(self, new_prompt: str, suppress_warning: bool = False):
|
||||||
|
"""
|
||||||
|
修改LLM提示词
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
在没有生成llm_prompt的情况下调用此方法,可能会导致修改无效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_prompt (str): 新的提示词内容
|
||||||
|
"""
|
||||||
|
if self.llm_prompt is None and not suppress_warning:
|
||||||
|
warnings.warn(
|
||||||
|
"当前llm_prompt为空,此时调用方法可能导致修改无效",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.llm_prompt = new_prompt
|
||||||
|
self._modify_flags.modify_llm_prompt = True
|
||||||
|
|
||||||
|
def modify_plain_text(self, new_text: str, suppress_warning: bool = False):
|
||||||
|
"""
|
||||||
|
修改生成的plain_text内容
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
在未生成plain_text的情况下调用此方法,可能会导致plain_text为空或者修改无效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_text (str): 新的纯文本内容
|
||||||
|
"""
|
||||||
|
if not self.plain_text and not suppress_warning:
|
||||||
|
warnings.warn(
|
||||||
|
"当前plain_text为空,此时调用方法可能导致修改无效",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.plain_text = new_text
|
||||||
|
self._modify_flags.modify_plain_text = True
|
||||||
|
|
||||||
|
def modify_llm_response_content(self, new_content: str, suppress_warning: bool = False):
|
||||||
|
"""
|
||||||
|
修改生成的llm_response_content内容
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
在未生成llm_response_content的情况下调用此方法,可能会导致llm_response_content为空或者修改无效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_content (str): 新的LLM响应内容
|
||||||
|
"""
|
||||||
|
if not self.llm_response_content and not suppress_warning:
|
||||||
|
warnings.warn(
|
||||||
|
"当前llm_response_content为空,此时调用方法可能导致修改无效",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.llm_response_content = new_content
|
||||||
|
self._modify_flags.modify_llm_response_content = True
|
||||||
|
|
||||||
|
def modify_llm_response_reasoning(self, new_reasoning: str, suppress_warning: bool = False):
|
||||||
|
"""
|
||||||
|
修改生成的llm_response_reasoning内容
|
||||||
|
|
||||||
|
Warning:
|
||||||
|
在未生成llm_response_reasoning的情况下调用此方法,可能会导致llm_response_reasoning为空或者修改无效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_reasoning (str): 新的LLM响应推理内容
|
||||||
|
"""
|
||||||
|
if not self.llm_response_reasoning and not suppress_warning:
|
||||||
|
warnings.warn(
|
||||||
|
"当前llm_response_reasoning为空,此时调用方法可能导致修改无效",
|
||||||
|
UserWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self.llm_response_reasoning = new_reasoning
|
||||||
|
self._modify_flags.modify_llm_response_reasoning = True
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CustomEventHandlerResult:
|
class CustomEventHandlerResult:
|
||||||
message: str = ""
|
message: str = ""
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ class EventsManager:
|
||||||
llm_response: Optional["LLMGenerationDataModel"] = None,
|
llm_response: Optional["LLMGenerationDataModel"] = None,
|
||||||
stream_id: Optional[str] = None,
|
stream_id: Optional[str] = None,
|
||||||
action_usage: Optional[List[str]] = None,
|
action_usage: Optional[List[str]] = None,
|
||||||
) -> bool:
|
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||||
"""
|
"""
|
||||||
处理所有事件,根据事件类型分发给订阅的处理器。
|
处理所有事件,根据事件类型分发给订阅的处理器。
|
||||||
"""
|
"""
|
||||||
|
|
@ -89,10 +89,10 @@ class EventsManager:
|
||||||
# 2. 获取并遍历处理器
|
# 2. 获取并遍历处理器
|
||||||
handlers = self._events_subscribers.get(event_type, [])
|
handlers = self._events_subscribers.get(event_type, [])
|
||||||
if not handlers:
|
if not handlers:
|
||||||
return True
|
return True, None
|
||||||
|
|
||||||
current_stream_id = transformed_message.stream_id if transformed_message else None
|
current_stream_id = transformed_message.stream_id if transformed_message else None
|
||||||
|
modified_message: Optional[MaiMessages] = None
|
||||||
for handler in handlers:
|
for handler in handlers:
|
||||||
# 3. 前置检查和配置加载
|
# 3. 前置检查和配置加载
|
||||||
if (
|
if (
|
||||||
|
|
@ -107,15 +107,19 @@ class EventsManager:
|
||||||
handler.set_plugin_config(plugin_config)
|
handler.set_plugin_config(plugin_config)
|
||||||
|
|
||||||
# 4. 根据类型分发任务
|
# 4. 根据类型分发任务
|
||||||
if handler.intercept_message or event_type == EventType.ON_STOP: # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
if (
|
||||||
|
handler.intercept_message or event_type == EventType.ON_STOP
|
||||||
|
): # 让ON_STOP的所有事件处理器都发挥作用,防止还没执行即被取消
|
||||||
# 阻塞执行,并更新 continue_flag
|
# 阻塞执行,并更新 continue_flag
|
||||||
should_continue = await self._dispatch_intercepting_handler(handler, event_type, transformed_message)
|
should_continue, modified_message = await self._dispatch_intercepting_handler_task(
|
||||||
|
handler, event_type, modified_message or transformed_message
|
||||||
|
)
|
||||||
continue_flag = continue_flag and should_continue
|
continue_flag = continue_flag and should_continue
|
||||||
else:
|
else:
|
||||||
# 异步执行,不阻塞
|
# 异步执行,不阻塞
|
||||||
self._dispatch_handler_task(handler, event_type, transformed_message)
|
self._dispatch_handler_task(handler, event_type, transformed_message)
|
||||||
|
|
||||||
return continue_flag
|
return continue_flag, modified_message
|
||||||
|
|
||||||
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
async def cancel_handler_tasks(self, handler_name: str) -> None:
|
||||||
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
tasks_to_be_cancelled = self._handler_tasks.get(handler_name, [])
|
||||||
|
|
@ -327,16 +331,18 @@ class EventsManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
logger.error(f"创建事件处理器任务 {handler.handler_name} 时发生异常: {e}", exc_info=True)
|
||||||
|
|
||||||
async def _dispatch_intercepting_handler(
|
async def _dispatch_intercepting_handler_task(
|
||||||
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
self, handler: BaseEventHandler, event_type: EventType | str, message: Optional[MaiMessages] = None
|
||||||
) -> bool:
|
) -> Tuple[bool, Optional[MaiMessages]]:
|
||||||
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
"""分发并等待一个阻塞(同步)的事件处理器,返回是否应继续处理。"""
|
||||||
if event_type == EventType.UNKNOWN:
|
if event_type == EventType.UNKNOWN:
|
||||||
raise ValueError("未知事件类型")
|
raise ValueError("未知事件类型")
|
||||||
if event_type not in self._history_enable_map:
|
if event_type not in self._history_enable_map:
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
try:
|
try:
|
||||||
success, continue_processing, return_message, custom_result = await handler.execute(message)
|
success, continue_processing, return_message, custom_result, modified_message = await handler.execute(
|
||||||
|
message
|
||||||
|
)
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
|
logger.error(f"EventHandler {handler.handler_name} 执行失败: {return_message}")
|
||||||
|
|
@ -345,17 +351,17 @@ class EventsManager:
|
||||||
|
|
||||||
if self._history_enable_map[event_type] and custom_result:
|
if self._history_enable_map[event_type] and custom_result:
|
||||||
self._events_result_history[event_type].append(custom_result)
|
self._events_result_history[event_type].append(custom_result)
|
||||||
return continue_processing
|
return continue_processing, modified_message
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
logger.error(f"事件 {event_type} 注册的历史记录启用情况与实际不符合")
|
||||||
return True
|
return True, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
logger.error(f"EventHandler {handler.handler_name} 发生异常: {e}", exc_info=True)
|
||||||
return True # 发生异常时默认不中断其他处理
|
return True, None # 发生异常时默认不中断其他处理
|
||||||
|
|
||||||
def _task_done_callback(
|
def _task_done_callback(
|
||||||
self,
|
self,
|
||||||
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None]],
|
task: asyncio.Task[Tuple[bool, bool, str | None, CustomEventHandlerResult | None, MaiMessages | None]],
|
||||||
event_type: EventType | str,
|
event_type: EventType | str,
|
||||||
):
|
):
|
||||||
"""任务完成回调"""
|
"""任务完成回调"""
|
||||||
|
|
@ -365,7 +371,7 @@ class EventsManager:
|
||||||
if event_type not in self._history_enable_map:
|
if event_type not in self._history_enable_map:
|
||||||
raise ValueError(f"事件类型 {event_type} 未注册")
|
raise ValueError(f"事件类型 {event_type} 未注册")
|
||||||
try:
|
try:
|
||||||
success, _, result, custom_result = task.result() # 忽略是否继续的标志,因为消息本身未被拦截
|
success, _, result, custom_result, _ = task.result() # 忽略是否继续的标志和消息的修改,因为消息本身未被拦截
|
||||||
if success:
|
if success:
|
||||||
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
logger.debug(f"事件处理任务 {task_name} 已成功完成: {result}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue