diff --git a/src/chat/express/expression_learner.py b/src/chat/express/expression_learner.py index 528dc15e..79fa433e 100644 --- a/src/chat/express/expression_learner.py +++ b/src/chat/express/expression_learner.py @@ -3,16 +3,18 @@ import random import json import os from datetime import datetime - +import jieba from typing import List, Dict, Optional, Any, Tuple - +import traceback from src.common.logger import get_logger from src.common.database.database_model import Expression from src.llm_models.utils_model import LLMRequest from src.config.config import model_config, global_config -from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive, build_anonymous_messages, build_bare_messages from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.message_receive.chat_stream import get_chat_manager +from json_repair import repair_json +from src.chat.utils.utils import get_embedding MAX_EXPRESSION_COUNT = 300 @@ -62,7 +64,7 @@ def init_prompt() -> None: **从聊天内容总结的表达方式pairs** {expression_pairs} -请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果。 +请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。 如果找不到原句,就不输出该句的匹配结果。 以json格式输出: 格式如下: @@ -86,6 +88,9 @@ class ExpressionLearner: self.express_learn_model: LLMRequest = LLMRequest( model_set=model_config.model_task_config.replyer, request_type="expression.learner" ) + self.embedding_model: LLMRequest = LLMRequest( + model_set=model_config.model_task_config.embedding, request_type="expression.embedding" + ) self.chat_id = chat_id self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id @@ -94,9 +99,9 @@ class ExpressionLearner: self.last_learning_time: float = time.time() # 学习参数 - self.min_messages_for_learning = 25 # 触发学习所需的最少消息数 _, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(self.chat_id) - self.min_learning_interval = 300 / self.learning_intensity + self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数 + self.min_learning_interval = 150 / self.learning_intensity def should_trigger_learning(self) -> bool: """ @@ -160,6 +165,7 @@ class ExpressionLearner: except Exception as e: logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}") + traceback.print_exc() return False def _apply_global_decay_to_database(self, current_time: float) -> None: @@ -229,19 +235,19 @@ class ExpressionLearner: if res is None: logger.info("没有学习到表达风格") return [] - learnt_expressions, chat_id = res + learnt_expressions = res learnt_expressions_str = "" - for _chat_id, situation, style in learnt_expressions: + for _chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions: learnt_expressions_str += f"{situation}->{style}\n" logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") # 按chat_id分组 chat_dict: Dict[str, List[Dict[str, Any]]] = {} - for chat_id, situation, style in learnt_expressions: + for chat_id, situation, style, context, context_words, full_context, full_context_embedding in learnt_expressions: if chat_id not in chat_dict: chat_dict[chat_id] = [] - chat_dict[chat_id].append({"situation": situation, "style": style}) + chat_dict[chat_id].append({"situation": situation, "style": style, "context": context, "context_words": context_words, "full_context": full_context, "full_context_embedding": full_context_embedding}) current_time = time.time() @@ -261,6 +267,10 @@ class ExpressionLearner: if random.random() < 0.5: expr_obj.situation = new_expr["situation"] expr_obj.style = new_expr["style"] + expr_obj.context = new_expr["context"] + expr_obj.context_words = new_expr["context_words"] + expr_obj.full_context = new_expr["full_context"] + expr_obj.full_context_embedding = new_expr["full_context_embedding"] expr_obj.count = expr_obj.count + 1 expr_obj.last_active_time = current_time expr_obj.save() @@ -273,6 +283,10 @@ class ExpressionLearner: chat_id=chat_id, type="style", create_date=current_time, # 手动设置创建日期 + context=new_expr["context"], + context_words=new_expr["context_words"], + full_context=new_expr["full_context"], + full_context_embedding=new_expr["full_context_embedding"], ) # 限制最大数量 exprs = list( @@ -287,28 +301,100 @@ class ExpressionLearner: return learnt_expressions async def match_expression_context(self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str) -> List[Tuple[str, str, str]]: + # 为expression_pairs逐个条目赋予编号,并构建成字符串 + numbered_pairs = [] + for i, (situation, style) in enumerate(expression_pairs, 1): + numbered_pairs.append(f"{i}. 当\"{situation}\"时,使用\"{style}\"") + + expression_pairs_str = "\n".join(numbered_pairs) + prompt = "match_expression_context_prompt" prompt = await global_prompt_manager.format_prompt( prompt, - expression_pairs=expression_pairs, + expression_pairs=expression_pairs_str, chat_str=random_msg_match_str, ) - match_responses = [] - # 解析所有match结果到 match_response - - matched_expressions = [] - for match_response in match_responses: - #exp序号 - match_response["expression_pair"] - matched_expressions.append((match_response["expression_pair"], match_response["context"])) - - - - response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3) - async def learn_expression(self, num: int = 10) -> Optional[Tuple[List[Tuple[str, str, str]], str]]: + print(f"match_expression_context_prompt: {prompt}") + print(f"random_msg_match_str: {response}") + + # 解析JSON响应 + match_responses = [] + try: + response = response.strip() + # 检查是否已经是标准JSON数组格式 + if response.startswith('[') and response.endswith(']'): + match_responses = json.loads(response) + else: + # 尝试直接解析多个JSON对象 + try: + # 如果是多个JSON对象用逗号分隔,包装成数组 + if response.startswith('{') and not response.startswith('['): + response = '[' + response + ']' + match_responses = json.loads(response) + else: + # 使用repair_json处理响应 + repaired_content = repair_json(response) + + # 确保repaired_content是列表格式 + if isinstance(repaired_content, str): + try: + parsed_data = json.loads(repaired_content) + if isinstance(parsed_data, dict): + # 如果是字典,包装成列表 + match_responses = [parsed_data] + elif isinstance(parsed_data, list): + match_responses = parsed_data + else: + match_responses = [] + except json.JSONDecodeError: + match_responses = [] + elif isinstance(repaired_content, dict): + # 如果是字典,包装成列表 + match_responses = [repaired_content] + elif isinstance(repaired_content, list): + match_responses = repaired_content + else: + match_responses = [] + except json.JSONDecodeError: + # 如果还是失败,尝试repair_json + repaired_content = repair_json(response) + if isinstance(repaired_content, str): + parsed_data = json.loads(repaired_content) + match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data] + else: + match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content] + + except (json.JSONDecodeError, Exception) as e: + logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}") + return [] + + matched_expressions = [] + used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引 + + for match_response in match_responses: + try: + # 获取表达方式序号 + pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引 + + # 检查索引是否有效且未被使用过 + if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices: + situation, style = expression_pairs[pair_index] + context = match_response["context"] + matched_expressions.append((situation, style, context)) + used_pair_indices.add(pair_index) # 标记该索引已使用 + logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}") + elif pair_index in used_pair_indices: + logger.debug(f"跳过重复的表达方式 {pair_index + 1}") + except (ValueError, KeyError, IndexError) as e: + logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}") + continue + + return matched_expressions + + async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, List[str], str, List[float]]]]: """从指定聊天流学习表达方式 Args: @@ -352,16 +438,44 @@ class ExpressionLearner: logger.debug(f"学习{type_str}的response: {response}") - expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response, chat_id) + expressions: List[Tuple[str, str]] = self.parse_expression_response(response) + + matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(expressions, random_msg_match_str) + + split_matched_expressions: List[Tuple[str, str, str, List[str]]] = self.split_expression_context(matched_expressions) + + split_matched_expressions_w_emb = [] + full_context_embedding: List[float] = await self.get_full_context_embedding(random_msg_match_str) + + for situation, style, context, context_words in split_matched_expressions: + split_matched_expressions_w_emb.append((self.chat_id, situation, style, context, context_words, random_msg_match_str,full_context_embedding)) - matched_expressions = await self.match_expression_context(expressions, random_msg_match_str) + return split_matched_expressions_w_emb + + async def get_full_context_embedding(self, context: str) -> List[float]: + embedding, _ = await self.embedding_model.get_embedding(context) + return embedding + + def split_expression_context(self, matched_expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str, List[str]]]: + """ + 对matched_expressions中的context部分进行jieba分词 + Args: + matched_expressions: 匹配到的表达方式列表,每个元素为(situation, style, context) + Returns: + 添加了分词结果的表达方式列表,每个元素为(situation, style, context, context_words) + """ + result = [] + for situation, style, context in matched_expressions: + # 使用jieba进行分词 + context_words = list(jieba.cut(context)) + result.append((situation, style, context, context_words)) - return matched_expressions, chat_id + return result - def parse_expression_response(self, response: str, chat_id: str) -> List[Tuple[str, str, str]]: + def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: """ 解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组 """ @@ -388,7 +502,7 @@ class ExpressionLearner: if idx_quote4 == -1: continue style = line[idx_quote3 + 1 : idx_quote4] - expressions.append((chat_id, situation, style)) + expressions.append((situation, style)) return expressions diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index bb36b102..88909b2b 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -412,6 +412,9 @@ class HeartFChatting: }, } reply_text = action_reply_text + + self.end_cycle(loop_info, cycle_timers) + self.print_cycle_info(cycle_timers) end_time = time.time() if end_time - start_time < global_config.chat.planner_smooth: @@ -421,8 +424,7 @@ class HeartFChatting: await asyncio.sleep(0.1) - self.end_cycle(loop_info, cycle_timers) - self.print_cycle_info(cycle_timers) + diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index 1bd72c85..5488dc9f 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -417,12 +417,6 @@ def _build_readable_messages_internal( timestamp = message.time content = message.display_message or message.processed_plain_text or "" - # 向下兼容 - if "ᶠ" in content: - content = content.replace("ᶠ", "") - if "ⁿ" in content: - content = content.replace("ⁿ", "") - # 处理图片ID if show_pic: content = process_pic_ids(content) @@ -862,16 +856,9 @@ async def build_anonymous_messages(messages: List[DatabaseMessages]) -> str: user_id = msg.user_info.user_id content = msg.display_message or msg.processed_plain_text or "" - if "ᶠ" in content: - content = content.replace("ᶠ", "") - if "ⁿ" in content: - content = content.replace("ⁿ", "") - # 处理图片ID content = process_pic_ids(content) - # if not all([platform, user_id, timestamp is not None]): - # continue anon_name = get_anon_name(platform, user_id) # print(f"anon_name:{anon_name}") @@ -937,3 +924,44 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set.add(person_id) return list(person_ids_set) # 将集合转换为列表返回 + + +async def build_bare_messages(messages: List[DatabaseMessages]) -> str: + """ + 构建简化版消息字符串,只包含processed_plain_text内容,不考虑用户名和时间戳 + + Args: + messages: 消息列表 + + Returns: + 只包含消息内容的字符串 + """ + if not messages: + return "" + + output_lines = [] + + for msg in messages: + # 获取纯文本内容 + content = msg.processed_plain_text or "" + + + # 处理图片ID + pic_pattern = r"\[picid:[^\]]+\]" + def replace_pic_id(match): + return "[图片]" + content = re.sub(pic_pattern, replace_pic_id, content) + + # 处理用户引用格式,移除回复和@标记 + reply_pattern = r"回复<[^:<>]+:[^:<>]+>" + content = re.sub(reply_pattern, "回复[某人]", content) + + at_pattern = r"@<[^:<>]+:[^:<>]+>" + content = re.sub(at_pattern, "@[某人]", content) + + # 清理并添加到输出 + content = content.strip() + if content: + output_lines.append(content) + + return "\n".join(output_lines)