From e2c40db366ab47a6657fb86a726ad83c1b14efb1 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Sun, 5 Oct 2025 17:03:13 +0800 Subject: [PATCH] =?UTF-8?q?fix=EF=BC=9A=E9=97=AE=E9=A2=98=E4=B8=8D?= =?UTF-8?q?=E4=BC=9A=E9=87=8D=E5=A4=8D=E6=8F=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/planner_actions/planner.py | 2 + src/common/database/database_model.py | 1 + src/memory_system/question_maker.py | 70 ++++++++++++++++++++++++--- 3 files changed, 65 insertions(+), 8 deletions(-) diff --git a/src/chat/planner_actions/planner.py b/src/chat/planner_actions/planner.py index a8d670db..83b5d8ad 100644 --- a/src/chat/planner_actions/planner.py +++ b/src/chat/planner_actions/planner.py @@ -290,6 +290,8 @@ class ActionPlanner: loop_start_time=loop_start_time, ) + logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}") + self.add_plan_log(reasoning, actions) return actions diff --git a/src/common/database/database_model.py b/src/common/database/database_model.py index 0efdfb7b..d3d5e2ad 100644 --- a/src/common/database/database_model.py +++ b/src/common/database/database_model.py @@ -341,6 +341,7 @@ class MemoryConflict(BaseModel): update_time = FloatField() # 更新时间 context = TextField(null=True) # 上下文 chat_id = TextField(null=True) # 聊天ID + raise_time = FloatField(null=True) # 触发次数 class Meta: table_name = "memory_conflicts" diff --git a/src/memory_system/question_maker.py b/src/memory_system/question_maker.py index 9a44d408..8dccce7e 100644 --- a/src/memory_system/question_maker.py +++ b/src/memory_system/question_maker.py @@ -1,16 +1,27 @@ import time import random +from typing import List, Optional, Tuple from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages from src.common.database.database_model import MemoryConflict from src.config.config import global_config class QuestionMaker: - def __init__(self, chat_id: str,context: str = ""): + def __init__(self, chat_id: str, context: str = "") -> None: + """问题生成器。 + + - chat_id: 会话 ID,用于筛选该会话下的冲突记录。 + - context: 额外上下文,可用于后续扩展。 + + 用法示例: + >>> qm = QuestionMaker(chat_id="some_chat") + >>> question, chat_ctx, conflict_ctx = await qm.make_question() + """ self.chat_id = chat_id self.context = context - def get_context(self,timestamp: float = time.time()): + def get_context(self, timestamp: float = time.time()) -> str: + """获取指定时间点之前的对话上下文字符串。""" latest_30_msgs = get_raw_msg_before_timestamp_with_chat( chat_id=self.chat_id, timestamp=timestamp, @@ -25,21 +36,64 @@ class QuestionMaker: return all_dialogue_prompt_str - async def get_all_conflicts(self): - conflicts = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id)) + async def get_all_conflicts(self) -> List[MemoryConflict]: + """获取当前会话下的所有记忆冲突记录。""" + conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id)) return conflicts - async def get_un_answered_conflict(self): + async def get_un_answered_conflict(self) -> List[MemoryConflict]: + """获取未回答的记忆冲突记录(answer 为空)。""" conflicts = await self.get_all_conflicts() return [conflict for conflict in conflicts if not conflict.answer] - async def get_random_unanswered_conflict(self): + async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]: + """按权重随机选取一个未回答的冲突并自增 raise_time。 + + 选择规则: + - 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.05)。 + - 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。 + - 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。 + """ conflicts = await self.get_un_answered_conflict() if not conflicts: return None - return random.choice(conflicts) - async def make_question(self): + # 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个 + conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0] + if not conflicts_with_zero: + if random.random() >= 0.05: + return None + # 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重) + chosen_conflict = random.choice(conflicts) + else: + # 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.05 + weights = [] + for conflict in conflicts: + current_raise_time = getattr(conflict, "raise_time", 0) or 0 + weight = 1.0 if current_raise_time == 0 else 0.05 + weights.append(weight) + + # 按权重随机选择 + chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0] + + # 选中后,自增 raise_time 并保存 + try: + chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1 + chosen_conflict.save() + except Exception: + # 静默失败不影响流程 + pass + + return chosen_conflict + + async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]: + """生成一条用于询问用户的冲突问题与上下文。 + + 返回三元组 (question, chat_context, conflict_context): + - question: 冲突文本;若本次未选中任何冲突则为 None。 + - chat_context: 该冲突创建时间点前的会话上下文字符串;若无则为 None。 + - conflict_context: 冲突在 DB 中存储的上下文;若无则为 None。 + """ conflict = await self.get_random_unanswered_conflict() if not conflict: return None, None, None