Merge branch 'MaiM-with-u:dev' into dev

pull/1207/head
Windpicker-owo 2025-08-22 12:06:36 +08:00 committed by GitHub
commit 9f3906de37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 903 additions and 815 deletions

View File

@ -3,13 +3,13 @@ import time
import traceback
import math
import random
from typing import List, Optional, Dict, Any, Tuple
from typing import List, Optional, Dict, Any, Tuple, TYPE_CHECKING
from rich.traceback import install
from collections import deque
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.timer_calculator import Timer
@ -24,12 +24,15 @@ from src.chat.frequency_control.focus_value_control import focus_value_control
from src.chat.express.expression_learner import expression_learner_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from src.person_info.person_info import Person
from src.plugin_system.base.component_types import ChatMode, EventType
from src.plugin_system.base.component_types import ChatMode, EventType, ActionInfo
from src.plugin_system.core import events_manager
from src.plugin_system.apis import generator_api, send_api, message_api, database_api
from src.mais4u.mai_think import mai_thinking_manager
from src.mais4u.s4u_config import s4u_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
ERROR_LOOP_INFO = {
"loop_plan_info": {
@ -100,7 +103,7 @@ class HeartFChatting:
self.reply_timeout_count = 0
self.plan_timeout_count = 0
self.last_read_time = time.time() - 1
self.last_read_time = time.time() - 10
self.focus_energy = 1
self.no_action_consecutive = 0
@ -141,7 +144,7 @@ class HeartFChatting:
except asyncio.CancelledError:
logger.info(f"{self.log_prefix} HeartFChatting: 结束了聊天")
def start_cycle(self):
def start_cycle(self) -> Tuple[Dict[str, float], str]:
self._cycle_counter += 1
self._current_cycle_detail = CycleDetail(self._cycle_counter)
self._current_cycle_detail.thinking_id = f"tid{str(round(time.time(), 2))}"
@ -172,7 +175,8 @@ class HeartFChatting:
action_type = action_result.get("action_type", "未知动作")
elif isinstance(action_result, list) and action_result:
# 新格式action_result是actions列表
action_type = action_result[0].get("action_type", "未知动作")
# TODO: 把这里写明白
action_type = action_result[0].action_type or "未知动作"
elif isinstance(loop_plan_info, list) and loop_plan_info:
# 直接是actions列表的情况
action_type = loop_plan_info[0].get("action_type", "未知动作")
@ -207,7 +211,7 @@ class HeartFChatting:
logger.info(f"{self.log_prefix} 兴趣度充足,等待新消息")
self.focus_energy = 1
async def _should_process_messages(self, new_message: List[DatabaseMessages]) -> tuple[bool, float]:
async def _should_process_messages(self, new_message: List["DatabaseMessages"]) -> tuple[bool, float]:
"""
判断是否应该处理消息
@ -265,7 +269,7 @@ class HeartFChatting:
return False, 0.0
async def _loopbody(self):
recent_messages_dict = message_api.get_messages_by_time_in_chat(
recent_messages_list = message_api.get_messages_by_time_in_chat(
chat_id=self.stream_id,
start_time=self.last_read_time,
end_time=time.time(),
@ -275,7 +279,7 @@ class HeartFChatting:
filter_command=True,
)
# 统一的消息处理逻辑
should_process, interest_value = await self._should_process_messages(recent_messages_dict)
should_process, interest_value = await self._should_process_messages(recent_messages_list)
if should_process:
self.last_read_time = time.time()
@ -290,11 +294,11 @@ class HeartFChatting:
async def _send_and_store_reply(
self,
response_set,
action_message,
action_message: "DatabaseMessages",
cycle_timers: Dict[str, float],
thinking_id,
actions,
selected_expressions: List[int] = None,
selected_expressions: Optional[List[int]] = None,
) -> Tuple[Dict[str, Any], str, Dict[str, float]]:
with Timer("回复发送", cycle_timers):
reply_text = await self._send_response(
@ -304,11 +308,11 @@ class HeartFChatting:
)
# 获取 platform如果不存在则从 chat_stream 获取,如果还是 None 则使用默认值
platform = action_message.get("chat_info_platform")
platform = action_message.chat_info.platform
if platform is None:
platform = getattr(self.chat_stream, "platform", "unknown")
person = Person(platform=platform, user_id=action_message.get("user_id", ""))
person = Person(platform=platform, user_id=action_message.user_info.user_id)
person_name = person.person_name
action_prompt_display = f"你对{person_name}进行了回复:{reply_text}"
@ -353,9 +357,13 @@ class HeartFChatting:
k = 2.0 # 控制曲线陡峭程度
x0 = 1.0 # 控制曲线中心点
return 1.0 / (1.0 + math.exp(-k * (interest_val - x0)))
normal_mode_probability = calculate_normal_mode_probability(interest_value) * 2 * self.talk_frequency_control.get_current_talk_frequency()
normal_mode_probability = (
calculate_normal_mode_probability(interest_value)
* 2
* self.talk_frequency_control.get_current_talk_frequency()
)
# 根据概率决定使用哪种模式
if random.random() < normal_mode_probability:
mode = ChatMode.NORMAL
@ -383,17 +391,17 @@ class HeartFChatting:
except Exception as e:
logger.error(f"{self.log_prefix} 记忆构建失败: {e}")
available_actions: Dict[str, ActionInfo] = {}
if random.random() > self.focus_value_control.get_current_focus_value() and mode == ChatMode.FOCUS:
# 如果激活度没有激活并且聊天活跃度低有可能不进行plan相当于不在电脑前不进行认真思考
actions = [
{
"action_type": "no_action",
"reasoning": "专注不足",
"action_data": {},
}
action_to_use_info = [
ActionPlannerInfo(
action_type="no_action",
reasoning="专注不足",
action_data={},
)
]
else:
available_actions = {}
# 第一步:动作修改
with Timer("动作修改", cycle_timers):
try:
@ -414,105 +422,19 @@ class HeartFChatting:
):
return False
with Timer("规划器", cycle_timers):
actions, _ = await self.action_planner.plan(
action_to_use_info, _ = await self.action_planner.plan(
mode=mode,
loop_start_time=self.last_read_time,
available_actions=available_actions,
)
# 3. 并行执行所有动作
async def execute_action(action_info, actions):
"""执行单个动作的通用函数"""
try:
if action_info["action_type"] == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_info.get("reasoning", "选择不回复")
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_info["action_type"] != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_info["action_type"],
action_info["reasoning"],
action_info["action_data"],
cycle_timers,
thinking_id,
action_info["action_message"],
)
return {
"action_type": action_info["action_type"],
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try:
success, response_set, prompt_selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_info["action_message"],
available_actions=available_actions,
choosen_actions=actions,
reply_reason=action_info.get("reasoning", ""),
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
return_expressions=True,
)
if prompt_selected_expressions and len(prompt_selected_expressions) > 1:
_, selected_expressions = prompt_selected_expressions
else:
selected_expressions = []
if not success or not response_set:
logger.info(
f"{action_info['action_message'].get('processed_plain_text')} 的回复生成失败"
)
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, cycle_timers_reply = await self._send_and_store_reply(
response_set=response_set,
action_message=action_info["action_message"],
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=actions,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_info["action_type"],
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e),
}
action_tasks = [asyncio.create_task(execute_action(action, actions)) for action in actions]
action_tasks = [
asyncio.create_task(
self._execute_action(action, action_to_use_info, thinking_id, available_actions, cycle_timers)
)
for action in action_to_use_info
]
# 并行执行所有任务
results = await asyncio.gather(*action_tasks, return_exceptions=True)
@ -529,7 +451,7 @@ class HeartFChatting:
logger.error(f"{self.log_prefix} 动作执行异常: {result}")
continue
_cur_action = actions[i]
_cur_action = action_to_use_info[i]
if result["action_type"] != "reply":
action_success = result["success"]
action_reply_text = result["reply_text"]
@ -558,7 +480,7 @@ class HeartFChatting:
# 没有回复信息构建纯动作的loop_info
loop_info = {
"loop_plan_info": {
"action_result": actions,
"action_result": action_to_use_info,
},
"loop_action_info": {
"action_taken": action_success,
@ -578,7 +500,7 @@ class HeartFChatting:
# await self.willing_manager.after_generate_reply_handle(message_data.get("message_id", ""))
action_type = actions[0]["action_type"] if actions else "no_action"
action_type = action_to_use_info[0].action_type if action_to_use_info else "no_action"
# 管理no_action计数器当执行了非no_action动作时重置计数器
if action_type != "no_action":
@ -620,7 +542,7 @@ class HeartFChatting:
action_data: dict,
cycle_timers: Dict[str, float],
thinking_id: str,
action_message: dict,
action_message: Optional["DatabaseMessages"] = None,
) -> tuple[bool, str, str]:
"""
处理规划动作使用动作工厂创建相应的动作处理器
@ -672,8 +594,8 @@ class HeartFChatting:
async def _send_response(
self,
reply_set,
message_data,
selected_expressions: List[int] = None,
message_data: "DatabaseMessages",
selected_expressions: Optional[List[int]] = None,
) -> str:
new_message_count = message_api.count_new_messages(
chat_id=self.chat_stream.stream_id, start_time=self.last_read_time, end_time=time.time()
@ -710,3 +632,97 @@ class HeartFChatting:
reply_text += data
return reply_text
async def _execute_action(
self,
action_planner_info: ActionPlannerInfo,
chosen_action_plan_infos: List[ActionPlannerInfo],
thinking_id: str,
available_actions: Dict[str, ActionInfo],
cycle_timers: Dict[str, float],
):
"""执行单个动作的通用函数"""
try:
if action_planner_info.action_type == "no_action":
# 直接处理no_action逻辑不再通过动作系统
reason = action_planner_info.reasoning or "选择不回复"
logger.info(f"{self.log_prefix} 选择不回复,原因: {reason}")
# 存储no_action信息到数据库
await database_api.store_action_info(
chat_stream=self.chat_stream,
action_build_into_prompt=False,
action_prompt_display=reason,
action_done=True,
thinking_id=thinking_id,
action_data={"reason": reason},
action_name="no_action",
)
return {"action_type": "no_action", "success": True, "reply_text": "", "command": ""}
elif action_planner_info.action_type != "reply":
# 执行普通动作
with Timer("动作执行", cycle_timers):
success, reply_text, command = await self._handle_action(
action_planner_info.action_type,
action_planner_info.reasoning or "",
action_planner_info.action_data or {},
cycle_timers,
thinking_id,
action_planner_info.action_message,
)
return {
"action_type": action_planner_info.action_type,
"success": success,
"reply_text": reply_text,
"command": command,
}
else:
try:
success, response_set, prompt, selected_expressions = await generator_api.generate_reply(
chat_stream=self.chat_stream,
reply_message=action_planner_info.action_message,
available_actions=available_actions,
chosen_actions=chosen_action_plan_infos,
reply_reason=action_planner_info.reasoning or "",
enable_tool=global_config.tool.enable_tool,
request_type="replyer",
from_plugin=False,
return_expressions=True,
)
if not success or not response_set:
if action_planner_info.action_message:
logger.info(f"{action_planner_info.action_message.processed_plain_text} 的回复生成失败")
else:
logger.info("回复生成失败")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
except asyncio.CancelledError:
logger.debug(f"{self.log_prefix} 并行执行:回复生成任务已被取消")
return {"action_type": "reply", "success": False, "reply_text": "", "loop_info": None}
loop_info, reply_text, _ = await self._send_and_store_reply(
response_set=response_set,
action_message=action_planner_info.action_message, # type: ignore
cycle_timers=cycle_timers,
thinking_id=thinking_id,
actions=chosen_action_plan_infos,
selected_expressions=selected_expressions,
)
return {
"action_type": "reply",
"success": True,
"reply_text": reply_text,
"loop_info": loop_info,
}
except Exception as e:
logger.error(f"{self.log_prefix} 执行动作时出错: {e}")
logger.error(f"{self.log_prefix} 错误信息: {traceback.format_exc()}")
return {
"action_type": action_planner_info.action_type,
"success": False,
"reply_text": "",
"loop_info": None,
"error": str(e),
}

View File

@ -303,4 +303,4 @@ init_prompt()
try:
expression_selector = ExpressionSelector()
except Exception as e:
print(f"ExpressionSelector初始化失败: {e}")
logger.error(f"ExpressionSelector初始化失败: {e}")

View File

@ -4,44 +4,43 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class FocusValueControl:
def __init__(self,chat_id:str):
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.focus_value_adjust = 1
self.focus_value_adjust: float = 1
def get_current_focus_value(self) -> float:
return get_current_focus_value(self.chat_id) * self.focus_value_adjust
class FocusValueControlManager:
def __init__(self):
self.focus_value_controls = {}
def get_focus_value_control(self,chat_id:str) -> FocusValueControl:
self.focus_value_controls: dict[str, FocusValueControl] = {}
def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
if chat_id not in self.focus_value_controls:
self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
return self.focus_value_controls[chat_id]
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
"""
根据当前时间和聊天流获取对应的 focus_value
"""
if not global_config.chat.focus_value_adjust:
return global_config.chat.focus_value
if chat_id:
stream_focus_value = get_stream_specific_focus_value(chat_id)
if stream_focus_value is not None:
return stream_focus_value
global_focus_value = get_global_focus_value()
if global_focus_value is not None:
return global_focus_value
return global_config.chat.focus_value
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
"""
获取特定聊天流在当前时间的专注度
@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
return None
focus_value_control = FocusValueControlManager()
focus_value_control = FocusValueControlManager()

View File

@ -2,20 +2,21 @@ from typing import Optional
from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class TalkFrequencyControl:
def __init__(self,chat_id:str):
def __init__(self, chat_id: str):
self.chat_id = chat_id
self.talk_frequency_adjust = 1
self.talk_frequency_adjust: float = 1
def get_current_talk_frequency(self) -> float:
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
class TalkFrequencyControlManager:
def __init__(self):
self.talk_frequency_controls = {}
def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl:
def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
if chat_id not in self.talk_frequency_controls:
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
return self.talk_frequency_controls[chat_id]
@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
global_frequency = get_global_frequency()
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
"""
根据时间配置列表获取当前时段的频率
@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
return None
def get_global_frequency() -> Optional[float]:
"""
获取全局默认频率配置
@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
return None
talk_frequency_control = TalkFrequencyControlManager()
talk_frequency_control = TalkFrequencyControlManager()

View File

@ -12,7 +12,7 @@ from src.chat.message_receive.storage import MessageStorage
from src.chat.heart_flow.heartflow import heartflow
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.chat_message_builder import replace_user_references_sync
from src.chat.utils.chat_message_builder import replace_user_references
from src.common.logger import get_logger
from src.mood.mood_manager import mood_manager
from src.person_info.person_info import Person
@ -131,7 +131,7 @@ class HeartFCMessageReceiver:
processed_plain_text = re.sub(picid_pattern, "[图片]", message.processed_plain_text)
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references_sync(
processed_plain_text = replace_user_references(
processed_plain_text,
message.message_info.platform, # type: ignore
replace_bot_name=True

View File

@ -0,0 +1,82 @@
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config
import os
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端")
# 初始化Embedding库
embed_manager = EmbeddingManager()
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
logger.info("正在从文件加载KG")
try:
kg_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
logger.info(f"KG边数量{len(kg_manager.graph.get_edge_list())}")
# 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes:
# 使用与EmbeddingStore中一致的命名空间格式
key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
global qa_manager
# 问答系统(用于知识库)
qa_manager = QAManager(
embed_manager,
kg_manager,
)
# # 记忆激活(用于记忆库)
# global inspire_manager
# inspire_manager = MemoryActiveManager(
# embed_manager,
# llm_client_list[global_config["embedding"]["provider"]],
# )
else:
logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误

View File

@ -5,7 +5,7 @@ from typing import List, Union
from .global_logger import logger
from . import prompt_template
from .knowledge_lib import INVALID_ENTITY
from . import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json

View File

@ -1,80 +0,0 @@
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config
import os
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:
logger.info("正在初始化Mai-LPMM")
logger.info("创建LLM客户端")
# 初始化Embedding库
embed_manager = EmbeddingManager()
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载Embedding库时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
logger.info("正在从文件加载KG")
try:
kg_manager.load_from_file()
except Exception as e:
logger.warning(f"此消息不会影响正常使用从文件加载KG时{e}")
# logger.warning("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")
logger.info(f"KG边数量{len(kg_manager.graph.get_edge_list())}")
# 数据比对Embedding库与KG的段落hash集合
for pg_hash in kg_manager.stored_paragraph_hashes:
# 使用与EmbeddingStore中一致的命名空间格式
key = f"paragraph-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}")
# 问答系统(用于知识库)
qa_manager = QAManager(
embed_manager,
kg_manager,
)
# # 记忆激活(用于记忆库)
# inspire_manager = MemoryActiveManager(
# embed_manager,
# llm_client_list[global_config["embedding"]["provider"]],
# )
else:
logger.info("LPMM知识库已禁用跳过初始化")
# 创建空的占位符对象,避免导入错误

View File

@ -4,7 +4,7 @@ import glob
from typing import Any, Dict, List
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
from . import INVALID_ENTITY, ROOT_PATH, DATA_PATH
# from src.manager.local_store_manager import local_storage

View File

@ -60,7 +60,7 @@ class QAManager:
for res in relation_search_res:
if store_item := self.embed_manager.relation_embedding_store.store.get(res[0]):
rel_str = store_item.str
print(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
logger.info(f"找到相关关系,相似度:{(res[1] * 100):.2f}% - {rel_str}")
# TODO: 使用LLM过滤三元组结果
# logger.info(f"LLM过滤三元组用时{time.time() - part_start_time:.2f}s")
@ -94,7 +94,7 @@ class QAManager:
for res in result:
raw_paragraph = self.embed_manager.paragraphs_embedding_store.store[res[0]].str
print(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
logger.info(f"找到相关文段,相关系数:{res[1]:.8f}\n{raw_paragraph}\n\n")
return result, ppr_node_weights

View File

@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0:
return 0
return dot_product / (norm1 * norm2)
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
install(extra_lines=3)
@ -142,11 +140,10 @@ class MemoryGraph:
# 获取当前节点的记忆项
node_data = self.get_dot(topic)
if node_data:
concept, data = node_data
_, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容
if memory_items:
if memory_items := data["memory_items"]:
first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆
@ -154,11 +151,10 @@ class MemoryGraph:
# 获取相邻节点的记忆项
for neighbor in neighbors:
if node_data := self.get_dot(neighbor):
concept, data = node_data
_, data = node_data
if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容
if memory_items:
if memory_items := data["memory_items"]:
second_layer_items.append(memory_items)
return first_layer_items, second_layer_items
@ -224,27 +220,17 @@ class MemoryGraph:
# 获取话题节点数据
node_data = self.G.nodes[topic]
# 删除整个节点
self.G.remove_node(topic)
# 如果节点存在memory_items
if "memory_items" in node_data:
memory_items = node_data["memory_items"]
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
if memory_items:
# 删除整个节点
self.G.remove_node(topic)
if memory_items := node_data["memory_items"]:
return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
)
else:
# 如果没有记忆项,删除该节点
self.G.remove_node(topic)
return None
else:
# 如果没有memory_items字段删除该节点
self.G.remove_node(topic)
return None
return None
# 海马体
@ -392,9 +378,8 @@ class Hippocampus:
# 如果相似度超过阈值,获取该节点的记忆
if similarity >= 0.3: # 可以调整这个阈值
node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容
if memory_items:
if memory_items := node_data.get("memory_items", ""):
memories.append((node, memory_items, similarity))
# 按相似度降序排序
@ -411,7 +396,7 @@ class Hippocampus:
如果为False使用LLM提取关键词速度较慢但更准确
"""
if not text:
return []
return [], []
# 使用LLM提取关键词 - 根据详细文本长度分布优化topic_num计算
text_length = len(text)
@ -587,7 +572,7 @@ class Hippocampus:
unique_memories = []
for topic, memory_items, activation_value in all_memories:
# memory_items现在是完整的字符串格式
memory = memory_items if memory_items else ""
memory = memory_items or ""
if memory not in seen_memories:
seen_memories.add(memory)
unique_memories.append((topic, memory_items, activation_value))
@ -599,7 +584,7 @@ class Hippocampus:
result = []
for topic, memory_items, _ in unique_memories:
# memory_items现在是完整的字符串格式
memory = memory_items if memory_items else ""
memory = memory_items or ""
result.append((topic, memory))
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
@ -1435,13 +1420,11 @@ class HippocampusManager:
if not self._initialized:
raise RuntimeError("HippocampusManager 尚未初始化,请先调用 initialize 方法")
try:
response, keywords, keywords_lite = await self._hippocampus.get_activate_from_text(
text, max_depth, fast_retrieval
)
return await self._hippocampus.get_activate_from_text(text, max_depth, fast_retrieval)
except Exception as e:
logger.error(f"文本产生激活值失败: {e}")
logger.error(traceback.format_exc())
return 0.0, [], []
return 0.0, [], []
def get_memory_from_keyword(self, keyword: str, max_depth: int = 2) -> list:
"""从关键词获取相关记忆的公共接口"""
@ -1473,6 +1456,7 @@ class MemoryBuilder:
self.last_processed_time: float = 0.0
def should_trigger_memory_build(self) -> bool:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
"""检查是否应该触发记忆构建"""
current_time = time.time()

View File

@ -11,7 +11,7 @@ from datetime import datetime, timedelta
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.common.database.database_model import Memory # Peewee Models导入
from src.config.config import model_config
from src.config.config import model_config, global_config
logger = get_logger(__name__)
@ -42,7 +42,7 @@ class InstantMemory:
request_type="memory.summary",
)
async def if_need_build(self, text):
async def if_need_build(self, text: str):
prompt = f"""
请判断以下内容中是否有值得记忆的信息如果有请输出1否则输出0
{text}
@ -51,8 +51,9 @@ class InstantMemory:
try:
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if global_config.debug.show_prompt:
print(prompt)
print(response)
return "1" in response
except Exception as e:
@ -94,7 +95,7 @@ class InstantMemory:
logger.error(f"构建记忆出现错误:{str(e)} {traceback.format_exc()}")
return None
async def create_and_store_memory(self, text):
async def create_and_store_memory(self, text: str):
if_need = await self.if_need_build(text)
if if_need:
logger.info(f"需要记忆:{text}")
@ -126,24 +127,25 @@ class InstantMemory:
from json_repair import repair_json
prompt = f"""
请根据以下发言内容判断是否需要提取记忆
{target}
请用json格式输出包含以下字段
其中time的要求是
可以选择具体日期时间格式为YYYY-MM-DD HH:MM:SS或者大致时间格式为YYYY-MM-DD
可以选择相对时间例如今天昨天前天5天前1个月前
可以选择留空进行模糊搜索
{{
"need_memory": 1,
"keywords": "希望获取的记忆关键词,用/划分",
"time": "希望获取的记忆大致时间"
}}
请只输出json格式不要输出其他多余内容
"""
请根据以下发言内容判断是否需要提取记忆
{target}
请用json格式输出包含以下字段
其中time的要求是
可以选择具体日期时间格式为YYYY-MM-DD HH:MM:SS或者大致时间格式为YYYY-MM-DD
可以选择相对时间例如今天昨天前天5天前1个月前
可以选择留空进行模糊搜索
{{
"need_memory": 1,
"keywords": "希望获取的记忆关键词,用/划分",
"time": "希望获取的记忆大致时间"
}}
请只输出json格式不要输出其他多余内容
"""
try:
response, _ = await self.summary_model.generate_response_async(prompt, temperature=0.5)
print(prompt)
print(response)
if global_config.debug.show_prompt:
print(prompt)
print(response)
if not response:
return None
try:

View File

@ -145,7 +145,7 @@ class ChatBot:
logger.error(f"处理命令时出错: {e}")
return False, None, True # 出错时继续处理消息
async def hanle_notice_message(self, message: MessageRecv):
async def handle_notice_message(self, message: MessageRecv):
if message.message_info.message_id == "notice":
message.is_notify = True
logger.info("notice消息")
@ -212,7 +212,7 @@ class ChatBot:
# logger.debug(str(message_data))
message = MessageRecv(message_data)
if await self.hanle_notice_message(message):
if await self.handle_notice_message(message):
# return
pass

View File

@ -115,7 +115,7 @@ class MessageRecv(Message):
self.priority_mode = "interest"
self.priority_info = None
self.interest_value: float = None # type: ignore
self.key_words = []
self.key_words_lite = []
@ -213,9 +213,9 @@ class MessageRecvS4U(MessageRecv):
self.is_screen = False
self.is_internal = False
self.voice_done = None
self.chat_info = None
async def process(self) -> None:
self.processed_plain_text = await self._process_message_segments(self.message_segment)
@ -420,7 +420,7 @@ class MessageSending(MessageProcessBase):
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
selected_expressions:List[int] = None,
selected_expressions: Optional[List[int]] = None,
):
# 调用父类初始化
super().__init__(
@ -445,7 +445,7 @@ class MessageSending(MessageProcessBase):
self.display_message = display_message
self.interest_value = 0.0
self.selected_expressions = selected_expressions
def build_reply(self):

View File

@ -2,6 +2,7 @@ from typing import Dict, Optional, Type
from src.chat.message_receive.chat_stream import ChatStream
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.plugin_system.core.component_registry import component_registry
from src.plugin_system.base.component_types import ComponentType, ActionInfo
from src.plugin_system.base.base_action import BaseAction
@ -37,7 +38,7 @@ class ActionManager:
chat_stream: ChatStream,
log_prefix: str,
shutting_down: bool = False,
action_message: Optional[dict] = None,
action_message: Optional[DatabaseMessages] = None,
) -> Optional[BaseAction]:
"""
创建动作处理器实例
@ -83,7 +84,7 @@ class ActionManager:
log_prefix=log_prefix,
shutting_down=shutting_down,
plugin_config=plugin_config,
action_message=action_message,
action_message=action_message.flatten() if action_message else None,
)
logger.debug(f"创建Action实例成功: {action_name}")
@ -123,4 +124,4 @@ class ActionManager:
"""恢复到默认动作集"""
actions_to_restore = list(self._using_actions.keys())
self._using_actions = component_registry.get_default_actions()
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")
logger.debug(f"恢复动作集: 从 {actions_to_restore} 恢复到默认动作集 {list(self._using_actions.keys())}")

View File

@ -1,7 +1,7 @@
import json
import time
import traceback
from typing import Dict, Any, Optional, Tuple, List
from typing import Dict, Optional, Tuple, List
from rich.traceback import install
from datetime import datetime
from json_repair import repair_json
@ -9,6 +9,8 @@ from json_repair import repair_json
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
@ -97,7 +99,9 @@ class ActionPlanner:
self.plan_retry_count = 0
self.max_plan_retries = 3
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
def find_message_by_id(
self, message_id: str, message_id_list: List[DatabaseMessages]
) -> Optional[DatabaseMessages]:
# sourcery skip: use-next
"""
根据message_id从message_id_list中查找对应的原始消息
@ -110,37 +114,37 @@ class ActionPlanner:
找到的原始消息字典如果未找到则返回None
"""
for item in message_id_list:
if item.get("id") == message_id:
return item.get("message")
if item.message_id == message_id:
return item
return None
def get_latest_message(self, message_id_list: list) -> Optional[Dict[str, Any]]:
def get_latest_message(self, message_id_list: List[DatabaseMessages]) -> Optional[DatabaseMessages]:
"""
获取消息列表中的最新消息
Args:
message_id_list: 消息ID列表格式为[{'id': str, 'message': dict}, ...]
Returns:
最新的消息字典如果列表为空则返回None
"""
return message_id_list[-1].get("message") if message_id_list else None
return message_id_list[-1] if message_id_list else None
async def plan(
self,
mode: ChatMode = ChatMode.FOCUS,
loop_start_time:float = 0.0,
loop_start_time: float = 0.0,
available_actions: Optional[Dict[str, ActionInfo]] = None,
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
) -> Tuple[List[ActionPlannerInfo], Optional[DatabaseMessages]]:
"""
规划器 (Planner): 使用LLM根据上下文决定做出什么动作
"""
action = "no_action" # 默认动作
reasoning = "规划器初始化默认"
action: str = "no_action" # 默认动作
reasoning: str = "规划器初始化默认"
action_data = {}
current_available_actions: Dict[str, ActionInfo] = {}
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
target_message: Optional[DatabaseMessages] = None # 初始化target_message变量
prompt: str = ""
message_id_list: list = []
@ -208,19 +212,21 @@ class ActionPlanner:
# 如果获取的target_message为None输出warning并重新plan
if target_message is None:
self.plan_retry_count += 1
logger.warning(f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}")
logger.warning(
f"{self.log_prefix}无法找到target_message_id '{target_message_id}' 对应的消息,重试次数: {self.plan_retry_count}/{self.max_plan_retries}"
)
# 仍有重试次数
if self.plan_retry_count < self.max_plan_retries:
# 递归重新plan
return await self.plan(mode, loop_start_time, available_actions)
logger.error(f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message")
logger.error(
f"{self.log_prefix}连续{self.max_plan_retries}次plan获取target_message失败选择最新消息作为target_message"
)
target_message = self.get_latest_message(message_id_list)
self.plan_retry_count = 0 # 重置计数器
else:
logger.warning(f"{self.log_prefix}动作'{action}'缺少target_message_id")
if action != "no_action" and action != "reply" and action not in current_available_actions:
logger.warning(
f"{self.log_prefix}LLM 返回了当前不可用或无效的动作: '{action}' (可用: {list(current_available_actions.keys())}),将强制使用 'no_action'"
@ -244,38 +250,37 @@ class ActionPlanner:
if mode == ChatMode.NORMAL and action in current_available_actions:
is_parallel = current_available_actions[action].parallel_action
action_data["loop_start_time"] = loop_start_time
actions = [
{
"action_type": action,
"reasoning": reasoning,
"action_data": action_data,
"action_message": target_message,
"available_actions": available_actions,
}
ActionPlannerInfo(
action_type=action,
reasoning=reasoning,
action_data=action_data,
action_message=target_message,
available_actions=available_actions,
)
]
if action != "reply" and is_parallel:
actions.append({
"action_type": "reply",
"action_message": target_message,
"available_actions": available_actions
})
actions.append(
ActionPlannerInfo(
action_type="reply",
action_message=target_message,
available_actions=available_actions,
)
)
return actions,target_message
return actions, target_message
async def build_planner_prompt(
self,
is_group_chat: bool, # Now passed as argument
chat_target_info: Optional[dict], # Now passed as argument
current_available_actions: Dict[str, ActionInfo],
refresh_time :bool = False,
refresh_time: bool = False,
mode: ChatMode = ChatMode.FOCUS,
) -> tuple[str, list]: # sourcery skip: use-join
) -> tuple[str, List[DatabaseMessages]]: # sourcery skip: use-join
"""构建 Planner LLM 的提示词 (获取模板并填充数据)"""
try:
message_list_before_now = get_raw_msg_before_timestamp_with_chat(
@ -305,13 +310,12 @@ class ActionPlanner:
actions_before_now_block = f"你刚刚选择并执行过的action是\n{actions_before_now_block}"
if refresh_time:
self.last_obs_time_mark = time.time()
mentioned_bonus = ""
if global_config.chat.mentioned_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你"
if global_config.chat.at_bot_inevitable_reply:
mentioned_bonus = "\n- 有人提到你或者at你"
if mode == ChatMode.FOCUS:
no_action_block = """
@ -332,7 +336,7 @@ class ActionPlanner:
"""
chat_context_description = "你现在正在一个群聊中"
chat_target_name = None
chat_target_name = None
if not is_group_chat and chat_target_info:
chat_target_name = (
chat_target_info.get("person_name") or chat_target_info.get("user_nickname") or "对方"
@ -388,7 +392,7 @@ class ActionPlanner:
action_options_text=action_options_block,
moderation_prompt=moderation_prompt_block,
identity_block=identity_block,
plan_style = global_config.personality.plan_style
plan_style=global_config.personality.plan_style,
)
return prompt, message_id_list
except Exception as e:

View File

@ -9,6 +9,7 @@ from datetime import datetime
from src.mais4u.mai_think import mai_thinking_manager
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.config.config import global_config, model_config
from src.individuality.individuality import get_individuality
from src.llm_models.utils_model import LLMRequest
@ -21,7 +22,7 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import (
build_readable_messages,
get_raw_msg_before_timestamp_with_chat,
replace_user_references_sync,
replace_user_references,
)
from src.chat.express.expression_selector import expression_selector
from src.chat.memory_system.memory_activator import MemoryActivator
@ -157,12 +158,12 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
from_plugin: bool = True,
stream_id: Optional[str] = None,
reply_message: Optional[Dict[str, Any]] = None,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], List[Dict[str, Any]]]:
reply_message: Optional[DatabaseMessages] = None,
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str], Optional[List[int]]]:
# sourcery skip: merge-nested-ifs
"""
回复器 (Replier): 负责生成回复文本的核心逻辑
@ -181,7 +182,7 @@ class DefaultReplyer:
"""
prompt = None
selected_expressions = None
selected_expressions: Optional[List[int]] = None
if available_actions is None:
available_actions = {}
try:
@ -374,7 +375,12 @@ class DefaultReplyer:
)
if global_config.memory.enable_instant_memory:
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history))
chat_history_str = build_readable_messages(
messages=chat_history,
replace_bot_name=True,
timestamp_mode="normal"
)
asyncio.create_task(self.instant_memory.create_and_store_memory(chat_history_str))
instant_memory = await self.instant_memory.get_memory(target)
logger.info(f"即时记忆:{instant_memory}")
@ -527,7 +533,7 @@ class DefaultReplyer:
Returns:
Tuple[str, str]: (核心对话prompt, 背景对话prompt)
"""
core_dialogue_list = []
core_dialogue_list: List[DatabaseMessages] = []
bot_id = str(global_config.bot.qq_account)
# 过滤消息分离bot和目标用户的对话 vs 其他用户的对话
@ -559,7 +565,7 @@ class DefaultReplyer:
if core_dialogue_list:
# 检查最新五条消息中是否包含bot自己说的消息
latest_5_messages = core_dialogue_list[-5:] if len(core_dialogue_list) >= 5 else core_dialogue_list
has_bot_message = any(str(msg.get("user_id")) == bot_id for msg in latest_5_messages)
has_bot_message = any(str(msg.user_info.user_id) == bot_id for msg in latest_5_messages)
# logger.info(f"最新五条消息:{latest_5_messages}")
# logger.info(f"最新五条消息中是否包含bot自己说的消息{has_bot_message}")
@ -634,7 +640,7 @@ class DefaultReplyer:
return mai_think
async def build_actions_prompt(
self, available_actions, choosen_actions: Optional[List[Dict[str, Any]]] = None
self, available_actions: Dict[str, ActionInfo], chosen_actions_info: Optional[List[ActionPlannerInfo]] = None
) -> str:
"""构建动作提示"""
@ -646,20 +652,21 @@ class DefaultReplyer:
action_descriptions += f"- {action_name}: {action_description}\n"
action_descriptions += "\n"
choosen_action_descriptions = ""
if choosen_actions:
for action in choosen_actions:
action_name = action.get("action_type", "unknown_action")
chosen_action_descriptions = ""
if chosen_actions_info:
for action_plan_info in chosen_actions_info:
action_name = action_plan_info.action_type
if action_name == "reply":
continue
action_description = action.get("reason", "无描述")
reasoning = action.get("reasoning", "无原因")
if action := available_actions.get(action_name):
action_description = action.description or "无描述"
reasoning = action_plan_info.reasoning or "无原因"
choosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
chosen_action_descriptions += f"- {action_name}: {action_description},原因:{reasoning}\n"
if choosen_action_descriptions:
if chosen_action_descriptions:
action_descriptions += "根据聊天情况,另一个模型决定在回复的同时做以下这些动作:\n"
action_descriptions += choosen_action_descriptions
action_descriptions += chosen_action_descriptions
return action_descriptions
@ -668,9 +675,9 @@ class DefaultReplyer:
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
chosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List[ActionPlannerInfo]] = None,
enable_tool: bool = True,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional[DatabaseMessages] = None,
) -> Tuple[str, List[int]]:
"""
构建回复器上下文
@ -694,11 +701,11 @@ class DefaultReplyer:
platform = chat_stream.platform
if reply_message:
user_id = reply_message.get("user_id", "")
user_id = reply_message.user_info.user_id
person = Person(platform=platform, user_id=user_id)
person_name = person.person_name or user_id
sender = person_name
target = reply_message.get("processed_plain_text")
target = reply_message.processed_plain_text
else:
person_name = "用户"
sender = "用户"
@ -710,7 +717,7 @@ class DefaultReplyer:
else:
mood_prompt = ""
target = replace_user_references_sync(target, chat_stream.platform, replace_bot_name=True)
target = replace_user_references(target, chat_stream.platform, replace_bot_name=True)
message_list_before_now_long = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
@ -774,11 +781,13 @@ class DefaultReplyer:
logger.info(f"回复准备: {'; '.join(timing_logs)}; {almost_zero_str} <0.01s")
expression_habits_block, selected_expressions = results_dict["expression_habits"]
relation_info = results_dict["relation_info"]
memory_block = results_dict["memory_block"]
tool_info = results_dict["tool_info"]
prompt_info = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info = results_dict["actions_info"]
expression_habits_block: str
selected_expressions: List[int]
relation_info: str = results_dict["relation_info"]
memory_block: str = results_dict["memory_block"]
tool_info: str = results_dict["tool_info"]
prompt_info: str = results_dict["prompt_info"] # 直接使用格式化后的结果
actions_info: str = results_dict["actions_info"]
keywords_reaction_prompt = await self.build_keywords_reaction_prompt(target)
if extra_info:

View File

@ -19,8 +19,8 @@ install(extra_lines=3)
logger = get_logger("chat_message_builder")
def replace_user_references_sync(
content: str,
def replace_user_references(
content: Optional[str],
platform: str,
name_resolver: Optional[Callable[[str, str], str]] = None,
replace_bot_name: bool = True,
@ -38,6 +38,8 @@ def replace_user_references_sync(
Returns:
str: 处理后的内容字符串
"""
if not content:
return ""
if name_resolver is None:
def default_resolver(platform: str, user_id: str) -> str:
@ -93,80 +95,6 @@ def replace_user_references_sync(
return content
async def replace_user_references_async(
content: str,
platform: str,
name_resolver: Optional[Callable[[str, str], Any]] = None,
replace_bot_name: bool = True,
) -> str:
"""
替换内容中的用户引用格式包括回复<aaa:bbb>@<aaa:bbb>格式
Args:
content: 要处理的内容字符串
platform: 平台标识
name_resolver: 名称解析函数接收(platform, user_id)参数返回用户名称
如果为None则使用默认的person_info_manager
replace_bot_name: 是否将机器人的user_id替换为"机器人昵称(你)"
Returns:
str: 处理后的内容字符串
"""
if name_resolver is None:
async def default_resolver(platform: str, user_id: str) -> str:
# 检查是否是机器人自己
if replace_bot_name and user_id == global_config.bot.qq_account:
return f"{global_config.bot.nickname}(你)"
person = Person(platform=platform, user_id=user_id)
return person.person_name or user_id # type: ignore
name_resolver = default_resolver
# 处理回复<aaa:bbb>格式
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
match = re.search(reply_pattern, content)
if match:
aaa = match.group(1)
bbb = match.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
reply_person_name = f"{global_config.bot.nickname}(你)"
else:
reply_person_name = await name_resolver(platform, bbb) or aaa
content = re.sub(reply_pattern, f"回复 {reply_person_name}", content, count=1)
except Exception:
# 如果解析失败,使用原始昵称
content = re.sub(reply_pattern, f"回复 {aaa}", content, count=1)
# 处理@<aaa:bbb>格式
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
at_matches = list(re.finditer(at_pattern, content))
if at_matches:
new_content = ""
last_end = 0
for m in at_matches:
new_content += content[last_end : m.start()]
aaa = m.group(1)
bbb = m.group(2)
try:
# 检查是否是机器人自己
if replace_bot_name and bbb == global_config.bot.qq_account:
at_person_name = f"{global_config.bot.nickname}(你)"
else:
at_person_name = await name_resolver(platform, bbb) or aaa
new_content += f"@{at_person_name}"
except Exception:
# 如果解析失败,使用原始昵称
new_content += f"@{aaa}"
last_end = m.end()
new_content += content[last_end:]
content = new_content
return content
def get_raw_msg_by_timestamp(timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"):
"""
获取从指定时间戳到指定时间戳的消息按时间升序排序返回消息列表
@ -498,7 +426,7 @@ def _build_readable_messages_internal(
person_name = f"{global_config.bot.nickname}(你)"
# 使用独立函数处理用户引用格式
if content := replace_user_references_sync(content, platform, replace_bot_name=replace_bot_name):
if content := replace_user_references(content, platform, replace_bot_name=replace_bot_name):
detailed_messages_raw.append((timestamp, person_name, content, False))
if not detailed_messages_raw:
@ -658,7 +586,10 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
convert_DatabaseMessages_to_MessageAndActionModel(messages), replace_bot_name, timestamp_mode, truncate
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
replace_bot_name,
timestamp_mode,
truncate,
)
if pic_mapping_info := build_pic_mapping_info(pic_id_mapping):
@ -725,19 +656,7 @@ def build_readable_messages(
if not messages:
return ""
copy_messages: List[MessageAndActionModel] = [
MessageAndActionModel(
msg.time,
msg.user_info.user_id,
msg.user_info.platform,
msg.user_info.user_nickname,
msg.user_info.user_cardname,
msg.processed_plain_text,
msg.display_message,
msg.chat_info.platform,
)
for msg in messages
]
copy_messages: List[MessageAndActionModel] = [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages]
if show_actions and copy_messages:
# 获取所有消息的时间范围
@ -942,7 +861,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
except Exception:
return "?"
content = replace_user_references_sync(content, platform, anon_name_resolver, replace_bot_name=False)
content = replace_user_references(content, platform, anon_name_resolver, replace_bot_name=False)
header = f"{anon_name}"
output_lines.append(header)
@ -996,22 +915,3 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
def convert_DatabaseMessages_to_MessageAndActionModel(message: List[DatabaseMessages]) -> List[MessageAndActionModel]:
"""
DatabaseMessages 列表转换为 MessageAndActionModel 列表
"""
return [
MessageAndActionModel(
time=msg.time,
user_id=msg.user_info.user_id,
user_platform=msg.user_info.platform,
user_nickname=msg.user_info.user_nickname,
user_cardname=msg.user_info.user_cardname,
processed_plain_text=msg.processed_plain_text,
display_message=msg.display_message,
chat_info_platform=msg.chat_info.platform,
)
for msg in message
]

View File

@ -11,7 +11,6 @@ from collections import Counter
from typing import Optional, Tuple, Dict, List, Any
from src.common.logger import get_logger
from src.common.data_models.info_data_model import TargetPersonInfo
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.message_repository import find_messages, count_messages
from src.config.config import global_config, model_config
@ -641,6 +640,8 @@ def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
platform: str = chat_stream.platform
user_id: str = user_info.user_id # type: ignore
from src.common.data_models.info_data_model import TargetPersonInfo # 解决循环导入问题
# Initialize target_info with basic info
target_info = TargetPersonInfo(
platform=platform,

View File

@ -1,4 +1,4 @@
from typing import Optional, Any
from typing import Optional, Any, Dict
from dataclasses import dataclass, field
from . import BaseDataModel
@ -157,3 +157,42 @@ class DatabaseMessages(BaseDataModel):
# assert isinstance(self.interest_value, float) or self.interest_value is None, (
# "interest_value must be a float or None"
# )
def flatten(self) -> Dict[str, Any]:
"""
将消息数据模型转换为字典格式便于存储或传输
"""
return {
"message_id": self.message_id,
"time": self.time,
"chat_id": self.chat_id,
"reply_to": self.reply_to,
"interest_value": self.interest_value,
"key_words": self.key_words,
"key_words_lite": self.key_words_lite,
"is_mentioned": self.is_mentioned,
"processed_plain_text": self.processed_plain_text,
"display_message": self.display_message,
"priority_mode": self.priority_mode,
"priority_info": self.priority_info,
"additional_config": self.additional_config,
"is_emoji": self.is_emoji,
"is_picid": self.is_picid,
"is_command": self.is_command,
"is_notify": self.is_notify,
"selected_expressions": self.selected_expressions,
"user_id": self.user_info.user_id,
"user_nickname": self.user_info.user_nickname,
"user_cardname": self.user_info.user_cardname,
"user_platform": self.user_info.platform,
"chat_info_group_id": self.group_info.group_id if self.group_info else None,
"chat_info_group_name": self.group_info.group_name if self.group_info else None,
"chat_info_group_platform": self.group_info.group_platform if self.group_info else None,
"chat_info_stream_id": self.chat_info.stream_id,
"chat_info_platform": self.chat_info.platform,
"chat_info_create_time": self.chat_info.create_time,
"chat_info_last_active_time": self.chat_info.last_active_time,
"chat_info_user_platform": self.chat_info.user_info.platform,
"chat_info_user_id": self.chat_info.user_info.user_id,
"chat_info_user_nickname": self.chat_info.user_info.user_nickname,
"chat_info_user_cardname": self.chat_info.user_info.user_cardname,
}

View File

@ -1,12 +1,25 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Dict, TYPE_CHECKING
from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
from src.plugin_system.base.component_types import ActionInfo
@dataclass
class TargetPersonInfo(BaseDataModel):
platform: str = field(default_factory=str)
user_id: str = field(default_factory=str)
user_nickname: str = field(default_factory=str)
person_id: Optional[str] = None
person_name: Optional[str] = None
person_name: Optional[str] = None
@dataclass
class ActionPlannerInfo(BaseDataModel):
action_type: str = field(default_factory=str)
reasoning: Optional[str] = None
action_data: Optional[Dict] = None
action_message: Optional["DatabaseMessages"] = None
available_actions: Optional[Dict[str, "ActionInfo"]] = None

View File

@ -1,10 +1,15 @@
from typing import Optional
from typing import Optional, TYPE_CHECKING
from dataclasses import dataclass, field
from . import BaseDataModel
if TYPE_CHECKING:
from .database_data_model import DatabaseMessages
@dataclass
class MessageAndActionModel(BaseDataModel):
chat_id: str = field(default_factory=str)
time: float = field(default_factory=float)
user_id: str = field(default_factory=str)
user_platform: str = field(default_factory=str)
@ -15,3 +20,17 @@ class MessageAndActionModel(BaseDataModel):
chat_info_platform: str = field(default_factory=str)
is_action_record: bool = field(default=False)
action_name: Optional[str] = None
@classmethod
def from_DatabaseMessages(cls, message: "DatabaseMessages"):
return cls(
chat_id=message.chat_id,
time=message.time,
user_id=message.user_info.user_id,
user_platform=message.user_info.platform,
user_nickname=message.user_info.user_nickname,
user_cardname=message.user_info.user_cardname,
processed_plain_text=message.processed_plain_text,
display_message=message.display_message,
chat_info_platform=message.chat_info.platform,
)

View File

@ -47,10 +47,13 @@ logger = get_logger("Gemini客户端")
# gemini_thinking参数默认范围
# 不同模型的思考预算范围配置
THINKING_BUDGET_LIMITS = {
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
"gemini-2.5-flash": {"min": 1, "max": 24576, "can_disable": True},
"gemini-2.5-flash-lite": {"min": 512, "max": 24576, "can_disable": True},
"gemini-2.5-pro": {"min": 128, "max": 32768, "can_disable": False},
}
# 思维预算特殊值
THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定
THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用)
gemini_safe_settings = [
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.BLOCK_NONE),
@ -91,9 +94,7 @@ def _convert_messages(
for item in message.content:
if isinstance(item, tuple):
image_format = "jpeg" if item[0].lower() == "jpg" else item[0].lower()
content.append(
Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}")
)
content.append(Part.from_bytes(data=base64.b64decode(item[1]), mime_type=f"image/{image_format}"))
elif isinstance(item, str):
content.append(Part.from_text(text=item))
else:
@ -336,47 +337,40 @@ class GeminiClient(BaseClient):
api_key=api_provider.api_key,
) # 这里和openai不一样gemini会自己决定自己是否需要retry
# 思维预算特殊值
THINKING_BUDGET_AUTO = -1 # 自动调整思考预算,由模型决定
THINKING_BUDGET_DISABLED = 0 # 禁用思考预算(如果模型允许禁用)
@staticmethod
def clamp_thinking_budget(tb: int, model_id: str):
def clamp_thinking_budget(tb: int, model_id: str) -> int:
"""
按模型限制思考预算范围仅支持指定的模型支持带数字后缀的新版本
"""
limits = None
matched_key = None
# 优先尝试精确匹配
if model_id in THINKING_BUDGET_LIMITS:
limits = THINKING_BUDGET_LIMITS[model_id]
matched_key = model_id
else:
# 按 key 长度倒序,保证更长的(更具体的,如 -lite优先
sorted_keys = sorted(THINKING_BUDGET_LIMITS.keys(), key=len, reverse=True)
for key in sorted_keys:
# 必须满足:完全等于 或者 前缀匹配(带 "-" 边界)
if model_id == key or model_id.startswith(key + "-"):
limits = THINKING_BUDGET_LIMITS[key]
matched_key = key
break
if model_id == key or model_id.startswith(f"{key}-"):
limits = THINKING_BUDGET_LIMITS[key]
break
# 特殊值处理
if tb == GeminiClient.THINKING_BUDGET_AUTO:
return GeminiClient.THINKING_BUDGET_AUTO
if tb == GeminiClient.THINKING_BUDGET_DISABLED:
if tb == THINKING_BUDGET_AUTO:
return THINKING_BUDGET_AUTO
if tb == THINKING_BUDGET_DISABLED:
if limits and limits.get("can_disable", False):
return GeminiClient.THINKING_BUDGET_DISABLED
return limits["min"] if limits else GeminiClient.THINKING_BUDGET_AUTO
return THINKING_BUDGET_DISABLED
return limits["min"] if limits else THINKING_BUDGET_AUTO
# 已知模型裁剪到范围
if limits:
return max(limits["min"], min(tb, limits["max"]))
return max(limits["min"], min(tb, limits["max"]))
# 未知模型,返回动态模式
logger.warning(f"模型 {model_id} 未在 THINKING_BUDGET_LIMITS 中定义,将使用动态模式 tb=-1 兼容。")
return GeminiClient.THINKING_BUDGET_AUTO
return THINKING_BUDGET_AUTO
async def get_response(
self,
@ -424,15 +418,13 @@ class GeminiClient(BaseClient):
# 将tool_options转换为Gemini API所需的格式
tools = _convert_tool_options(tool_options) if tool_options else None
tb = GeminiClient.THINKING_BUDGET_AUTO
#空处理
tb = THINKING_BUDGET_AUTO
# 空处理
if extra_params and "thinking_budget" in extra_params:
try:
tb = int(extra_params["thinking_budget"])
except (ValueError, TypeError):
logger.warning(
f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}"
)
logger.warning(f"无效的 thinking_budget 值 {extra_params['thinking_budget']},将使用默认动态模式 {tb}")
# 裁剪到模型支持的范围
tb = self.clamp_thinking_budget(tb, model_info.model_identifier)

View File

@ -13,6 +13,7 @@ from src.common.logger import get_logger
from src.individuality.individuality import get_individuality, Individuality
from src.common.server import get_global_server, Server
from src.mood.mood_manager import mood_manager
from src.chat.knowledge import lpmm_start_up
from rich.traceback import install
from src.migrate_helper.migrate import check_and_run_migrations
# from src.api.main import start_api_server
@ -83,6 +84,9 @@ class MainSystem:
# 启动API服务器
# start_api_server()
# logger.info("API服务器启动成功")
# 启动LPMM
lpmm_start_up()
# 加载所有actions包括默认的和插件的
plugin_manager.load_all_plugins()
@ -104,7 +108,6 @@ class MainSystem:
logger.info("情绪管理器初始化成功")
# 初始化聊天管理器
await get_chat_manager()._initialize()
asyncio.create_task(get_chat_manager()._auto_save_task())

View File

@ -3,6 +3,7 @@ import asyncio
import json
import time
import random
import math
from json_repair import repair_json
from typing import Union, Optional
@ -16,6 +17,7 @@ from src.config.config import global_config, model_config
logger = get_logger("person_info")
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id"""
if "-" in platform:
@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return ""
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
if person_id:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
return person.is_known if person else False
@ -47,89 +51,84 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No
return person.is_known if person else False
else:
return False
def get_catagory_from_memory(memory_point:str) -> str:
def get_category_from_memory(memory_point: str) -> Optional[str]:
"""从记忆点中获取分类"""
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
if not isinstance(memory_point, str):
return None
parts = memory_point.split(":", 1)
if len(parts) > 1:
return parts[0].strip()
else:
return None
def get_weight_from_memory(memory_point:str) -> float:
return parts[0].strip() if len(parts) > 1 else None
def get_weight_from_memory(memory_point: str) -> float:
"""从记忆点中获取权重"""
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
if not isinstance(memory_point, str):
return None
return -math.inf
parts = memory_point.rsplit(":", 1)
if len(parts) > 1:
try:
return float(parts[-1].strip())
except Exception:
return None
else:
return None
def get_memory_content_from_memory(memory_point:str) -> str:
if len(parts) <= 1:
return -math.inf
try:
return float(parts[-1].strip())
except Exception:
return -math.inf
def get_memory_content_from_memory(memory_point: str) -> str:
"""从记忆点中获取记忆内容"""
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
if not isinstance(memory_point, str):
return None
return ""
parts = memory_point.split(":")
if len(parts) > 2:
return ":".join(parts[1:-1]).strip()
else:
return None
return ":".join(parts[1:-1]).strip() if len(parts) > 2 else ""
def calculate_string_similarity(s1: str, s2: str) -> float:
"""
计算两个字符串的相似度
Args:
s1: 第一个字符串
s2: 第二个字符串
Returns:
float: 相似度范围0-11表示完全相同
"""
if s1 == s2:
return 1.0
if not s1 or not s2:
return 0.0
# 计算Levenshtein距离
distance = levenshtein_distance(s1, s2)
max_len = max(len(s1), len(s2))
# 计算相似度1 - (编辑距离 / 最大长度)
similarity = 1 - (distance / max_len if max_len > 0 else 0)
return similarity
def levenshtein_distance(s1: str, s2: str) -> int:
"""
计算两个字符串的编辑距离
Args:
s1: 第一个字符串
s2: 第二个字符串
Returns:
int: 编辑距离
"""
if len(s1) < len(s2):
return levenshtein_distance(s2, s1)
if len(s2) == 0:
return len(s1)
previous_row = range(len(s2) + 1)
for i, c1 in enumerate(s1):
current_row = [i + 1]
@ -139,44 +138,45 @@ def levenshtein_distance(s1: str, s2: str) -> int:
substitutions = previous_row[j] + (c1 != c2)
current_row.append(min(insertions, deletions, substitutions))
previous_row = current_row
return previous_row[-1]
class Person:
@classmethod
def register_person(cls, platform: str, user_id: str, nickname: str):
"""
注册新用户的类方法
必须输入 platformuser_id nickname 参数
Args:
platform: 平台名称
user_id: 用户ID
nickname: 用户昵称
Returns:
Person: 新注册的Person实例
"""
if not platform or not user_id or not nickname:
logger.error("注册用户失败platform、user_id 和 nickname 都是必需参数")
return None
# 生成唯一的person_id
person_id = get_person_id(platform, user_id)
if is_person_known(person_id=person_id):
logger.debug(f"用户 {nickname} 已存在")
return Person(person_id=person_id)
# 创建Person实例
person = cls.__new__(cls)
# 设置基本属性
person.person_id = person_id
person.platform = platform
person.user_id = user_id
person.nickname = nickname
# 初始化默认值
person.is_known = True # 注册后立即标记为已认识
person.person_name = nickname # 使用nickname作为初始person_name
@ -185,34 +185,34 @@ class Person:
person.know_since = time.time()
person.last_know = time.time()
person.memory_points = []
# 初始化性格特征相关字段
person.attitude_to_me = 0
person.attitude_to_me_confidence = 1
person.neuroticism = 5
person.neuroticism_confidence = 1
person.friendly_value = 50
person.friendly_value_confidence = 1
person.rudeness = 50
person.rudeness_confidence = 1
person.conscientiousness = 50
person.conscientiousness_confidence = 1
person.likeness = 50
person.likeness_confidence = 1
# 同步到数据库
person.sync_to_database()
logger.info(f"成功注册新用户:{person_id},平台:{platform},昵称:{nickname}")
return person
def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""):
def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
self.is_known = True
self.person_id = get_person_id(platform, user_id)
@ -221,10 +221,10 @@ class Person:
self.nickname = global_config.bot.nickname
self.person_name = global_config.bot.nickname
return
self.user_id = ""
self.platform = ""
if person_id:
self.person_id = person_id
elif person_name:
@ -232,7 +232,7 @@ class Person:
if not self.person_id:
self.is_known = False
logger.warning(f"根据用户名 {person_name} 获取用户ID时不存在用户{person_name}")
return
return
elif platform and user_id:
self.person_id = get_person_id(platform, user_id)
self.user_id = user_id
@ -240,17 +240,16 @@ class Person:
else:
logger.error("Person 初始化失败,缺少必要参数")
raise ValueError("Person 初始化失败,缺少必要参数")
if not is_person_known(person_id=self.person_id):
self.is_known = False
logger.debug(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.person_name = f"未知用户{self.person_id[:4]}"
return
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.is_known = False
# 初始化默认值
self.nickname = ""
self.person_name: Optional[str] = None
@ -259,47 +258,47 @@ class Person:
self.know_since = None
self.last_know = None
self.memory_points = []
# 初始化性格特征相关字段
self.attitude_to_me:float = 0
self.attitude_to_me_confidence:float = 1
self.neuroticism:float = 5
self.neuroticism_confidence:float = 1
self.friendly_value:float = 50
self.friendly_value_confidence:float = 1
self.rudeness:float = 50
self.rudeness_confidence:float = 1
self.conscientiousness:float = 50
self.conscientiousness_confidence:float = 1
self.likeness:float = 50
self.likeness_confidence:float = 1
self.attitude_to_me: float = 0
self.attitude_to_me_confidence: float = 1
self.neuroticism: float = 5
self.neuroticism_confidence: float = 1
self.friendly_value: float = 50
self.friendly_value_confidence: float = 1
self.rudeness: float = 50
self.rudeness_confidence: float = 1
self.conscientiousness: float = 50
self.conscientiousness_confidence: float = 1
self.likeness: float = 50
self.likeness_confidence: float = 1
# 从数据库加载数据
self.load_from_database()
def del_memory(self, category: str, memory_content: str, similarity_threshold: float = 0.95):
"""
删除指定分类和记忆内容的记忆点
Args:
category: 记忆分类
memory_content: 要删除的记忆内容
similarity_threshold: 相似度阈值默认0.9595%
Returns:
int: 删除的记忆点数量
"""
if not self.memory_points:
return 0
deleted_count = 0
memory_points_to_keep = []
for memory_point in self.memory_points:
# 跳过None值
if memory_point is None:
@ -310,80 +309,76 @@ class Person:
# 格式不正确,保留原样
memory_points_to_keep.append(memory_point)
continue
memory_category = parts[0].strip()
memory_text = parts[1].strip()
memory_weight = parts[2].strip()
# 检查分类是否匹配
if memory_category != category:
memory_points_to_keep.append(memory_point)
continue
# 计算记忆内容的相似度
similarity = calculate_string_similarity(memory_content, memory_text)
# 如果相似度达到阈值,则删除(不添加到保留列表)
if similarity >= similarity_threshold:
deleted_count += 1
logger.debug(f"删除记忆点: {memory_point} (相似度: {similarity:.4f})")
else:
memory_points_to_keep.append(memory_point)
# 更新memory_points
self.memory_points = memory_points_to_keep
# 同步到数据库
if deleted_count > 0:
self.sync_to_database()
logger.info(f"成功删除 {deleted_count} 个记忆点,分类: {category}")
return deleted_count
def get_all_category(self):
category_list = []
for memory in self.memory_points:
if memory is None:
continue
category = get_catagory_from_memory(memory)
category = get_category_from_memory(memory)
if category and category not in category_list:
category_list.append(category)
return category_list
def get_memory_list_by_category(self,category:str):
def get_memory_list_by_category(self, category: str):
memory_list = []
for memory in self.memory_points:
if memory is None:
continue
if get_catagory_from_memory(memory) == category:
if get_category_from_memory(memory) == category:
memory_list.append(memory)
return memory_list
def get_random_memory_by_category(self,category:str,num:int=1):
def get_random_memory_by_category(self, category: str, num: int = 1):
memory_list = self.get_memory_list_by_category(category)
if len(memory_list) < num:
return memory_list
return random.sample(memory_list, num)
def load_from_database(self):
"""从数据库加载个人信息数据"""
try:
# 查询数据库中的记录
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
if record:
self.user_id = record.user_id if record.user_id else ""
self.platform = record.platform if record.platform else ""
self.is_known = record.is_known if record.is_known else False
self.nickname = record.nickname if record.nickname else ""
self.person_name = record.person_name if record.person_name else self.nickname
self.name_reason = record.name_reason if record.name_reason else None
self.know_times = record.know_times if record.know_times else 0
self.user_id = record.user_id or ""
self.platform = record.platform or ""
self.is_known = record.is_known or False
self.nickname = record.nickname or ""
self.person_name = record.person_name or self.nickname
self.name_reason = record.name_reason or None
self.know_times = record.know_times or 0
# 处理points字段JSON格式的列表
if record.memory_points:
try:
@ -398,53 +393,53 @@ class Person:
self.memory_points = []
else:
self.memory_points = []
# 加载性格特征相关字段
if record.attitude_to_me and not isinstance(record.attitude_to_me, str):
self.attitude_to_me = record.attitude_to_me
if record.attitude_to_me_confidence is not None:
self.attitude_to_me_confidence = float(record.attitude_to_me_confidence)
if record.friendly_value is not None:
self.friendly_value = float(record.friendly_value)
if record.friendly_value_confidence is not None:
self.friendly_value_confidence = float(record.friendly_value_confidence)
if record.rudeness is not None:
self.rudeness = float(record.rudeness)
if record.rudeness_confidence is not None:
self.rudeness_confidence = float(record.rudeness_confidence)
if record.neuroticism and not isinstance(record.neuroticism, str):
self.neuroticism = float(record.neuroticism)
if record.neuroticism_confidence is not None:
self.neuroticism_confidence = float(record.neuroticism_confidence)
if record.conscientiousness is not None:
self.conscientiousness = float(record.conscientiousness)
if record.conscientiousness_confidence is not None:
self.conscientiousness_confidence = float(record.conscientiousness_confidence)
if record.likeness is not None:
self.likeness = float(record.likeness)
if record.likeness_confidence is not None:
self.likeness_confidence = float(record.likeness_confidence)
logger.debug(f"已从数据库加载用户 {self.person_id} 的信息")
else:
self.sync_to_database()
logger.info(f"用户 {self.person_id} 在数据库中不存在,使用默认值并创建")
except Exception as e:
logger.error(f"从数据库加载用户 {self.person_id} 信息时出错: {e}")
# 出错时保持默认值
def sync_to_database(self):
"""将所有属性同步回数据库"""
if not self.is_known:
@ -452,34 +447,38 @@ class Person:
try:
# 准备数据
data = {
'person_id': self.person_id,
'is_known': self.is_known,
'platform': self.platform,
'user_id': self.user_id,
'nickname': self.nickname,
'person_name': self.person_name,
'name_reason': self.name_reason,
'know_times': self.know_times,
'know_since': self.know_since,
'last_know': self.last_know,
'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False),
'attitude_to_me': self.attitude_to_me,
'attitude_to_me_confidence': self.attitude_to_me_confidence,
'friendly_value': self.friendly_value,
'friendly_value_confidence': self.friendly_value_confidence,
'rudeness': self.rudeness,
'rudeness_confidence': self.rudeness_confidence,
'neuroticism': self.neuroticism,
'neuroticism_confidence': self.neuroticism_confidence,
'conscientiousness': self.conscientiousness,
'conscientiousness_confidence': self.conscientiousness_confidence,
'likeness': self.likeness,
'likeness_confidence': self.likeness_confidence,
"person_id": self.person_id,
"is_known": self.is_known,
"platform": self.platform,
"user_id": self.user_id,
"nickname": self.nickname,
"person_name": self.person_name,
"name_reason": self.name_reason,
"know_times": self.know_times,
"know_since": self.know_since,
"last_know": self.last_know,
"memory_points": json.dumps(
[point for point in self.memory_points if point is not None], ensure_ascii=False
)
if self.memory_points
else json.dumps([], ensure_ascii=False),
"attitude_to_me": self.attitude_to_me,
"attitude_to_me_confidence": self.attitude_to_me_confidence,
"friendly_value": self.friendly_value,
"friendly_value_confidence": self.friendly_value_confidence,
"rudeness": self.rudeness,
"rudeness_confidence": self.rudeness_confidence,
"neuroticism": self.neuroticism,
"neuroticism_confidence": self.neuroticism_confidence,
"conscientiousness": self.conscientiousness,
"conscientiousness_confidence": self.conscientiousness_confidence,
"likeness": self.likeness,
"likeness_confidence": self.likeness_confidence,
}
# 检查记录是否存在
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
if record:
# 更新现有记录
for field, value in data.items():
@ -491,10 +490,10 @@ class Person:
# 创建新记录
PersonInfo.create(**data)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
except Exception as e:
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
def build_relationship(self):
if not self.is_known:
return ""
@ -505,22 +504,21 @@ class Person:
nickname_str = f"(ta在{self.platform}上的昵称是{self.nickname})"
relation_info = ""
attitude_info = ""
if self.attitude_to_me:
if self.attitude_to_me > 8:
attitude_info = f"{self.person_name}对你的态度十分好,"
elif self.attitude_to_me > 5:
attitude_info = f"{self.person_name}对你的态度较好,"
if self.attitude_to_me < -8:
attitude_info = f"{self.person_name}对你的态度十分恶劣,"
elif self.attitude_to_me < -4:
attitude_info = f"{self.person_name}对你的态度不好,"
elif self.attitude_to_me < 0:
attitude_info = f"{self.person_name}对你的态度一般,"
neuroticism_info = ""
if self.neuroticism:
if self.neuroticism > 8:
@ -533,29 +531,28 @@ class Person:
neuroticism_info = f"{self.person_name}的情绪比较稳定,"
else:
neuroticism_info = f"{self.person_name}的情绪非常稳定,毫无波动"
points_text = ""
category_list = self.get_all_category()
for category in category_list:
random_memory = self.get_random_memory_by_category(category,1)[0]
random_memory = self.get_random_memory_by_category(category, 1)[0]
if random_memory:
points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}"
break
points_info = ""
if points_text:
points_info = f"你还记得有关{self.person_name}的最近记忆:{points_text}"
if not (nickname_str or attitude_info or neuroticism_info or points_info):
return ""
relation_info = f"{self.person_name}:{nickname_str}{attitude_info}{neuroticism_info}{points_info}"
return relation_info
class PersonInfoManager:
def __init__(self):
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
@ -580,8 +577,6 @@ class PersonInfoManager:
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod
def _extract_json_from_text(text: str) -> dict:
@ -717,6 +712,6 @@ class PersonInfoManager:
person.sync_to_database()
self.person_name_list[person_id] = unique_nickname
return {"nickname": unique_nickname, "reason": "使用用户原始昵称作为默认值"}
person_info_manager = PersonInfoManager()

View File

@ -3,7 +3,8 @@ import traceback
import os
import pickle
import random
from typing import List, Dict, Any
import asyncio
from typing import List, Dict, Any, TYPE_CHECKING
from src.config.config import global_config
from src.common.logger import get_logger
from src.person_info.relationship_manager import get_relationship_manager
@ -15,7 +16,9 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
num_new_messages_since,
)
import asyncio
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("relationship_builder")
@ -429,7 +432,7 @@ class RelationshipBuilder:
if dropped_count > 0:
logger.debug(f"{person_id} 随机丢弃了 {dropped_count} / {original_segment_count} 个消息段")
processed_messages = []
processed_messages: List["DatabaseMessages"] = []
# 对筛选后的消息段进行排序,确保时间顺序
segments_to_process.sort(key=lambda x: x["start_time"])
@ -449,17 +452,18 @@ class RelationshipBuilder:
# 如果 processed_messages 不为空,说明这不是第一个被处理的消息段,在消息列表前添加间隔标识
if processed_messages:
# 创建一个特殊的间隔消息
gap_message = {
"time": start_time - 0.1, # 稍微早于段开始时间
"user_id": "system",
"user_platform": "system",
"user_nickname": "系统",
"user_cardname": "",
"display_message": f"...(中间省略一些消息){start_date} 之后的消息如下...",
"is_action_record": True,
"chat_info_platform": segment_messages[0].chat_info.platform or "",
"chat_id": chat_id,
}
gap_message = DatabaseMessages(
time=start_time - 0.1,
user_id="system",
user_platform="system",
user_nickname="系统",
user_cardname="",
display_message=f"...(中间省略一些消息){start_date} 之后的消息如下...",
is_action_record=True,
chat_info_platform=segment_messages[0].chat_info.platform or "",
chat_id=chat_id,
)
processed_messages.append(gap_message)
# 添加该段的所有消息
@ -467,11 +471,11 @@ class RelationshipBuilder:
if processed_messages:
# 按时间排序所有消息(包括间隔标识)
processed_messages.sort(key=lambda x: x["time"])
processed_messages.sort(key=lambda x: x.time)
logger.debug(f"{person_id} 获取到总共 {len(processed_messages)} 条消息(包含间隔标识)用于印象更新")
relationship_manager = get_relationship_manager()
build_frequency = 0.3 * global_config.relationship.relation_frequency
if random.random() < build_frequency:
# 调用原有的更新方法

