From d2f98145da40ecbccd307d0319f99a6b75add001 Mon Sep 17 00:00:00 2001 From: UnCLAS-Prommer Date: Sun, 24 Aug 2025 00:11:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9ActionRecord=E4=B8=BA?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/planner_actions/planner.py | 6 ++-- src/chat/utils/chat_message_builder.py | 27 ++++++++++++----- src/common/data_models/database_data_model.py | 30 +++++++++++++++++++ src/plugin_system/apis/database_api.py | 4 +-- 4 files changed, 54 insertions(+), 13 deletions(-) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index 713458d6..55473c0d 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -28,7 +28,7 @@ from src.plugin_system.core.component_registry import component_registry if TYPE_CHECKING: from src.common.data_models.info_data_model import TargetPersonInfo - from src.common.data_models.database_data_model import DatabaseMessages + from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords logger = get_logger("planner") @@ -272,12 +272,12 @@ class ActionPlanner: # 只保留action_type在action_list中的ActionPlannerInfo action_names_in_list = [name for name, _ in action_list] # actions_before_now是List[Dict[str, Any]]格式,需要提取action_type字段 - filtered_actions = [] + filtered_actions: List["DatabaseActionRecords"] = [] for action_record in actions_before_now: # print(action_record) # print(action_record['action_name']) # print(action_names_in_list) - action_type = action_record["action_name"] + action_type = action_record.action_name if action_type in action_names_in_list: filtered_actions.append(action_record) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 9e035529..2dbb19a1 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -8,7 +8,7 @@ from rich.traceback import install from src.config.config import global_config from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages -from src.common.data_models.database_data_model import DatabaseMessages +from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords from src.common.data_models.message_data_model import MessageAndActionModel from src.common.database.database_model import ActionRecords from src.common.database.database_model import Images @@ -183,7 +183,7 @@ def get_actions_by_timestamp_with_chat( timestamp_end: float = time.time(), limit: int = 0, limit_mode: str = "latest", -) -> List[Dict[str, Any]]: +) -> List[DatabaseActionRecords]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" query = ActionRecords.select().where( (ActionRecords.chat_id == chat_id) @@ -196,14 +196,25 @@ def get_actions_by_timestamp_with_chat( query = query.order_by(ActionRecords.time.desc()).limit(limit) # 获取后需要反转列表,以保持最终输出为时间升序 actions = list(query) - return [action.__data__ for action in reversed(actions)] + actions.reverse() else: # earliest query = query.order_by(ActionRecords.time.asc()).limit(limit) else: query = query.order_by(ActionRecords.time.asc()) actions = list(query) - return [action.__data__ for action in actions] + return [DatabaseActionRecords( + action_id=action.action_id, + time=action.time, + action_name=action.action_name, + action_data=action.action_data, + action_done=action.action_done, + action_build_into_prompt=action.action_build_into_prompt, + action_prompt_display=action.action_prompt_display, + chat_id=action.chat_id, + chat_info_stream_id=action.chat_info_stream_id, + chat_info_platform=action.chat_info_platform, + ) for action in actions] def get_actions_by_timestamp_with_chat_inclusive( @@ -533,7 +544,7 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: return "\n".join(mapping_lines) -def build_readable_actions(actions: List[Dict[str, Any]],mode:str="relative") -> str: +def build_readable_actions(actions: List[DatabaseActionRecords],mode:str="relative") -> str: """ 将动作列表转换为可读的文本格式。 格式: 在()分钟前,你使用了(action_name),具体内容是:(action_prompt_display) @@ -554,13 +565,13 @@ def build_readable_actions(actions: List[Dict[str, Any]],mode:str="relative") -> # sorted_actions = sorted(actions, key=lambda x: x.get("time", 0), reverse=True) for action in actions: - action_time = action.get("time", current_time) - action_name = action.get("action_name", "未知动作") + action_time = action.time or current_time + action_name = action.action_name or "未知动作" # action_reason = action.get(action_data") if action_name in ["no_action", "no_action"]: continue - action_prompt_display = action.get("action_prompt_display", "无具体内容") + action_prompt_display = action.action_prompt_display or "无具体内容" time_diff_seconds = current_time - action_time if mode == "relative": diff --git a/src/common/data_models/database_data_model.py b/src/common/data_models/database_data_model.py index 1f671890..b752cbb7 100644 --- a/src/common/data_models/database_data_model.py +++ b/src/common/data_models/database_data_model.py @@ -1,3 +1,4 @@ +import json from typing import Optional, Any, Dict from dataclasses import dataclass, field @@ -196,3 +197,32 @@ class DatabaseMessages(BaseDataModel): "chat_info_user_nickname": self.chat_info.user_info.user_nickname, "chat_info_user_cardname": self.chat_info.user_info.user_cardname, } + +@dataclass(init=False) +class DatabaseActionRecords(BaseDataModel): + def __init__( + self, + action_id: str, + time: float, + action_name: str, + action_data: str, + action_done: bool, + action_build_into_prompt: bool, + action_prompt_display: str, + chat_id: str, + chat_info_stream_id: str, + chat_info_platform: str, + ): + self.action_id = action_id + self.time = time + self.action_name = action_name + if isinstance(action_data, str): + self.action_data = json.loads(action_data) + else: + raise ValueError("action_data must be a JSON string") + self.action_done = action_done + self.action_build_into_prompt = action_build_into_prompt + self.action_prompt_display = action_prompt_display + self.chat_id = chat_id + self.chat_info_stream_id = chat_info_stream_id + self.chat_info_platform = chat_info_platform \ No newline at end of file diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index 8b253806..be087914 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -8,6 +8,8 @@ """ import traceback +import time +import json from typing import Dict, List, Any, Union, Type, Optional from src.common.logger import get_logger from peewee import Model, DoesNotExist @@ -337,8 +339,6 @@ async def store_action_info( ) """ try: - import time - import json from src.common.database.database_model import ActionRecords # 构建动作记录数据