MaiBot/src/express/expression_learner.py

633 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import time
import json
import os
from typing import List, Optional, Tuple
import traceback
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages,
build_bare_messages,
)
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.express_utils import filter_message_content, calculate_similarity
from json_repair import repair_json
# MAX_EXPRESSION_COUNT = 300
logger = get_logger("expressor")
def init_prompt() -> None:
learn_style_prompt = """
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
1. 只考虑文字,不要考虑表情包和图片
2. 不要涉及具体的人名,但是可以涉及具体名词
3. 思考有没有特殊的梗,一并总结成语言风格
4. 例子仅供参考,请严格根据群聊内容总结!!!
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景不超过20个字。BBBBB代表对应的语言风格特定句式或表达方式不超过20个字。
例如:
"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
"表示讽刺的赞同,不讲道理"时,使用"对对对"
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性
现在请你概括
"""
Prompt(learn_style_prompt, "learn_style_prompt")
match_expression_context_prompt = """
**聊天内容**
{chat_str}
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context。
如果找不到原句,就不输出该句的匹配结果。
以json格式输出
格式如下:
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
...
现在请你输出匹配结果:
"""
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
class ExpressionLearner:
def __init__(self, chat_id: str) -> None:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
)
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习参数
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 120 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
检查是否应该触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
"""
# 检查是否允许学习
if not self.enable_learning:
return False
# 检查时间间隔
time_diff = time.time() - self.last_learning_time
if time_diff < self.min_learning_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
return False
return True
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
if not self.should_trigger_learning():
return
try:
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格
learnt_style = await self.learn_and_store(num=25)
# 更新学习时间
self.last_learning_time = time.time()
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
return
async def learn_and_store(self, num: int = 10) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
"""
learnt_expressions = await self.learn_expression(num)
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
_up_content,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
up_content,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
return learnt_expressions
async def match_expression_context(
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
prompt = "match_expression_context_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt,
expression_pairs=expression_pairs_str,
chat_str=random_msg_match_str,
)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
repaired_content = repair_json(response)
# 确保repaired_content是列表格式
if isinstance(repaired_content, str):
try:
parsed_data = json.loads(repaired_content)
if isinstance(parsed_data, dict):
# 如果是字典,包装成列表
match_responses = [parsed_data]
elif isinstance(parsed_data, list):
match_responses = parsed_data
else:
match_responses = []
except json.JSONDecodeError:
match_responses = []
elif isinstance(repaired_content, dict):
# 如果是字典,包装成列表
match_responses = [repaired_content]
elif isinstance(repaired_content, list):
match_responses = repaired_content
else:
match_responses = []
except json.JSONDecodeError:
# 如果还是失败尝试repair_json
repaired_content = repair_json(response)
if isinstance(repaired_content, str):
parsed_data = json.loads(repaired_content)
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
else:
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
except (json.JSONDecodeError, Exception) as e:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
# 确保 match_responses 是一个列表
if not isinstance(match_responses, list):
if isinstance(match_responses, dict):
match_responses = [match_responses]
else:
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
return []
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"match_responses 内容: {match_responses}")
for match_response in match_responses:
try:
# 检查 match_response 的类型
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response.get("context", "")
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError, TypeError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
return matched_expressions
async def learn_expression(self, num: int = 10) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
"""
current_time = time.time()
# 获取上次学习之后的消息
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=current_time,
limit=num,
)
# print(random_msg)
if not random_msg or random_msg == []:
return None
# 学习用
random_msg_str: str = await build_anonymous_messages(random_msg)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt(
"learn_style_prompt",
chat_str=random_msg_str,
)
# print(f"random_msg_str:{random_msg_str}")
# logger.info(f"学习{type_str}的prompt: {prompt}")
try:
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return None
# logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
expressions, random_msg_match_str
)
# 为每条消息构建精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0:
# 没有匹配到目标句或没有上一句,跳过该表达
continue
# 检查目标句是否为空
target_content = bare_lines[pos][1]
if not target_content:
# 目标句为空,跳过该表达
continue
prev_original_idx = bare_lines[pos - 1][0]
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
if not up_content:
# 上一句为空,跳过该表达
continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up:
return None
return filtered_with_up
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
"""
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容,存储为(situation, style)元组
"""
expressions: List[Tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip()
if not line:
continue
# 查找"当"和下一个引号
idx_when = line.find('"')
if idx_when == -1:
continue
idx_quote1 = idx_when + 1
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
if idx_use == -1:
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((situation, style))
return expressions
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str]] = []
removed_count = 0
for situation, style in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
expr_obj = (
Expression.select()
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
.first()
)
if expr_obj:
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
context=context,
up_content=up_content,
current_time=current_time,
)
return
await self._create_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
async def _create_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = [situation]
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
Expression.create(
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
up_content=up_content,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = self._parse_content_list(expr_obj.content_list)
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time
expr_obj.context = context
expr_obj.up_content = up_content
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
expr_obj.save()
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表,保留到原消息索引的映射
Args:
messages: 消息列表
Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
"""
bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages):
content = msg.processed_plain_text or ""
content = filter_message_content(content)
# 即使content为空也要记录防止错位
bare_lines.append((idx, content))
return bare_lines
init_prompt()
class ExpressionLearnerManager:
def __init__(self):
self.expression_learners = {}
self._ensure_expression_directories()
def get_expression_learner(self, chat_id: str) -> ExpressionLearner:
if chat_id not in self.expression_learners:
self.expression_learners[chat_id] = ExpressionLearner(chat_id)
return self.expression_learners[chat_id]
def _ensure_expression_directories(self):
"""
确保表达方式相关的目录结构存在
"""
base_dir = os.path.join("data", "expression")
directories_to_create = [
base_dir,
os.path.join(base_dir, "learnt_style"),
os.path.join(base_dir, "learnt_grammar"),
]
for directory in directories_to_create:
try:
os.makedirs(directory, exist_ok=True)
logger.debug(f"确保目录存在: {directory}")
except Exception as e:
logger.error(f"创建目录失败 {directory}: {e}")
expression_learner_manager = ExpressionLearnerManager()