mirror of https://github.com/Mai-with-u/MaiBot.git
修复typing问题,保证类型正确
parent
ec500f1f5b
commit
a55979164e
|
|
@ -4,10 +4,9 @@ from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
|||
|
||||
|
||||
class FocusValueControl:
|
||||
def __init__(self,chat_id:str):
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.focus_value_adjust = 1
|
||||
|
||||
self.focus_value_adjust: float = 1
|
||||
|
||||
def get_current_focus_value(self) -> float:
|
||||
return get_current_focus_value(self.chat_id) * self.focus_value_adjust
|
||||
|
|
@ -15,15 +14,14 @@ class FocusValueControl:
|
|||
|
||||
class FocusValueControlManager:
|
||||
def __init__(self):
|
||||
self.focus_value_controls = {}
|
||||
self.focus_value_controls: dict[str, FocusValueControl] = {}
|
||||
|
||||
def get_focus_value_control(self,chat_id:str) -> FocusValueControl:
|
||||
def get_focus_value_control(self, chat_id: str) -> FocusValueControl:
|
||||
if chat_id not in self.focus_value_controls:
|
||||
self.focus_value_controls[chat_id] = FocusValueControl(chat_id)
|
||||
return self.focus_value_controls[chat_id]
|
||||
|
||||
|
||||
|
||||
def get_current_focus_value(chat_id: Optional[str] = None) -> float:
|
||||
"""
|
||||
根据当前时间和聊天流获取对应的 focus_value
|
||||
|
|
@ -42,6 +40,7 @@ def get_current_focus_value(chat_id: Optional[str] = None) -> float:
|
|||
|
||||
return global_config.chat.focus_value
|
||||
|
||||
|
||||
def get_stream_specific_focus_value(chat_id: str) -> Optional[float]:
|
||||
"""
|
||||
获取特定聊天流在当前时间的专注度
|
||||
|
|
@ -140,4 +139,5 @@ def get_global_focus_value() -> Optional[float]:
|
|||
|
||||
return None
|
||||
|
||||
|
||||
focus_value_control = FocusValueControlManager()
|
||||
|
|
@ -2,10 +2,11 @@ from typing import Optional
|
|||
from src.config.config import global_config
|
||||
from src.chat.frequency_control.utils import parse_stream_config_to_chat_id
|
||||
|
||||
|
||||
class TalkFrequencyControl:
|
||||
def __init__(self,chat_id:str):
|
||||
def __init__(self, chat_id: str):
|
||||
self.chat_id = chat_id
|
||||
self.talk_frequency_adjust = 1
|
||||
self.talk_frequency_adjust: float = 1
|
||||
|
||||
def get_current_talk_frequency(self) -> float:
|
||||
return get_current_talk_frequency(self.chat_id) * self.talk_frequency_adjust
|
||||
|
|
@ -15,7 +16,7 @@ class TalkFrequencyControlManager:
|
|||
def __init__(self):
|
||||
self.talk_frequency_controls = {}
|
||||
|
||||
def get_talk_frequency_control(self,chat_id:str) -> TalkFrequencyControl:
|
||||
def get_talk_frequency_control(self, chat_id: str) -> TalkFrequencyControl:
|
||||
if chat_id not in self.talk_frequency_controls:
|
||||
self.talk_frequency_controls[chat_id] = TalkFrequencyControl(chat_id)
|
||||
return self.talk_frequency_controls[chat_id]
|
||||
|
|
@ -44,6 +45,7 @@ def get_current_talk_frequency(chat_id: Optional[str] = None) -> float:
|
|||
global_frequency = get_global_frequency()
|
||||
return global_config.chat.talk_frequency if global_frequency is None else global_frequency
|
||||
|
||||
|
||||
def get_time_based_frequency(time_freq_list: list[str]) -> Optional[float]:
|
||||
"""
|
||||
根据时间配置列表获取当前时段的频率
|
||||
|
|
@ -124,6 +126,7 @@ def get_stream_specific_frequency(chat_stream_id: str):
|
|||
|
||||
return None
|
||||
|
||||
|
||||
def get_global_frequency() -> Optional[float]:
|
||||
"""
|
||||
获取全局默认频率配置
|
||||
|
|
@ -141,4 +144,5 @@ def get_global_frequency() -> Optional[float]:
|
|||
|
||||
return None
|
||||
|
||||
|
||||
talk_frequency_control = TalkFrequencyControlManager()
|
||||
|
|
@ -30,9 +30,7 @@ def cosine_similarity(v1, v2):
|
|||
dot_product = np.dot(v1, v2)
|
||||
norm1 = np.linalg.norm(v1)
|
||||
norm2 = np.linalg.norm(v2)
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0
|
||||
return dot_product / (norm1 * norm2)
|
||||
return 0 if norm1 == 0 or norm2 == 0 else dot_product / (norm1 * norm2)
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
|
@ -142,11 +140,10 @@ class MemoryGraph:
|
|||
# 获取当前节点的记忆项
|
||||
node_data = self.get_dot(topic)
|
||||
if node_data:
|
||||
concept, data = node_data
|
||||
_, data = node_data
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := data["memory_items"]:
|
||||
first_layer_items.append(memory_items)
|
||||
|
||||
# 只在depth=2时获取第二层记忆
|
||||
|
|
@ -154,11 +151,10 @@ class MemoryGraph:
|
|||
# 获取相邻节点的记忆项
|
||||
for neighbor in neighbors:
|
||||
if node_data := self.get_dot(neighbor):
|
||||
concept, data = node_data
|
||||
_, data = node_data
|
||||
if "memory_items" in data:
|
||||
memory_items = data["memory_items"]
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := data["memory_items"]:
|
||||
second_layer_items.append(memory_items)
|
||||
|
||||
return first_layer_items, second_layer_items
|
||||
|
|
@ -224,27 +220,17 @@ class MemoryGraph:
|
|||
# 获取话题节点数据
|
||||
node_data = self.G.nodes[topic]
|
||||
|
||||
# 删除整个节点
|
||||
self.G.remove_node(topic)
|
||||
# 如果节点存在memory_items
|
||||
if "memory_items" in node_data:
|
||||
memory_items = node_data["memory_items"]
|
||||
|
||||
# 既然每个节点现在是一个完整的记忆内容,直接删除整个节点
|
||||
if memory_items:
|
||||
# 删除整个节点
|
||||
self.G.remove_node(topic)
|
||||
if memory_items := node_data["memory_items"]:
|
||||
return (
|
||||
f"删除了节点 {topic} 的完整记忆: {memory_items[:50]}..."
|
||||
if len(memory_items) > 50
|
||||
else f"删除了节点 {topic} 的完整记忆: {memory_items}"
|
||||
)
|
||||
else:
|
||||
# 如果没有记忆项,删除该节点
|
||||
self.G.remove_node(topic)
|
||||
return None
|
||||
else:
|
||||
# 如果没有memory_items字段,删除该节点
|
||||
self.G.remove_node(topic)
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
# 海马体
|
||||
|
|
@ -392,9 +378,8 @@ class Hippocampus:
|
|||
# 如果相似度超过阈值,获取该节点的记忆
|
||||
if similarity >= 0.3: # 可以调整这个阈值
|
||||
node_data = self.memory_graph.G.nodes[node]
|
||||
memory_items = node_data.get("memory_items", "")
|
||||
# 直接使用完整的记忆内容
|
||||
if memory_items:
|
||||
if memory_items := node_data.get("memory_items", ""):
|
||||
memories.append((node, memory_items, similarity))
|
||||
|
||||
# 按相似度降序排序
|
||||
|
|
@ -587,7 +572,7 @@ class Hippocampus:
|
|||
unique_memories = []
|
||||
for topic, memory_items, activation_value in all_memories:
|
||||
# memory_items现在是完整的字符串格式
|
||||
memory = memory_items if memory_items else ""
|
||||
memory = memory_items or ""
|
||||
if memory not in seen_memories:
|
||||
seen_memories.add(memory)
|
||||
unique_memories.append((topic, memory_items, activation_value))
|
||||
|
|
@ -599,7 +584,7 @@ class Hippocampus:
|
|||
result = []
|
||||
for topic, memory_items, _ in unique_memories:
|
||||
# memory_items现在是完整的字符串格式
|
||||
memory = memory_items if memory_items else ""
|
||||
memory = memory_items or ""
|
||||
result.append((topic, memory))
|
||||
logger.debug(f"选中记忆: {memory} (来自节点: {topic})")
|
||||
|
||||
|
|
@ -1471,6 +1456,7 @@ class MemoryBuilder:
|
|||
self.last_processed_time: float = 0.0
|
||||
|
||||
def should_trigger_memory_build(self) -> bool:
|
||||
# sourcery skip: assign-if-exp, boolean-if-exp-identity, reintroduce-else
|
||||
"""检查是否应该触发记忆构建"""
|
||||
current_time = time.time()
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ import asyncio
|
|||
import json
|
||||
import time
|
||||
import random
|
||||
import math
|
||||
|
||||
from json_repair import repair_json
|
||||
from typing import Union, Optional
|
||||
|
|
@ -16,6 +17,7 @@ from src.config.config import global_config, model_config
|
|||
|
||||
logger = get_logger("person_info")
|
||||
|
||||
|
||||
def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||
"""获取唯一id"""
|
||||
if "-" in platform:
|
||||
|
|
@ -24,6 +26,7 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
|||
key = "_".join(components)
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_person_id_by_person_name(person_name: str) -> str:
|
||||
"""根据用户名获取用户ID"""
|
||||
try:
|
||||
|
|
@ -33,7 +36,8 @@ def get_person_id_by_person_name(person_name: str) -> str:
|
|||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
||||
return ""
|
||||
|
||||
def is_person_known(person_id: str = None,user_id: str = None,platform: str = None,person_name: str = None) -> bool:
|
||||
|
||||
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
|
||||
if person_id:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
return person.is_known if person else False
|
||||
|
|
@ -49,41 +53,36 @@ def is_person_known(person_id: str = None,user_id: str = None,platform: str = No
|
|||
return False
|
||||
|
||||
|
||||
def get_catagory_from_memory(memory_point:str) -> str:
|
||||
def get_category_from_memory(memory_point: str) -> Optional[str]:
|
||||
"""从记忆点中获取分类"""
|
||||
# 按照最左边的:符号进行分割,返回分割后的第一个部分作为分类
|
||||
if not isinstance(memory_point, str):
|
||||
return None
|
||||
parts = memory_point.split(":", 1)
|
||||
if len(parts) > 1:
|
||||
return parts[0].strip()
|
||||
else:
|
||||
return None
|
||||
return parts[0].strip() if len(parts) > 1 else None
|
||||
|
||||
def get_weight_from_memory(memory_point:str) -> float:
|
||||
|
||||
def get_weight_from_memory(memory_point: str) -> float:
|
||||
"""从记忆点中获取权重"""
|
||||
# 按照最右边的:符号进行分割,返回分割后的最后一个部分作为权重
|
||||
if not isinstance(memory_point, str):
|
||||
return None
|
||||
return -math.inf
|
||||
parts = memory_point.rsplit(":", 1)
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
return float(parts[-1].strip())
|
||||
except Exception:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
if len(parts) <= 1:
|
||||
return -math.inf
|
||||
try:
|
||||
return float(parts[-1].strip())
|
||||
except Exception:
|
||||
return -math.inf
|
||||
|
||||
def get_memory_content_from_memory(memory_point:str) -> str:
|
||||
|
||||
def get_memory_content_from_memory(memory_point: str) -> str:
|
||||
"""从记忆点中获取记忆内容"""
|
||||
# 按:进行分割,去掉第一段和最后一段,返回中间部分作为记忆内容
|
||||
if not isinstance(memory_point, str):
|
||||
return None
|
||||
return ""
|
||||
parts = memory_point.split(":")
|
||||
if len(parts) > 2:
|
||||
return ":".join(parts[1:-1]).strip()
|
||||
else:
|
||||
return None
|
||||
return ":".join(parts[1:-1]).strip() if len(parts) > 2 else ""
|
||||
|
||||
|
||||
def calculate_string_similarity(s1: str, s2: str) -> float:
|
||||
|
|
@ -105,7 +104,6 @@ def calculate_string_similarity(s1: str, s2: str) -> float:
|
|||
|
||||
# 计算Levenshtein距离
|
||||
|
||||
|
||||
distance = levenshtein_distance(s1, s2)
|
||||
max_len = max(len(s1), len(s2))
|
||||
|
||||
|
|
@ -113,6 +111,7 @@ def calculate_string_similarity(s1: str, s2: str) -> float:
|
|||
similarity = 1 - (distance / max_len if max_len > 0 else 0)
|
||||
return similarity
|
||||
|
||||
|
||||
def levenshtein_distance(s1: str, s2: str) -> int:
|
||||
"""
|
||||
计算两个字符串的编辑距离
|
||||
|
|
@ -142,6 +141,7 @@ def levenshtein_distance(s1: str, s2: str) -> int:
|
|||
|
||||
return previous_row[-1]
|
||||
|
||||
|
||||
class Person:
|
||||
@classmethod
|
||||
def register_person(cls, platform: str, user_id: str, nickname: str):
|
||||
|
|
@ -212,7 +212,7 @@ class Person:
|
|||
|
||||
return person
|
||||
|
||||
def __init__(self, platform: str = "", user_id: str = "",person_id: str = "",person_name: str = ""):
|
||||
def __init__(self, platform: str = "", user_id: str = "", person_id: str = "", person_name: str = ""):
|
||||
if platform == global_config.bot.platform and user_id == global_config.bot.qq_account:
|
||||
self.is_known = True
|
||||
self.person_id = get_person_id(platform, user_id)
|
||||
|
|
@ -248,7 +248,6 @@ class Person:
|
|||
return
|
||||
# raise ValueError(f"用户 {platform}:{user_id}:{person_name}:{person_id} 尚未认识")
|
||||
|
||||
|
||||
self.is_known = False
|
||||
|
||||
# 初始化默认值
|
||||
|
|
@ -261,23 +260,23 @@ class Person:
|
|||
self.memory_points = []
|
||||
|
||||
# 初始化性格特征相关字段
|
||||
self.attitude_to_me:float = 0
|
||||
self.attitude_to_me_confidence:float = 1
|
||||
self.attitude_to_me: float = 0
|
||||
self.attitude_to_me_confidence: float = 1
|
||||
|
||||
self.neuroticism:float = 5
|
||||
self.neuroticism_confidence:float = 1
|
||||
self.neuroticism: float = 5
|
||||
self.neuroticism_confidence: float = 1
|
||||
|
||||
self.friendly_value:float = 50
|
||||
self.friendly_value_confidence:float = 1
|
||||
self.friendly_value: float = 50
|
||||
self.friendly_value_confidence: float = 1
|
||||
|
||||
self.rudeness:float = 50
|
||||
self.rudeness_confidence:float = 1
|
||||
self.rudeness: float = 50
|
||||
self.rudeness_confidence: float = 1
|
||||
|
||||
self.conscientiousness:float = 50
|
||||
self.conscientiousness_confidence:float = 1
|
||||
self.conscientiousness: float = 50
|
||||
self.conscientiousness_confidence: float = 1
|
||||
|
||||
self.likeness:float = 50
|
||||
self.likeness_confidence:float = 1
|
||||
self.likeness: float = 50
|
||||
self.likeness_confidence: float = 1
|
||||
|
||||
# 从数据库加载数据
|
||||
self.load_from_database()
|
||||
|
|
@ -340,30 +339,26 @@ class Person:
|
|||
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
|
||||
def get_all_category(self):
|
||||
category_list = []
|
||||
for memory in self.memory_points:
|
||||
if memory is None:
|
||||
continue
|
||||
category = get_catagory_from_memory(memory)
|
||||
category = get_category_from_memory(memory)
|
||||
if category and category not in category_list:
|
||||
category_list.append(category)
|
||||
return category_list
|
||||
|
||||
|
||||
def get_memory_list_by_category(self,category:str):
|
||||
def get_memory_list_by_category(self, category: str):
|
||||
memory_list = []
|
||||
for memory in self.memory_points:
|
||||
if memory is None:
|
||||
continue
|
||||
if get_catagory_from_memory(memory) == category:
|
||||
if get_category_from_memory(memory) == category:
|
||||
memory_list.append(memory)
|
||||
return memory_list
|
||||
|
||||
def get_random_memory_by_category(self,category:str,num:int=1):
|
||||
def get_random_memory_by_category(self, category: str, num: int = 1):
|
||||
memory_list = self.get_memory_list_by_category(category)
|
||||
if len(memory_list) < num:
|
||||
return memory_list
|
||||
|
|
@ -376,13 +371,13 @@ class Person:
|
|||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
||||
|
||||
if record:
|
||||
self.user_id = record.user_id if record.user_id else ""
|
||||
self.platform = record.platform if record.platform else ""
|
||||
self.is_known = record.is_known if record.is_known else False
|
||||
self.nickname = record.nickname if record.nickname else ""
|
||||
self.person_name = record.person_name if record.person_name else self.nickname
|
||||
self.name_reason = record.name_reason if record.name_reason else None
|
||||
self.know_times = record.know_times if record.know_times else 0
|
||||
self.user_id = record.user_id or ""
|
||||
self.platform = record.platform or ""
|
||||
self.is_known = record.is_known or False
|
||||
self.nickname = record.nickname or ""
|
||||
self.person_name = record.person_name or self.nickname
|
||||
self.name_reason = record.name_reason or None
|
||||
self.know_times = record.know_times or 0
|
||||
|
||||
# 处理points字段(JSON格式的列表)
|
||||
if record.memory_points:
|
||||
|
|
@ -452,29 +447,33 @@ class Person:
|
|||
try:
|
||||
# 准备数据
|
||||
data = {
|
||||
'person_id': self.person_id,
|
||||
'is_known': self.is_known,
|
||||
'platform': self.platform,
|
||||
'user_id': self.user_id,
|
||||
'nickname': self.nickname,
|
||||
'person_name': self.person_name,
|
||||
'name_reason': self.name_reason,
|
||||
'know_times': self.know_times,
|
||||
'know_since': self.know_since,
|
||||
'last_know': self.last_know,
|
||||
'memory_points': json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points else json.dumps([], ensure_ascii=False),
|
||||
'attitude_to_me': self.attitude_to_me,
|
||||
'attitude_to_me_confidence': self.attitude_to_me_confidence,
|
||||
'friendly_value': self.friendly_value,
|
||||
'friendly_value_confidence': self.friendly_value_confidence,
|
||||
'rudeness': self.rudeness,
|
||||
'rudeness_confidence': self.rudeness_confidence,
|
||||
'neuroticism': self.neuroticism,
|
||||
'neuroticism_confidence': self.neuroticism_confidence,
|
||||
'conscientiousness': self.conscientiousness,
|
||||
'conscientiousness_confidence': self.conscientiousness_confidence,
|
||||
'likeness': self.likeness,
|
||||
'likeness_confidence': self.likeness_confidence,
|
||||
"person_id": self.person_id,
|
||||
"is_known": self.is_known,
|
||||
"platform": self.platform,
|
||||
"user_id": self.user_id,
|
||||
"nickname": self.nickname,
|
||||
"person_name": self.person_name,
|
||||
"name_reason": self.name_reason,
|
||||
"know_times": self.know_times,
|
||||
"know_since": self.know_since,
|
||||
"last_know": self.last_know,
|
||||
"memory_points": json.dumps(
|
||||
[point for point in self.memory_points if point is not None], ensure_ascii=False
|
||||
)
|
||||
if self.memory_points
|
||||
else json.dumps([], ensure_ascii=False),
|
||||
"attitude_to_me": self.attitude_to_me,
|
||||
"attitude_to_me_confidence": self.attitude_to_me_confidence,
|
||||
"friendly_value": self.friendly_value,
|
||||
"friendly_value_confidence": self.friendly_value_confidence,
|
||||
"rudeness": self.rudeness,
|
||||
"rudeness_confidence": self.rudeness_confidence,
|
||||
"neuroticism": self.neuroticism,
|
||||
"neuroticism_confidence": self.neuroticism_confidence,
|
||||
"conscientiousness": self.conscientiousness,
|
||||
"conscientiousness_confidence": self.conscientiousness_confidence,
|
||||
"likeness": self.likeness,
|
||||
"likeness_confidence": self.likeness_confidence,
|
||||
}
|
||||
|
||||
# 检查记录是否存在
|
||||
|
|
@ -513,7 +512,6 @@ class Person:
|
|||
elif self.attitude_to_me > 5:
|
||||
attitude_info = f"{self.person_name}对你的态度较好,"
|
||||
|
||||
|
||||
if self.attitude_to_me < -8:
|
||||
attitude_info = f"{self.person_name}对你的态度十分恶劣,"
|
||||
elif self.attitude_to_me < -4:
|
||||
|
|
@ -537,7 +535,7 @@ class Person:
|
|||
points_text = ""
|
||||
category_list = self.get_all_category()
|
||||
for category in category_list:
|
||||
random_memory = self.get_random_memory_by_category(category,1)[0]
|
||||
random_memory = self.get_random_memory_by_category(category, 1)[0]
|
||||
if random_memory:
|
||||
points_text = f"有关 {category} 的记忆:{get_memory_content_from_memory(random_memory)}"
|
||||
break
|
||||
|
|
@ -555,7 +553,6 @@ class Person:
|
|||
|
||||
class PersonInfoManager:
|
||||
def __init__(self):
|
||||
|
||||
self.person_name_list = {}
|
||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||
try:
|
||||
|
|
@ -581,8 +578,6 @@ class PersonInfoManager:
|
|||
except Exception as e:
|
||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
|
|
|
|||
|
|
@ -1,20 +1,13 @@
|
|||
import random
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
from typing import Tuple
|
||||
|
||||
# 导入新插件系统
|
||||
from src.plugin_system import BaseAction, ActionActivationType, ChatMode
|
||||
|
||||
# 导入依赖的系统组件
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# 导入API模块 - 标准Python包方式
|
||||
from src.plugin_system.apis import emoji_api, llm_api, message_api
|
||||
# NoReplyAction已集成到heartFC_chat.py中,不再需要导入
|
||||
from src.config.config import global_config
|
||||
from src.person_info.person_info import Person, get_memory_content_from_memory, get_weight_from_memory
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
import json
|
||||
from json_repair import repair_json
|
||||
from src.plugin_system import BaseAction, ActionActivationType
|
||||
from src.plugin_system.apis import llm_api
|
||||
|
||||
|
||||
logger = get_logger("relation")
|
||||
|
|
@ -39,10 +32,9 @@ def init_prompt():
|
|||
{{
|
||||
"category": "分类名称"
|
||||
}} """,
|
||||
"relation_category"
|
||||
"relation_category",
|
||||
)
|
||||
|
||||
|
||||
Prompt(
|
||||
"""
|
||||
以下是有关{category}的现有记忆:
|
||||
|
|
@ -73,7 +65,7 @@ def init_prompt():
|
|||
|
||||
现在,请你根据情况选出合适的修改方式,并输出json,不要输出其他内容:
|
||||
""",
|
||||
"relation_category_update"
|
||||
"relation_category_update",
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -98,17 +90,14 @@ class BuildRelationAction(BaseAction):
|
|||
"""
|
||||
|
||||
# 动作参数定义
|
||||
action_parameters = {
|
||||
"person_name":"需要了解或记忆的人的名称",
|
||||
"impression":"需要了解的对某人的记忆或印象"
|
||||
}
|
||||
action_parameters = {"person_name": "需要了解或记忆的人的名称", "impression": "需要了解的对某人的记忆或印象"}
|
||||
|
||||
# 动作使用场景
|
||||
action_require = [
|
||||
"了解对于某人的记忆,并添加到你对对方的印象中",
|
||||
"对方与有明确提到有关其自身的事件",
|
||||
"对方有提到其个人信息,包括喜好,身份,等等",
|
||||
"对方希望你记住对方的信息"
|
||||
"对方希望你记住对方的信息",
|
||||
]
|
||||
|
||||
# 关联类型
|
||||
|
|
@ -130,8 +119,6 @@ class BuildRelationAction(BaseAction):
|
|||
logger.warning(f"{self.log_prefix} 用户 {person_name} 不存在,跳过添加记忆")
|
||||
return False, f"用户 {person_name} 不存在,跳过添加记忆"
|
||||
|
||||
|
||||
|
||||
category_list = person.get_all_category()
|
||||
if not category_list:
|
||||
category_list_str = "无分类"
|
||||
|
|
@ -142,10 +129,9 @@ class BuildRelationAction(BaseAction):
|
|||
"relation_category",
|
||||
category_list=category_list_str,
|
||||
memory_point=impression,
|
||||
person_name=person.person_name
|
||||
person_name=person.person_name,
|
||||
)
|
||||
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
logger.info(f"{self.log_prefix} 生成的LLM Prompt: {prompt}")
|
||||
else:
|
||||
|
|
@ -162,15 +148,12 @@ class BuildRelationAction(BaseAction):
|
|||
prompt, model_config=chat_model_config, request_type="relation.category"
|
||||
)
|
||||
|
||||
|
||||
|
||||
category_data = json.loads(repair_json(category))
|
||||
category = category_data.get("category", "")
|
||||
if not category:
|
||||
logger.warning(f"{self.log_prefix} LLM未给出分类,跳过添加记忆")
|
||||
return False, "LLM未给出分类,跳过添加记忆"
|
||||
|
||||
|
||||
# 第二部分:更新记忆
|
||||
|
||||
memory_list = person.get_memory_list_by_category(category)
|
||||
|
|
@ -183,19 +166,16 @@ class BuildRelationAction(BaseAction):
|
|||
|
||||
memory_list_str = ""
|
||||
memory_list_id = {}
|
||||
id = 1
|
||||
for memory in memory_list:
|
||||
for id, memory in enumerate(memory_list, start=1):
|
||||
memory_content = get_memory_content_from_memory(memory)
|
||||
memory_list_str += f"{id}. {memory_content}\n"
|
||||
memory_list_id[id] = memory
|
||||
id += 1
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
"relation_category_update",
|
||||
category=category,
|
||||
memory_list=memory_list_str,
|
||||
memory_point=impression,
|
||||
person_name=person.person_name
|
||||
person_name=person.person_name,
|
||||
)
|
||||
|
||||
if global_config.debug.show_prompt:
|
||||
|
|
@ -205,7 +185,7 @@ class BuildRelationAction(BaseAction):
|
|||
|
||||
chat_model_config = models.get("utils")
|
||||
success, update_memory, _, _ = await llm_api.generate_with_model(
|
||||
prompt, model_config=chat_model_config, request_type="relation.category.update"
|
||||
prompt, model_config=chat_model_config, request_type="relation.category.update" # type: ignore
|
||||
)
|
||||
|
||||
update_memory_data = json.loads(repair_json(update_memory))
|
||||
|
|
@ -223,7 +203,7 @@ class BuildRelationAction(BaseAction):
|
|||
# 现存或冲突记忆
|
||||
memory = memory_list_id[memory_id]
|
||||
memory_content = get_memory_content_from_memory(memory)
|
||||
del_count = person.del_memory(category,memory_content)
|
||||
del_count = person.del_memory(category, memory_content)
|
||||
|
||||
if del_count > 0:
|
||||
logger.info(f"{self.log_prefix} 删除记忆点: {memory_content}")
|
||||
|
|
@ -238,8 +218,6 @@ class BuildRelationAction(BaseAction):
|
|||
logger.warning(f"{self.log_prefix} 删除记忆点失败: {memory_content}")
|
||||
return False, f"删除{person.person_name}的记忆点失败: {memory_content}"
|
||||
|
||||
|
||||
|
||||
return True, "关系动作执行成功"
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from src.plugin_system.apis.plugin_register_api import register_plugin
|
|||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
from src.plugin_system.base.component_types import ComponentInfo
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType, ChatMode
|
||||
from src.plugin_system.base.base_action import BaseAction, ActionActivationType
|
||||
from src.plugin_system.base.config_types import ConfigField
|
||||
from typing import Tuple, List, Type
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue