mirror of https://github.com/Mai-with-u/MaiBot.git
action的reply_message设置为数据模型,维护typing以及增强稳定性
parent
bd83795df8
commit
82e5a710c3
|
|
@ -84,7 +84,7 @@ class ActionManager:
|
|||
log_prefix=log_prefix,
|
||||
shutting_down=shutting_down,
|
||||
plugin_config=plugin_config,
|
||||
action_message=action_message.flatten() if action_message else None,
|
||||
action_message=action_message,
|
||||
)
|
||||
|
||||
logger.debug(f"创建Action实例成功: {action_name}")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ import time
|
|||
import hashlib
|
||||
import uuid
|
||||
import io
|
||||
import asyncio
|
||||
import numpy as np
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
|
@ -177,7 +176,7 @@ class ImageManager:
|
|||
emotion_prompt, temperature=0.3, max_tokens=50
|
||||
)
|
||||
|
||||
if emotion_result is None:
|
||||
if not emotion_result:
|
||||
logger.warning("LLM未能生成情感标签,使用详细描述的前几个词")
|
||||
# 降级处理:从详细描述中提取关键词
|
||||
import jieba
|
||||
|
|
|
|||
|
|
@ -156,19 +156,19 @@ class LLMRequest:
|
|||
"""
|
||||
# 请求体构建
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
|
||||
tool_built = self._build_tool_options(tools)
|
||||
|
||||
|
||||
# 模型选择
|
||||
model_info, api_provider, client = self._select_model()
|
||||
|
||||
|
||||
# 请求并处理返回值
|
||||
logger.debug(f"LLM选择耗时: {model_info.name} {time.time() - start_time}")
|
||||
|
||||
|
||||
response = await self._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
|
|
@ -179,8 +179,7 @@ class LLMRequest:
|
|||
max_tokens=max_tokens,
|
||||
tool_options=tool_built,
|
||||
)
|
||||
|
||||
|
||||
|
||||
content = response.content
|
||||
reasoning_content = response.reasoning_content or ""
|
||||
tool_calls = response.tool_calls
|
||||
|
|
@ -188,7 +187,7 @@ class LLMRequest:
|
|||
if not reasoning_content and content:
|
||||
content, extracted_reasoning = self._extract_reasoning(content)
|
||||
reasoning_content = extracted_reasoning
|
||||
|
||||
|
||||
if usage := response.usage:
|
||||
llm_usage_recorder.record_usage_to_database(
|
||||
model_info=model_info,
|
||||
|
|
@ -199,7 +198,7 @@ class LLMRequest:
|
|||
time_cost=time.time() - start_time,
|
||||
)
|
||||
|
||||
return content, (reasoning_content, model_info.name, tool_calls)
|
||||
return content or "", (reasoning_content, model_info.name, tool_calls)
|
||||
|
||||
async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
|
||||
"""获取嵌入向量
|
||||
|
|
@ -248,11 +247,11 @@ class LLMRequest:
|
|||
)
|
||||
model_info = model_config.get_model_info(least_used_model_name)
|
||||
api_provider = model_config.get_provider(model_info.api_provider)
|
||||
|
||||
|
||||
# 对于嵌入任务,强制创建新的客户端实例以避免事件循环问题
|
||||
force_new_client = (self.request_type == "embedding")
|
||||
force_new_client = self.request_type == "embedding"
|
||||
client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
|
||||
|
||||
|
||||
logger.debug(f"选择请求模型: {model_info.name}")
|
||||
total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
|
||||
self.model_usage[model_info.name] = (total_tokens, penalty, usage_penalty + 1) # 增加使用惩罚值防止连续使用
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class Person:
|
|||
self.name_reason: Optional[str] = None
|
||||
self.know_times = 0
|
||||
self.know_since = None
|
||||
self.last_know = None
|
||||
self.last_know: Optional[float] = None
|
||||
self.memory_points = []
|
||||
|
||||
# 初始化性格特征相关字段
|
||||
|
|
|
|||
|
|
@ -53,6 +53,15 @@ from .apis import (
|
|||
get_logger,
|
||||
)
|
||||
|
||||
from src.common.data_models.database_data_model import (
|
||||
DatabaseMessages,
|
||||
DatabaseUserInfo,
|
||||
DatabaseGroupInfo,
|
||||
DatabaseChatInfo,
|
||||
)
|
||||
from src.common.data_models.info_data_model import TargetPersonInfo, ActionPlannerInfo
|
||||
from src.common.data_models.llm_data_model import LLMGenerationDataModel
|
||||
|
||||
|
||||
__version__ = "2.0.0"
|
||||
|
||||
|
|
@ -103,4 +112,12 @@ __all__ = [
|
|||
# "ManifestGenerator",
|
||||
# "validate_plugin_manifest",
|
||||
# "generate_plugin_manifest",
|
||||
# 数据模型
|
||||
"DatabaseMessages",
|
||||
"DatabaseUserInfo",
|
||||
"DatabaseGroupInfo",
|
||||
"DatabaseChatInfo",
|
||||
"TargetPersonInfo",
|
||||
"ActionPlannerInfo",
|
||||
"LLMGenerationDataModel"
|
||||
]
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class BaseAction(ABC):
|
|||
chat_stream: ChatStream,
|
||||
log_prefix: str = "",
|
||||
plugin_config: Optional[dict] = None,
|
||||
action_message: Optional[dict] = None,
|
||||
action_message: Optional["DatabaseMessages"] = None,
|
||||
**kwargs,
|
||||
):
|
||||
# sourcery skip: hoist-similar-statement-from-if, merge-else-if-into-elif, move-assign-in-block, swap-if-else-branches, swap-nested-ifs
|
||||
|
|
@ -114,16 +114,13 @@ class BaseAction(ABC):
|
|||
|
||||
if self.action_message:
|
||||
self.has_action_message = True
|
||||
else:
|
||||
self.action_message = {}
|
||||
|
||||
if self.has_action_message:
|
||||
if self.action_name != "no_action":
|
||||
self.group_id = str(self.action_message.get("chat_info_group_id", None))
|
||||
self.group_name = self.action_message.get("chat_info_group_name", None)
|
||||
self.group_id = str(self.action_message.chat_info.group_info.group_id if self.action_message.chat_info.group_info else None)
|
||||
self.group_name = self.action_message.chat_info.group_info.group_name if self.action_message.chat_info.group_info else None
|
||||
|
||||
self.user_id = str(self.action_message.get("user_id", None))
|
||||
self.user_nickname = self.action_message.get("user_nickname", None)
|
||||
self.user_id = str(self.action_message.user_info.user_id)
|
||||
self.user_nickname = self.action_message.user_info.user_nickname
|
||||
if self.group_id:
|
||||
self.is_group = True
|
||||
self.target_id = self.group_id
|
||||
|
|
|
|||
Loading…
Reference in New Issue