mirror of https://github.com/Mai-with-u/MaiBot.git
Change the generator's return value to a stable data-model API
parent 2d4fd08ac5
commit 1eeabe76ba
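In short: callers that used to unpack a four-element tuple (success, reply_set, prompt, selected_expressions) from generator_api.generate_reply now receive (success, llm_response), where llm_response is an LLMGenerationDataModel carrying content, reasoning, model, tool_calls, prompt, selected_expressions and reply_set, and the return_prompt / return_expressions flags disappear. A minimal sketch of the new calling convention; the wrapper function and the import path are assumptions for illustration, not part of this commit:

from src.plugin_system.apis import generator_api  # import path assumed

async def handle_reply(chat_stream, reply_message):
    # hypothetical caller, mirroring the HeartFChatting change below
    success, llm_response = await generator_api.generate_reply(
        chat_stream=chat_stream,
        reply_message=reply_message,
        request_type="replyer",
        from_plugin=False,
    )
    if not success or not llm_response or not llm_response.reply_set:
        return None
    # reply_set is List[Tuple[str, Any]] produced by process_human_text;
    # prompt and selected_expressions replace the old return_prompt / return_expressions flags
    return llm_response.reply_set, llm_response.prompt, llm_response.selected_expressions
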
@@ -679,7 +679,7 @@ class HeartFChatting:
                 }
             else:
                 try:
-                    success, response_set, prompt, selected_expressions = await generator_api.generate_reply(
+                    success, llm_response = await generator_api.generate_reply(
                         chat_stream=self.chat_stream,
                         reply_message=action_planner_info.action_message,
                         available_actions=available_actions,
@@ -688,10 +688,9 @@ class HeartFChatting:
                         enable_tool=global_config.tool.enable_tool,
                         request_type="replyer",
                         from_plugin=False,
-                        return_expressions=True,
                     )

-                    if not success or not response_set:
+                    if not success or not llm_response or not llm_response.reply_set:
                         if action_planner_info.action_message:
                             logger.info(f"对 {action_planner_info.action_message.processed_plain_text} 的回复生成失败")
                         else:
@@ -701,7 +700,8 @@ class HeartFChatting:
                 except asyncio.CancelledError:
                     logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
                     return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
-
+                response_set = llm_response.reply_set
+                selected_expressions = llm_response.selected_expressions
                 loop_info, reply_text, _ = await self._send_and_store_reply(
                     response_set=response_set,
                     action_message=action_planner_info.action_message,  # type: ignore

@@ -2,7 +2,7 @@ import random
 import asyncio
 import hashlib
 import time
-from typing import List, Any, Dict, TYPE_CHECKING, Tuple
+from typing import List, Dict, TYPE_CHECKING, Tuple

 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
@@ -161,7 +161,7 @@ class ActionModifier:
         deactivated_actions = []

         # 分类处理不同激活类型的actions
-        llm_judge_actions = {}
+        llm_judge_actions: Dict[str, ActionInfo] = {}

         actions_to_check = list(actions_with_info.items())
         random.shuffle(actions_to_check)
@@ -218,7 +218,7 @@ class ActionModifier:

     async def _process_llm_judge_actions_parallel(
         self,
-        llm_judge_actions: Dict[str, Any],
+        llm_judge_actions: Dict[str, ActionInfo],
         chat_content: str = "",
     ) -> Dict[str, bool]:
         """
@@ -237,7 +237,7 @@ class ActionModifier:
         current_time = time.time()

         results = {}
-        tasks_to_run = {}
+        tasks_to_run: Dict[str, ActionInfo] = {}

         # 检查缓存
         for action_name, action_info in llm_judge_actions.items():

@@ -10,6 +10,7 @@ from src.mais4u.mai_think import mai_thinking_manager
 from src.common.logger import get_logger
 from src.common.data_models.database_data_model import DatabaseMessages
 from src.common.data_models.info_data_model import ActionPlannerInfo
+from src.common.data_models.llm_data_model import LLMGenerationDataModel
 from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
@@ -162,7 +163,7 @@ class DefaultReplyer:
         from_plugin: bool = True,
         stream_id: Optional[str] = None,
         reply_message: Optional[DatabaseMessages] = None,
-    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]:
+    ) -> Tuple[bool, LLMGenerationDataModel]:
         # sourcery skip: merge-nested-ifs
         """
         回复器 (Replier): 负责生成回复文本的核心逻辑。
@@ -182,6 +183,7 @@ class DefaultReplyer:

         prompt = None
         selected_expressions: Optional[List[int]] = None
+        llm_response = LLMGenerationDataModel()
         if available_actions is None:
             available_actions = {}
         try:
@@ -195,10 +197,12 @@ class DefaultReplyer:
                 reply_message=reply_message,
                 reply_reason=reply_reason,
             )
+            llm_response.prompt = prompt
+            llm_response.selected_expressions = selected_expressions

             if not prompt:
                 logger.warning("构建prompt失败,跳过回复生成")
-                return False, None, None, []
+                return False, llm_response
             from src.plugin_system.core.events_manager import events_manager

             if not from_plugin:
@@ -215,12 +219,10 @@ class DefaultReplyer:
             try:
                 content, reasoning_content, model_name, tool_call = await self.llm_generate_content(prompt)
                 logger.debug(f"replyer生成内容: {content}")
-                llm_response = {
-                    "content": content,
-                    "reasoning": reasoning_content,
-                    "model": model_name,
-                    "tool_calls": tool_call,
-                }
+                llm_response.content = content
+                llm_response.reasoning = reasoning_content
+                llm_response.model = model_name
+                llm_response.tool_calls = tool_call
                 if not from_plugin and not await events_manager.handle_mai_events(
                     EventType.AFTER_LLM, None, prompt, llm_response, stream_id=stream_id
                 ):
@@ -230,24 +232,23 @@ class DefaultReplyer:
             except Exception as llm_e:
                 # 精简报错信息
                 logger.error(f"LLM 生成失败: {llm_e}")
-                return False, None, prompt, selected_expressions  # LLM 调用失败则无法生成回复
+                return False, llm_response  # LLM 调用失败则无法生成回复

-            return True, llm_response, prompt, selected_expressions
+            return True, llm_response

         except UserWarning as uw:
             raise uw
         except Exception as e:
             logger.error(f"回复生成意外失败: {e}")
             traceback.print_exc()
-            return False, None, prompt, selected_expressions
+            return False, llm_response

     async def rewrite_reply_with_context(
         self,
         raw_reply: str = "",
         reason: str = "",
         reply_to: str = "",
-        return_prompt: bool = False,
-    ) -> Tuple[bool, Optional[str], Optional[str]]:
+    ) -> Tuple[bool, LLMGenerationDataModel]:
         """
         表达器 (Expressor): 负责重写和优化回复文本。

@@ -260,6 +261,7 @@ class DefaultReplyer:
         Returns:
             Tuple[bool, Optional[str]]: (是否成功, 重写后的回复内容)
         """
+        llm_response = LLMGenerationDataModel()
         try:
             with Timer("构建Prompt", {}):  # 内部计时器,可选保留
                 prompt = await self.build_prompt_rewrite_context(
@@ -267,29 +269,33 @@ class DefaultReplyer:
                     reason=reason,
                     reply_to=reply_to,
                 )
+            llm_response.prompt = prompt

             content = None
             reasoning_content = None
             model_name = "unknown_model"
             if not prompt:
                 logger.error("Prompt 构建失败,无法生成回复。")
-                return False, None, None
+                return False, llm_response

             try:
                 content, reasoning_content, model_name, _ = await self.llm_generate_content(prompt)
                 logger.info(f"想要表达:{raw_reply}||理由:{reason}||生成回复: {content}\n")
+                llm_response.content = content
+                llm_response.reasoning = reasoning_content
+                llm_response.model = model_name

             except Exception as llm_e:
                 # 精简报错信息
                 logger.error(f"LLM 生成失败: {llm_e}")
-                return False, None, prompt if return_prompt else None  # LLM 调用失败则无法生成回复
+                return False, llm_response  # LLM 调用失败则无法生成回复

-            return True, content, prompt if return_prompt else None
+            return True, llm_response

         except Exception as e:
             logger.error(f"回复生成意外失败: {e}")
             traceback.print_exc()
-            return False, None, prompt if return_prompt else None
+            return False, llm_response

     async def build_relation_info(self, sender: str, target: str):
         if not global_config.relationship.enable_relationship:

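One contract worth noting from the two entry points above: generate_reply_with_context and rewrite_reply_with_context now always hand back an LLMGenerationDataModel instance, even on failure, so whatever was already produced (at minimum the prompt) stays inspectable; only the generator_api wrappers map failure to None. A small sketch of that failure-path behaviour, where the replyer instance and the argument values are made up for illustration:

async def rewrite(replyer):
    # sketch only: replyer is an existing DefaultReplyer; the arguments are placeholders
    success, llm_response = await replyer.rewrite_reply_with_context(
        raw_reply="好啊", reason="附和对方", reply_to="某人:一起打游戏吗"
    )
    if not success:
        # even on failure the data model is returned, so the prompt built so far can be logged
        return llm_response.prompt
    return llm_response.content
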
@@ -375,9 +381,7 @@ class DefaultReplyer:

         if global_config.memory.enable_instant_memory:
             chat_history_str = build_readable_messages(
-                messages=chat_history,
-                replace_bot_name=True,
-                timestamp_mode="normal"
+                messages=chat_history, replace_bot_name=True, timestamp_mode="normal"
             )
             asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str))

@@ -668,16 +672,18 @@ class DefaultReplyer:
         action_descriptions += chosen_action_descriptions

         return action_descriptions

     async def build_personality_prompt(self) -> str:
         bot_name = global_config.bot.nickname
         if global_config.bot.alias_names:
             bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
         else:
             bot_nickname = ""

-        prompt_personality = f"{global_config.personality.personality_core};{global_config.personality.personality_side}"
-        return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"
+        prompt_personality = (
+            f"{global_config.personality.personality_core};{global_config.personality.personality_side}"
+        )
+        return f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}"

     async def build_prompt_reply_context(
         self,
@@ -875,17 +881,12 @@ class DefaultReplyer:
         raw_reply: str,
         reason: str,
         reply_to: str,
-        reply_message: Optional[Dict[str, Any]] = None,
     ) -> str: # sourcery skip: merge-else-if-into-elif, remove-redundant-if
         chat_stream = self.chat_stream
         chat_id = chat_stream.stream_id
         is_group_chat = bool(chat_stream.group_info)

-        if reply_message:
-            sender = reply_message.get("sender", "")
-            target = reply_message.get("target", "")
-        else:
-            sender, target = self._parse_reply_target(reply_to)
+        sender, target = self._parse_reply_target(reply_to)

         # 添加情绪状态获取
         if global_config.mood.enable_mood:
@@ -908,7 +909,7 @@ class DefaultReplyer:
         )

         # 并行执行2个构建任务
         (expression_habits_block, _), relation_info, personality_prompt = await asyncio.gather(
             self.build_expression_habits(chat_talking_prompt_half, target),
             self.build_relation_info(sender, target),
             self.build_personality_prompt(),

@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+from typing import Optional, List, Tuple, TYPE_CHECKING, Any
+
+from . import BaseDataModel
+if TYPE_CHECKING:
+    from src.llm_models.payload_content.tool_option import ToolCall
+
+@dataclass
+class LLMGenerationDataModel(BaseDataModel):
+    content: Optional[str] = None
+    reasoning: Optional[str] = None
+    model: Optional[str] = None
+    tool_calls: Optional[List["ToolCall"]] = None
+    prompt: Optional[str] = None
+    selected_expressions: Optional[List[int]] = None
+    reply_set: Optional[List[Tuple[str, Any]]] = None

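The new file above (judging from the import lines elsewhere in this diff, src/common/data_models/llm_data_model.py) is plain data: every field defaults to None, producers fill fields in as they become available, and consumers read back whatever was set. A minimal usage sketch, assuming BaseDataModel needs no constructor arguments; the literal values are placeholders:

resp = LLMGenerationDataModel()
resp.prompt = "<built prompt>"            # set by the replyer right after prompt construction
resp.content = "<raw llm output>"         # set once llm_generate_content returns
resp.reply_set = [("text", "<raw llm output>")]  # set by generator_api after process_human_text
assert resp.tool_calls is None            # untouched fields simply stay None
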
@@ -21,6 +21,7 @@ from src.plugin_system.base.component_types import ActionInfo
 if TYPE_CHECKING:
     from src.common.data_models.info_data_model import ActionPlannerInfo
     from src.common.data_models.database_data_model import DatabaseMessages
+    from src.common.data_models.llm_data_model import LLMGenerationDataModel

 install(extra_lines=3)

@@ -85,11 +86,9 @@ async def generate_reply(
     enable_tool: bool = False,
     enable_splitter: bool = True,
     enable_chinese_typo: bool = True,
-    return_prompt: bool = False,
     request_type: str = "generator_api",
     from_plugin: bool = True,
-    return_expressions: bool = False,
-) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
+) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
     """生成回复

     Args:
@@ -117,7 +116,7 @@ async def generate_reply(
         replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
         if not replyer:
             logger.error("[GeneratorAPI] 无法获取回复器")
-            return False, [], None, None
+            return False, None

         if not extra_info and action_data:
             extra_info = action_data.get("extra_info", "")
@@ -126,7 +125,7 @@ async def generate_reply(
             reply_reason = action_data.get("reason", "")

         # 调用回复器生成回复
-        success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
+        success, llm_response = await replyer.generate_reply_with_context(
             extra_info=extra_info,
             available_actions=available_actions,
             chosen_actions=chosen_actions,
@@ -138,43 +137,27 @@ async def generate_reply(
         )
         if not success:
             logger.warning("[GeneratorAPI] 回复生成失败")
-            return False, [], None, None
-        assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
-        if content := llm_response_dict.get("content", ""):
+            return False, None
+        if content := llm_response.content:
             reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
         else:
             reply_set = []
+        llm_response.reply_set = reply_set
         logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")

-        # if return_prompt:
-        #     if return_expressions:
-        #         return success, reply_set, prompt, selected_expressions
-        #     else:
-        #         return success, reply_set, prompt, None
-        # else:
-        #     if return_expressions:
-        #         return success, reply_set, (None, selected_expressions)
-        #     else:
-        #         return success, reply_set, None
-        return (
-            success,
-            reply_set,
-            prompt if return_prompt else None,
-            selected_expressions if return_expressions else None,
-        )
+        return success, llm_response

     except ValueError as ve:
         raise ve

     except UserWarning as uw:
         logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
-        return False, [], None, None
+        return False, None

     except Exception as e:
         logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
         logger.error(traceback.format_exc())
-        return False, [], None, None
+        return False, None


 async def rewrite_reply(
     chat_stream: Optional[ChatStream] = None,
@@ -185,9 +168,8 @@ async def rewrite_reply(
     raw_reply: str = "",
     reason: str = "",
     reply_to: str = "",
-    return_prompt: bool = False,
     request_type: str = "generator_api",
-) -> Tuple[bool, List[Tuple[str, Any]], Optional[str]]:
+) -> Tuple[bool, Optional["LLMGenerationDataModel"]]:
     """重写回复

     Args:
@@ -210,7 +192,7 @@ async def rewrite_reply(
         replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
         if not replyer:
             logger.error("[GeneratorAPI] 无法获取回复器")
-            return False, [], None
+            return False, None

         logger.info("[GeneratorAPI] 开始重写回复")

@@ -221,29 +203,28 @@ async def rewrite_reply(
             reply_to = reply_to or reply_data.get("reply_to", "")

         # 调用回复器重写回复
-        success, content, prompt = await replyer.rewrite_reply_with_context(
+        success, llm_response = await replyer.rewrite_reply_with_context(
             raw_reply=raw_reply,
             reason=reason,
             reply_to=reply_to,
-            return_prompt=return_prompt,
         )
         reply_set = []
-        if content:
+        if success and llm_response and (content := llm_response.content):
             reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
+            llm_response.reply_set = reply_set
         if success:
             logger.info(f"[GeneratorAPI] 重写回复成功,生成了 {len(reply_set)} 个回复项")
         else:
             logger.warning("[GeneratorAPI] 重写回复失败")

-        return success, reply_set, prompt if return_prompt else None
+        return success, llm_response

     except ValueError as ve:
         raise ve

     except Exception as e:
         logger.error(f"[GeneratorAPI] 重写回复时出错: {e}")
-        return False, [], None
+        return False, None


 def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: bool) -> List[Tuple[str, Any]]:

@@ -1,6 +1,6 @@
 import asyncio
 import contextlib
-from typing import List, Dict, Optional, Type, Tuple, Any, Coroutine
+from typing import List, Dict, Optional, Type, Tuple, Any, TYPE_CHECKING

 from src.chat.message_receive.message import MessageRecv
 from src.chat.message_receive.chat_stream import get_chat_manager
@@ -9,6 +9,9 @@ from src.plugin_system.base.component_types import EventType, EventHandlerInfo,
 from src.plugin_system.base.base_events_handler import BaseEventHandler
 from .global_announcement_manager import global_announcement_manager

+if TYPE_CHECKING:
+    from src.common.data_models.llm_data_model import LLMGenerationDataModel
+
 logger = get_logger("events_manager")


@@ -47,7 +50,7 @@ class EventsManager:
         event_type: EventType,
         message: Optional[MessageRecv] = None,
         llm_prompt: Optional[str] = None,
-        llm_response: Optional[Dict[str, Any]] = None,
+        llm_response: Optional["LLMGenerationDataModel"] = None,
         stream_id: Optional[str] = None,
         action_usage: Optional[List[str]] = None,
     ) -> Optional[MaiMessages]:
@@ -97,7 +100,7 @@ class EventsManager:
         event_type: EventType,
         message: Optional[MessageRecv] = None,
         llm_prompt: Optional[str] = None,
-        llm_response: Optional[Dict[str, Any]] = None,
+        llm_response: Optional["LLMGenerationDataModel"] = None,
         stream_id: Optional[str] = None,
         action_usage: Optional[List[str]] = None,
     ) -> bool:
@@ -175,16 +178,16 @@ class EventsManager:
             return False

     def _transform_event_message(
-        self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
+        self, message: MessageRecv, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
     ) -> MaiMessages:
         """转换事件消息格式"""
         # 直接赋值部分内容
         transformed_message = MaiMessages(
             llm_prompt=llm_prompt,
-            llm_response_content=llm_response.get("content") if llm_response else None,
-            llm_response_reasoning=llm_response.get("reasoning") if llm_response else None,
-            llm_response_model=llm_response.get("model") if llm_response else None,
-            llm_response_tool_call=llm_response.get("tool_calls") if llm_response else None,
+            llm_response_content=llm_response.content if llm_response else None,
+            llm_response_reasoning=llm_response.reasoning if llm_response else None,
+            llm_response_model=llm_response.model if llm_response else None,
+            llm_response_tool_call=llm_response.tool_calls if llm_response else None,
             raw_message=message.raw_message,
             additional_data=message.message_info.additional_config or {},
         )
@@ -228,7 +231,7 @@ class EventsManager:
         return transformed_message

     def _build_message_from_stream(
-        self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional[Dict[str, Any]] = None
+        self, stream_id: str, llm_prompt: Optional[str] = None, llm_response: Optional["LLMGenerationDataModel"] = None
     ) -> MaiMessages:
         """从流ID构建消息"""
         chat_stream = get_chat_manager().get_stream(stream_id)
@@ -240,7 +243,7 @@ class EventsManager:
         self,
         stream_id: str,
         llm_prompt: Optional[str] = None,
-        llm_response: Optional[Dict[str, Any]] = None,
+        llm_response: Optional["LLMGenerationDataModel"] = None,
         action_usage: Optional[List[str]] = None,
     ) -> MaiMessages:
         """没有message对象时进行转换"""
@@ -249,10 +252,10 @@ class EventsManager:
         return MaiMessages(
             stream_id=stream_id,
             llm_prompt=llm_prompt,
-            llm_response_content=(llm_response.get("content") if llm_response else None),
-            llm_response_reasoning=(llm_response.get("reasoning") if llm_response else None),
-            llm_response_model=llm_response.get("model") if llm_response else None,
-            llm_response_tool_call=(llm_response.get("tool_calls") if llm_response else None),
+            llm_response_content=(llm_response.content if llm_response else None),
+            llm_response_reasoning=(llm_response.reasoning if llm_response else None),
+            llm_response_model=(llm_response.model if llm_response else None),
+            llm_response_tool_call=(llm_response.tool_calls if llm_response else None),
             is_group_message=(not (not chat_stream.group_info)),
             is_private_message=(not chat_stream.group_info),
             action_usage=action_usage,

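For the events side this is a pure type change: handle_mai_events and the message builders accept the data model instead of a dict, and field extraction moves from .get("...") lookups to attribute access, while the MaiMessages fields that plugins read (llm_response_content, llm_response_reasoning, llm_response_model, llm_response_tool_call) keep their names. The access change in one line:

# before: content = llm_response.get("content") if llm_response else None
content = llm_response.content if llm_response else None  # attributes default to None on the dataclass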