mirror of https://github.com/Mai-with-u/MaiBot.git
feat:为表达方式记录更多信息
parent
e79da24c23
commit
5cc1e56904
|
|
@ -3,16 +3,18 @@ import random
|
|||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
import jieba
|
||||
from typing import List, Dict, Optional, Any, Tuple
|
||||
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import model_config, global_config
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
|
||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages, build_bare_messages
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from json_repair import repair_json
|
||||
from src.chat.utils.utils import get_embedding
|
||||
|
||||
|
||||
MAX_EXPRESSION_COUNT = 300
|
||||
|
|
@ -62,7 +64,7 @@ def init_prompt() -> None:
|
|||
**从聊天内容总结的表达方式pairs**
|
||||
{expression_pairs}
|
||||
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果。
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。
|
||||
如果找不到原句,就不输出该句的匹配结果。
|
||||
以json格式输出:
|
||||
格式如下:
|
||||
|
|
@ -86,6 +88,9 @@ class ExpressionLearner:
|
|||
self.express_learn_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
|
||||
)
|
||||
self.embedding_model: LLMRequest = LLMRequest(
|
||||
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
|
||||
)
|
||||
self.chat_id = chat_id
|
||||
self.chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
|
||||
|
|
@ -94,9 +99,9 @@ class ExpressionLearner:
|
|||
self.last_learning_time: float = time.time()
|
||||
|
||||
# 学习参数
|
||||
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
|
||||
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
|
||||
self.min_learning_interval = 300 / self.learning_intensity
|
||||
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
|
||||
self.min_learning_interval = 150 / self.learning_intensity
|
||||
|
||||
def should_trigger_learning(self) -> bool:
|
||||
"""
|
||||
|
|
@ -160,6 +165,7 @@ class ExpressionLearner:
|
|||
|
||||
except Exception as e:
|
||||
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def _apply_global_decay_to_database(self, current_time: float) -> None:
|
||||
|
|
@ -229,19 +235,19 @@ class ExpressionLearner:
|
|||
if res is None:
|
||||
logger.info("没有学习到表达风格")
|
||||
return []
|
||||
learnt_expressions, chat_id = res
|
||||
learnt_expressions = res
|
||||
learnt_expressions_str = ""
|
||||
for _chat_id, situation, style in learnt_expressions:
|
||||
for _chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
||||
# 按chat_id分组
|
||||
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
|
||||
for chat_id, situation, style in learnt_expressions:
|
||||
for chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
|
||||
if chat_id not in chat_dict:
|
||||
chat_dict[chat_id] = []
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style})
|
||||
chat_dict[chat_id].append({"situation": situation, "style": style, "context": context, "context_words": context_words, "full_context": full_context, "full_context_embedding": full_context_embedding})
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
|
|
@ -261,6 +267,10 @@ class ExpressionLearner:
|
|||
if random.random() < 0.5:
|
||||
expr_obj.situation = new_expr["situation"]
|
||||
expr_obj.style = new_expr["style"]
|
||||
expr_obj.context = new_expr["context"]
|
||||
expr_obj.context_words = new_expr["context_words"]
|
||||
expr_obj.full_context = new_expr["full_context"]
|
||||
expr_obj.full_context_embedding = new_expr["full_context_embedding"]
|
||||
expr_obj.count = expr_obj.count + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.save()
|
||||
|
|
@ -273,6 +283,10 @@ class ExpressionLearner:
|
|||
chat_id=chat_id,
|
||||
type="style",
|
||||
create_date=current_time, # 手动设置创建日期
|
||||
context=new_expr["context"],
|
||||
context_words=new_expr["context_words"],
|
||||
full_context=new_expr["full_context"],
|
||||
full_context_embedding=new_expr["full_context_embedding"],
|
||||
)
|
||||
# 限制最大数量
|
||||
exprs = list(
|
||||
|
|
@ -287,28 +301,100 @@ class ExpressionLearner:
|
|||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str) -> List[Tuple[str, str, str]]:
    """Ask the LLM for the original chat sentence behind each learnt expression.

    Every (situation, style) pair is numbered from 1 and rendered into the
    ``match_expression_context_prompt`` template together with the chat text.
    The reply is read as JSON whose entries carry the keys ``expression_pair``
    (the 1-based number of the pair) and ``context`` (the matched sentence).

    Args:
        expression_pairs: (situation, style) tuples to look up.
        random_msg_match_str: Chat text in which the sentences are searched.

    Returns:
        (situation, style, context) tuples, at most one per input pair. Entries
        with an out-of-range number, a number that was already matched, a
        malformed shape, or a non-text/blank context are left out. An empty
        list is returned when the reply cannot be parsed at all.
    """
    # Number the pairs so the model can refer to them by index
    numbered_pairs = [
        f"{i}. 当\"{situation}\"时,使用\"{style}\""
        for i, (situation, style) in enumerate(expression_pairs, 1)
    ]

    prompt = await global_prompt_manager.format_prompt(
        "match_expression_context_prompt",
        expression_pairs="\n".join(numbered_pairs),
        chat_str=random_msg_match_str,
    )

    response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)

    logger.debug(f"match_expression_context_prompt: {prompt}")
    logger.debug(f"match_expression_context response: {response}")

    # Parse the JSON reply
    try:
        response = response.strip()
        # Several comma-separated objects are not valid JSON on their own;
        # wrapping them in brackets turns them into a single array.
        candidate = f"[{response}]" if response.startswith("{") else response
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            # Accept both a JSON string and an already parsed object from
            # repair_json, as the previous implementation did.
            repaired = repair_json(candidate)
            parsed = json.loads(repaired) if isinstance(repaired, str) else repaired
    except Exception as e:
        # Model output is untrusted: any failure here means "nothing matched"
        logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
        return []

    # A single object counts as a one-element list; anything else is ignored
    if isinstance(parsed, dict):
        match_responses = [parsed]
    elif isinstance(parsed, list):
        match_responses = parsed
    else:
        match_responses = []

    matched_expressions: List[Tuple[str, str, str]] = []
    used_pair_indices = set()  # 0-based indices that already received a context

    for match_response in match_responses:
        try:
            # Convert the 1-based number from the reply to a 0-based index
            pair_index = int(match_response["expression_pair"]) - 1

            if pair_index in used_pair_indices:
                logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
                continue
            if not 0 <= pair_index < len(expression_pairs):
                continue

            situation, style = expression_pairs[pair_index]
            context = match_response["context"]
            # The context is tokenized as text later on, so anything that is
            # not a non-blank string is unusable and must not block the index.
            if not isinstance(context, str) or not context.strip():
                logger.debug(f"表达方式 {pair_index + 1} 的context无效,已跳过")
                continue

            matched_expressions.append((situation, style, context))
            used_pair_indices.add(pair_index)
            logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
        except (ValueError, KeyError, IndexError, TypeError) as e:
            # TypeError covers entries that are not dicts and null numbers
            logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
            continue

    return matched_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
|
|
@ -352,16 +438,44 @@ class ExpressionLearner:
|
|||
|
||||
logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(expressions, random_msg_match_str)
|
||||
|
||||
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(matched_expressions)
|
||||
|
||||
split_matched_expressions_w_emb = []
|
||||
full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)
|
||||
|
||||
for situation, style, context, context_words in split_matched_expressions:
|
||||
split_matched_expressions_w_emb.append((self.chat_id, situation, style, context, context_words, random_msg_match_str,full_context_embedding))
|
||||
|
||||
|
||||
matched_expressions = await self.match_expression_context(expressions, random_msg_match_str)
|
||||
return split_matched_expressions_w_emb
|
||||
|
||||
async def get_full_context_embedding(self, context: str) -> List[float]:
    """Embed ``context`` with this learner's embedding model.

    Args:
        context: Text to embed.

    Returns:
        The first element of the pair returned by ``get_embedding``; the
        second element is discarded.
    """
    vector, _unused = await self.embedding_model.get_embedding(context)
    return vector
|
||||
|
||||
def split_expression_context(self, matched_expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str, List[str]]]:
    """
    Tokenize the context of every matched expression with jieba.

    Args:
        matched_expressions: Matched expressions, each a (situation, style, context) tuple.

    Returns:
        The same entries in the same order with the token list appended,
        each a (situation, style, context, context_words) tuple.
    """
    return [
        (situation, style, context, list(jieba.cut(context)))
        for situation, style, context in matched_expressions
    ]
|
||||
|
||||
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
"""
|
||||
|
|
@ -388,7 +502,7 @@ class ExpressionLearner:
|
|||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((chat_id, situation, style))
|
||||
expressions.append((situation, style))
|
||||
return expressions
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -412,6 +412,9 @@ class HeartFChatting:
|
|||
},
|
||||
}
|
||||
reply_text = action_reply_text
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
end_time = time.time()
|
||||
if end_time - start_time < global_config.chat.planner_smooth:
|
||||
|
|
@ -421,8 +424,7 @@ class HeartFChatting:
|
|||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
self.end_cycle(loop_info, cycle_timers)
|
||||
self.print_cycle_info(cycle_timers)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -417,12 +417,6 @@ def _build_readable_messages_internal(
|
|||
timestamp = message.time
|
||||
content = message.display_message or message.processed_plain_text or ""
|
||||
|
||||
# 向下兼容
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
if show_pic:
|
||||
content = process_pic_ids(content)
|
||||
|
|
@ -862,16 +856,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
|
|||
user_id = msg.user_info.user_id
|
||||
content = msg.display_message or msg.processed_plain_text or ""
|
||||
|
||||
if "ᶠ" in content:
|
||||
content = content.replace("ᶠ", "")
|
||||
if "ⁿ" in content:
|
||||
content = content.replace("ⁿ", "")
|
||||
|
||||
# 处理图片ID
|
||||
content = process_pic_ids(content)
|
||||
|
||||
# if not all([platform, user_id, timestamp is not None]):
|
||||
# continue
|
||||
|
||||
anon_name = get_anon_name(platform, user_id)
|
||||
# print(f"anon_name:{anon_name}")
|
||||
|
|
@ -937,3 +924,44 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
|||
person_ids_set.add(person_id)
|
||||
|
||||
return list(person_ids_set) # 将集合转换为列表返回
|
||||
|
||||
|
||||
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
    """
    Build a bare transcript that contains nothing but message text.

    Only ``processed_plain_text`` is read from each message; sender names and
    timestamps are left out. Picture placeholders of the form ``[picid:...]``
    are collapsed to ``[图片]``, and ``回复<a:b>`` / ``@<a:b>`` user references
    are anonymised to ``回复[某人]`` / ``@[某人]``.

    Args:
        messages: Messages to render, in output order.

    Returns:
        The cleaned, non-empty message texts joined by newlines, or an empty
        string when ``messages`` is empty.
    """
    if not messages:
        return ""

    # The patterns do not depend on the message, so compile them once up front
    pic_pattern = re.compile(r"\[picid:[^\]]+\]")
    reply_pattern = re.compile(r"回复<[^:<>]+:[^:<>]+>")
    at_pattern = re.compile(r"@<[^:<>]+:[^:<>]+>")

    output_lines = []

    for msg in messages:
        # A missing text is treated as an empty message
        content = msg.processed_plain_text or ""

        # Replace picture ids with a fixed placeholder
        content = pic_pattern.sub("[图片]", content)

        # Strip the user identity out of reply and @ markers
        content = reply_pattern.sub("回复[某人]", content)
        content = at_pattern.sub("@[某人]", content)

        # Messages that are blank after cleaning are dropped
        content = content.strip()
        if content:
            output_lines.append(content)

    return "\n".join(output_lines)
|
||||
|
|
|
|||
Loading…
Reference in New Issue