feat:黑化和表达不会将名字总结进去

pull/1443/head
SengokuCola 2025-12-15 00:05:15 +08:00
parent b73a748f52
commit 3db9fafe65
4 changed files with 201 additions and 15 deletions

View File

@ -3,7 +3,7 @@ import json
import os
import re
import asyncio
from typing import List, Optional, Tuple, Any, Dict
from typing import List, Optional, Tuple, Any, Dict, Callable
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest
@ -97,14 +97,14 @@ class ExpressionLearner:
async def learn_and_store(
self,
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str, str]]:
"""
学习并存储表达方式
Args:
messages: 外部传入的消息列表必需
num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
person_name_filter: 可选的过滤函数用于检查内容是否包含人物名称
"""
if not messages:
return None
@ -135,6 +135,17 @@ class ExpressionLearner:
expressions, jargon_entries = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
# 过滤掉包含人物名称的表达方式
if person_name_filter:
filtered_expressions = []
for situation, style, source_id in expressions:
# 检查 situation 和 style 是否包含人物名称
if person_name_filter(situation) or person_name_filter(style):
logger.info(f"跳过包含人物名称的表达方式: situation={situation}, style={style}")
continue
filtered_expressions.append((situation, style, source_id))
expressions = filtered_expressions
# 检查表达方式数量如果超过10个则放弃本次表达学习
if len(expressions) > 10:
logger.info(f"表达方式提取数量超过10个实际{len(expressions)}个),放弃本次表达学习")
@ -147,7 +158,7 @@ class ExpressionLearner:
# 处理黑话条目,路由到 jargon_miner即使没有表达方式也要处理黑话
if jargon_entries:
await self._process_jargon_entries(jargon_entries, random_msg)
await self._process_jargon_entries(jargon_entries, random_msg, person_name_filter)
# 如果没有表达方式,直接返回
if not expressions:
@ -500,13 +511,19 @@ class ExpressionLearner:
logger.error(f"概括表达情境失败: {e}")
return None
async def _process_jargon_entries(self, jargon_entries: List[Tuple[str, str]], messages: List[Any]) -> None:
async def _process_jargon_entries(
self,
jargon_entries: List[Tuple[str, str]],
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
"""
处理从 expression learner 提取的黑话条目路由到 jargon_miner
Args:
jargon_entries: 黑话条目列表每个元素是 (content, source_id)
messages: 消息列表用于构建上下文
person_name_filter: 可选的过滤函数用于检查内容是否包含人物名称
"""
if not jargon_entries or not messages:
return
@ -527,6 +544,11 @@ class ExpressionLearner:
logger.info(f"跳过包含机器人昵称/别名的黑话: {content}")
continue
# 检查是否包含人物名称
if person_name_filter and person_name_filter(content):
logger.info(f"跳过包含人物名称的黑话: {content}")
continue
# 解析 source_id
source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
@ -557,7 +579,7 @@ class ExpressionLearner:
return
# 调用 jargon_miner 处理这些条目
await jargon_miner.process_extracted_entries(entries)
await jargon_miner.process_extracted_entries(entries, person_name_filter)
init_prompt()

View File

@ -3,7 +3,7 @@ import json
import asyncio
import random
from collections import OrderedDict
from typing import List, Dict, Optional, Any
from typing import List, Dict, Optional, Any, Callable
from json_repair import repair_json
from peewee import fn
@ -478,12 +478,17 @@ class JargonMiner:
traceback.print_exc()
async def run_once(self, messages: List[Any]) -> None:
async def run_once(
self,
messages: List[Any],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
"""
运行一次黑话提取
Args:
messages: 外部传入的消息列表必需
person_name_filter: 可选的过滤函数用于检查内容是否包含人物名称
"""
# 使用异步锁防止并发执行
async with self._extraction_lock:
@ -563,6 +568,11 @@ class JargonMiner:
logger.info(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
# 检查是否包含人物名称
if person_name_filter and person_name_filter(content):
logger.info(f"解析阶段跳过包含人物名称的词条: {content}")
continue
msg_id_str = str(msg_id_value or "").strip()
if not msg_id_str:
logger.warning(f"解析jargon失败msg_id缺失content={content}")
@ -723,12 +733,17 @@ class JargonMiner:
logger.error(f"JargonMiner 运行失败: {e}")
# 即使失败也保持时间戳更新,避免频繁重试
async def process_extracted_entries(self, entries: List[Dict[str, List[str]]]) -> None:
async def process_extracted_entries(
self,
entries: List[Dict[str, List[str]]],
person_name_filter: Optional[Callable[[str], bool]] = None
) -> None:
"""
处理已提取的黑话条目 expression_learner 路由过来的
Args:
entries: 黑话条目列表每个元素格式为 {"content": "...", "raw_content": [...]}
person_name_filter: 可选的过滤函数用于检查内容是否包含人物名称
"""
if not entries:
return
@ -738,6 +753,14 @@ class JargonMiner:
merged_entries: OrderedDict[str, Dict[str, List[str]]] = OrderedDict()
for entry in entries:
content_key = entry["content"]
# 检查是否包含人物名称
logger.info(f"process_extracted_entries 检查是否包含人物名称: {content_key}")
logger.info(f"person_name_filter: {person_name_filter}")
if person_name_filter and person_name_filter(content_key):
logger.info(f"process_extracted_entries 跳过包含人物名称的黑话: {content_key}")
continue
raw_list = entry.get("raw_content", []) or []
if content_key in merged_entries:
merged_entries[content_key]["raw_content"].extend(raw_list)

View File

@ -1,16 +1,34 @@
import time
import asyncio
from typing import List, Any
from typing import List, Any, Optional
from collections import OrderedDict
from dataclasses import dataclass
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.jargon_miner import miner_manager
from src.person_info.person_info import Person
logger = get_logger("bw_learner")
@dataclass
class PersonInfo:
"""参与聊天的人物信息"""
user_id: str
user_platform: str
user_nickname: str
user_cardname: Optional[str]
person_name: str
last_seen_time: float # 最后发言时间
def get_unique_key(self) -> str:
"""获取唯一标识(用于去重)"""
return f"{self.user_platform}:{self.user_id}"
class MessageRecorder:
"""
统一的消息记录器负责管理时间窗口和消息提取并将消息分发给 expression_learner jargon_miner
@ -27,6 +45,11 @@ class MessageRecorder:
# 提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
# 维护参与该chat_id的人物列表最多30个使用OrderedDict保持插入顺序
# key: f"{platform}:{user_id}", value: PersonInfo
self._person_list: OrderedDict[str, PersonInfo] = OrderedDict()
self._max_person_count = 30
# 获取 expression 和 jargon 的配置参数
self._init_parameters()
@ -111,6 +134,11 @@ class MessageRecorder:
# 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0)
# 更新参与聊天的人物列表
self._update_person_list(messages)
logger.info(f"聊天流 {self.chat_name} 的人物列表: {self._person_list}")
logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
@ -152,8 +180,11 @@ class MessageRecorder:
messages: 消息列表
"""
try:
# 传递消息给 ExpressionLearner必需参数
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
# 传递消息和过滤函数给 ExpressionLearner
learnt_style = await self.expression_learner.learn_and_store(
messages=messages,
person_name_filter=self.contains_person_name
)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
@ -179,14 +210,124 @@ class MessageRecorder:
messages: 消息列表
"""
try:
# 传递消息给 JargonMiner避免它重复获取
await self.jargon_miner.run_once(messages=messages)
# 传递消息和过滤函数给 JargonMiner
await self.jargon_miner.run_once(
messages=messages,
person_name_filter=self.contains_person_name
)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
def _update_person_list(self, messages: List[Any]) -> None:
"""
从消息中提取人物信息并更新人物列表
Args:
messages: 消息列表
"""
for msg in messages:
# 获取消息发送者信息
# 消息对象可能是 DatabaseMessages它有 user_info 属性
if hasattr(msg, 'user_info'):
# DatabaseMessages 类型
user_info = msg.user_info
user_id = getattr(user_info, 'user_id', None) or ''
user_platform = getattr(user_info, 'platform', None) or ''
user_nickname = getattr(user_info, 'user_nickname', None) or ''
user_cardname = getattr(user_info, 'user_cardname', None)
else:
# 直接属性访问
user_id = getattr(msg, 'user_id', None) or ''
user_platform = getattr(msg, 'user_platform', None) or ''
user_nickname = getattr(msg, 'user_nickname', None) or ''
user_cardname = getattr(msg, 'user_cardname', None)
msg_time = getattr(msg, 'time', time.time())
# 检查必要信息
if not user_id or not user_platform:
continue
# 获取 person_name
try:
person = Person(platform=user_platform, user_id=str(user_id))
person_name = person.person_name or user_nickname or (user_cardname if user_cardname else "未知用户")
except Exception as e:
logger.info(f"获取person_name失败: {e}, 使用nickname")
person_name = user_nickname or (user_cardname if user_cardname else "未知用户")
# 生成唯一key
unique_key = f"{user_platform}:{user_id}"
# 如果已存在,更新最后发言时间
if unique_key in self._person_list:
self._person_list[unique_key].last_seen_time = msg_time
# 移动到末尾(表示最近活跃)
self._person_list.move_to_end(unique_key)
else:
# 如果超过最大数量,移除最早的(最前面的)
if len(self._person_list) >= self._max_person_count:
oldest_key = next(iter(self._person_list))
del self._person_list[oldest_key]
logger.info(f"人物列表已满,移除最早的人物: {oldest_key}")
# 添加新人物
person_info = PersonInfo(
user_id=str(user_id),
user_platform=user_platform,
user_nickname=user_nickname or "",
user_cardname=user_cardname,
person_name=person_name,
last_seen_time=msg_time
)
self._person_list[unique_key] = person_info
logger.info(f"添加新人物到列表: {unique_key}, person_name={person_name}")
def contains_person_name(self, content: str) -> bool:
"""
检查内容是否包含任何参与聊天的人物的名称或昵称
Args:
content: 要检查的内容
Returns:
bool: 如果包含任何人物名称或昵称返回True
"""
if not content or not self._person_list:
return False
content_lower = content.strip().lower()
if not content_lower:
return False
# 检查所有人物
for person_info in self._person_list.values():
# 检查 person_name
if person_info.person_name:
person_name_lower = person_info.person_name.strip().lower()
if person_name_lower and person_name_lower in content_lower:
logger.debug(f"内容包含person_name: {person_info.person_name} in {content}")
return True
# 检查 user_nickname
if person_info.user_nickname:
nickname_lower = person_info.user_nickname.strip().lower()
if nickname_lower and nickname_lower in content_lower:
logger.debug(f"内容包含nickname: {person_info.user_nickname} in {content}")
return True
# 检查 user_cardname群昵称
if person_info.user_cardname:
cardname_lower = person_info.user_cardname.strip().lower()
if cardname_lower and cardname_lower in content_lower:
logger.debug(f"内容包含cardname: {person_info.user_cardname} in {content}")
return True
return False
class MessageRecorderManager:
"""MessageRecorder 管理器"""

View File

@ -36,7 +36,7 @@ def init_replyer_prompt():
{reply_target_block}
{planner_reasoning}
{identity}
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录把握当前的话题然后给出口语化回复
{chat_prompt}你正在群里聊天,现在请你读读之前的聊天记录把握当前的话题然后给出日常且简短的回复
{keywords_reaction_prompt}
请注意把握聊天内容
{reply_style}