mirror of https://github.com/Mai-with-u/MaiBot.git
Bring PFC back; fix a pile of mysterious PFC errors
parent 77725ba9d8
commit 7bdd394bf0
@@ -1,12 +1,12 @@
import time
from typing import Tuple, Optional # 增加了 Optional
from typing import Tuple, Optional, Dict, Any # 增加了 Optional
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from .observation_info import ObservationInfo
from .observation_info import ObservationInfo, dict_to_database_message
from .conversation_info import ConversationInfo
from src.chat.utils.chat_message_builder import build_readable_messages

@@ -108,13 +108,11 @@ class ActionPlanner:

    def __init__(self, stream_id: str, private_name: str):
        self.llm = LLMRequest(
            model=global_config.llm_PFC_action_planner,
            temperature=global_config.llm_PFC_action_planner["temp"],
            max_tokens=1500,
            model_set=model_config.model_task_config.planner,
            request_type="action_planning",
        )
        self.personality_info = self._get_personality_prompt()
        self.name = global_config.BOT_NICKNAME
        self.name = global_config.bot.nickname
        self.private_name = private_name
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
        # self.action_planner_info = ActionPlannerInfo() # 移除未使用的变量

@@ -131,7 +129,7 @@ class ActionPlanner:
        ):
            prompt_personality = random.choice(global_config.personality.states)

        bot_name = global_config.BOT_NICKNAME
        bot_name = global_config.bot.nickname
        return f"你的名字是{bot_name},你{prompt_personality};"

    # 修改 plan 方法签名,增加 last_successful_reply_action 参数

@@ -152,13 +150,13 @@
            Tuple[str, str]: (行动类型, 行动原因)
        """
        # --- 获取 Bot 上次发言时间信息 ---
        # (这部分逻辑不变)
        time_since_last_bot_message_info = ""
        try:
            bot_id = str(global_config.BOT_QQ)
            if hasattr(observation_info, "chat_history") and observation_info.chat_history:
                for i in range(len(observation_info.chat_history) - 1, -1, -1):
                    msg = observation_info.chat_history[i]
            bot_id = str(global_config.bot.qq_account)
            chat_history = getattr(observation_info, "chat_history", None)
            if chat_history and len(chat_history) > 0:
                for i in range(len(chat_history) - 1, -1, -1):
                    msg = chat_history[i]
                    if not isinstance(msg, dict):
                        continue
                    sender_info = msg.get("user_info", {})

@@ -173,14 +171,11 @@ class ActionPlanner:
                        break
            else:
                logger.debug(
                    f"[私聊][{self.private_name}]Observation info chat history is empty or not available for bot time check."
                    f"[私聊][{self.private_name}]聊天历史为空或尚未加载,跳过 Bot 发言时间检查。"
                )
        except AttributeError:
            logger.warning(
                f"[私聊][{self.private_name}]ObservationInfo object might not have chat_history attribute yet for bot time check."
            )
        except Exception as e:
            logger.warning(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")
            logger.debug(f"[私聊][{self.private_name}]获取 Bot 上次发言时间时出错: {e}")

        # --- 获取超时提示信息 ---
        # (这部分逻辑不变)

@@ -288,10 +283,11 @@ class ActionPlanner:
        if hasattr(observation_info, "new_messages_count") and observation_info.new_messages_count > 0:
            if hasattr(observation_info, "unprocessed_messages") and observation_info.unprocessed_messages:
                new_messages_list = observation_info.unprocessed_messages
                new_messages_str = await build_readable_messages(
                    new_messages_list,
                # Convert dict format to DatabaseMessages objects
                db_messages = [dict_to_database_message(m) for m in new_messages_list]
                new_messages_str = build_readable_messages(
                    db_messages,
                    replace_bot_name=True,
                    merge_messages=False,
                    timestamp_mode="relative",
                    read_mark=0.0,
                )

@@ -2,16 +2,43 @@ import time
import asyncio
import traceback
from typing import Optional, Dict, Any, List
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from src.common.database.database_model import Messages
from maim_message import UserInfo
from ...config.config import global_config
from src.config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage
from rich.traceback import install

install(extra_lines=3)

logger = get_module_logger("chat_observer")
logger = get_logger("chat_observer")


def _message_to_dict(message: Messages) -> Dict[str, Any]:
    """Convert Peewee Message model to dict for PFC compatibility

    Args:
        message: Peewee Messages model instance

    Returns:
        Dict[str, Any]: Message dictionary
    """
    return {
        "message_id": message.message_id,
        "time": message.time,
        "chat_id": message.chat_id,
        "user_id": message.user_id,
        "user_nickname": message.user_nickname,
        "processed_plain_text": message.processed_plain_text,
        "display_message": message.display_message,
        "is_mentioned": message.is_mentioned,
        "is_command": message.is_command,
        # Add user_info dict for compatibility with existing code
        "user_info": {
            "user_id": message.user_id,
            "user_nickname": message.user_nickname,
        }
    }
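
A minimal usage sketch, not part of the commit: it assumes a populated Messages table and feeds the newest row for a hypothetical stream id through the helper above via a standard Peewee query.

    latest = (
        Messages.select()
        .where(Messages.chat_id == "stream-123")  # hypothetical stream id
        .order_by(Messages.time.desc())
        .first()
    )
    if latest is not None:
        msg_dict = _message_to_dict(latest)
        # the nested "user_info" mirrors the old MongoDB document shape PFC expects
        print(msg_dict["user_info"]["user_nickname"], msg_dict["processed_plain_text"])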


class ChatObserver:

@@ -49,12 +76,8 @@ class ChatObserver:

        self.stream_id = stream_id
        self.private_name = private_name
        self.message_storage = MongoDBMessageStorage()

        # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
        # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
        # self.last_check_time: float = time.time() # 上次查看聊天记录时间
        self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
        self.last_message_read: Optional[Dict[str, Any]] = None
        self.last_message_time: float = time.time()

        self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间

@@ -86,7 +109,10 @@ class ChatObserver:
        """
        logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")

        new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time)
        new_message_exists = Messages.select().where(
            (Messages.chat_id == self.stream_id) &
            (Messages.time > self.last_check_time)
        ).exists()

        if new_message_exists:
            logger.debug(f"[私聊][{self.private_name}]发现新消息")

@@ -157,41 +183,7 @@ class ChatObserver:
        )
        return has_new

    def get_message_history(
        self,
        start_time: Optional[float] = None,
        end_time: Optional[float] = None,
        limit: Optional[int] = None,
        user_id: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """获取消息历史

        Args:
            start_time: 开始时间戳
            end_time: 结束时间戳
            limit: 限制返回消息数量
            user_id: 指定用户ID

        Returns:
            List[Dict[str, Any]]: 消息列表
        """
        filtered_messages = self.message_history

        if start_time is not None:
            filtered_messages = [m for m in filtered_messages if m["time"] >= start_time]

        if end_time is not None:
            filtered_messages = [m for m in filtered_messages if m["time"] <= end_time]

        if user_id is not None:
            filtered_messages = [
                m for m in filtered_messages if UserInfo.from_dict(m.get("user_info", {})).user_id == user_id
            ]

        if limit is not None:
            filtered_messages = filtered_messages[-limit:]

        return filtered_messages

    async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
        """获取新消息

@@ -199,7 +191,12 @@ class ChatObserver:
        Returns:
            List[Dict[str, Any]]: 新消息列表
        """
        new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
        query = Messages.select().where(
            (Messages.chat_id == self.stream_id) &
            (Messages.time > self.last_message_time)
        ).order_by(Messages.time.asc())

        new_messages = [_message_to_dict(msg) for msg in query]

        if new_messages:
            self.last_message_read = new_messages[-1]

@@ -218,7 +215,14 @@ class ChatObserver:
        Returns:
            List[Dict[str, Any]]: 最多5条消息
        """
        new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point)
        query = Messages.select().where(
            (Messages.chat_id == self.stream_id) &
            (Messages.time < time_point)
        ).order_by(Messages.time.desc()).limit(5)

        messages = list(query)
        messages.reverse() # 需要按时间正序排列
        new_messages = [_message_to_dict(msg) for msg in messages]

        if new_messages:
            self.last_message_read = new_messages[-1]["message_id"]

@@ -319,7 +323,7 @@ class ChatObserver:
        for msg in messages:
            try:
                user_info = UserInfo.from_dict(msg.get("user_info", {}))
                if user_info.user_id == global_config.BOT_QQ:
                if user_info.user_id == global_config.bot.qq_account:
                    self.update_bot_speak_time(msg["time"])
                else:
                    self.update_user_speak_time(msg["time"])

@@ -4,6 +4,7 @@ import datetime

# from .message_storage import MongoDBMessageStorage
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.common.data_models import transform_class_to_dict

# from src.config.config import global_config
from typing import Dict, Any, Optional

@@ -89,33 +90,53 @@ class Conversation:
            timestamp=time.time(),
            limit=30, # 加载最近30条作为初始上下文,可以调整
        )
        chat_talking_prompt = await build_readable_messages(
        chat_talking_prompt = build_readable_messages(
            initial_messages,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
        )
        if initial_messages:
            # 将 DatabaseMessages 列表转换为 PFC 期望的 dict 格式(保持嵌套结构)
            initial_messages_dict: list[dict] = []
            for msg in initial_messages:
                msg_dict = {
                    "message_id": msg.message_id,
                    "time": msg.time,
                    "chat_id": msg.chat_id,
                    "processed_plain_text": msg.processed_plain_text,
                    "display_message": msg.display_message,
                    "is_mentioned": msg.is_mentioned,
                    "is_command": msg.is_command,
                    "user_info": {
                        "user_id": msg.user_info.user_id if msg.user_info else "",
                        "user_nickname": msg.user_info.user_nickname if msg.user_info else "",
                        "user_cardname": msg.user_info.user_cardname if msg.user_info else None,
                        "platform": msg.user_info.platform if msg.user_info else "",
                    }
                }
                initial_messages_dict.append(msg_dict)

            # 将加载的消息填充到 ObservationInfo 的 chat_history
            self.observation_info.chat_history = initial_messages
            self.observation_info.chat_history = initial_messages_dict
            self.observation_info.chat_history_str = chat_talking_prompt + "\n"
            self.observation_info.chat_history_count = len(initial_messages)
            self.observation_info.chat_history_count = len(initial_messages_dict)

            # 更新 ObservationInfo 中的时间戳等信息
            last_msg = initial_messages[-1]
            self.observation_info.last_message_time = last_msg.get("time")
            last_user_info = UserInfo.from_dict(last_msg.get("user_info", {}))
            last_msg_dict: dict = initial_messages_dict[-1]
            self.observation_info.last_message_time = last_msg_dict.get("time")
            last_user_info = UserInfo.from_dict(last_msg_dict.get("user_info", {}))
            self.observation_info.last_message_sender = last_user_info.user_id
            self.observation_info.last_message_content = last_msg.get("processed_plain_text", "")
            self.observation_info.last_message_content = last_msg_dict.get("processed_plain_text", "")

            logger.info(
                f"[私聊][{self.private_name}]成功加载 {len(initial_messages)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
                f"[私聊][{self.private_name}]成功加载 {len(initial_messages_dict)} 条初始聊天记录。最后一条消息时间: {self.observation_info.last_message_time}"
            )

            # 让 ChatObserver 从加载的最后一条消息之后开始同步
            self.chat_observer.last_message_time = self.observation_info.last_message_time
            self.chat_observer.last_message_read = last_msg # 更新 observer 的最后读取记录
            if self.observation_info.last_message_time:
                self.chat_observer.last_message_time = self.observation_info.last_message_time
                self.chat_observer.last_message_read = last_msg_dict # 更新 observer 的最后读取记录
        else:
            logger.info(f"[私聊][{self.private_name}]没有找到初始聊天记录。")

@@ -3,8 +3,8 @@ from typing import Optional

class ConversationInfo:
    def __init__(self):
        self.done_action = []
        self.goal_list = []
        self.knowledge_list = []
        self.memory_list = []
        self.done_action: list = []
        self.goal_list: list = []
        self.knowledge_list: list = []
        self.memory_list: list = []
        self.last_successful_reply_action: Optional[str] = None

@@ -1,6 +1,6 @@
import time
from typing import Optional
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import Message, MessageSending
from maim_message import UserInfo, Seg

@@ -11,7 +11,7 @@ from rich.traceback import install
install(extra_lines=3)


logger = get_module_logger("message_sender")
logger = get_logger("message_sender")


class DirectMessageSender:

@@ -40,8 +40,8 @@ class DirectMessageSender:

        # 获取麦麦的信息
        bot_user_info = UserInfo(
            user_id=global_config.BOT_QQ,
            user_nickname=global_config.BOT_NICKNAME,
            user_id=global_config.bot.qq_account,
            user_nickname=global_config.bot.nickname,
            platform=chat_stream.platform,
        )

@@ -1,119 +0,0 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from src.common.database import db


class MessageStorage(ABC):
    """消息存储接口"""

    @abstractmethod
    async def get_messages_after(self, chat_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]:
        """获取指定消息ID之后的所有消息

        Args:
            chat_id: 聊天ID
            message: 消息

        Returns:
            List[Dict[str, Any]]: 消息列表
        """
        pass

    @abstractmethod
    async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
        """获取指定时间点之前的消息

        Args:
            chat_id: 聊天ID
            time_point: 时间戳
            limit: 最大消息数量

        Returns:
            List[Dict[str, Any]]: 消息列表
        """
        pass

    @abstractmethod
    async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
        """检查是否有新消息

        Args:
            chat_id: 聊天ID
            after_time: 时间戳

        Returns:
            bool: 是否有新消息
        """
        pass


class MongoDBMessageStorage(MessageStorage):
    """MongoDB消息存储实现"""

    async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
        query = {"chat_id": chat_id, "time": {"$gt": message_time}}
        # print(f"storage_check_message: {message_time}")

        return list(db.messages.find(query).sort("time", 1))

    async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
        query = {"chat_id": chat_id, "time": {"$lt": time_point}}

        messages = list(db.messages.find(query).sort("time", -1).limit(limit))

        # 将消息按时间正序排列
        messages.reverse()
        return messages

    async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
        query = {"chat_id": chat_id, "time": {"$gt": after_time}}

        return db.messages.find_one(query) is not None


# # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage):
#     """内存消息存储实现,主要用于测试"""

#     def __init__(self):
#         self.messages: Dict[str, List[Dict[str, Any]]] = {}

#     async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
#         if chat_id not in self.messages:
#             return []

#         messages = self.messages[chat_id]
#         if not message_id:
#             return messages

#         # 找到message_id的索引
#         try:
#             index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
#             return messages[index + 1:]
#         except StopIteration:
#             return []

#     async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
#         if chat_id not in self.messages:
#             return []

#         messages = [
#             m for m in self.messages[chat_id]
#             if m["time"] < time_point
#         ]

#         return messages[-limit:]

#     async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
#         if chat_id not in self.messages:
#             return False

#         return any(m["time"] > after_time for m in self.messages[chat_id])

#     # 测试辅助方法
#     def add_message(self, chat_id: str, message: Dict[str, Any]):
#         """添加测试消息"""
#         if chat_id not in self.messages:
#             self.messages[chat_id] = []
#         self.messages[chat_id].append(message)
#         self.messages[chat_id].sort(key=lambda m: m["time"])

@@ -1,13 +1,40 @@
from typing import List, Optional, Dict, Any, Set
from maim_message import UserInfo
import time
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler, NotificationType, Notification
from src.chat.utils.chat_message_builder import build_readable_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseUserInfo
import traceback # 导入 traceback 用于调试

logger = get_module_logger("observation_info")
logger = get_logger("observation_info")


def dict_to_database_message(msg_dict: Dict[str, Any]) -> DatabaseMessages:
    """Convert PFC dict format to DatabaseMessages object

    Args:
        msg_dict: Message in PFC dict format with nested user_info

    Returns:
        DatabaseMessages object compatible with build_readable_messages()
    """
    user_info_dict: Dict[str, Any] = msg_dict.get("user_info", {})

    return DatabaseMessages(
        message_id=msg_dict.get("message_id", ""),
        time=msg_dict.get("time", 0.0),
        chat_id=msg_dict.get("chat_id", ""),
        processed_plain_text=msg_dict.get("processed_plain_text", ""),
        display_message=msg_dict.get("display_message", ""),
        is_mentioned=msg_dict.get("is_mentioned", False),
        is_command=msg_dict.get("is_command", False),
        user_id=user_info_dict.get("user_id", ""),
        user_nickname=user_info_dict.get("user_nickname", ""),
        user_cardname=user_info_dict.get("user_cardname"),
        user_platform=user_info_dict.get("platform", ""),
    )
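
A small round-trip sketch, not part of the commit: a hand-written PFC-format dict (all values hypothetical) is converted with dict_to_database_message and rendered with the same build_readable_messages call used elsewhere in this diff.

    sample = {
        "message_id": "m-1",
        "time": 1700000000.0,
        "chat_id": "stream-123",
        "processed_plain_text": "你好",
        "display_message": "你好",
        "is_mentioned": False,
        "is_command": False,
        "user_info": {"user_id": "10001", "user_nickname": "Alice", "platform": "qq"},
    }
    rendered = build_readable_messages(
        [dict_to_database_message(sample)],
        replace_bot_name=True,
        merge_messages=False,
        timestamp_mode="relative",
        read_mark=0.0,
    )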


class ObservationInfoHandler(NotificationHandler):

@@ -366,10 +393,11 @@ class ObservationInfo:
        # 更新历史记录字符串 (只使用最近一部分生成,例如20条)
        history_slice_for_str = self.chat_history[-20:]
        try:
            self.chat_history_str = await build_readable_messages(
                history_slice_for_str,
            # Convert dict format to DatabaseMessages objects
            db_messages = [dict_to_database_message(m) for m in history_slice_for_str]
            self.chat_history_str = build_readable_messages(
                db_messages,
                replace_bot_name=True,
                merge_messages=False,
                timestamp_mode="relative",
                read_mark=0.0, # read_mark 可能需要根据逻辑调整
            )

@@ -1,12 +1,12 @@
from typing import List, Tuple, TYPE_CHECKING
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
import random
from .chat_observer import ChatObserver
from .pfc_utils import get_items_from_json
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
from .observation_info import ObservationInfo, dict_to_database_message
from src.chat.utils.chat_message_builder import build_readable_messages
from rich.traceback import install

@@ -15,7 +15,7 @@ install(extra_lines=3)
if TYPE_CHECKING:
    pass

logger = get_module_logger("pfc")
logger = get_logger("pfc")


def _calculate_similarity(goal1: str, goal2: str) -> float:

@@ -43,12 +43,12 @@ class GoalAnalyzer:

    def __init__(self, stream_id: str, private_name: str):
        self.llm = LLMRequest(
            model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="conversation_goal"
            model_set=model_config.model_task_config.planner, request_type="conversation_goal"
        )

        self.personality_info = self._get_personality_prompt()
        self.name = global_config.BOT_NICKNAME
        self.nick_name = global_config.BOT_ALIAS_NAMES
        self.name = global_config.bot.nickname
        self.nick_name = global_config.bot.alias_names
        self.private_name = private_name
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)

@@ -105,10 +105,10 @@ class GoalAnalyzer:

        if observation_info.new_messages_count > 0:
            new_messages_list = observation_info.unprocessed_messages
            new_messages_str = await build_readable_messages(
                new_messages_list,
            db_messages = [dict_to_database_message(m) for m in new_messages_list]
            new_messages_str = build_readable_messages(
                db_messages,
                replace_bot_name=True,
                merge_messages=False,
                timestamp_mode="relative",
                read_mark=0.0,
            )

@@ -189,7 +189,9 @@ class GoalAnalyzer:
        else:
            # 单个目标的情况
            conversation_info.goal_list.append(result)
            return goal, "", reasoning
            goal_value = result.get("goal", "")
            reasoning_value = result.get("reasoning", "")
            return goal_value, "", reasoning_value

        # 如果解析失败,返回默认值
        return "", "", ""

@@ -238,10 +240,10 @@ class GoalAnalyzer:

    async def analyze_conversation(self, goal, reasoning):
        messages = self.chat_observer.get_cached_messages()
        chat_history_text = await build_readable_messages(
            messages,
        db_messages = [dict_to_database_message(m) for m in messages]
        chat_history_text = build_readable_messages(
            db_messages,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
        )

@@ -1,13 +1,15 @@
from typing import List, Tuple
from src.common.logger import get_module_logger
from src.plugins.memory_system.Hippocampus import HippocampusManager
from typing import List, Tuple, Dict, Any
from src.common.logger import get_logger
# NOTE: HippocampusManager doesn't exist in v0.12.2 - memory system was redesigned
# from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from src.chat.message_receive.message import Message
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.knowledge import qa_manager
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.brain_chat.PFC.observation_info import dict_to_database_message

logger = get_module_logger("knowledge_fetcher")
logger = get_logger("knowledge_fetcher")


class KnowledgeFetcher:

@@ -15,10 +17,7 @@ class KnowledgeFetcher:

    def __init__(self, private_name: str):
        self.llm = LLMRequest(
            model=global_config.llm_normal,
            temperature=global_config.llm_normal["temp"],
            max_tokens=1000,
            request_type="knowledge_fetch",
            model_set=model_config.model_task_config.utils
        )
        self.private_name = private_name

@@ -41,42 +40,46 @@ class KnowledgeFetcher:
            logger.error(f"[私聊][{self.private_name}]LPMM知识库搜索工具执行失败: {str(e)}")
            return "未找到匹配的知识"

    async def fetch(self, query: str, chat_history: List[Message]) -> Tuple[str, str]:
    async def fetch(self, query: str, chat_history: List[Dict[str, Any]]) -> Tuple[str, str]:
        """获取相关知识

        Args:
            query: 查询内容
            chat_history: 聊天历史
            chat_history: 聊天历史 (PFC dict format)

        Returns:
            Tuple[str, str]: (获取的知识, 知识来源)
        """
        # 构建查询上下文
        chat_history_text = await build_readable_messages(
            chat_history,
        db_messages = [dict_to_database_message(m) for m in chat_history]
        chat_history_text = build_readable_messages(
            db_messages,
            replace_bot_name=True,
            merge_messages=False,
            timestamp_mode="relative",
            read_mark=0.0,
        )

        # 从记忆中获取相关知识
        related_memory = await HippocampusManager.get_instance().get_memory_from_text(
            text=f"{query}\n{chat_history_text}",
            max_memory_num=3,
            max_memory_length=2,
            max_depth=3,
            fast_retrieval=False,
        )
        # NOTE: Hippocampus memory system was redesigned in v0.12.2
        # The old get_memory_from_text API no longer exists
        # For now, we'll skip the memory retrieval part and only use LPMM knowledge
        # TODO: Integrate with new memory system if needed
        knowledge_text = ""
        sources_text = "无记忆匹配" # 默认值
        if related_memory:
            sources = []
            for memory in related_memory:
                knowledge_text += memory[1] + "\n"
                sources.append(f"记忆片段{memory[0]}")
            knowledge_text = knowledge_text.strip()
            sources_text = ",".join(sources)

        # # 从记忆中获取相关知识 (DISABLED - old Hippocampus API)
        # related_memory = await HippocampusManager.get_instance().get_memory_from_text(
        #     text=f"{query}\n{chat_history_text}",
        #     max_memory_num=3,
        #     max_memory_length=2,
        #     max_depth=3,
        #     fast_retrieval=False,
        # )
        # if related_memory:
        #     sources = []
        #     for memory in related_memory:
        #         knowledge_text += memory[1] + "\n"
        #         sources.append(f"记忆片段{memory[0]}")
        #     knowledge_text = knowledge_text.strip()
        #     sources_text = ",".join(sources)

        knowledge_text += "\n现在有以下**知识**可供参考:\n "
        knowledge_text += self._lpmm_get_knowledge(query)
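
A brief calling sketch, not part of the commit: it assumes an async caller and a single PFC dict-format history entry with hypothetical values.

    fetcher = KnowledgeFetcher(private_name="Alice")
    history = [{
        "message_id": "m-1", "time": 1700000000.0, "chat_id": "stream-123",
        "processed_plain_text": "周末去爬山吗?", "display_message": "",
        "is_mentioned": False, "is_command": False,
        "user_info": {"user_id": "10001", "user_nickname": "Alice", "platform": "qq"},
    }]
    knowledge, sources = await fetcher.fetch(query="周末的计划", chat_history=history)
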
@@ -1,10 +1,10 @@
import time
from typing import Dict, Optional
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from .conversation import Conversation
import traceback

logger = get_module_logger("pfc_manager")
logger = get_logger("pfc_manager")


class PFCManager:

@@ -1,9 +1,9 @@
import json
import re
from typing import Dict, Any, Optional, Tuple, List, Union
from src.common.logger import get_module_logger
from src.common.logger import get_logger

logger = get_module_logger("pfc_utils")
logger = get_logger("pfc_utils")


def get_items_from_json(

@@ -1,12 +1,13 @@
import json
import random
from typing import Tuple, List, Dict, Any
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
from .chat_observer import ChatObserver
from maim_message import UserInfo

logger = get_module_logger("reply_checker")
logger = get_logger("reply_checker")


class ReplyChecker:

@@ -14,13 +15,30 @@ class ReplyChecker:

    def __init__(self, stream_id: str, private_name: str):
        self.llm = LLMRequest(
            model=global_config.llm_PFC_reply_checker, temperature=0.50, max_tokens=1000, request_type="reply_check"
            model_set=model_config.model_task_config.utils,
            request_type="reply_check"
        )
        self.name = global_config.BOT_NICKNAME
        self.personality_info = self._get_personality_prompt()
        self.name = global_config.bot.nickname
        self.private_name = private_name
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
        self.max_retries = 3 # 最大重试次数

    def _get_personality_prompt(self) -> str:
        """获取个性提示信息"""
        prompt_personality = global_config.personality.personality

        # 检查是否需要随机替换为状态
        if (
            global_config.personality.states
            and global_config.personality.state_probability > 0
            and random.random() < global_config.personality.state_probability
        ):
            prompt_personality = random.choice(global_config.personality.states)

        bot_name = global_config.bot.nickname
        return f"你的名字是{bot_name},你{prompt_personality};"

    async def check(
        self, reply: str, goal: str, chat_history: List[Dict[str, Any]], chat_history_text: str, retry_count: int = 0
    ) -> Tuple[bool, str, bool]:

@@ -43,7 +61,7 @@ class ReplyChecker:
        bot_messages = []
        for msg in reversed(chat_history):
            user_info = UserInfo.from_dict(msg.get("user_info", {}))
            if str(user_info.user_id) == str(global_config.BOT_QQ): # 确保比较的是字符串
            if str(user_info.user_id) == str(global_config.bot.qq_account):
                bot_messages.append(msg.get("processed_plain_text", ""))
                if len(bot_messages) >= 2: # 只和最近的两条比较
                    break

@@ -129,7 +147,7 @@ class ReplyChecker:
        content = content.strip()
        try:
            # 尝试直接解析
            result = json.loads(content)
            result: dict = json.loads(content)
        except json.JSONDecodeError:
            # 如果直接解析失败,尝试查找和提取JSON部分
            import re

@@ -138,7 +156,7 @@ class ReplyChecker:
            json_match = re.search(json_pattern, content)
            if json_match:
                try:
                    result = json.loads(json_match.group())
                    result: dict = json.loads(json_match.group())
                except json.JSONDecodeError:
                    # 如果JSON解析失败,尝试从文本中提取结果
                    is_suitable = "不合适" not in content.lower() and "违规" not in content.lower()

@@ -1,11 +1,11 @@
from typing import Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.config.config import global_config, model_config
import random
from .chat_observer import ChatObserver
from .reply_checker import ReplyChecker
from .observation_info import ObservationInfo
from .observation_info import ObservationInfo, dict_to_database_message
from .conversation_info import ConversationInfo
from src.chat.utils.chat_message_builder import build_readable_messages

@@ -87,13 +87,11 @@ class ReplyGenerator:

    def __init__(self, stream_id: str, private_name: str):
        self.llm = LLMRequest(
            model=global_config.llm_PFC_chat,
            temperature=global_config.llm_PFC_chat["temp"],
            max_tokens=300,
            model_set=model_config.model_task_config.replyer,
            request_type="reply_generation",
        )
        self.personality_info = self._get_personality_prompt()
        self.name = global_config.BOT_NICKNAME
        self.name = global_config.bot.nickname
        self.private_name = private_name
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
        self.reply_checker = ReplyChecker(stream_id, private_name)

@@ -110,7 +108,7 @@ class ReplyGenerator:
        ):
            prompt_personality = random.choice(global_config.personality.states)

        bot_name = global_config.BOT_NICKNAME
        bot_name = global_config.bot.nickname
        return f"你的名字是{bot_name},你{prompt_personality};"

    # 修改 generate 方法签名,增加 action_type 参数

@@ -188,10 +186,10 @@ class ReplyGenerator:
        chat_history_text = observation_info.chat_history_str
        if observation_info.new_messages_count > 0 and observation_info.unprocessed_messages:
            new_messages_list = observation_info.unprocessed_messages
            new_messages_str = await build_readable_messages(
                new_messages_list,
            db_messages = [dict_to_database_message(m) for m in new_messages_list]
            new_messages_str = build_readable_messages(
                db_messages,
                replace_bot_name=True,
                merge_messages=False,
                timestamp_mode="relative",
                read_mark=0.0,
            )

@@ -1,4 +1,4 @@
from src.common.logger import get_module_logger
from src.common.logger import get_logger
from .chat_observer import ChatObserver
from .conversation_info import ConversationInfo

@@ -7,7 +7,7 @@ from src.config.config import global_config
import time
import asyncio

logger = get_module_logger("waiter")
logger = get_logger("waiter")

# --- 在这里设定你想要的超时时间(秒) ---
# 例如: 120 秒 = 2 分钟

@@ -19,7 +19,7 @@ class Waiter:

    def __init__(self, stream_id: str, private_name: str):
        self.chat_observer = ChatObserver.get_instance(stream_id, private_name)
        self.name = global_config.BOT_NICKNAME
        self.name = global_config.bot.nickname
        self.private_name = private_name
        # self.wait_accumulated_time = 0 # 不再需要累加计时

@@ -11,6 +11,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecv
from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.brain_chat.PFC.pfc_manager import PFCManager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, events_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType

@@ -73,6 +74,7 @@ class ChatBot:
        self.bot = None # bot 实例引用
        self._started = False
        self.heartflow_message_receiver = HeartFCMessageReceiver() # 新增
        self.pfc_manager = PFCManager.get_instance() # PFC管理器

    async def _ensure_started(self):
        """确保所有任务已启动"""

@@ -81,6 +83,23 @@ class ChatBot:

        self._started = True

    async def _create_pfc_chat(self, message: MessageRecv):
        """创建或获取PFC对话实例

        Args:
            message: 消息对象
        """
        try:
            chat_id = str(message.chat_stream.stream_id)
            private_name = str(message.message_info.user_info.user_nickname)

            logger.debug(f"[私聊][{private_name}]创建或获取PFC对话: {chat_id}")
            await self.pfc_manager.get_or_create_conversation(chat_id, private_name)

        except Exception as e:
            logger.error(f"创建PFC聊天失败: {e}")
            logger.error(traceback.format_exc())

    async def _process_commands(self, message: MessageRecv):
        # sourcery skip: use-named-expression
        """使用新插件系统处理命令"""

@@ -324,7 +343,16 @@ class ChatBot:
        template_group_name = None

        async def preprocess():
            await self.heartflow_message_receiver.process_message(message)
            # 根据聊天类型路由消息
            if group_info is None:
                # 私聊消息 -> PFC系统
                logger.debug("[私聊]检测到私聊消息,路由到PFC系统")
                await MessageStorage.store_message(message, chat)
                await self._create_pfc_chat(message)
            else:
                # 群聊消息 -> HeartFlow系统
                logger.debug("[群聊]检测到群聊消息,路由到HeartFlow系统")
                await self.heartflow_message_receiver.process_message(message)

        if template_group_name:
            async with global_prompt_manager.async_message_scope(template_group_name):

@@ -162,6 +162,44 @@ class ModelConfig(ConfigBase):
        return super().model_post_init(context)


def get_model_info_by_name(model_config: ModelConfig, model_name: str) -> ModelInfo:
    """根据模型名称获取模型信息

    Args:
        model_config: ModelConfig实例
        model_name: 模型名称

    Returns:
        ModelInfo: 模型信息

    Raises:
        ValueError: 未找到指定模型
    """
    for model in model_config.models:
        if model.name == model_name:
            return model
    raise ValueError(f"未找到名为 '{model_name}' 的模型")


def get_provider_by_name(model_config: ModelConfig, provider_name: str) -> APIProvider:
    """根据提供商名称获取提供商信息

    Args:
        model_config: ModelConfig实例
        provider_name: 提供商名称

    Returns:
        APIProvider: API提供商信息

    Raises:
        ValueError: 未找到指定提供商
    """
    for provider in model_config.api_providers:
        if provider.name == provider_name:
            return provider
    raise ValueError(f"未找到名为 '{provider_name}' 的API提供商")


class ConfigManager:
    """总配置管理类"""

@@ -9,7 +9,7 @@ from typing import Tuple, List, Dict, Optional, Callable, Any, Set
import traceback

from src.common.logger import get_logger
from src.config.config import model_config
from src.config.config import model_config, get_model_info_by_name, get_provider_by_name
from src.config.model_configs import APIProvider, ModelInfo, TaskConfig
from .payload_content.message import MessageBuilder, Message
from .payload_content.resp_format import RespFormat

@@ -296,8 +296,8 @@ class LLMRequest:
            key=lambda k: available_models[k][0] + available_models[k][1] * 300 + available_models[k][2] * 1000,
        )

        model_info = model_config.get_model_info(selected_model_name)
        api_provider = model_config.get_provider(model_info.api_provider)
        model_info = get_model_info_by_name(model_config, selected_model_name)
        api_provider = get_provider_by_name(model_config, model_info.api_provider)
        force_new_client = self.request_type == "embedding"
        client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
        logger.debug(f"选择请求模型: {model_info.name} (策略: {strategy})")