mirror of https://github.com/Mai-with-u/MaiBot.git
refactor of focus_chat
parent
f1373fce4a
commit
108d883675
|
|
@ -1,8 +1,8 @@
|
|||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import List, Optional, Dict, Any, Deque, Callable, Awaitable
|
||||
import random
|
||||
from typing import List, Optional, Dict, Any
|
||||
from rich.traceback import install
|
||||
|
||||
from src.config.config import global_config
|
||||
|
|
@ -10,19 +10,18 @@ from src.common.logger import get_logger
|
|||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||
from src.chat.planner_actions.planner import ActionPlanner
|
||||
from src.chat.planner_actions.action_modifier import ActionModifier
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
from src.chat.focus_chat.hfc_utils import CycleDetail
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.plugin_system.base.component_types import ChatMode
|
||||
import random
|
||||
from src.chat.focus_chat.hfc_utils import get_recent_message_stats
|
||||
from src.person_info.relationship_builder_manager import relationship_builder_manager
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
from src.plugin_system.base.component_types import ActionInfo, ChatMode
|
||||
from src.plugin_system.apis import generator_api, send_api, message_api
|
||||
from src.chat.willing.willing_manager import get_willing_manager
|
||||
from .priority_manager import PriorityManager
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||
|
||||
|
||||
ERROR_LOOP_INFO = {
|
||||
|
|
@ -107,7 +106,7 @@ class HeartFChatting:
|
|||
# 添加循环信息管理相关的属性
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
self._cycle_counter = 0
|
||||
self._current_cycle_detail: Optional[CycleDetail] = None
|
||||
self._current_cycle_detail: CycleDetail = None # type: ignore
|
||||
|
||||
self.reply_timeout_count = 0
|
||||
self.plan_timeout_count = 0
|
||||
|
|
@ -169,7 +168,7 @@ class HeartFChatting:
|
|||
def start_cycle(self):
|
||||
self._cycle_counter += 1
|
||||
self._current_cycle_detail = CycleDetail(self._cycle_counter)
|
||||
self._current_cycle_detail.thinking_id = "tid" + str(round(time.time(), 2))
|
||||
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
|
||||
cycle_timers = {}
|
||||
return cycle_timers, self._current_cycle_detail.thinking_id
|
||||
|
||||
|
|
@ -230,13 +229,15 @@ class HeartFChatting:
|
|||
async def build_reply_to_str(self, message_data: dict):
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_id = person_info_manager.get_person_id(
|
||||
message_data.get("chat_info_platform"), message_data.get("user_id")
|
||||
message_data.get("chat_info_platform"), # type: ignore
|
||||
message_data.get("user_id"), # type: ignore
|
||||
)
|
||||
person_name = await person_info_manager.get_value(person_id, "person_name")
|
||||
reply_to_str = f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
return reply_to_str
|
||||
return f"{person_name}:{message_data.get('processed_plain_text')}"
|
||||
|
||||
async def _observe(self, message_data: dict = None):
|
||||
async def _observe(self, message_data: Optional[Dict[str, Any]] = None):
|
||||
if not message_data:
|
||||
message_data = {}
|
||||
# 创建新的循环信息
|
||||
cycle_timers, thinking_id = self.start_cycle()
|
||||
|
||||
|
|
@ -339,7 +340,7 @@ class HeartFChatting:
|
|||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
if self.loop_mode == "normal":
|
||||
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id"))
|
||||
await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
|
||||
|
||||
return True
|
||||
|
||||
|
|
@ -425,39 +426,39 @@ class HeartFChatting:
|
|||
traceback.print_exc()
|
||||
return False, "", ""
|
||||
|
||||
async def shutdown(self):
|
||||
"""优雅关闭HeartFChatting实例,取消活动循环任务"""
|
||||
logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
||||
self.running = False # <-- 在开始关闭时设置标志位
|
||||
# async def shutdown(self):
|
||||
# """优雅关闭HeartFChatting实例,取消活动循环任务"""
|
||||
# logger.info(f"{self.log_prefix} 正在关闭HeartFChatting...")
|
||||
# self.running = False # <-- 在开始关闭时设置标志位
|
||||
|
||||
# 记录最终的消息统计
|
||||
if self._message_count > 0:
|
||||
logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
|
||||
if self._fatigue_triggered:
|
||||
logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
|
||||
# # 记录最终的消息统计
|
||||
# if self._message_count > 0:
|
||||
# logger.info(f"{self.log_prefix} 本次focus会话共发送了 {self._message_count} 条消息")
|
||||
# if self._fatigue_triggered:
|
||||
# logger.info(f"{self.log_prefix} 因疲惫而退出focus模式")
|
||||
|
||||
# 取消循环任务
|
||||
if self._loop_task and not self._loop_task.done():
|
||||
logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
|
||||
self._loop_task.cancel()
|
||||
try:
|
||||
await asyncio.wait_for(self._loop_task, timeout=1.0)
|
||||
logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
|
||||
else:
|
||||
logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
|
||||
# # 取消循环任务
|
||||
# if self._loop_task and not self._loop_task.done():
|
||||
# logger.info(f"{self.log_prefix} 正在取消HeartFChatting循环任务")
|
||||
# self._loop_task.cancel()
|
||||
# try:
|
||||
# await asyncio.wait_for(self._loop_task, timeout=1.0)
|
||||
# logger.info(f"{self.log_prefix} HeartFChatting循环任务已取消")
|
||||
# except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
# pass
|
||||
# except Exception as e:
|
||||
# logger.error(f"{self.log_prefix} 取消循环任务出错: {e}")
|
||||
# else:
|
||||
# logger.info(f"{self.log_prefix} 没有活动的HeartFChatting循环任务")
|
||||
|
||||
# 清理状态
|
||||
self.running = False
|
||||
self._loop_task = None
|
||||
# # 清理状态
|
||||
# self.running = False
|
||||
# self._loop_task = None
|
||||
|
||||
# 重置消息计数器,为下次启动做准备
|
||||
self.reset_message_count()
|
||||
# # 重置消息计数器,为下次启动做准备
|
||||
# self.reset_message_count()
|
||||
|
||||
logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
||||
# logger.info(f"{self.log_prefix} HeartFChatting关闭完成")
|
||||
|
||||
def adjust_reply_frequency(self):
|
||||
"""
|
||||
|
|
@ -549,7 +550,7 @@ class HeartFChatting:
|
|||
# 仅在未被提及或基础概率不为1时查询意愿概率
|
||||
if reply_probability < 1: # 简化逻辑,如果未提及 (reply_probability 为 0),则获取意愿概率
|
||||
# is_willing = True
|
||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id"))
|
||||
reply_probability = await self.willing_manager.get_reply_probability(message_data.get("message_id", ""))
|
||||
|
||||
additional_config = message_data.get("additional_config", {})
|
||||
if additional_config and "maimcore_reply_probability_gain" in additional_config:
|
||||
|
|
@ -570,20 +571,18 @@ class HeartFChatting:
|
|||
)
|
||||
|
||||
if random.random() < reply_probability:
|
||||
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id"))
|
||||
await self.willing_manager.before_generate_reply_handle(message_data.get("message_id", ""))
|
||||
await self._observe(message_data=message_data)
|
||||
|
||||
# 意愿管理器:注销当前message信息 (无论是否回复,只要处理过就删除)
|
||||
self.willing_manager.delete(message_data.get("message_id"))
|
||||
|
||||
return True
|
||||
self.willing_manager.delete(message_data.get("message_id", ""))
|
||||
|
||||
async def _generate_response(
|
||||
self, message_data: dict, available_actions: Optional[list], reply_to: str
|
||||
self, message_data: dict, available_actions: Optional[Dict[str, ActionInfo]], reply_to: str
|
||||
) -> Optional[list]:
|
||||
"""生成普通回复"""
|
||||
try:
|
||||
success, reply_set = await generator_api.generate_reply(
|
||||
success, reply_set, _ = await generator_api.generate_reply(
|
||||
chat_stream=self.chat_stream,
|
||||
reply_to=reply_to,
|
||||
available_actions=available_actions,
|
||||
|
|
@ -622,10 +621,9 @@ class HeartFChatting:
|
|||
await send_api.text_to_stream(
|
||||
text=data, stream_id=self.chat_stream.stream_id, reply_to=reply_to, typing=False
|
||||
)
|
||||
first_replyed = True
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=False)
|
||||
first_replyed = True
|
||||
first_replyed = True
|
||||
else:
|
||||
await send_api.text_to_stream(text=data, stream_id=self.chat_stream.stream_id, typing=True)
|
||||
reply_text += data
|
||||
|
|
|
|||
|
|
@ -1,14 +1,10 @@
|
|||
import time
|
||||
import json
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.message_repository import count_messages
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageRecv, BaseMessageInfo
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.message_receive.message import UserInfo
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
|
@ -85,10 +81,10 @@ class CycleDetail:
|
|||
self.loop_action_info = loop_info["loop_action_info"]
|
||||
|
||||
|
||||
def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
||||
def get_recent_message_stats(minutes: float = 30, chat_id: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Args:
|
||||
minutes (int): 检索的分钟数,默认30分钟
|
||||
minutes (float): 检索的分钟数,默认30分钟
|
||||
chat_id (str, optional): 指定的chat_id,仅统计该chat下的消息。为None时统计全部。
|
||||
Returns:
|
||||
dict: {"bot_reply_count": int, "total_message_count": int}
|
||||
|
|
@ -98,7 +94,7 @@ def get_recent_message_stats(minutes: int = 30, chat_id: str = None) -> dict:
|
|||
start_time = now - minutes * 60
|
||||
bot_id = global_config.bot.qq_account
|
||||
|
||||
filter_base = {"time": {"$gte": start_time}}
|
||||
filter_base: Dict[str, Any] = {"time": {"$gte": start_time}}
|
||||
if chat_id is not None:
|
||||
filter_base["chat_id"] = chat_id
|
||||
|
||||
|
|
|
|||
|
|
@ -25,8 +25,7 @@ class PrioritizedMessage:
|
|||
"""
|
||||
age = time.time() - self.arrival_time
|
||||
decay_factor = math.exp(-decay_rate * age)
|
||||
priority = sum(self.interest_scores) + decay_factor
|
||||
return priority
|
||||
return sum(self.interest_scores) + decay_factor
|
||||
|
||||
def __lt__(self, other: "PrioritizedMessage") -> bool:
|
||||
"""用于堆排序的比较函数,我们想要一个最大堆,所以用 >"""
|
||||
|
|
@ -43,7 +42,7 @@ class PriorityManager:
|
|||
self.normal_queue: List[PrioritizedMessage] = [] # 普通消息队列 (最大堆)
|
||||
self.normal_queue_max_size = normal_queue_max_size
|
||||
|
||||
def add_message(self, message_data: dict, interest_score: Optional[float] = None):
|
||||
def add_message(self, message_data: dict, interest_score: float = 0):
|
||||
"""
|
||||
添加新消息到合适的队列中。
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class ActionPlanner:
|
|||
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
async def plan(self, mode: str = "focus") -> Dict[str, Any]: # sourcery skip: dict-comprehension
|
||||
async def plan(self, mode: str = "focus") -> Dict[str, Dict[str, Any]]: # sourcery skip: dict-comprehension
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue