feat: reuse shared code between jargon and expression, merge them at the code level, and merge their config options

Mitigate the bot repeatedly learning its own expressions
Mitigate the excessive cost of inferring single-character jargon
Fix inference growing too long when an entry's count is too high
Remove the expression learning intensity config option
pull/1421/head
SengokuCola 2025-12-07 14:28:30 +08:00
parent 717b18be1e
commit 2e31fa2055
20 changed files with 587 additions and 469 deletions
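
For reference, a minimal sketch (values illustrative, taken from the template changes in this diff) of the merged per-chat configuration after this commit: the old [jargon] section and the numeric learning-intensity slot are removed, the fourth slot of each learning_list entry now toggles jargon learning, and all_global_jargon moves under [expression]:

[expression]
learning_list = [
    ["", "enable", "enable", "enable"],                     # global: use expressions, learn expressions, learn jargon
    ["qq:1919810:group", "enable", "enable", "enable"],     # per-group override
    ["qq:114514:private", "enable", "disable", "disable"],  # per-private-chat override
]
all_global_jargon = true  # replaces the former [jargon] all_global switch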

View File

@ -3,19 +3,17 @@ import json
import os import os
import re import re
import asyncio import asyncio
from typing import List, Optional, Tuple from typing import List, Optional, Tuple, Any
import traceback
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages, build_anonymous_messages,
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.express_utils import filter_message_content from src.bw_learner.learner_utils import filter_message_content, is_bot_message
from json_repair import repair_json from json_repair import repair_json
@ -26,15 +24,14 @@ logger = get_logger("expressor")
def init_prompt() -> None: def init_prompt() -> None:
learn_style_prompt = """{chat_str} learn_style_prompt = """{chat_str}
你的名字是{bot_name},现在请你请从上面这段群聊中用户的语言风格和说话方式
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
每一行消息前面的方括号中的数字 [1][2]是该行消息的唯一编号请在输出中引用这些编号来标注表达方式的来源行
1. 只考虑文字不要考虑表情包和图片 1. 只考虑文字不要考虑表情包和图片
2. 不要涉及具体的人名但是可以涉及具体名词 2. 不要总结SELF的发言
3. 思考有没有特殊的梗一并总结成语言风格 3. 不要涉及具体的人名也不要涉及具体名词
4. 例子仅供参考请严格根据群聊内容总结!!! 4. 思考有没有特殊的梗一并总结成语言风格
5. 例子仅供参考请严格根据群聊内容总结!!!
注意总结成如下格式的规律总结的内容要详细但具有概括性 注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字 例如"AAAAA"可以"BBBBB", AAAAA代表某个场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
请严格以 JSON 数组的形式输出结果每个元素为一个对象结构如下注意字段名 请严格以 JSON 数组的形式输出结果每个元素为一个对象结构如下注意字段名
[ [
@ -45,10 +42,6 @@ def init_prompt() -> None:
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}}, {{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
] ]
请注意
- 不要总结你自己SELF的发言尽量保证总结内容的逻辑性
- 请只针对最重要的若干条表达方式进行总结避免输出太多重复或相似的条目
其中 其中
- situation表示在什么情境下的简短概括不超过20个字 - situation表示在什么情境下的简短概括不超过20个字
- style表示对应的语言风格或常用表达不超过20个字 - style表示对应的语言风格或常用表达不超过20个字
@ -69,170 +62,36 @@ class ExpressionLearner:
self.summary_model: LLMRequest = LLMRequest( self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary" model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
) )
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
self.chat_id = chat_id self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id) self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次学习时间
self.last_learning_time: float = time.time()
# 学习锁,防止并发执行学习任务 # 学习锁,防止并发执行学习任务
self._learning_lock = asyncio.Lock() self._learning_lock = asyncio.Lock()
# 学习参数 async def learn_and_store(
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat( self,
self.chat_id messages: List[Any],
) ) -> List[Tuple[str, str, str]]:
# 防止除以零如果学习强度为0或负数使用最小值0.0001
if self.learning_intensity <= 0:
logger.warning(f"学习强度为 {self.learning_intensity},已自动调整为 0.0001 以避免除以零错误")
self.learning_intensity = 0.0000001
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 120 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
检查是否应该触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否应该触发学习
"""
# 检查是否允许学习
if not self.enable_learning:
return False
# 检查时间间隔
time_diff = time.time() - self.last_learning_time
if time_diff < self.min_learning_interval:
return False
# 检查消息数量(只检查指定聊天流的消息)
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
return False
return True
async def trigger_learning_for_chat(self):
"""
为指定聊天流触发学习
Args:
chat_id: 聊天流ID
Returns:
bool: 是否成功触发学习
"""
# 使用异步锁防止并发执行
async with self._learning_lock:
# 在锁内检查,避免并发触发
# 如果锁被持有,其他协程会等待,但等待期间条件可能已变化,所以需要再次检查
if not self.should_trigger_learning():
return
# 保存学习开始前的时间戳,用于获取消息范围
learning_start_timestamp = time.time()
previous_learning_time = self.last_learning_time
# 立即更新学习时间,防止并发触发
self.last_learning_time = learning_start_timestamp
try:
logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
# 学习语言风格,传递学习开始前的时间戳
learnt_style = await self.learn_and_store(num=25, timestamp_start=previous_learning_time)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
return
async def learn_and_store(self, num: int = 10, timestamp_start: Optional[float] = None) -> List[Tuple[str, str, str]]:
""" """
学习并存储表达方式 学习并存储表达方式
Args: Args:
messages: 外部传入的消息列表必需
num: 学习数量 num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
""" """
learnt_expressions = await self.learn_expression(num, timestamp_start=timestamp_start) if not messages:
if learnt_expressions is None:
logger.info("没有学习到表达风格")
return []
# 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
return learnt_expressions
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str]]]:
"""从指定聊天流学习表达方式
Args:
num: 学习数量
timestamp_start: 学习开始的时间戳如果为None则使用self.last_learning_time
"""
current_time = time.time()
# 使用传入的时间戳如果没有则使用self.last_learning_time
start_timestamp = timestamp_start if timestamp_start is not None else self.last_learning_time
# 获取上次学习之后的消息
random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=start_timestamp,
timestamp_end=current_time,
limit=num,
)
# print(random_msg)
if not random_msg or random_msg == []:
return None return None
random_msg = messages
# 学习用(开启行编号,便于溯源) # 学习用(开启行编号,便于溯源)
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True) random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
prompt: str = await global_prompt_manager.format_prompt( prompt: str = await global_prompt_manager.format_prompt(
"learn_style_prompt", "learn_style_prompt",
bot_name=global_config.bot.nickname,
chat_str=random_msg_str, chat_str=random_msg_str,
) )
@ -269,16 +128,50 @@ class ExpressionLearner:
# 当前行的原始内容 # 当前行的原始内容
current_msg = random_msg[line_index] current_msg = random_msg[line_index]
# 过滤掉从bot自己发言中提取到的表达方式
if is_bot_message(current_msg):
continue
context = filter_message_content(current_msg.processed_plain_text or "") context = filter_message_content(current_msg.processed_plain_text or "")
if not context: if not context:
continue continue
filtered_expressions.append((situation, style, context)) filtered_expressions.append((situation, style, context))
learnt_expressions = filtered_expressions
if not filtered_expressions: if learnt_expressions is None:
return None logger.info("没有学习到表达风格")
return []
return filtered_expressions # 展示学到的表达方式
learnt_expressions_str = ""
for (
situation,
style,
_context,
) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
current_time = time.time()
# 存储到数据库 Expression 表
for (
situation,
style,
context,
) in learnt_expressions:
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
current_time=current_time,
)
return learnt_expressions
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
""" """
@ -356,9 +249,9 @@ class ExpressionLearner:
if in_string: if in_string:
# 在字符串值内部,将中文引号替换为转义的英文引号 # 在字符串值内部,将中文引号替换为转义的英文引号
if char == '"': # 中文左引号 if char == '"': # 中文左引号 U+201C
result.append('\\"') result.append('\\"')
elif char == '"': # 中文右引号 elif char == '"': # 中文右引号 U+201D
result.append('\\"') result.append('\\"')
else: else:
result.append(char) result.append(char)

View File

@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression from src.common.database.database_model import Expression
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.express.express_utils import weighted_sample from src.bw_learner.learner_utils import weighted_sample
logger = get_logger("expression_selector") logger = get_logger("expression_selector")

View File

@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config, global_config from src.config.config import model_config, global_config
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.jargon.jargon_miner import search_jargon from src.bw_learner.jargon_miner import search_jargon
from src.jargon.jargon_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
logger = get_logger("jargon") logger = get_logger("jargon")
@ -82,7 +82,7 @@ class JargonExplainer:
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
# 根据all_global配置决定查询逻辑 # 根据all_global配置决定查询逻辑
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启all_global只查询is_global=True的记录 # 开启all_global只查询is_global=True的记录
query = query.where(Jargon.is_global) query = query.where(Jargon.is_global)
else: else:
@ -107,7 +107,7 @@ class JargonExplainer:
continue continue
# 检查chat_id如果all_global=False # 检查chat_id如果all_global=False
if not global_config.jargon.all_global: if not global_config.expression.all_global_jargon:
if jargon.is_global: if jargon.is_global:
# 全局黑话,包含 # 全局黑话,包含
pass pass
@ -181,7 +181,7 @@ class JargonExplainer:
content = entry["content"] content = entry["content"]
# 根据是否开启全局黑话,决定查询方式 # 根据是否开启全局黑话,决定查询方式
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启全局黑话查询所有is_global=True的记录 # 开启全局黑话查询所有is_global=True的记录
results = search_jargon( results = search_jargon(
keyword=content, keyword=content,
@ -265,7 +265,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
return [] return []
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
query = query.where(Jargon.is_global) query = query.where(Jargon.is_global)
query = query.order_by(Jargon.count.desc()) query = query.order_by(Jargon.count.desc())
@ -277,7 +277,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
if not content: if not content:
continue continue
if not global_config.jargon.all_global and not jargon.is_global: if not global_config.expression.all_global_jargon and not jargon.is_global:
chat_id_list = parse_chat_id_list(jargon.chat_id) chat_id_list = parse_chat_id_list(jargon.chat_id)
if not chat_id_list_contains(chat_id_list, chat_id): if not chat_id_list_contains(chat_id_list, chat_id):
continue continue

View File

@ -1,6 +1,7 @@
import time import time
import json import json
import asyncio import asyncio
import random
from collections import OrderedDict from collections import OrderedDict
from typing import List, Dict, Optional, Any from typing import List, Dict, Optional, Any
from json_repair import repair_json from json_repair import repair_json
@ -16,7 +17,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive, get_raw_msg_by_timestamp_with_chat_inclusive,
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.jargon.jargon_utils import ( from src.bw_learner.learner_utils import (
is_bot_message, is_bot_message,
build_context_paragraph, build_context_paragraph,
contains_bot_self_name, contains_bot_self_name,
@ -29,6 +30,29 @@ from src.jargon.jargon_utils import (
logger = get_logger("jargon") logger = get_logger("jargon")
def _is_single_char_jargon(content: str) -> bool:
"""
判断是否是单字黑话单个汉字英文或数字
Args:
content: 词条内容
Returns:
bool: 如果是单字黑话返回True否则返回False
"""
if not content or len(content) != 1:
return False
char = content[0]
# 判断是否是单个汉字、单个英文字母或单个数字
return (
'\u4e00' <= char <= '\u9fff' or # 汉字
'a' <= char <= 'z' or # 小写字母
'A' <= char <= 'Z' or # 大写字母
'0' <= char <= '9' # 数字
)
def _init_prompt() -> None: def _init_prompt() -> None:
prompt_str = """ prompt_str = """
**聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID** **聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID**
@ -36,11 +60,9 @@ def _init_prompt() -> None:
请从上面这段聊天内容中提取"可能是黑话"的候选项黑话/俚语/网络缩写/口头禅 请从上面这段聊天内容中提取"可能是黑话"的候选项黑话/俚语/网络缩写/口头禅
- 必须为对话中真实出现过的短词或短语 - 必须为对话中真实出现过的短词或短语
- 必须是你无法理解含义的词语没有明确含义的词语 - 必须是你无法理解含义的词语没有明确含义的词语请不要选择有明确含义或者含义清晰的词语
- 请不要选择有明确含义或者含义清晰的词语
- 排除人名@表情包/图片中的内容纯标点常规功能词如的啊等 - 排除人名@表情包/图片中的内容纯标点常规功能词如的啊等
- 每个词条长度建议 2-8 个字符不强制尽量短小 - 每个词条长度建议 2-8 个字符不强制尽量短小
- 合并重复项去重
黑话必须为以下几种类型 黑话必须为以下几种类型
- 由字母构成的汉语拼音首字母的简写词例如nbyydsxswl - 由字母构成的汉语拼音首字母的简写词例如nbyydsxswl
@ -67,12 +89,14 @@ def _init_inference_prompts() -> None:
{content} {content}
**词条出现的上下文其中的{bot_name}的发言内容是你自己的发言** **词条出现的上下文其中的{bot_name}的发言内容是你自己的发言**
{raw_content_list} {raw_content_list}
{previous_meaning_section}
请根据上下文推断"{content}"这个词条的含义 请根据上下文推断"{content}"这个词条的含义
- 如果这是一个黑话俚语或网络用语请推断其含义 - 如果这是一个黑话俚语或网络用语请推断其含义
- 如果含义明确常规词汇也请说明 - 如果含义明确常规词汇也请说明
- {bot_name} 的发言内容可能包含错误请不要参考其发言内容 - {bot_name} 的发言内容可能包含错误请不要参考其发言内容
- 如果上下文信息不足无法推断含义请设置 no_info true - 如果上下文信息不足无法推断含义请设置 no_info true
{previous_meaning_instruction}
JSON 格式输出 JSON 格式输出
{{ {{
@ -166,10 +190,6 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
class JargonMiner: class JargonMiner:
def __init__(self, chat_id: str) -> None: def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id self.chat_id = chat_id
self.last_learning_time: float = time.time()
# 频率控制,可按需调整
self.min_messages_for_learning: int = 30
self.min_learning_interval: float = 60
self.llm = LLMRequest( self.llm = LLMRequest(
model_set=model_config.model_task_config.utils, model_set=model_config.model_task_config.utils,
@ -200,6 +220,10 @@ class JargonMiner:
if not key: if not key:
return return
# 单字黑话(单个汉字、英文或数字)不记录到缓存
if _is_single_char_jargon(key):
return
if key in self.cache: if key in self.cache:
self.cache.move_to_end(key) self.cache.move_to_end(key)
else: else:
@ -272,13 +296,37 @@ class JargonMiner:
logger.warning(f"jargon {content} 没有raw_content跳过推断") logger.warning(f"jargon {content} 没有raw_content跳过推断")
return return
# 获取当前count和上一次的meaning
current_count = jargon_obj.count or 0
previous_meaning = jargon_obj.meaning or ""
# 当count为24, 60时随机移除一半的raw_content项目
if current_count in [24, 60] and len(raw_content_list) > 1:
# 计算要保留的数量至少保留1个
keep_count = max(1, len(raw_content_list) // 2)
raw_content_list = random.sample(raw_content_list, keep_count)
logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目")
# 步骤1: 基于raw_content和content推断 # 步骤1: 基于raw_content和content推断
raw_content_text = "\n".join(raw_content_list) raw_content_text = "\n".join(raw_content_list)
# 当count为24, 60, 100时在prompt中放入上一次推断出的meaning作为参考
previous_meaning_section = ""
previous_meaning_instruction = ""
if current_count in [24, 60, 100] and previous_meaning:
previous_meaning_section = f"""
**上一次推断的含义仅供参考**
{previous_meaning}
"""
previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果"
prompt1 = await global_prompt_manager.format_prompt( prompt1 = await global_prompt_manager.format_prompt(
"jargon_inference_with_context_prompt", "jargon_inference_with_context_prompt",
content=content, content=content,
bot_name=global_config.bot.nickname, bot_name=global_config.bot.nickname,
raw_content_list=raw_content_text, raw_content_list=raw_content_text,
previous_meaning_section=previous_meaning_section,
previous_meaning_instruction=previous_meaning_instruction,
) )
response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3) response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3)
@ -430,45 +478,16 @@ class JargonMiner:
traceback.print_exc() traceback.print_exc()
def should_trigger(self) -> bool: async def run_once(self, messages: List[Any]) -> None:
# 冷却时间检查 """
if time.time() - self.last_learning_time < self.min_learning_interval: 运行一次黑话提取
return False
Args:
# 拉取最近消息数量是否足够 messages: 外部传入的消息列表必需
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( """
chat_id=self.chat_id,
timestamp_start=self.last_learning_time,
timestamp_end=time.time(),
)
return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning)
async def run_once(self) -> None:
# 使用异步锁防止并发执行 # 使用异步锁防止并发执行
async with self._extraction_lock: async with self._extraction_lock:
try: try:
# 在锁内检查,避免并发触发
if not self.should_trigger():
return
chat_stream = get_chat_manager().get_stream(self.chat_id)
if not chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_learning_time
extraction_end_time = time.time()
# 立即更新学习时间,防止并发触发
self.last_learning_time = extraction_end_time
# 拉取学习窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
limit=20,
)
if not messages: if not messages:
return return
@ -608,7 +627,7 @@ class JargonMiner:
# 查找匹配的记录 # 查找匹配的记录
matched_obj = None matched_obj = None
for obj in query: for obj in query:
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启all_global所有content匹配的记录都可以 # 开启all_global所有content匹配的记录都可以
matched_obj = obj matched_obj = obj
break break
@ -648,7 +667,7 @@ class JargonMiner:
obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False) obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False)
# 开启all_global时确保记录标记为is_global=True # 开启all_global时确保记录标记为is_global=True
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
obj.is_global = True obj.is_global = True
# 关闭all_global时保持原有is_global不变不修改 # 关闭all_global时保持原有is_global不变不修改
@ -664,7 +683,7 @@ class JargonMiner:
updated += 1 updated += 1
else: else:
# 没找到匹配记录,创建新记录 # 没找到匹配记录,创建新记录
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启all_global新记录默认为is_global=True # 开启all_global新记录默认为is_global=True
is_global_new = True is_global_new = True
else: else:
@ -718,9 +737,6 @@ class JargonMinerManager:
miner_manager = JargonMinerManager() miner_manager = JargonMinerManager()
async def extract_and_store_jargon(chat_id: str) -> None:
miner = miner_manager.get_miner(chat_id)
await miner.run_once()
def search_jargon( def search_jargon(
@ -770,7 +786,7 @@ def search_jargon(
query = query.where(search_condition) query = query.where(search_condition)
# 根据all_global配置决定查询逻辑 # 根据all_global配置决定查询逻辑
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启all_global所有记录都是全局的查询所有is_global=True的记录无视chat_id # 开启all_global所有记录都是全局的查询所有is_global=True的记录无视chat_id
query = query.where(Jargon.is_global) query = query.where(Jargon.is_global)
# 注意对于all_global=False的情况chat_id过滤在Python层面进行以便兼容新旧格式 # 注意对于all_global=False的情况chat_id过滤在Python层面进行以便兼容新旧格式
@ -787,7 +803,7 @@ def search_jargon(
results = [] results = []
for jargon in query: for jargon in query:
# 如果提供了chat_id且all_global=False需要检查chat_id列表是否包含目标chat_id # 如果提供了chat_id且all_global=False需要检查chat_id列表是否包含目标chat_id
if chat_id and not global_config.jargon.all_global: if chat_id and not global_config.expression.all_global_jargon:
chat_id_list = parse_chat_id_list(jargon.chat_id) chat_id_list = parse_chat_id_list(jargon.chat_id)
# 如果记录是is_global=True或者chat_id列表包含目标chat_id则包含 # 如果记录是is_global=True或者chat_id列表包含目标chat_id则包含
if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):

View File

@ -1,5 +1,9 @@
import re
import difflib
import random
import json import json
from typing import List, Dict, Optional, Any from datetime import datetime
from typing import Optional, List, Dict, Any
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config from src.config.config import global_config
@ -9,7 +13,147 @@ from src.chat.utils.chat_message_builder import (
from src.chat.utils.utils import parse_platform_accounts from src.chat.utils.utils import parse_platform_accounts
logger = get_logger("jargon") logger = get_logger("learner_utils")
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID
content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值范围0-1
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间
count越高权重越高但最多为基础权重的5倍
如果表达已checked权重会再乘以3倍
"""
if not population:
return []
counts = []
checked_flags = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
# 获取checked状态
checked = item.get("checked", False)
checked_flags.append(bool(checked))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
base_weights = [1.0 for _ in counts]
else:
base_weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
base_weights.append(1.0 + normalized * 4.0) # 1~5
# 如果checked权重乘以3
weights = []
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
if checked:
weights.append(base_weight * 3.0)
else:
weights.append(base_weight)
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
@ -62,25 +206,37 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr
Returns: Returns:
List[List[Any]]: 更新后的chat_id列表 List[List[Any]]: 更新后的chat_id列表
""" """
# 查找是否已存在该chat_id item = _find_chat_id_item(chat_id_list, target_chat_id)
found = False if item is not None:
for item in chat_id_list: # 找到匹配的chat_id增加计数
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): if len(item) >= 2:
# 找到匹配的chat_id增加计数 item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment
if len(item) >= 2: else:
item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment item.append(increment)
else: else:
item.append(increment)
found = True
break
if not found:
# 未找到,添加新条目 # 未找到,添加新条目
chat_id_list.append([target_chat_id, increment]) chat_id_list.append([target_chat_id, increment])
return chat_id_list return chat_id_list
def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]:
"""
在chat_id列表中查找匹配的项辅助函数
Args:
chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
target_chat_id: 要查找的chat_id
Returns:
如果找到则返回匹配的项否则返回None
"""
for item in chat_id_list:
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return item
return None
def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool: def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
""" """
检查chat_id列表中是否包含指定的chat_id 检查chat_id列表中是否包含指定的chat_id
@ -92,10 +248,7 @@ def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) ->
Returns: Returns:
bool: 如果包含则返回True bool: 如果包含则返回True
""" """
for item in chat_id_list: return _find_chat_id_item(chat_id_list, target_chat_id) is not None
if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id):
return True
return False
def contains_bot_self_name(content: str) -> bool: def contains_bot_self_name(content: str) -> bool:
@ -115,7 +268,7 @@ def contains_bot_self_name(content: str) -> bool:
candidates = [name for name in [nickname, *alias_names] if name] candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates if target) return any(name in target for name in candidates)
def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]:

View File

@ -0,0 +1,217 @@
import time
import asyncio
from typing import List, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive
from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.jargon_miner import miner_manager
logger = get_logger("bw_learner")
class MessageRecorder:
"""
统一的消息记录器负责管理时间窗口和消息提取并将消息分发给 expression_learner jargon_miner
"""
def __init__(self, chat_id: str) -> None:
self.chat_id = chat_id
self.chat_stream = get_chat_manager().get_stream(chat_id)
self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
# 维护每个chat的上次提取时间
self.last_extraction_time: float = time.time()
# 提取锁,防止并发执行
self._extraction_lock = asyncio.Lock()
# 获取 expression 和 jargon 的配置参数
self._init_parameters()
# 获取 expression_learner 和 jargon_miner 实例
self.expression_learner = expression_learner_manager.get_expression_learner(chat_id)
self.jargon_miner = miner_manager.get_miner(chat_id)
def _init_parameters(self) -> None:
"""初始化提取参数"""
# 获取 expression 配置
_, self.enable_expression_learning, self.enable_jargon_learning = (
global_config.expression.get_expression_config_for_chat(self.chat_id)
)
self.min_messages_for_extraction = 30
self.min_extraction_interval = 60
logger.debug(
f"MessageRecorder 初始化: chat_id={self.chat_id}, "
f"min_messages={self.min_messages_for_extraction}, "
f"min_interval={self.min_extraction_interval}"
)
def should_trigger_extraction(self) -> bool:
"""
检查是否应该触发消息提取
Returns:
bool: 是否应该触发提取
"""
# 检查时间间隔
time_diff = time.time() - self.last_extraction_time
if time_diff < self.min_extraction_interval:
return False
# 检查消息数量
recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=self.last_extraction_time,
timestamp_end=time.time(),
)
if not recent_messages or len(recent_messages) < self.min_messages_for_extraction:
return False
return True
async def extract_and_distribute(self) -> None:
"""
提取消息并分发给 expression_learner jargon_miner
"""
# 使用异步锁防止并发执行
async with self._extraction_lock:
# 在锁内检查,避免并发触发
if not self.should_trigger_extraction():
return
# 检查 chat_stream 是否存在
if not self.chat_stream:
return
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_extraction_time
extraction_end_time = time.time()
# 立即更新提取时间,防止并发触发
self.last_extraction_time = extraction_end_time
try:
logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发")
# 拉取提取窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
timestamp_start=extraction_start_time,
timestamp_end=extraction_end_time,
)
if not messages:
logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取")
return
# 按时间排序,确保顺序一致
messages = sorted(messages, key=lambda msg: msg.time or 0)
logger.info(
f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息,"
f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}"
)
# 分别触发 expression_learner 和 jargon_miner 的处理
# 传递提取的消息,避免它们重复获取
# 触发 expression 学习(如果启用)
if self.enable_expression_learning:
asyncio.create_task(
self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages)
)
# 触发 jargon 提取(如果启用),传递消息
if self.enable_jargon_learning:
asyncio.create_task(
self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages)
)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}")
import traceback
traceback.print_exc()
# 即使失败也保持时间戳更新,避免频繁重试
async def _trigger_expression_learning(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None:
"""
触发 expression 学习使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息给 ExpressionLearner必需参数
learnt_style = await self.expression_learner.learn_and_store(messages=messages)
if learnt_style:
logger.info(f"聊天流 {self.chat_name} 表达学习完成")
else:
logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}")
import traceback
traceback.print_exc()
async def _trigger_jargon_extraction(
self,
timestamp_start: float,
timestamp_end: float,
messages: List[Any]
) -> None:
"""
触发 jargon 提取使用指定的消息列表
Args:
timestamp_start: 开始时间戳
timestamp_end: 结束时间戳
messages: 消息列表
"""
try:
# 传递消息给 JargonMiner避免它重复获取
await self.jargon_miner.run_once(messages=messages)
except Exception as e:
logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}")
import traceback
traceback.print_exc()
class MessageRecorderManager:
"""MessageRecorder 管理器"""
def __init__(self) -> None:
self._recorders: dict[str, MessageRecorder] = {}
def get_recorder(self, chat_id: str) -> MessageRecorder:
"""获取或创建指定 chat_id 的 MessageRecorder"""
if chat_id not in self._recorders:
self._recorders[chat_id] = MessageRecorder(chat_id)
return self._recorders[chat_id]
# 全局管理器实例
recorder_manager = MessageRecorderManager()
async def extract_and_distribute_messages(chat_id: str) -> None:
"""
统一的消息提取和分发入口函数
Args:
chat_id: 聊天流ID
"""
recorder = recorder_manager.get_recorder(chat_id)
await recorder.extract_and_distribute()

View File

@ -16,7 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner
from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail from src.chat.heart_flow.hfc_utils import CycleDetail
from src.express.expression_learner import expression_learner_manager from src.bw_learner.expression_learner import expression_learner_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo from src.plugin_system.base.component_types import EventType, ActionInfo
from src.plugin_system.core import events_manager from src.plugin_system.core import events_manager
@ -252,7 +253,7 @@ class BrainChatting:
# ReflectTracker Check # ReflectTracker Check
# 在每次回复前检查一次上下文,看是否有反思问题得到了解答 # 在每次回复前检查一次上下文,看是否有反思问题得到了解答
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
from src.express.reflect_tracker import reflect_tracker_manager from src.bw_learner.reflect_tracker import reflect_tracker_manager
tracker = reflect_tracker_manager.get_tracker(self.stream_id) tracker = reflect_tracker_manager.get_tracker(self.stream_id)
if tracker: if tracker:
@ -265,13 +266,15 @@ class BrainChatting:
# Expression Reflection Check # Expression Reflection Check
# 检查是否需要提问表达反思 # 检查是否需要提问表达反思
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
from src.express.expression_reflector import expression_reflector_manager from src.bw_learner.expression_reflector import expression_reflector_manager
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
asyncio.create_task(reflector.check_and_ask()) asyncio.create_task(reflector.check_and_ask())
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner
# 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息
asyncio.create_task(extract_and_distribute_messages(self.stream_id))
cycle_timers, thinking_id = self.start_cycle() cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考")

View File

@ -16,10 +16,10 @@ from src.chat.planner_actions.planner import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager from src.chat.planner_actions.action_manager import ActionManager
from src.chat.heart_flow.hfc_utils import CycleDetail from src.chat.heart_flow.hfc_utils import CycleDetail
from src.express.expression_learner import expression_learner_manager from src.bw_learner.expression_learner import expression_learner_manager
from src.chat.heart_flow.frequency_control import frequency_control_manager from src.chat.heart_flow.frequency_control import frequency_control_manager
from src.express.reflect_tracker import reflect_tracker_manager from src.bw_learner.reflect_tracker import reflect_tracker_manager
from src.express.expression_reflector import expression_reflector_manager from src.bw_learner.expression_reflector import expression_reflector_manager
from src.bw_learner.message_recorder import extract_and_distribute_messages from src.bw_learner.message_recorder import extract_and_distribute_messages
from src.person_info.person_info import Person from src.person_info.person_info import Person
from src.plugin_system.base.component_types import EventType, ActionInfo from src.plugin_system.base.component_types import EventType, ActionInfo

View File

@ -23,7 +23,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
replace_user_references, replace_user_references,
) )
from src.express.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
@ -35,7 +35,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
from src.jargon.jargon_explainer import explain_jargon_in_context from src.bw_learner.jargon_explainer import explain_jargon_in_context
init_lpmm_prompt() init_lpmm_prompt()
init_replyer_prompt() init_replyer_prompt()

View File

@ -23,7 +23,7 @@ from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
replace_user_references, replace_user_references,
) )
from src.express.expression_selector import expression_selector from src.bw_learner.expression_selector import expression_selector
from src.plugin_system.apis.message_api import translate_pid_to_description from src.plugin_system.apis.message_api import translate_pid_to_description
# from src.memory_system.memory_activator import MemoryActivator # from src.memory_system.memory_activator import MemoryActivator
@ -36,7 +36,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt
from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt
from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt
from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt
from src.jargon.jargon_explainer import explain_jargon_in_context from src.bw_learner.jargon_explainer import explain_jargon_in_context
init_lpmm_prompt() init_lpmm_prompt()
init_replyer_prompt() init_replyer_prompt()

View File

@ -33,7 +33,6 @@ from src.config.official_configs import (
VoiceConfig, VoiceConfig,
MemoryConfig, MemoryConfig,
DebugConfig, DebugConfig,
JargonConfig,
DreamConfig, DreamConfig,
) )
@ -355,7 +354,6 @@ class Config(ConfigBase):
memory: MemoryConfig memory: MemoryConfig
debug: DebugConfig debug: DebugConfig
voice: VoiceConfig voice: VoiceConfig
jargon: JargonConfig
dream: DreamConfig dream: DreamConfig

View File

@ -284,20 +284,20 @@ class ExpressionConfig(ConfigBase):
learning_list: list[list] = field(default_factory=lambda: []) learning_list: list[list] = field(default_factory=lambda: [])
""" """
表达学习配置列表支持按聊天流配置 表达学习配置列表支持按聊天流配置
格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...] 格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...]
示例: 示例:
[ [
["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0 ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5 ["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习
["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
] ]
说明: 说明:
- 第一位: chat_stream_id空字符串表示全局配置 - 第一位: chat_stream_id空字符串表示全局配置
- 第二位: 是否使用学到的表达 ("enable"/"disable") - 第二位: 是否使用学到的表达 ("enable"/"disable")
- 第三位: 是否学习表达 ("enable"/"disable") - 第三位: 是否学习表达 ("enable"/"disable")
- 第四位: 学习强度浮点数影响学习频率最短学习时间间隔 = 300/学习强度 - 第四位: 是否启用jargon学习 ("enable"/"disable")
""" """
expression_groups: list[list[str]] = field(default_factory=list) expression_groups: list[list[str]] = field(default_factory=list)
@ -320,6 +320,9 @@ class ExpressionConfig(ConfigBase):
如果列表为空则所有聊天流都可以进行表达反思前提是 reflect = true 如果列表为空则所有聊天流都可以进行表达反思前提是 reflect = true
""" """
all_global_jargon: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id。注意此功能关闭后已经记录的全局黑话不会改变需要手动删除"""
def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]:
""" """
解析流配置字符串并生成对应的 chat_id 解析流配置字符串并生成对应的 chat_id
@ -355,7 +358,7 @@ class ExpressionConfig(ConfigBase):
except (ValueError, IndexError): except (ValueError, IndexError):
return None return None
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]: def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]:
""" """
根据聊天流ID获取表达配置 根据聊天流ID获取表达配置
@ -363,35 +366,27 @@ class ExpressionConfig(ConfigBase):
chat_stream_id: 聊天流ID格式为哈希值 chat_stream_id: 聊天流ID格式为哈希值
Returns: Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔) tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)
""" """
if not self.learning_list: if not self.learning_list:
# 如果没有配置,使用默认值:启用表达,启用学习,学习强度1.0对应300秒间隔 # 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习
return True, True, 1.0 return True, True, True
# 优先检查聊天流特定的配置 # 优先检查聊天流特定的配置
if chat_stream_id: if chat_stream_id:
specific_expression_config = self._get_stream_specific_config(chat_stream_id) specific_expression_config = self._get_stream_specific_config(chat_stream_id)
if specific_expression_config is not None: if specific_expression_config is not None:
use_expression, enable_learning, learning_intensity = specific_expression_config return specific_expression_config
# 防止学习强度为0自动转换为0.0001
if learning_intensity == 0:
learning_intensity = 0.0000001
return use_expression, enable_learning, learning_intensity
# 检查全局配置(第一个元素为空字符串的配置) # 检查全局配置(第一个元素为空字符串的配置)
global_expression_config = self._get_global_config() global_expression_config = self._get_global_config()
if global_expression_config is not None: if global_expression_config is not None:
use_expression, enable_learning, learning_intensity = global_expression_config return global_expression_config
# 防止学习强度为0自动转换为0.0001
if learning_intensity == 0:
learning_intensity = 0.0000001
return use_expression, enable_learning, learning_intensity
# 如果都没有匹配,返回默认值:启用表达,启用学习,学习强度1.0对应300秒间隔 # 如果都没有匹配返回默认值启用表达启用学习启用jargon学习
return True, True, 1.0 return True, True, True
def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]: def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]:
""" """
获取特定聊天流的表达配置 获取特定聊天流的表达配置
@ -399,7 +394,7 @@ class ExpressionConfig(ConfigBase):
chat_stream_id: 聊天流ID哈希值 chat_stream_id: 聊天流ID哈希值
Returns: Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)如果没有配置则返回 None tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)如果没有配置则返回 None
""" """
for config_item in self.learning_list: for config_item in self.learning_list:
if not config_item or len(config_item) < 4: if not config_item or len(config_item) < 4:
@ -424,22 +419,19 @@ class ExpressionConfig(ConfigBase):
try: try:
use_expression: bool = config_item[1].lower() == "enable" use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable" enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity: float = float(config_item[3]) enable_jargon_learning: bool = config_item[3].lower() == "enable"
# 防止学习强度为0自动转换为0.0001 return use_expression, enable_learning, enable_jargon_learning # type: ignore
if learning_intensity == 0:
learning_intensity = 0.0000001
return use_expression, enable_learning, learning_intensity # type: ignore
except (ValueError, IndexError): except (ValueError, IndexError):
continue continue
return None return None
def _get_global_config(self) -> Optional[tuple[bool, bool, int]]: def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]:
""" """
获取全局表达配置 获取全局表达配置
Returns: Returns:
tuple: (是否使用表达, 是否学习表达, 学习间隔)如果没有配置则返回 None tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习)如果没有配置则返回 None
""" """
for config_item in self.learning_list: for config_item in self.learning_list:
if not config_item or len(config_item) < 4: if not config_item or len(config_item) < 4:
@ -450,11 +442,8 @@ class ExpressionConfig(ConfigBase):
try: try:
use_expression: bool = config_item[1].lower() == "enable" use_expression: bool = config_item[1].lower() == "enable"
enable_learning: bool = config_item[2].lower() == "enable" enable_learning: bool = config_item[2].lower() == "enable"
learning_intensity = float(config_item[3]) enable_jargon_learning: bool = config_item[3].lower() == "enable"
# 防止学习强度为0自动转换为0.0001 return use_expression, enable_learning, enable_jargon_learning # type: ignore
if learning_intensity == 0:
learning_intensity = 0.0000001
return use_expression, enable_learning, learning_intensity # type: ignore
except (ValueError, IndexError): except (ValueError, IndexError):
continue continue
@ -732,14 +721,6 @@ class LPMMKnowledgeConfig(ConfigBase):
"""嵌入向量维度,应该与模型的输出维度一致""" """嵌入向量维度,应该与模型的输出维度一致"""
@dataclass
class JargonConfig(ConfigBase):
"""Jargon配置类"""
all_global: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""
@dataclass @dataclass
class DreamConfig(ConfigBase): class DreamConfig(ConfigBase):
"""Dream配置类""" """Dream配置类"""

View File

@ -4,7 +4,7 @@ from src.common.logger import get_logger
from src.common.database.database_model import Jargon from src.common.database.database_model import Jargon
from src.config.config import global_config from src.config.config import global_config
from src.chat.utils.utils import parse_keywords_string from src.chat.utils.utils import parse_keywords_string
from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains
logger = get_logger("dream_agent") logger = get_logger("dream_agent")
@ -24,7 +24,7 @@ def make_search_jargon(chat_id: str):
query = Jargon.select().where(Jargon.is_jargon) query = Jargon.select().where(Jargon.is_jargon)
# 根据 all_global 配置决定 chat_id 作用域 # 根据 all_global 配置决定 chat_id 作用域
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 开启全局黑话:只看 is_global=True 的记录,不区分 chat_id # 开启全局黑话:只看 is_global=True 的记录,不区分 chat_id
query = query.where(Jargon.is_global) query = query.where(Jargon.is_global)
else: else:
@ -63,7 +63,7 @@ def make_search_jargon(chat_id: str):
if any_matched: if any_matched:
filtered_keyword.append(r) filtered_keyword.append(r)
if global_config.jargon.all_global: if global_config.expression.all_global_jargon:
# 全局黑话模式:不再做 chat_id 过滤,直接使用关键词过滤结果 # 全局黑话模式:不再做 chat_id 过滤,直接使用关键词过滤结果
records = filtered_keyword records = filtered_keyword
else: else:
@ -80,7 +80,7 @@ def make_search_jargon(chat_id: str):
if not records: if not records:
scope_note = ( scope_note = (
"(当前为全局黑话模式,仅统计 is_global=True 的条目)" "(当前为全局黑话模式,仅统计 is_global=True 的条目)"
if global_config.jargon.all_global if global_config.expression.all_global_jargon
else "(当前为按 chat_id 作用域模式,仅统计全局黑话或与当前 chat_id 相关的条目)" else "(当前为按 chat_id 作用域模式,仅统计全局黑话或与当前 chat_id 相关的条目)"
) )
return f"未找到包含关键词'{keyword}'的 Jargon 记录{scope_note}" return f"未找到包含关键词'{keyword}'的 Jargon 记录{scope_note}"

View File

@ -1,145 +0,0 @@
import re
import difflib
import random
from datetime import datetime
from typing import Optional, List, Dict
def filter_message_content(content: Optional[str]) -> str:
"""
过滤消息内容移除回复@图片等格式
Args:
content: 原始消息内容
Returns:
str: 过滤后的内容
"""
if not content:
return ""
# 移除以[回复开头、]结尾的部分,包括后面的",说:"部分
content = re.sub(r"\[回复.*?\],说:\s*", "", content)
# 移除@<...>格式的内容
content = re.sub(r"@<[^>]*>", "", content)
# 移除[picid:...]格式的图片ID
content = re.sub(r"\[picid:[^\]]*\]", "", content)
# 移除[表情包:...]格式的内容
content = re.sub(r"\[表情包:[^\]]*\]", "", content)
return content.strip()
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度返回0-1之间的值
使用SequenceMatcher计算相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度值范围0-1
"""
return difflib.SequenceMatcher(None, text1, text2).ratio()
def format_create_date(timestamp: float) -> str:
"""
将时间戳格式化为可读的日期字符串
Args:
timestamp: 时间戳
Returns:
str: 格式化后的日期字符串
"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~5之间
count越高权重越高但最多为基础权重的5倍
如果表达已checked权重会再乘以3倍
"""
if not population:
return []
counts = []
checked_flags = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
# 获取checked状态
checked = item.get("checked", False)
checked_flags.append(bool(checked))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
base_weights = [1.0 for _ in counts]
else:
base_weights = []
for count_value in counts:
# 线性映射到[1,5]区间
normalized = (count_value - min_count) / (max_count - min_count)
base_weights.append(1.0 + normalized * 4.0) # 1~3
# 如果checked权重乘以3
weights = []
for base_weight, checked in zip(base_weights, checked_flags, strict=False):
if checked:
weights.append(base_weight * 3.0)
else:
weights.append(base_weight)
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
Args:
population: 总体数据列表
k: 需要抽取的数量
Returns:
List[Dict]: 抽取的数据列表
"""
if not population or k <= 0:
return []
if len(population) <= k:
return population.copy()
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected

View File

@ -1,5 +0,0 @@
from .jargon_miner import extract_and_store_jargon
__all__ = [
"extract_and_store_jargon",
]

View File

@ -11,7 +11,7 @@ from src.common.database.database_model import ThinkingBack
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.memory_utils import parse_questions_json from src.memory_system.memory_utils import parse_questions_json
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon from src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon
logger = get_logger("memory_retrieval") logger = get_logger("memory_retrieval")
@ -972,6 +972,7 @@ async def _process_single_question(
context: str, context: str,
initial_info: str = "", initial_info: str = "",
initial_jargon_concepts: Optional[List[str]] = None, initial_jargon_concepts: Optional[List[str]] = None,
max_iterations: Optional[int] = None,
) -> Optional[str]: ) -> Optional[str]:
"""处理单个问题的查询 """处理单个问题的查询
@ -996,10 +997,14 @@ async def _process_single_question(
jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None
# 如果未指定max_iterations使用配置的默认值
if max_iterations is None:
max_iterations = global_config.memory.max_agent_iterations
found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question(
question=question, question=question,
chat_id=chat_id, chat_id=chat_id,
max_iterations=global_config.memory.max_agent_iterations, max_iterations=max_iterations,
timeout=global_config.memory.agent_timeout_seconds, timeout=global_config.memory.agent_timeout_seconds,
initial_info=question_initial_info, initial_info=question_initial_info,
initial_jargon_concepts=jargon_concepts_for_agent, initial_jargon_concepts=jargon_concepts_for_agent,
@ -1030,6 +1035,7 @@ async def build_memory_retrieval_prompt(
target: str, target: str,
chat_stream, chat_stream,
tool_executor, tool_executor,
think_level: int = 1,
) -> str: ) -> str:
"""构建记忆检索提示 """构建记忆检索提示
使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案 使用两段式查询第一步生成问题第二步使用ReAct Agent查询答案
@ -1117,9 +1123,14 @@ async def build_memory_retrieval_prompt(
return "" return ""
# 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间) # 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间)
max_iterations = global_config.memory.max_agent_iterations base_max_iterations = global_config.memory.max_agent_iterations
# 根据think_level调整迭代次数think_level=1时不变think_level=0时减半
if think_level == 0:
max_iterations = max(1, base_max_iterations // 2) # 至少为1
else:
max_iterations = base_max_iterations
timeout_seconds = global_config.memory.agent_timeout_seconds timeout_seconds = global_config.memory.agent_timeout_seconds
logger.debug(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations},超时时间: {timeout_seconds}") logger.debug(f"问题数量: {len(questions)}think_level={think_level}设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations},超时时间: {timeout_seconds}")
# 并行处理所有问题,将概念检索结果作为初始信息传递 # 并行处理所有问题,将概念检索结果作为初始信息传递
question_tasks = [ question_tasks = [
@ -1129,6 +1140,7 @@ async def build_memory_retrieval_prompt(
context=message, context=message,
initial_info=initial_info, initial_info=initial_info,
initial_jargon_concepts=concepts if enable_jargon_detection else None, initial_jargon_concepts=concepts if enable_jargon_detection else None,
max_iterations=max_iterations,
) )
for question in questions for question in questions
] ]

View File

@ -30,7 +30,6 @@ from src.config.official_configs import (
MemoryConfig, MemoryConfig,
DebugConfig, DebugConfig,
VoiceConfig, VoiceConfig,
JargonConfig,
) )
from src.config.api_ada_configs import ( from src.config.api_ada_configs import (
ModelTaskConfig, ModelTaskConfig,
@ -129,7 +128,6 @@ async def get_config_section_schema(section_name: str):
"memory": MemoryConfig, "memory": MemoryConfig,
"debug": DebugConfig, "debug": DebugConfig,
"voice": VoiceConfig, "voice": VoiceConfig,
"jargon": JargonConfig,
"model_task_config": ModelTaskConfig, "model_task_config": ModelTaskConfig,
"api_provider": APIProvider, "api_provider": APIProvider,
"model_info": ModelInfo, "model_info": ModelInfo,

View File

@ -1,5 +1,5 @@
[inner] [inner]
version = "7.0.2" version = "7.1.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
# 如果你想要修改配置文件请递增version的值 # 如果你想要修改配置文件请递增version的值
@ -60,16 +60,14 @@ state_probability = 0.3
[expression] [expression]
# 表达学习配置 # 表达学习配置
learning_list = [ # 表达学习配置列表,支持按聊天流配置 learning_list = [ # 表达学习配置列表,支持按聊天流配置
["", "enable", "enable", "1.0"], # 全局配置:使用表达,启用学习,学习强度1.0 ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习
["qq:1919810:group", "enable", "enable", "1.5"], # 特定群聊配置:使用表达,启用学习,学习强度1.5 ["qq:1919810:group", "enable", "enable", "enable"], # 特定群聊配置:使用表达,启用学习,启用jargon学习
["qq:114514:private", "enable", "disable", "0.5"], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习
# 格式说明: # 格式说明:
# 第一位: chat_stream_id空字符串表示全局配置 # 第一位: chat_stream_id空字符串表示全局配置
# 第二位: 是否使用学到的表达 ("enable"/"disable") # 第二位: 是否使用学到的表达 ("enable"/"disable")
# 第三位: 是否学习表达 ("enable"/"disable") # 第三位: 是否学习表达 ("enable"/"disable")
# 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) # 第四位: 是否启用jargon学习 ("enable"/"disable")
# 学习强度越高,学习越频繁;学习强度越低,学习越少
# 如果学习强度设置为0会自动转换为0.0001以避免除以零错误
] ]
expression_groups = [ expression_groups = [
@ -85,6 +83,8 @@ reflect = false # 是否启用表达反思Bot主动向管理员询问表达
reflect_operator_id = "" # 表达反思操作员ID格式platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group") reflect_operator_id = "" # 表达反思操作员ID格式platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group")
allow_reflect = [] # 允许进行表达反思的聊天流ID列表格式["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true allow_reflect = [] # 允许进行表达反思的聊天流ID列表格式["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true
all_global_jargon = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
[chat] # 麦麦的聊天设置 [chat] # 麦麦的聊天设置
talk_value = 1 # 聊天频率越小越沉默范围0-1如果设置为0会自动转换为0.0001以避免除以零错误 talk_value = 1 # 聊天频率越小越沉默范围0-1如果设置为0会自动转换为0.0001以避免除以零错误
@ -131,9 +131,6 @@ dream_time_ranges = [
] ]
# dream_time_ranges = [] # dream_time_ranges = []
[jargon]
all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除
[tool] [tool]
enable_tool = true # 是否启用工具 enable_tool = true # 是否启用工具