修复typing问题,保证类型正确

pull/1211/head
UnCLAS-Prommer 2025-08-21 23:52:44 +08:00
parent ec500f1f5b
commit a55979164e
No known key found for this signature in database
6 changed files with 230 additions and 267 deletions

View File

@ -6,8 +6,7 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class FocusValueControl: class FocusValueControl:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.chat_id = chat_id self.chat_id = chat_id
self.focus_value_adjust = 1 self.focus_value_adjust: float = 1
def get_current_focus_value(self) -> float: def get_current_focus_value(self) -> float:
return get_current_focus_value(self.chat_id) * self.focus_value_adjust return get_current_focus_value(self.chat_id) * self.focus_value_adjust
@ -15,7 +14,7 @@ class FocusValueControl:
class FocusValueControlManager: class FocusValueControlManager:
def __init__(self): def __init__(self):
self.focus_value_controls = {} self.focus_value_controls: dict[str, FocusValueControl] = {}
def get_focus_value_control(self, chat_id: str) -> FocusValueControl: def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
if chat_id not in self.focus_value_controls: if chat_id not in self.focus_value_controls:
@ -23,7 +22,6 @@ class FocusValueControlManager:
return self.focus_value_controls[chat_id] return self.focus_value_controls[chat_id]
def get_current_focus_value(chat_id: Optional[str] = None) -> float: def get_current_focus_value(chat_id: Optional[str] = None) -> float:
""" """
根据当前时间和聊天流获取对应的 focus_value 根据当前时间和聊天流获取对应的 focus_value
@ -42,6 +40,7 @@ def get_current_focus_value(chat_id: Optional[str] = None) -> float:
return global_config.chat.focus_value return global_config.chat.focus_value
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]: def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
""" """
获取特定聊天流在当前时间的专注度 获取特定聊天流在当前时间的专注度
@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
return None return None
focus_value_control = FocusValueControlManager() focus_value_control = FocusValueControlManager()

View File

@ -2,10 +2,11 @@ from typing import Optional
from src.config.config import global_config from src.config.config import global_config
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
class TalkFrequencyControl: class TalkFrequencyControl:
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.chat_id = chat_id self.chat_id = chat_id
self.talk_frequency_adjust = 1 self.talk_frequency_adjust: float = 1
def get_current_talk_frequency(self) -> float: def get_current_talk_frequency(self) -> float:
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
global_frequency = get_global_frequency() global_frequency = get_global_frequency()
return global_config.chat.talk_frequency if global_frequency is None else global_frequency return global_config.chat.talk_frequency if global_frequency is None else global_frequency
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]: def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
""" """
根据时间配置列表获取当前时段的频率 根据时间配置列表获取当前时段的频率
@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
return None return None
def get_global_frequency() -> Optional[float]: def get_global_frequency() -> Optional[float]:
""" """
获取全局默认频率配置 获取全局默认频率配置
@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
return None return None
talk_frequency_control = TalkFrequencyControlManager() talk_frequency_control = TalkFrequencyControlManager()

View File

@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
norm1 = np.linalg.norm(v1) norm1 = np.linalg.norm(v1)
norm2 = np.linalg.norm(v2) norm2 = np.linalg.norm(v2)
if norm1 == 0 or norm2 == 0: return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
return 0
return dot_product / (norm1 * norm2)
install(extra_lines=3) install(extra_lines=3)
@ -142,11 +140,10 @@ class MemoryGraph:
# 获取当前节点的记忆项 # 获取当前节点的记忆项
node_data = self.get_dot(topic) node_data = self.get_dot(topic)
if node_data: if node_data:
concept, data = node_data _, data = node_data
if "memory_items" in data: if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := data["memory_items"]:
first_layer_items.append(memory_items) first_layer_items.append(memory_items)
# 只在depth=2时获取第二层记忆 # 只在depth=2时获取第二层记忆
@ -154,11 +151,10 @@ class MemoryGraph:
# 获取相邻节点的记忆项 # 获取相邻节点的记忆项
for neighbor in neighbors: for neighbor in neighbors:
if node_data := self.get_dot(neighbor): if node_data := self.get_dot(neighbor):
concept, data = node_data _, data = node_data
if "memory_items" in data: if "memory_items" in data:
memory_items = data["memory_items"]
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := data["memory_items"]:
second_layer_items.append(memory_items) second_layer_items.append(memory_items)
return first_layer_items, second_layer_items return first_layer_items, second_layer_items
@ -224,26 +220,16 @@ class MemoryGraph:
# 获取话题节点数据 # 获取话题节点数据
node_data = self.G.nodes[topic] node_data = self.G.nodes[topic]
# 如果节点存在memory_items
if "memory_items" in node_data:
memory_items = node_data["memory_items"]
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
if memory_items:
# 删除整个节点 # 删除整个节点
self.G.remove_node(topic) self.G.remove_node(topic)
# 如果节点存在memory_items
if "memory_items" in node_data:
if memory_items := node_data["memory_items"]:
return ( return (
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..." f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
if len(memory_items) > 50 if len(memory_items) > 50
else f"删除了节点 {topic} 的完整记忆: {memory_items}" else f"删除了节点 {topic} 的完整记忆: {memory_items}"
) )
else:
# 如果没有记忆项,删除该节点
self.G.remove_node(topic)
return None
else:
# 如果没有memory_items字段删除该节点
self.G.remove_node(topic)
return None return None
@ -392,9 +378,8 @@ class Hippocampus:
# 如果相似度超过阈值,获取该节点的记忆 # 如果相似度超过阈值,获取该节点的记忆
if similarity >= 0.3: # 可以调整这个阈值 if similarity >= 0.3: # 可以调整这个阈值
node_data = self.memory_graph.G.nodes[node] node_data = self.memory_graph.G.nodes[node]
memory_items = node_data.get("memory_items", "")
# 直接使用完整的记忆内容 # 直接使用完整的记忆内容
if memory_items: if memory_items := node_data.get("memory_items", ""):
memories.append((node, memory_items, similarity)) memories.append((node, memory_items, similarity))
# 按相似度降序排序 # 按相似度降序排序
@ -587,7 +572,7 @@ class Hippocampus:
unique_memories = [] unique_memories = []
for topic, memory_items, activation_value in all_memories: for topic, memory_items, activation_value in all_memories:
# memory_items现在是完整的字符串格式 # memory_items现在是完整的字符串格式
memory = memory_items if memory_items else "" memory = memory_items or ""
if memory not in seen_memories: if memory not in seen_memories:
seen_memories.add(memory) seen_memories.add(memory)
unique_memories.append((topic, memory_items, activation_value)) unique_memories.append((topic, memory_items, activation_value))
@ -599,7 +584,7 @@ class Hippocampus:
result = [] result = []
for topic, memory_items, _ in unique_memories: for topic, memory_items, _ in unique_memories:
# memory_items现在是完整的字符串格式 # memory_items现在是完整的字符串格式
memory = memory_items if memory_items else "" memory = memory_items or ""
result.append((topic, memory)) result.append((topic, memory))
logger.debug(f"选中记忆: {memory} (来自节点: {topic})") logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
@ -1471,6 +1456,7 @@ class MemoryBuilder:
self.last_processed_time: float = 0.0 self.last_processed_time: float = 0.0
def should_trigger_memory_build(self) -> bool: def should_trigger_memory_build(self) -> bool:
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
"""检查是否应该触发记忆构建""" """检查是否应该触发记忆构建"""
current_time = time.time() current_time = time.time()

View File

@ -3,6 +3,7 @@ import asyncio
import json import json
import time import time
import random import random
import math
from json_repair import repair_json from json_repair import repair_json
from typing import Union, Optional from typing import Union, Optional
@ -16,6 +17,7 @@ from src.config.config import global_config, model_config
logger = get_logger("person_info") logger = get_logger("person_info")
def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id(platform: str, user_id: Union[int, str]) -> str:
"""获取唯一id""" """获取唯一id"""
if "-" in platform: if "-" in platform:
@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
key = "_".join(components) key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest() return hashlib.md5(key.encode()).hexdigest()
def get_person_id_by_person_name(person_name: str) -> str: def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID""" """根据用户名获取用户ID"""
try: try:
@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
return "" return ""
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
if person_id: if person_id:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
return person.is_known if person else False return person.is_known if person else False
@ -49,41 +53,36 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No
return False return False
def get_catagory_from_memory(memory_point:str) -> str: def get_category_from_memory(memory_point: str) -> Optional[str]:
"""从记忆点中获取分类""" """从记忆点中获取分类"""
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类 # 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return None
parts = memory_point.split(":", 1) parts = memory_point.split(":", 1)
if len(parts) > 1: return parts[0].strip() if len(parts) > 1 else None
return parts[0].strip()
else:
return None
def get_weight_from_memory(memory_point: str) -> float: def get_weight_from_memory(memory_point: str) -> float:
"""从记忆点中获取权重""" """从记忆点中获取权重"""
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重 # 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return -math.inf
parts = memory_point.rsplit(":", 1) parts = memory_point.rsplit(":", 1)
if len(parts) > 1: if len(parts) <= 1:
return -math.inf
try: try:
return float(parts[-1].strip()) return float(parts[-1].strip())
except Exception: except Exception:
return None return -math.inf
else:
return None
def get_memory_content_from_memory(memory_point: str) -> str: def get_memory_content_from_memory(memory_point: str) -> str:
"""从记忆点中获取记忆内容""" """从记忆点中获取记忆内容"""
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容 # 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
if not isinstance(memory_point, str): if not isinstance(memory_point, str):
return None return ""
parts = memory_point.split(":") parts = memory_point.split(":")
if len(parts) > 2: return ":".join(parts[1:-1]).strip() if len(parts) > 2 else ""
return ":".join(parts[1:-1]).strip()
else:
return None
def calculate_string_similarity(s1: str, s2: str) -> float: def calculate_string_similarity(s1: str, s2: str) -> float:
@ -105,7 +104,6 @@ def calculate_string_similarity(s1: str, s2: str) -> float:
# 计算Levenshtein距离 # 计算Levenshtein距离
distance = levenshtein_distance(s1, s2) distance = levenshtein_distance(s1, s2)
max_len = max(len(s1), len(s2)) max_len = max(len(s1), len(s2))
@ -113,6 +111,7 @@ def calculate_string_similarity(s1: str, s2: str) -> float:
similarity = 1 - (distance / max_len if max_len > 0 else 0) similarity = 1 - (distance / max_len if max_len > 0 else 0)
return similarity return similarity
def levenshtein_distance(s1: str, s2: str) -> int: def levenshtein_distance(s1: str, s2: str) -> int:
""" """
计算两个字符串的编辑距离 计算两个字符串的编辑距离
@ -142,6 +141,7 @@ def levenshtein_distance(s1: str, s2: str) -> int:
return previous_row[-1] return previous_row[-1]
class Person: class Person:
@classmethod @classmethod
def register_person(cls, platform: str, user_id: str, nickname: str): def register_person(cls, platform: str, user_id: str, nickname: str):
@ -248,7 +248,6 @@ class Person:
return return
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识") # raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
self.is_known = False self.is_known = False
# 初始化默认值 # 初始化默认值
@ -340,26 +339,22 @@ class Person:
return deleted_count return deleted_count
def get_all_category(self): def get_all_category(self):
category_list = [] category_list = []
for memory in self.memory_points: for memory in self.memory_points:
if memory is None: if memory is None:
continue continue
category = get_catagory_from_memory(memory) category = get_category_from_memory(memory)
if category and category not in category_list: if category and category not in category_list:
category_list.append(category) category_list.append(category)
return category_list return category_list
def get_memory_list_by_category(self, category: str): def get_memory_list_by_category(self, category: str):
memory_list = [] memory_list = []
for memory in self.memory_points: for memory in self.memory_points:
if memory is None: if memory is None:
continue continue
if get_catagory_from_memory(memory) == category: if get_category_from_memory(memory) == category:
memory_list.append(memory) memory_list.append(memory)
return memory_list return memory_list
@ -376,13 +371,13 @@ class Person:
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
if record: if record:
self.user_id = record.user_id if record.user_id else "" self.user_id = record.user_id or ""
self.platform = record.platform if record.platform else "" self.platform = record.platform or ""
self.is_known = record.is_known if record.is_known else False self.is_known = record.is_known or False
self.nickname = record.nickname if record.nickname else "" self.nickname = record.nickname or ""
self.person_name = record.person_name if record.person_name else self.nickname self.person_name = record.person_name or self.nickname
self.name_reason = record.name_reason if record.name_reason else None self.name_reason = record.name_reason or None
self.know_times = record.know_times if record.know_times else 0 self.know_times = record.know_times or 0
# 处理points字段JSON格式的列表 # 处理points字段JSON格式的列表
if record.memory_points: if record.memory_points:
@ -452,29 +447,33 @@ class Person:
try: try:
# 准备数据 # 准备数据
data = { data = {
'person_id': self.person_id, "person_id": self.person_id,
'is_known': self.is_known, "is_known": self.is_known,
'platform': self.platform, "platform": self.platform,
'user_id': self.user_id, "user_id": self.user_id,
'nickname': self.nickname, "nickname": self.nickname,
'person_name': self.person_name, "person_name": self.person_name,
'name_reason': self.name_reason, "name_reason": self.name_reason,
'know_times': self.know_times, "know_times": self.know_times,
'know_since': self.know_since, "know_since": self.know_since,
'last_know': self.last_know, "last_know": self.last_know,
'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False), "memory_points": json.dumps(
'attitude_to_me': self.attitude_to_me, [point for point in self.memory_points if point is not None], ensure_ascii=False
'attitude_to_me_confidence': self.attitude_to_me_confidence, )
'friendly_value': self.friendly_value, if self.memory_points
'friendly_value_confidence': self.friendly_value_confidence, else json.dumps([], ensure_ascii=False),
'rudeness': self.rudeness, "attitude_to_me": self.attitude_to_me,
'rudeness_confidence': self.rudeness_confidence, "attitude_to_me_confidence": self.attitude_to_me_confidence,
'neuroticism': self.neuroticism, "friendly_value": self.friendly_value,
'neuroticism_confidence': self.neuroticism_confidence, "friendly_value_confidence": self.friendly_value_confidence,
'conscientiousness': self.conscientiousness, "rudeness": self.rudeness,
'conscientiousness_confidence': self.conscientiousness_confidence, "rudeness_confidence": self.rudeness_confidence,
'likeness': self.likeness, "neuroticism": self.neuroticism,
'likeness_confidence': self.likeness_confidence, "neuroticism_confidence": self.neuroticism_confidence,
"conscientiousness": self.conscientiousness,
"conscientiousness_confidence": self.conscientiousness_confidence,
"likeness": self.likeness,
"likeness_confidence": self.likeness_confidence,
} }
# 检查记录是否存在 # 检查记录是否存在
@ -513,7 +512,6 @@ class Person:
elif self.attitude_to_me > 5: elif self.attitude_to_me > 5:
attitude_info = f"{self.person_name}对你的态度较好," attitude_info = f"{self.person_name}对你的态度较好,"
if self.attitude_to_me < -8: if self.attitude_to_me < -8:
attitude_info = f"{self.person_name}对你的态度十分恶劣," attitude_info = f"{self.person_name}对你的态度十分恶劣,"
elif self.attitude_to_me < -4: elif self.attitude_to_me < -4:
@ -555,7 +553,6 @@ class Person:
class PersonInfoManager: class PersonInfoManager:
def __init__(self): def __init__(self):
self.person_name_list = {} self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try: try:
@ -581,8 +578,6 @@ class PersonInfoManager:
except Exception as e: except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
@staticmethod @staticmethod
def _extract_json_from_text(text: str) -> dict: def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""

View File

@ -1,20 +1,13 @@
import random import json
from json_repair import repair_json
from typing import Tuple from typing import Tuple
# 导入新插件系统
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
# 导入依赖的系统组件
from src.common.logger import get_logger from src.common.logger import get_logger
# 导入API模块 - 标准Python包方式
from src.plugin_system.apis import emoji_api, llm_api, message_api
# NoReplyAction已集成到heartFC_chat.py中不再需要导入
from src.config.config import global_config from src.config.config import global_config
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
import json from src.plugin_system import BaseAction, ActionActivationType
from json_repair import repair_json from src.plugin_system.apis import llm_api
logger = get_logger("relation") logger = get_logger("relation")
@ -39,10 +32,9 @@ def init_prompt():
{{ {{
"category": "分类名称" "category": "分类名称"
}} """, }} """,
"relation_category" "relation_category",
) )
Prompt( Prompt(
""" """
以下是有关{category}的现有记忆 以下是有关{category}的现有记忆
@ -73,7 +65,7 @@ def init_prompt():
现在请你根据情况选出合适的修改方式并输出json不要输出其他内容 现在请你根据情况选出合适的修改方式并输出json不要输出其他内容
""", """,
"relation_category_update" "relation_category_update",
) )
@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction):
""" """
# 动作参数定义 # 动作参数定义
action_parameters = { action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
"person_name":"需要了解或记忆的人的名称",
"impression":"需要了解的对某人的记忆或印象"
}
# 动作使用场景 # 动作使用场景
action_require = [ action_require = [
"了解对于某人的记忆,并添加到你对对方的印象中", "了解对于某人的记忆,并添加到你对对方的印象中",
"对方与有明确提到有关其自身的事件", "对方与有明确提到有关其自身的事件",
"对方有提到其个人信息,包括喜好,身份,等等", "对方有提到其个人信息,包括喜好,身份,等等",
"对方希望你记住对方的信息" "对方希望你记住对方的信息",
] ]
# 关联类型 # 关联类型
@ -130,8 +119,6 @@ class BuildRelationAction(BaseAction):
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆") logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
return False, f"用户 {person_name} 不存在,跳过添加记忆" return False, f"用户 {person_name} 不存在,跳过添加记忆"
category_list = person.get_all_category() category_list = person.get_all_category()
if not category_list: if not category_list:
category_list_str = "无分类" category_list_str = "无分类"
@ -142,10 +129,9 @@ class BuildRelationAction(BaseAction):
"relation_category", "relation_category",
category_list=category_list_str, category_list=category_list_str,
memory_point=impression, memory_point=impression,
person_name=person.person_name person_name=person.person_name,
) )
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
else: else:
@ -162,15 +148,12 @@ class BuildRelationAction(BaseAction):
prompt, model_config=chat_model_config, request_type="relation.category" prompt, model_config=chat_model_config, request_type="relation.category"
) )
category_data = json.loads(repair_json(category)) category_data = json.loads(repair_json(category))
category = category_data.get("category", "") category = category_data.get("category", "")
if not category: if not category:
logger.warning(f"{self.log_prefix} LLM未给出分类跳过添加记忆") logger.warning(f"{self.log_prefix} LLM未给出分类跳过添加记忆")
return False, "LLM未给出分类跳过添加记忆" return False, "LLM未给出分类跳过添加记忆"
# 第二部分:更新记忆 # 第二部分:更新记忆
memory_list = person.get_memory_list_by_category(category) memory_list = person.get_memory_list_by_category(category)
@ -183,19 +166,16 @@ class BuildRelationAction(BaseAction):
memory_list_str = "" memory_list_str = ""
memory_list_id = {} memory_list_id = {}
id = 1 for id, memory in enumerate(memory_list, start=1):
for memory in memory_list:
memory_content = get_memory_content_from_memory(memory) memory_content = get_memory_content_from_memory(memory)
memory_list_str += f"{id}. {memory_content}\n" memory_list_str += f"{id}. {memory_content}\n"
memory_list_id[id] = memory memory_list_id[id] = memory
id += 1
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
"relation_category_update", "relation_category_update",
category=category, category=category,
memory_list=memory_list_str, memory_list=memory_list_str,
memory_point=impression, memory_point=impression,
person_name=person.person_name person_name=person.person_name,
) )
if global_config.debug.show_prompt: if global_config.debug.show_prompt:
@ -205,7 +185,7 @@ class BuildRelationAction(BaseAction):
chat_model_config = models.get("utils") chat_model_config = models.get("utils")
success, update_memory, _, _ = await llm_api.generate_with_model( success, update_memory, _, _ = await llm_api.generate_with_model(
prompt, model_config=chat_model_config, request_type="relation.category.update" prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
) )
update_memory_data = json.loads(repair_json(update_memory)) update_memory_data = json.loads(repair_json(update_memory))
@ -238,8 +218,6 @@ class BuildRelationAction(BaseAction):
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}") logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
return False, f"删除{person.person_name}的记忆点失败: {memory_content}" return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
return True, "关系动作执行成功" return True, "关系动作执行成功"
except Exception as e: except Exception as e:

View File

@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin
from src.plugin_system.base.base_plugin import BasePlugin from src.plugin_system.base.base_plugin import BasePlugin
from src.plugin_system.base.component_types import ComponentInfo from src.plugin_system.base.component_types import ComponentInfo
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode from src.plugin_system.base.base_action import BaseAction, ActionActivationType
from src.plugin_system.base.config_types import ConfigField from src.plugin_system.base.config_types import ConfigField
from typing import Tuple, List, Type from typing import Tuple, List, Type