feat: record more information for expressions

pull/1266/head
SengokuCola 2025-09-26 13:36:35 +08:00
parent e79da24c23
commit 5cc1e56904
3 changed files with 188 additions and 44 deletions

View File

@@ -3,16 +3,18 @@ import random
import json
import os
from datetime import datetime
import jieba
from typing import List, Dict, Optional, Any, Tuple
import traceback
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages, build_bare_messages
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from json_repair import repair_json
from src.chat.utils.utils import get_embedding
MAX_EXPRESSION_COUNT = 300
@@ -62,7 +64,7 @@ def init_prompt() -> None:
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context
如果找不到原句就不输出该句的匹配结果
以json格式输出
格式如下
@@ -86,6 +88,9 @@ class ExpressionLearner:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.replyer, request_type="expression.learner"
)
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
@@ -94,9 +99,9 @@
self.last_learning_time: float = time.time()
# 学习参数
self.min_messages_for_learning = 25 # 触发学习所需的最少消息数
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id)
self.min_learning_interval = 300 / self.learning_intensity
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 150 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
@@ -160,6 +165,7 @@
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return False
def _apply_global_decay_to_database(self, current_time: float) -> None:
@@ -229,19 +235,19 @@
if res is None:
logger.info("没有学习到表达风格")
return []
learnt_expressions, chat_id = res
learnt_expressions = res
learnt_expressions_str = ""
for _chat_id, situation, style in learnt_expressions:
for _chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
# 按chat_id分组
chat_dict: Dict[str, List[Dict[str, Any]]] = {}
for chat_id, situation, style in learnt_expressions:
for chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions:
if chat_id not in chat_dict:
chat_dict[chat_id] = []
chat_dict[chat_id].append({"situation": situation, "style": style})
chat_dict[chat_id].append({"situation": situation, "style": style, "context": context, "context_words": context_words, "full_context": full_context, "full_context_embedding": full_context_embedding})
current_time = time.time()
@@ -261,6 +267,10 @@
if random.random() < 0.5:
expr_obj.situation = new_expr["situation"]
expr_obj.style = new_expr["style"]
expr_obj.context = new_expr["context"]
expr_obj.context_words = new_expr["context_words"]
expr_obj.full_context = new_expr["full_context"]
expr_obj.full_context_embedding = new_expr["full_context_embedding"]
expr_obj.count = expr_obj.count + 1
expr_obj.last_active_time = current_time
expr_obj.save()
@@ -273,6 +283,10 @@
chat_id=chat_id,
type="style",
create_date=current_time, # 手动设置创建日期
context=new_expr["context"],
context_words=new_expr["context_words"],
full_context=new_expr["full_context"],
full_context_embedding=new_expr["full_context_embedding"],
)
# 限制最大数量
exprs = list(
@@ -287,28 +301,100 @@
return learnt_expressions
async def match_expression_context(self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f"{i}. 当\"{situation}\"时,使用\"{style}\"")
expression_pairs_str = "\n".join(numbered_pairs)
prompt = "match_expression_context_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt,
expression_pairs=expression_pairs,
expression_pairs=expression_pairs_str,
chat_str=random_msg_match_str,
)
match_responses = []
# 解析所有match结果到 match_response
matched_expressions = []
for match_response in match_responses:
#exp序号
match_response["expression_pair"]
matched_expressions.append((match_response["expression_pair"], match_response["context"]))
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]:
print(f"match_expression_context_prompt: {prompt}")
print(f"random_msg_match_str: {response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith('[') and response.endswith(']'):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith('{') and not response.startswith('['):
response = '[' + response + ']'
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
repaired_content = repair_json(response)
# 确保repaired_content是列表格式
if isinstance(repaired_content, str):
try:
parsed_data = json.loads(repaired_content)
if isinstance(parsed_data, dict):
# 如果是字典,包装成列表
match_responses = [parsed_data]
elif isinstance(parsed_data, list):
match_responses = parsed_data
else:
match_responses = []
except json.JSONDecodeError:
match_responses = []
elif isinstance(repaired_content, dict):
# 如果是字典,包装成列表
match_responses = [repaired_content]
elif isinstance(repaired_content, list):
match_responses = repaired_content
else:
match_responses = []
except json.JSONDecodeError:
# 如果还是失败尝试repair_json
repaired_content = repair_json(response)
if isinstance(repaired_content, str):
parsed_data = json.loads(repaired_content)
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
else:
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
except Exception as e:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
for match_response in match_responses:
try:
# 获取表达方式序号
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response["context"]
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
return matched_expressions
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str, List[str], str, List[float]]]]:
"""从指定聊天流学习表达方式
Args:
@@ -352,16 +438,44 @@
logger.debug(f"学习{type_str}的response: {response}")
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id)
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(expressions, random_msg_match_str)
split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(matched_expressions)
split_matched_expressions_w_emb = []
full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str)
for situation, style, context, context_words in split_matched_expressions:
split_matched_expressions_w_emb.append((self.chat_id, situation, style, context, context_words, random_msg_match_str, full_context_embedding))
matched_expressions = await self.match_expression_context(expressions, random_msg_match_str)
return split_matched_expressions_w_emb
async def get_full_context_embedding(self, context: str) -> List[float]:
embedding, _ = await self.embedding_model.get_embedding(context)
return embedding
def split_expression_context(self, matched_expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str, List[str]]]:
"""
对matched_expressions中的context部分进行jieba分词
Args:
matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context)
Returns:
添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words)
"""
result = []
for situation, style, context in matched_expressions:
# 使用jieba进行分词
context_words = list(jieba.cut(context))
result.append((situation, style, context, context_words))
return matched_expressions, chat_id
return result
def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]:
def parse_expression_response(self, response: str) -> List[Tuple[str, str]]:
"""
解析LLM返回的表达风格总结,每一行提取"当"与"使用"后引号之间的内容,存储为(situation, style)元组
"""
@@ -388,7 +502,7 @@
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((chat_id, situation, style))
expressions.append((situation, style))
return expressions
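
Reviewer sketch: the new match_expression_context asks the model for a JSON array that pairs each numbered expression with its original sentence, then maps the 1-based "expression_pair" indices back onto expression_pairs and skips duplicates. The snippet below is a minimal, self-contained illustration of that mapping step; the helper name and sample data are hypothetical and not part of this commit.

import json
from typing import List, Tuple

def match_pairs_to_context(expression_pairs: List[Tuple[str, str]], response: str) -> List[Tuple[str, str, str]]:
    matched: List[Tuple[str, str, str]] = []
    used = set()  # each expression_pair may be matched at most once
    for item in json.loads(response):
        idx = int(item["expression_pair"]) - 1  # the prompt numbers pairs starting from 1
        if 0 <= idx < len(expression_pairs) and idx not in used:
            situation, style = expression_pairs[idx]
            matched.append((situation, style, item["context"]))
            used.add(idx)
    return matched

pairs = [("对方表示赞同", "那可不"), ("对方发了图片", "看不懂但大受震撼")]
reply = '[{"expression_pair": 1, "context": "那可不,必须的"}, {"expression_pair": 1, "context": "重复条目会被跳过"}]'
print(match_pairs_to_context(pairs, reply))  # [('对方表示赞同', '那可不', '那可不,必须的')]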
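
Reviewer sketch: learn_expression now records, per expression, the matched context sentence, its jieba tokenisation, the full anonymised chat text, and an embedding of that text. The hypothetical helper below only shows the shape of the resulting 7-tuple; the embedding is stubbed here, whereas the commit obtains it asynchronously via self.embedding_model.get_embedding.

from typing import List, Tuple
import jieba

def fake_embedding(text: str) -> List[float]:
    # stand-in for the async LLMRequest embedding call; values are illustrative only
    return [0.0] * 8

def build_record(chat_id: str, situation: str, style: str, context: str, full_context: str) -> Tuple:
    context_words = list(jieba.cut(context))  # same tokenisation as split_expression_context
    return (chat_id, situation, style, context, context_words, full_context, fake_embedding(full_context))

record = build_record("chat_123", "对方表示赞同", "那可不", "那可不,必须的", "...整段匿名化聊天记录...")
print(record[4])  # e.g. ['那', '可不', ',', '必须', '的']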

View File

@@ -412,6 +412,9 @@ class HeartFChatting:
},
}
reply_text = action_reply_text
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)
end_time = time.time()
if end_time - start_time < global_config.chat.planner_smooth:
@@ -421,8 +424,7 @@
await asyncio.sleep(0.1)
self.end_cycle(loop_info, cycle_timers)
self.print_cycle_info(cycle_timers)

View File

@@ -417,12 +417,6 @@ def _build_readable_messages_internal(
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
# 向下兼容
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
if show_pic:
content = process_pic_ids(content)
@@ -862,16 +856,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str:
user_id = msg.user_info.user_id
content = msg.display_message or msg.processed_plain_text or ""
if "" in content:
content = content.replace("", "")
if "" in content:
content = content.replace("", "")
# 处理图片ID
content = process_pic_ids(content)
# if not all([platform, user_id, timestamp is not None]):
# continue
anon_name = get_anon_name(platform, user_id)
# print(f"anon_name:{anon_name}")
@@ -937,3 +924,44 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set.add(person_id)
return list(person_ids_set) # 将集合转换为列表返回
async def build_bare_messages(messages: List[DatabaseMessages]) -> str:
"""
构建简化版消息字符串,只包含processed_plain_text内容,不考虑用户名和时间戳
Args:
messages: 消息列表
Returns:
只包含消息内容的字符串
"""
if not messages:
return ""
output_lines = []
for msg in messages:
# 获取纯文本内容
content = msg.processed_plain_text or ""
# 处理图片ID
pic_pattern = r"\[picid:[^\]]+\]"
content = re.sub(pic_pattern, "[图片]", content)
# 处理用户引用格式,移除回复和@标记
reply_pattern = r"回复<[^:<>]+:[^:<>]+>"
content = re.sub(reply_pattern, "回复[某人]", content)
at_pattern = r"@<[^:<>]+:[^:<>]+>"
content = re.sub(at_pattern, "@[某人]", content)
# 清理并添加到输出
content = content.strip()
if content:
output_lines.append(content)
return "\n".join(output_lines)