View File

@ -3,16 +3,18 @@ import traceback
from json_repair import repair_json
from datetime import datetime
from typing import List
from typing import List, TYPE_CHECKING
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
from src.chat.utils.chat_message_builder import build_readable_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from .person_info import Person
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("relation")
@ -177,7 +179,7 @@ class RelationshipManager:
return person
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List[DatabaseMessages]):
async def update_person_impression(self, person_id, timestamp, bot_engaged_messages: List["DatabaseMessages"]):
"""更新用户印象
Args:
@ -192,8 +194,6 @@ class RelationshipManager:
# nickname = person.nickname
know_times: float = person.know_times
user_messages = bot_engaged_messages
# 匿名化消息
# 创建用户名称映射
name_mapping = {}
@ -201,13 +201,14 @@ class RelationshipManager:
user_count = 1
# 遍历消息,构建映射
for msg in user_messages:
for msg in bot_engaged_messages:
if msg.user_info.user_id == "system":
continue
try:
user_id = msg.user_info.user_id
platform = msg.chat_info.platform
assert isinstance(user_id, str) and isinstance(platform, str)
assert user_id, "用户ID不能为空"
assert platform, "平台不能为空"
msg_person = Person(user_id=user_id, platform=platform)
except Exception as e:
@ -233,7 +234,7 @@ class RelationshipManager:
current_user = chr(ord(current_user) + 1)
readable_messages = build_readable_messages(
messages=user_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
messages=bot_engaged_messages, replace_bot_name=True, timestamp_mode="normal_no_YMD", truncate=True
)
for original_name, mapped_name in name_mapping.items():

View File

@ -9,7 +9,7 @@
"""
import traceback
from typing import Tuple, Any, Dict, List, Optional
from typing import Tuple, Any, Dict, List, Optional, TYPE_CHECKING
from rich.traceback import install
from src.common.logger import get_logger
from src.chat.replyer.default_generator import DefaultReplyer
@ -18,6 +18,10 @@ from src.chat.utils.utils import process_llm_response
from src.chat.replyer.replyer_manager import replyer_manager
from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING:
from src.common.data_models.info_data_model import ActionPlannerInfo
from src.common.data_models.database_data_model import DatabaseMessages
install(extra_lines=3)
logger = get_logger("generator_api")
@ -73,11 +77,11 @@ async def generate_reply(
chat_stream: Optional[ChatStream] = None,
chat_id: Optional[str] = None,
action_data: Optional[Dict[str, Any]] = None,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
extra_info: str = "",
reply_reason: str = "",
available_actions: Optional[Dict[str, ActionInfo]] = None,
choosen_actions: Optional[List[Dict[str, Any]]] = None,
chosen_actions: Optional[List["ActionPlannerInfo"]] = None,
enable_tool: bool = False,
enable_splitter: bool = True,
enable_chinese_typo: bool = True,
@ -85,7 +89,7 @@ async def generate_reply(
request_type: str = "generator_api",
from_plugin: bool = True,
return_expressions: bool = False,
) -> Tuple[bool, List[Tuple[str, Any]], Optional[Tuple[str, List[Dict[str, Any]]]]]:
) -> Tuple[bool, List[Tuple[str, Any]], Optional[str], Optional[List[int]]]:
"""生成回复
Args:
@ -96,7 +100,7 @@ async def generate_reply(
extra_info: 额外信息用于补充上下文
reply_reason: 回复原因
available_actions: 可用动作
choosen_actions: 已选动作
chosen_actions: 已选动作
enable_tool: 是否启用工具调用
enable_splitter: 是否启用消息分割器
enable_chinese_typo: 是否启用错字生成器
@ -110,16 +114,14 @@ async def generate_reply(
try:
# 获取回复器
logger.debug("[GeneratorAPI] 开始生成回复")
replyer = get_replyer(
chat_stream, chat_id, request_type=request_type
)
replyer = get_replyer(chat_stream, chat_id, request_type=request_type)
if not replyer:
logger.error("[GeneratorAPI] 无法获取回复器")
return False, [], None
return False, [], None, None
if not extra_info and action_data:
extra_info = action_data.get("extra_info", "")
if not reply_reason and action_data:
reply_reason = action_data.get("reason", "")
@ -127,7 +129,7 @@ async def generate_reply(
success, llm_response_dict, prompt, selected_expressions = await replyer.generate_reply_with_context(
extra_info=extra_info,
available_actions=available_actions,
chosen_actions=choosen_actions,
chosen_actions=chosen_actions,
enable_tool=enable_tool,
reply_message=reply_message,
reply_reason=reply_reason,
@ -136,7 +138,7 @@ async def generate_reply(
)
if not success:
logger.warning("[GeneratorAPI] 回复生成失败")
return False, [], None
return False, [], None, None
assert llm_response_dict is not None, "llm_response_dict不应为None" # 虽然说不会出现llm_response为空的情况
if content := llm_response_dict.get("content", ""):
reply_set = process_human_text(content, enable_splitter, enable_chinese_typo)
@ -144,28 +146,34 @@ async def generate_reply(
reply_set = []
logger.debug(f"[GeneratorAPI] 回复生成成功,生成了 {len(reply_set)} 个回复项")
if return_prompt:
if return_expressions:
return success, reply_set, (prompt, selected_expressions)
else:
return success, reply_set, prompt
else:
if return_expressions:
return success, reply_set, (None, selected_expressions)
else:
return success, reply_set, None
# if return_prompt:
# if return_expressions:
# return success, reply_set, prompt, selected_expressions
# else:
# return success, reply_set, prompt, None
# else:
# if return_expressions:
# return success, reply_set, (None, selected_expressions)
# else:
# return success, reply_set, None
return (
success,
reply_set,
prompt if return_prompt else None,
selected_expressions if return_expressions else None,
)
except ValueError as ve:
raise ve
except UserWarning as uw:
logger.warning(f"[GeneratorAPI] 中断了生成: {uw}")
return False, [], None
return False, [], None, None
except Exception as e:
logger.error(f"[GeneratorAPI] 生成回复时出错: {e}")
logger.error(traceback.format_exc())
return False, [], None
return False, [], None, None
async def rewrite_reply(

View File

@ -21,15 +21,17 @@
import traceback
import time
from typing import Optional, Union, Dict, Any, List
from src.common.logger import get_logger
from typing import Optional, Union, Dict, Any, List, TYPE_CHECKING
# 导入依赖
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.message_receive.message import MessageSending, MessageRecv
from maim_message import Seg, UserInfo
from src.config.config import global_config
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("send_api")
@ -46,10 +48,10 @@ async def _send_to_target(
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
show_log: bool = True,
selected_expressions:List[int] = None,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定目标发送消息的内部实现
@ -70,7 +72,7 @@ async def _send_to_target(
if set_reply and not reply_message:
logger.warning("[SendAPI] 使用引用回复,但未提供回复消息")
return False
if show_log:
logger.debug(f"[SendAPI] 发送{message_type}消息到 {stream_id}")
@ -98,13 +100,13 @@ async def _send_to_target(
message_segment = Seg(type=message_type, data=content) # type: ignore
if reply_message:
anchor_message = message_dict_to_message_recv(reply_message)
anchor_message = message_dict_to_message_recv(reply_message.flatten())
if anchor_message:
anchor_message.update_chat_stream(target_stream)
assert anchor_message.message_info.user_info, "用户信息缺失"
reply_to_platform_id = (
f"{anchor_message.message_info.platform}:{anchor_message.message_info.user_info.user_id}"
)
)
else:
reply_to_platform_id = ""
anchor_message = None
@ -192,12 +194,11 @@ def message_dict_to_message_recv(message_dict: Dict[str, Any]) -> Optional[Messa
}
message_recv = MessageRecv(message_dict_recv)
logger.info(f"[SendAPI] 找到匹配的回复消息,发送者: {message_dict.get('user_nickname', '')}")
return message_recv
# =============================================================================
# 公共API函数 - 预定义类型的发送函数
# =============================================================================
@ -208,9 +209,9 @@ async def text_to_stream(
stream_id: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
selected_expressions:List[int] = None,
selected_expressions: Optional[List[int]] = None,
) -> bool:
"""向指定流发送文本消息
@ -237,7 +238,13 @@ async def text_to_stream(
)
async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def emoji_to_stream(
emoji_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送表情包
Args:
@ -248,10 +255,25 @@ async def emoji_to_stream(emoji_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("emoji", emoji_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
return await _send_to_target(
"emoji",
emoji_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def image_to_stream(image_base64: str, stream_id: str, storage_message: bool = True, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def image_to_stream(
image_base64: str,
stream_id: str,
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送图片
Args:
@ -262,11 +284,25 @@ async def image_to_stream(image_base64: str, stream_id: str, storage_message: bo
Returns:
bool: 是否发送成功
"""
return await _send_to_target("image", image_base64, stream_id, "", typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message)
return await _send_to_target(
"image",
image_base64,
stream_id,
"",
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
async def command_to_stream(
command: Union[str, dict], stream_id: str, storage_message: bool = True, display_message: str = "", set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
command: Union[str, dict],
stream_id: str,
storage_message: bool = True,
display_message: str = "",
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""向指定流发送命令
@ -279,7 +315,14 @@ async def command_to_stream(
bool: 是否发送成功
"""
return await _send_to_target(
"command", command, stream_id, display_message, typing=False, storage_message=storage_message, set_reply=set_reply,reply_message=reply_message
"command",
command,
stream_id,
display_message,
typing=False,
storage_message=storage_message,
set_reply=set_reply,
reply_message=reply_message,
)
@ -289,7 +332,7 @@ async def custom_to_stream(
stream_id: str,
display_message: str = "",
typing: bool = False,
reply_message: Optional[Dict[str, Any]] = None,
reply_message: Optional["DatabaseMessages"] = None,
set_reply: bool = False,
storage_message: bool = True,
show_log: bool = True,

View File

@ -2,13 +2,15 @@ import time
import asyncio
from abc import ABC, abstractmethod
from typing import Tuple, Optional, Dict, Any
from typing import Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
from src.plugin_system.base.component_types import ActionActivationType, ChatMode, ActionInfo, ComponentType
from src.plugin_system.base.component_types import ActionActivationType, ActionInfo, ComponentType
from src.plugin_system.apis import send_api, database_api, message_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_action")
@ -206,7 +208,11 @@ class BaseAction(ABC):
return False, f"等待新消息失败: {str(e)}"
async def send_text(
self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None, typing: bool = False
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
typing: bool = False,
) -> bool:
"""发送文本消息
@ -229,7 +235,9 @@ class BaseAction(ABC):
typing=typing,
)
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包
Args:
@ -242,9 +250,13 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.emoji_to_stream(emoji_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.emoji_to_stream(
emoji_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_image(
self, image_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送图片
Args:
@ -257,9 +269,18 @@ class BaseAction(ABC):
logger.error(f"{self.log_prefix} 缺少聊天ID")
return False
return await send_api.image_to_stream(image_base64, self.chat_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.image_to_stream(
image_base64, self.chat_id, set_reply=set_reply, reply_message=reply_message
)
async def send_custom(self, message_type: str, content: str, typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_custom(
self,
message_type: str,
content: str,
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送自定义类型消息
Args:
@ -308,7 +329,13 @@ class BaseAction(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送命令消息

View File

@ -1,10 +1,13 @@
from abc import ABC, abstractmethod
from typing import Dict, Tuple, Optional, Any
from typing import Dict, Tuple, Optional, TYPE_CHECKING
from src.common.logger import get_logger
from src.plugin_system.base.component_types import CommandInfo, ComponentType
from src.chat.message_receive.message import MessageRecv
from src.plugin_system.apis import send_api
if TYPE_CHECKING:
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger("base_command")
@ -84,7 +87,13 @@ class BaseCommand(ABC):
return current
async def send_text(self, content: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
async def send_text(
self,
content: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送回复消息
Args:
@ -100,10 +109,22 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.text_to_stream(text=content, stream_id=chat_stream.stream_id, set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
return await send_api.text_to_stream(
text=content,
stream_id=chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
async def send_type(
self, message_type: str, content: str, display_message: str = "", typing: bool = False, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
message_type: str,
content: str,
display_message: str = "",
typing: bool = False,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送指定类型的回复消息到当前聊天环境
@ -134,7 +155,13 @@ class BaseCommand(ABC):
)
async def send_command(
self, command_name: str, args: Optional[dict] = None, display_message: str = "", storage_message: bool = True,set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None
self,
command_name: str,
args: Optional[dict] = None,
display_message: str = "",
storage_message: bool = True,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
) -> bool:
"""发送命令消息
@ -177,7 +204,9 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 发送命令时出错: {e}")
return False
async def send_emoji(self, emoji_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None) -> bool:
async def send_emoji(
self, emoji_base64: str, set_reply: bool = False, reply_message: Optional["DatabaseMessages"] = None
) -> bool:
"""发送表情包
Args:
@ -191,9 +220,17 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.emoji_to_stream(emoji_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message)
return await send_api.emoji_to_stream(
emoji_base64, chat_stream.stream_id, set_reply=set_reply, reply_message=reply_message
)
async def send_image(self, image_base64: str, set_reply: bool = False,reply_message: Optional[Dict[str, Any]] = None,storage_message: bool = True) -> bool:
async def send_image(
self,
image_base64: str,
set_reply: bool = False,
reply_message: Optional["DatabaseMessages"] = None,
storage_message: bool = True,
) -> bool:
"""发送图片
Args:
@ -207,7 +244,13 @@ class BaseCommand(ABC):
logger.error(f"{self.log_prefix} 缺少聊天流或stream_id")
return False
return await send_api.image_to_stream(image_base64, chat_stream.stream_id,set_reply=set_reply,reply_message=reply_message,storage_message=storage_message)
return await send_api.image_to_stream(
image_base64,
chat_stream.stream_id,
set_reply=set_reply,
reply_message=reply_message,
storage_message=storage_message,
)
@classmethod
def get_command_info(cls) -> "CommandInfo":

View File

@ -2,7 +2,7 @@ from typing import Dict, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.knowledge import qa_manager
from src.plugin_system import BaseTool, ToolParamType
logger = get_logger("lpmm_get_knowledge_tool")

View File

@ -1,20 +1,13 @@
import random
import json
from json_repair import repair_json
from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
# 导入依赖的系统组件
from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import json
from json_repair import repair_json
from src.plugin_system import BaseAction, ActionActivationType
from src.plugin_system.apis import llm_api
logger = get_logger("relation")
@ -39,10 +32,9 @@ def init_prompt():
{{
"category": "分类名称"
}} """,
"relation_category"
"relation_category",
)
Prompt(
"""
以下是有关{category}的现有记忆
@ -73,7 +65,7 @@ def init_prompt():
现在请你根据情况选出合适的修改方式并输出json不要输出其他内容
""",
"relation_category_update"
"relation_category_update",
)
@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction):
"""
# 动作参数定义
action_parameters = {
"person_name":"需要了解或记忆的人的名称",
"impression":"需要了解的对某人的记忆或印象"
}
action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
# 动作使用场景
action_require = [
"了解对于某人的记忆,并添加到你对对方的印象中",
"对方与有明确提到有关其自身的事件",
"对方有提到其个人信息,包括喜好,身份,等等",
"对方希望你记住对方的信息"
"对方希望你记住对方的信息",
]
# 关联类型
@ -129,9 +118,7 @@ class BuildRelationAction(BaseAction):
if not person.is_known:
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
return False, f"用户 {person_name} 不存在,跳过添加记忆"
category_list = person.get_all_category()
if not category_list:
category_list_str = "无分类"
@ -142,9 +129,8 @@ class BuildRelationAction(BaseAction):
"relation_category",
category_list=category_list_str,
memory_point=impression,
person_name=person.person_name
person_name=person.person_name,
)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
@ -161,84 +147,76 @@ class BuildRelationAction(BaseAction):
success, category, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category"
)
category_data = json.loads(repair_json(category))
category = category_data.get("category", "")
if not category:
logger.warning(f"{self.log_prefix} LLM未给出分类跳过添加记忆")
return False, "LLM未给出分类跳过添加记忆"
# 第二部分:更新记忆
memory_list = person.get_memory_list_by_category(category)
if not memory_list:
logger.info(f"{self.log_prefix} {person.person_name}{category} 的记忆为空,进行创建")
person.memory_points.append(f"{category}:{impression}:1.0")
person.sync_to_database()
return True, f"未找到分类为{category}的记忆点,进行添加"
memory_list_str = ""
memory_list_id = {}
id = 1
for memory in memory_list:
for id, memory in enumerate(memory_list, start=1):
memory_content = get_memory_content_from_memory(memory)
memory_list_str += f"{id}. {memory_content}\n"
memory_list_id[id] = memory
id += 1
prompt = await global_prompt_manager.format_prompt(
"relation_category_update",
category=category,
memory_list=memory_list_str,
memory_point=impression,
person_name=person.person_name
person_name=person.person_name,
)
if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
else:
logger.debug(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
chat_model_config = models.get("utils")
chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update"
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
)
update_memory_data = json.loads(repair_json(update_memory))
new_memory = update_memory_data.get("new_memory", "")
memory_id = update_memory_data.get("memory_id", "")
integrate_memory = update_memory_data.get("integrate_memory", "")
if new_memory:
# 新记忆
person.memory_points.append(f"{category}:{new_memory}:1.0")
person.sync_to_database()
return True, f"{person.person_name}新增记忆点: {new_memory}"
elif memory_id and integrate_memory:
# 现存或冲突记忆
memory = memory_list_id[memory_id]
memory_content = get_memory_content_from_memory(memory)
del_count = person.del_memory(category,memory_content)
del_count = person.del_memory(category, memory_content)
if del_count > 0:
logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
memory_weight = get_weight_from_memory(memory)
person.memory_points.append(f"{category}:{integrate_memory}:{memory_weight + 1.0}")
person.sync_to_database()
return True, f"更新{person.person_name}的记忆点: {memory_content} -> {integrate_memory}"
else:
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功"
@ -248,4 +226,4 @@ class BuildRelationAction(BaseAction):
# 还缺一个关系的太多遗忘和对应的提取
init_prompt()
init_prompt()

View File

@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type

View File

@ -1,5 +1,5 @@
[inner]
version = "1.3.0"
version = "1.3.1"
# 配置文件版本号迭代规则同bot_config.toml