mirror of https://github.com/Mai-with-u/MaiBot.git
fix:问题不会重复提
parent
f44857d856
commit
e2c40db366
|
|
@ -290,6 +290,8 @@ class ActionPlanner:
|
||||||
loop_start_time=loop_start_time,
|
loop_start_time=loop_start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"{self.log_prefix}Planner:{reasoning}。选择了{len(actions)}个动作: {' '.join([a.action_type for a in actions])}")
|
||||||
|
|
||||||
self.add_plan_log(reasoning, actions)
|
self.add_plan_log(reasoning, actions)
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
|
||||||
|
|
@ -341,6 +341,7 @@ class MemoryConflict(BaseModel):
|
||||||
update_time = FloatField() # 更新时间
|
update_time = FloatField() # 更新时间
|
||||||
context = TextField(null=True) # 上下文
|
context = TextField(null=True) # 上下文
|
||||||
chat_id = TextField(null=True) # 聊天ID
|
chat_id = TextField(null=True) # 聊天ID
|
||||||
|
raise_time = FloatField(null=True) # 触发次数
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
table_name = "memory_conflicts"
|
table_name = "memory_conflicts"
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,27 @@
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_before_timestamp_with_chat, build_readable_messages
|
||||||
from src.common.database.database_model import MemoryConflict
|
from src.common.database.database_model import MemoryConflict
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
class QuestionMaker:
|
class QuestionMaker:
|
||||||
def __init__(self, chat_id: str,context: str = ""):
|
def __init__(self, chat_id: str, context: str = "") -> None:
|
||||||
|
"""问题生成器。
|
||||||
|
|
||||||
|
- chat_id: 会话 ID,用于筛选该会话下的冲突记录。
|
||||||
|
- context: 额外上下文,可用于后续扩展。
|
||||||
|
|
||||||
|
用法示例:
|
||||||
|
>>> qm = QuestionMaker(chat_id="some_chat")
|
||||||
|
>>> question, chat_ctx, conflict_ctx = await qm.make_question()
|
||||||
|
"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|
||||||
def get_context(self,timestamp: float = time.time()):
|
def get_context(self, timestamp: float = time.time()) -> str:
|
||||||
|
"""获取指定时间点之前的对话上下文字符串。"""
|
||||||
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
|
latest_30_msgs = get_raw_msg_before_timestamp_with_chat(
|
||||||
chat_id=self.chat_id,
|
chat_id=self.chat_id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
|
|
@ -25,21 +36,64 @@ class QuestionMaker:
|
||||||
return all_dialogue_prompt_str
|
return all_dialogue_prompt_str
|
||||||
|
|
||||||
|
|
||||||
async def get_all_conflicts(self):
|
async def get_all_conflicts(self) -> List[MemoryConflict]:
|
||||||
conflicts = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
|
"""获取当前会话下的所有记忆冲突记录。"""
|
||||||
|
conflicts: List[MemoryConflict] = list(MemoryConflict.select().where(MemoryConflict.chat_id == self.chat_id))
|
||||||
return conflicts
|
return conflicts
|
||||||
|
|
||||||
async def get_un_answered_conflict(self):
|
async def get_un_answered_conflict(self) -> List[MemoryConflict]:
|
||||||
|
"""获取未回答的记忆冲突记录(answer 为空)。"""
|
||||||
conflicts = await self.get_all_conflicts()
|
conflicts = await self.get_all_conflicts()
|
||||||
return [conflict for conflict in conflicts if not conflict.answer]
|
return [conflict for conflict in conflicts if not conflict.answer]
|
||||||
|
|
||||||
async def get_random_unanswered_conflict(self):
|
async def get_random_unanswered_conflict(self) -> Optional[MemoryConflict]:
|
||||||
|
"""按权重随机选取一个未回答的冲突并自增 raise_time。
|
||||||
|
|
||||||
|
选择规则:
|
||||||
|
- 若存在 `raise_time == 0` 的项:按权重抽样(0 次权重 1.0,≥1 次权重 0.05)。
|
||||||
|
- 若不存在 `raise_time == 0` 的项:仅 5% 概率返回其中任意一条,否则返回 None。
|
||||||
|
- 每次成功选中后,将该条目的 `raise_time` 自增 1 并保存。
|
||||||
|
"""
|
||||||
conflicts = await self.get_un_answered_conflict()
|
conflicts = await self.get_un_answered_conflict()
|
||||||
if not conflicts:
|
if not conflicts:
|
||||||
return None
|
return None
|
||||||
return random.choice(conflicts)
|
|
||||||
|
|
||||||
async def make_question(self):
|
# 如果没有 raise_time==0 的项,则仅有 5% 概率抽样一个
|
||||||
|
conflicts_with_zero = [c for c in conflicts if (getattr(c, "raise_time", 0) or 0) == 0]
|
||||||
|
if not conflicts_with_zero:
|
||||||
|
if random.random() >= 0.05:
|
||||||
|
return None
|
||||||
|
# 以均匀概率选择一个(此时权重都等同于 0.05,无需再按权重)
|
||||||
|
chosen_conflict = random.choice(conflicts)
|
||||||
|
else:
|
||||||
|
# 权重规则:raise_time == 0 -> 1.0;raise_time >= 1 -> 0.05
|
||||||
|
weights = []
|
||||||
|
for conflict in conflicts:
|
||||||
|
current_raise_time = getattr(conflict, "raise_time", 0) or 0
|
||||||
|
weight = 1.0 if current_raise_time == 0 else 0.05
|
||||||
|
weights.append(weight)
|
||||||
|
|
||||||
|
# 按权重随机选择
|
||||||
|
chosen_conflict = random.choices(conflicts, weights=weights, k=1)[0]
|
||||||
|
|
||||||
|
# 选中后,自增 raise_time 并保存
|
||||||
|
try:
|
||||||
|
chosen_conflict.raise_time = (getattr(chosen_conflict, "raise_time", 0) or 0) + 1
|
||||||
|
chosen_conflict.save()
|
||||||
|
except Exception:
|
||||||
|
# 静默失败不影响流程
|
||||||
|
pass
|
||||||
|
|
||||||
|
return chosen_conflict
|
||||||
|
|
||||||
|
async def make_question(self) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||||
|
"""生成一条用于询问用户的冲突问题与上下文。
|
||||||
|
|
||||||
|
返回三元组 (question, chat_context, conflict_context):
|
||||||
|
- question: 冲突文本;若本次未选中任何冲突则为 None。
|
||||||
|
- chat_context: 该冲突创建时间点前的会话上下文字符串;若无则为 None。
|
||||||
|
- conflict_context: 冲突在 DB 中存储的上下文;若无则为 None。
|
||||||
|
"""
|
||||||
conflict = await self.get_random_unanswered_conflict()
|
conflict = await self.get_random_unanswered_conflict()
|
||||||
if not conflict:
|
if not conflict:
|
||||||
return None, None, None
|
return None, None, None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue