diff --git a/src/express/expression_learner.py b/src/bw_learner/expression_learner.py
similarity index 74%
rename from src/express/expression_learner.py
rename to src/bw_learner/expression_learner.py
index 58b2cd6d..38fcb370 100644
--- a/src/express/expression_learner.py
+++ b/src/bw_learner/expression_learner.py
@@ -3,19 +3,17 @@ import json
 import os
 import re
 import asyncio
-from typing import List, Optional, Tuple
-import traceback
+from typing import List, Optional, Tuple, Any
 from src.common.logger import get_logger
 from src.common.database.database_model import Expression
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import model_config, global_config
 from src.chat.utils.chat_message_builder import (
-    get_raw_msg_by_timestamp_with_chat_inclusive,
     build_anonymous_messages,
 )
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.chat.message_receive.chat_stream import get_chat_manager
-from src.express.express_utils import filter_message_content
+from src.bw_learner.learner_utils import filter_message_content, is_bot_message
 from json_repair import repair_json


@@ -26,15 +24,14 @@ logger = get_logger("expressor")
 def init_prompt() -> None:
     learn_style_prompt = """{chat_str}
-
-请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格。
-每一行消息前面的方括号中的数字(如 [1]、[2])是该行消息的唯一编号,请在输出中引用这些编号来标注“表达方式的来源行”。
+你的名字是{bot_name},现在请你从上面这段群聊中总结用户的语言风格和说话方式。
 1. 只考虑文字,不要考虑表情包和图片
-2. 不要涉及具体的人名,但是可以涉及具体名词
-3. 思考有没有特殊的梗,一并总结成语言风格
-4. 例子仅供参考,请严格根据群聊内容总结!!!
+2. 不要总结SELF的发言
+3. 不要涉及具体的人名,也不要涉及具体名词
+4. 思考有没有特殊的梗,一并总结成语言风格
+5. 例子仅供参考,请严格根据群聊内容总结!!!
 注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
-例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
+例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。

 请严格以 JSON 数组的形式输出结果,每个元素为一个对象,结构如下(注意字段名):
 [
@@ -45,10 +42,6 @@ def init_prompt() -> None:
     {{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "使用 这么强!", "source_id": "[消息编号]"}},
 ]

-请注意:
-- 不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
-- 请只针对最重要的若干条表达方式进行总结,避免输出太多重复或相似的条目
-
 其中:
 - situation:表示“在什么情境下”的简短概括(不超过20个字)
 - style:表示对应的语言风格或常用表达(不超过20个字)
@@ -69,170 +62,36 @@ class ExpressionLearner:
         self.summary_model: LLMRequest = LLMRequest(
             model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
         )
-        self.embedding_model: LLMRequest = LLMRequest(
-            model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
-        )
         self.chat_id = chat_id
         self.chat_stream = get_chat_manager().get_stream(chat_id)
         self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id
-        # 维护每个chat的上次学习时间
-        self.last_learning_time: float = time.time()
-        # 学习锁,防止并发执行学习任务
         self._learning_lock = asyncio.Lock()
-        # 学习参数
-        _, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
-            self.chat_id
-        )
-        # 防止除以零:如果学习强度为0或负数,使用最小值0.0001
-        if self.learning_intensity <= 0:
-            logger.warning(f"学习强度为 {self.learning_intensity},已自动调整为 0.0001 以避免除以零错误")
-            self.learning_intensity = 0.0000001
-        self.min_messages_for_learning = 15 / self.learning_intensity  # 触发学习所需的最少消息数
-        self.min_learning_interval = 120 / self.learning_intensity
-
-    def should_trigger_learning(self) -> bool:
-        """
-        检查是否应该触发学习
-
-        Args:
-            chat_id: 聊天流ID
-
-        Returns:
-            bool: 是否应该触发学习
-        """
-        # 检查是否允许学习
-        if not self.enable_learning:
-            return False
-
-        # 检查时间间隔
-        time_diff = time.time() - self.last_learning_time
-        if time_diff < self.min_learning_interval:
-            return False
-
-        # 检查消息数量(只检查指定聊天流的消息)
-        recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive(
-            chat_id=self.chat_id,
-            timestamp_start=self.last_learning_time,
-            timestamp_end=time.time(),
-        )
-
-        if not recent_messages or len(recent_messages) < self.min_messages_for_learning:
-            return False
-
-        return True
-
-    async def trigger_learning_for_chat(self):
-        """
-        为指定聊天流触发学习
-
-        Args:
-            chat_id: 聊天流ID
-
-        Returns:
-            bool: 是否成功触发学习
-        """
-        # 使用异步锁防止并发执行
-        async with self._learning_lock:
-            # 在锁内检查,避免并发触发
-            # 如果锁被持有,其他协程会等待,但等待期间条件可能已变化,所以需要再次检查
-            if not self.should_trigger_learning():
-                return
-
-            # 保存学习开始前的时间戳,用于获取消息范围
-            learning_start_timestamp = time.time()
-            previous_learning_time = self.last_learning_time
-
-            # 立即更新学习时间,防止并发触发
-            self.last_learning_time = learning_start_timestamp
-
-            try:
-                logger.info(f"在聊天流 {self.chat_name} 学习表达方式")
-                # 学习语言风格,传递学习开始前的时间戳
-                learnt_style = await self.learn_and_store(num=25, timestamp_start=previous_learning_time)
-
-                if learnt_style:
-                    logger.info(f"聊天流 {self.chat_name} 表达学习完成")
-                else:
-                    logger.warning(f"聊天流 {self.chat_name} 表达学习未获得有效结果")
-
-            except Exception as e:
-                logger.error(f"为聊天流 {self.chat_name} 触发学习失败: {e}")
-                traceback.print_exc()
-                # 即使失败也保持时间戳更新,避免频繁重试
-                return
-
-    async def learn_and_store(self, num: int = 10, timestamp_start: Optional[float] = None) -> List[Tuple[str, str, str]]:
+    async def learn_and_store(
+        self,
+        messages: List[Any],
+    ) -> List[Tuple[str, str, str]]:
         """
         学习并存储表达方式

         Args:
-            num: 学习数量
-            timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
+            messages: 外部传入的消息列表(必需)
         """
-        learnt_expressions = await self.learn_expression(num, timestamp_start=timestamp_start)
-
-        if learnt_expressions is None:
-            logger.info("没有学习到表达风格")
-            return []
-
-        # 展示学到的表达方式
-        learnt_expressions_str = ""
-        for (
-            situation,
-            style,
-            _context,
-        ) in learnt_expressions:
-            learnt_expressions_str += f"{situation}->{style}\n"
-        logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
-
-        current_time = time.time()
-
-        # 存储到数据库 Expression 表
-        for (
-            situation,
-            style,
-            context,
-        ) in learnt_expressions:
-            await self._upsert_expression_record(
-                situation=situation,
-                style=style,
-                context=context,
-                current_time=current_time,
-            )
-
-        return learnt_expressions
-
-    async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str]]]:
-        """从指定聊天流学习表达方式
-
-        Args:
-            num: 学习数量
-            timestamp_start: 学习开始的时间戳,如果为None则使用self.last_learning_time
-        """
-        current_time = time.time()
-
-        # 使用传入的时间戳,如果没有则使用self.last_learning_time
-        start_timestamp = timestamp_start if timestamp_start is not None else self.last_learning_time
-
-        # 获取上次学习之后的消息
-        random_msg = get_raw_msg_by_timestamp_with_chat_inclusive(
-            chat_id=self.chat_id,
-            timestamp_start=start_timestamp,
-            timestamp_end=current_time,
-            limit=num,
-        )
-        # print(random_msg)
-        if not random_msg or random_msg == []:
+        if not messages:
             return None
+
+        random_msg = messages

         # 学习用(开启行编号,便于溯源)
         random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)

         prompt: str = await global_prompt_manager.format_prompt(
             "learn_style_prompt",
+            bot_name=global_config.bot.nickname,
             chat_str=random_msg_str,
         )
@@ -269,16 +128,50 @@ class ExpressionLearner:

             # 当前行的原始内容
             current_msg = random_msg[line_index]
+
+            # 过滤掉从bot自己发言中提取到的表达方式
+            if is_bot_message(current_msg):
+                continue
+
             context = filter_message_content(current_msg.processed_plain_text or "")
             if not context:
                 continue

             filtered_expressions.append((situation, style, context))
+
+        learnt_expressions = filtered_expressions

-        if not filtered_expressions:
-            return None
+        if not learnt_expressions:
+            logger.info("没有学习到表达风格")
+            return []

-        return filtered_expressions
+        # 展示学到的表达方式
+        learnt_expressions_str = ""
+        for (
+            situation,
+            style,
+            _context,
+        ) in learnt_expressions:
+            learnt_expressions_str += f"{situation}->{style}\n"
+        logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
+
+        current_time = time.time()
+
+        # 存储到数据库 Expression 表
+        for (
+            situation,
+            style,
+            context,
+        ) in learnt_expressions:
+            await self._upsert_expression_record(
+                situation=situation,
+                style=style,
+                context=context,
+                current_time=current_time,
+            )
+
+        return learnt_expressions

     def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
         """
@@ -356,9 +249,9 @@ class ExpressionLearner:

             if in_string:
                 # 在字符串值内部,将中文引号替换为转义的英文引号
-                if char == '“':  # 中文左引号
+                if char == '“':  # 中文左引号 U+201C
                     result.append('\\"')
-                elif char == '”':  # 中文右引号
+                elif char == '”':  # 中文右引号 U+201D
                     result.append('\\"')
                 else:
                     result.append(char)
diff --git a/src/express/expression_reflector.py b/src/bw_learner/expression_reflector.py
similarity index 100%
rename from src/express/expression_reflector.py
rename to src/bw_learner/expression_reflector.py
diff --git a/src/express/expression_selector.py b/src/bw_learner/expression_selector.py
similarity index 99%
rename from src/express/expression_selector.py
rename to src/bw_learner/expression_selector.py
index 8d586d5d..4bbca2b8 100644
--- a/src/express/expression_selector.py
+++ b/src/bw_learner/expression_selector.py
@@ -10,7 +10,7 @@ from src.config.config import global_config, model_config
 from src.common.logger import get_logger
 from src.common.database.database_model import Expression
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from src.express.express_utils import weighted_sample
+from src.bw_learner.learner_utils import weighted_sample

 logger = get_logger("expression_selector")

diff --git a/src/jargon/jargon_explainer.py b/src/bw_learner/jargon_explainer.py
similarity index 96%
rename from src/jargon/jargon_explainer.py
rename to src/bw_learner/jargon_explainer.py
index 28122008..ac62fa5f 100644
--- a/src/jargon/jargon_explainer.py
+++ b/src/bw_learner/jargon_explainer.py
@@ -7,8 +7,8 @@ from src.common.database.database_model import Jargon
 from src.llm_models.utils_model import LLMRequest
 from src.config.config import model_config, global_config
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
-from src.jargon.jargon_miner import search_jargon
-from src.jargon.jargon_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains
+from src.bw_learner.jargon_miner import search_jargon
+from src.bw_learner.learner_utils import is_bot_message, contains_bot_self_name, parse_chat_id_list, chat_id_list_contains

 logger = get_logger("jargon")

@@ -82,7 +82,7 @@ class JargonExplainer:
         query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))

         # 根据all_global配置决定查询逻辑
-        if global_config.jargon.all_global:
+        if global_config.expression.all_global_jargon:
             # 开启all_global:只查询is_global=True的记录
             query = query.where(Jargon.is_global)
         else:
@@ -107,7 +107,7 @@ class JargonExplainer:
                 continue

             # 检查chat_id(如果all_global=False)
-            if not global_config.jargon.all_global:
+            if not global_config.expression.all_global_jargon:
                 if jargon.is_global:
                     # 全局黑话,包含
                     pass
@@ -181,7 +181,7 @@ class JargonExplainer:
             content = entry["content"]

             # 根据是否开启全局黑话,决定查询方式
-            if global_config.jargon.all_global:
+            if 
global_config.expression.all_global_jargon: # 开启全局黑话:查询所有is_global=True的记录 results = search_jargon( keyword=content, @@ -265,7 +265,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]: return [] query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != "")) - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: query = query.where(Jargon.is_global) query = query.order_by(Jargon.count.desc()) @@ -277,7 +277,7 @@ def match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]: if not content: continue - if not global_config.jargon.all_global and not jargon.is_global: + if not global_config.expression.all_global_jargon and not jargon.is_global: chat_id_list = parse_chat_id_list(jargon.chat_id) if not chat_id_list_contains(chat_id_list, chat_id): continue diff --git a/src/jargon/jargon_miner.py b/src/bw_learner/jargon_miner.py similarity index 91% rename from src/jargon/jargon_miner.py rename to src/bw_learner/jargon_miner.py index 7cda4e02..2e456122 100644 --- a/src/jargon/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -1,6 +1,7 @@ import time import json import asyncio +import random from collections import OrderedDict from typing import List, Dict, Optional, Any from json_repair import repair_json @@ -16,7 +17,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp_with_chat_inclusive, ) from src.chat.utils.prompt_builder import Prompt, global_prompt_manager -from src.jargon.jargon_utils import ( +from src.bw_learner.learner_utils import ( is_bot_message, build_context_paragraph, contains_bot_self_name, @@ -29,6 +30,29 @@ from src.jargon.jargon_utils import ( logger = get_logger("jargon") +def _is_single_char_jargon(content: str) -> bool: + """ + 判断是否是单字黑话(单个汉字、英文或数字) + + Args: + content: 词条内容 + + Returns: + bool: 如果是单字黑话返回True,否则返回False + """ + if not content or len(content) != 1: + return False + + char = content[0] + # 判断是否是单个汉字、单个英文字母或单个数字 + return ( + '\u4e00' <= char <= '\u9fff' or # 汉字 + 'a' <= char <= 'z' or # 小写字母 + 'A' <= char <= 'Z' or # 大写字母 + '0' <= char <= '9' # 数字 + ) + + def _init_prompt() -> None: prompt_str = """ **聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID** @@ -36,11 +60,9 @@ def _init_prompt() -> None: 请从上面这段聊天内容中提取"可能是黑话"的候选项(黑话/俚语/网络缩写/口头禅)。 - 必须为对话中真实出现过的短词或短语 -- 必须是你无法理解含义的词语,没有明确含义的词语 -- 请不要选择有明确含义,或者含义清晰的词语 +- 必须是你无法理解含义的词语,没有明确含义的词语,请不要选择有明确含义,或者含义清晰的词语 - 排除:人名、@、表情包/图片中的内容、纯标点、常规功能词(如的、了、呢、啊等) - 每个词条长度建议 2-8 个字符(不强制),尽量短小 -- 合并重复项,去重 黑话必须为以下几种类型: - 由字母构成的,汉语拼音首字母的简写词,例如:nb、yyds、xswl @@ -67,12 +89,14 @@ def _init_inference_prompts() -> None: {content} **词条出现的上下文。其中的{bot_name}的发言内容是你自己的发言** {raw_content_list} +{previous_meaning_section} 请根据上下文,推断"{content}"这个词条的含义。 - 如果这是一个黑话、俚语或网络用语,请推断其含义 - 如果含义明确(常规词汇),也请说明 - {bot_name} 的发言内容可能包含错误,请不要参考其发言内容 - 如果上下文信息不足,无法推断含义,请设置 no_info 为 true +{previous_meaning_instruction} 以 JSON 格式输出: {{ @@ -166,10 +190,6 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool: class JargonMiner: def __init__(self, chat_id: str) -> None: self.chat_id = chat_id - self.last_learning_time: float = time.time() - # 频率控制,可按需调整 - self.min_messages_for_learning: int = 30 - self.min_learning_interval: float = 60 self.llm = LLMRequest( model_set=model_config.model_task_config.utils, @@ -200,6 +220,10 @@ class JargonMiner: if not key: return + # 单字黑话(单个汉字、英文或数字)不记录到缓存 + if _is_single_char_jargon(key): + return + if key in self.cache: self.cache.move_to_end(key) else: @@ -272,13 +296,37 @@ class JargonMiner: 
logger.warning(f"jargon {content} 没有raw_content,跳过推断") return + # 获取当前count和上一次的meaning + current_count = jargon_obj.count or 0 + previous_meaning = jargon_obj.meaning or "" + + # 当count为24, 60时,随机移除一半的raw_content项目 + if current_count in [24, 60] and len(raw_content_list) > 1: + # 计算要保留的数量(至少保留1个) + keep_count = max(1, len(raw_content_list) // 2) + raw_content_list = random.sample(raw_content_list, keep_count) + logger.info(f"jargon {content} count={current_count},随机移除后剩余 {len(raw_content_list)} 个raw_content项目") + # 步骤1: 基于raw_content和content推断 raw_content_text = "\n".join(raw_content_list) + + # 当count为24, 60, 100时,在prompt中放入上一次推断出的meaning作为参考 + previous_meaning_section = "" + previous_meaning_instruction = "" + if current_count in [24, 60, 100] and previous_meaning: + previous_meaning_section = f""" +**上一次推断的含义(仅供参考)** +{previous_meaning} +""" + previous_meaning_instruction = "- 请参考上一次推断的含义,结合新的上下文信息,给出更准确或更新的推断结果" + prompt1 = await global_prompt_manager.format_prompt( "jargon_inference_with_context_prompt", content=content, bot_name=global_config.bot.nickname, raw_content_list=raw_content_text, + previous_meaning_section=previous_meaning_section, + previous_meaning_instruction=previous_meaning_instruction, ) response1, _ = await self.llm_inference.generate_response_async(prompt1, temperature=0.3) @@ -430,45 +478,16 @@ class JargonMiner: traceback.print_exc() - def should_trigger(self) -> bool: - # 冷却时间检查 - if time.time() - self.last_learning_time < self.min_learning_interval: - return False - - # 拉取最近消息数量是否足够 - recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=self.last_learning_time, - timestamp_end=time.time(), - ) - return bool(recent_messages and len(recent_messages) >= self.min_messages_for_learning) - - async def run_once(self) -> None: + async def run_once(self, messages: List[Any]) -> None: + """ + 运行一次黑话提取 + + Args: + messages: 外部传入的消息列表(必需) + """ # 使用异步锁防止并发执行 async with self._extraction_lock: try: - # 在锁内检查,避免并发触发 - if not self.should_trigger(): - return - - chat_stream = get_chat_manager().get_stream(self.chat_id) - if not chat_stream: - return - - # 记录本次提取的时间窗口,避免重复提取 - extraction_start_time = self.last_learning_time - extraction_end_time = time.time() - - # 立即更新学习时间,防止并发触发 - self.last_learning_time = extraction_end_time - - # 拉取学习窗口内的消息 - messages = get_raw_msg_by_timestamp_with_chat_inclusive( - chat_id=self.chat_id, - timestamp_start=extraction_start_time, - timestamp_end=extraction_end_time, - limit=20, - ) if not messages: return @@ -608,7 +627,7 @@ class JargonMiner: # 查找匹配的记录 matched_obj = None for obj in query: - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: # 开启all_global:所有content匹配的记录都可以 matched_obj = obj break @@ -648,7 +667,7 @@ class JargonMiner: obj.chat_id = json.dumps(updated_chat_id_list, ensure_ascii=False) # 开启all_global时,确保记录标记为is_global=True - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: obj.is_global = True # 关闭all_global时,保持原有is_global不变(不修改) @@ -664,7 +683,7 @@ class JargonMiner: updated += 1 else: # 没找到匹配记录,创建新记录 - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: # 开启all_global:新记录默认为is_global=True is_global_new = True else: @@ -718,9 +737,6 @@ class JargonMinerManager: miner_manager = JargonMinerManager() -async def extract_and_store_jargon(chat_id: str) -> None: - miner = miner_manager.get_miner(chat_id) - await miner.run_once() def search_jargon( @@ -770,7 +786,7 @@ def 
search_jargon( query = query.where(search_condition) # 根据all_global配置决定查询逻辑 - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: # 开启all_global:所有记录都是全局的,查询所有is_global=True的记录(无视chat_id) query = query.where(Jargon.is_global) # 注意:对于all_global=False的情况,chat_id过滤在Python层面进行,以便兼容新旧格式 @@ -787,7 +803,7 @@ def search_jargon( results = [] for jargon in query: # 如果提供了chat_id且all_global=False,需要检查chat_id列表是否包含目标chat_id - if chat_id and not global_config.jargon.all_global: + if chat_id and not global_config.expression.all_global_jargon: chat_id_list = parse_chat_id_list(jargon.chat_id) # 如果记录是is_global=True,或者chat_id列表包含目标chat_id,则包含 if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id): diff --git a/src/jargon/jargon_utils.py b/src/bw_learner/learner_utils.py similarity index 53% rename from src/jargon/jargon_utils.py rename to src/bw_learner/learner_utils.py index f42d807b..e67f8ea4 100644 --- a/src/jargon/jargon_utils.py +++ b/src/bw_learner/learner_utils.py @@ -1,5 +1,9 @@ +import re +import difflib +import random import json -from typing import List, Dict, Optional, Any +from datetime import datetime +from typing import Optional, List, Dict, Any from src.common.logger import get_logger from src.config.config import global_config @@ -9,7 +13,147 @@ from src.chat.utils.chat_message_builder import ( from src.chat.utils.utils import parse_platform_accounts -logger = get_logger("jargon") +logger = get_logger("learner_utils") + + +def filter_message_content(content: Optional[str]) -> str: + """ + 过滤消息内容,移除回复、@、图片等格式 + + Args: + content: 原始消息内容 + + Returns: + str: 过滤后的内容 + """ + if not content: + return "" + + # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 + content = re.sub(r"\[回复.*?\],说:\s*", "", content) + # 移除@<...>格式的内容 + content = re.sub(r"@<[^>]*>", "", content) + # 移除[picid:...]格式的图片ID + content = re.sub(r"\[picid:[^\]]*\]", "", content) + # 移除[表情包:...]格式的内容 + content = re.sub(r"\[表情包:[^\]]*\]", "", content) + + return content.strip() + + +def calculate_similarity(text1: str, text2: str) -> float: + """ + 计算两个文本的相似度,返回0-1之间的值 + 使用SequenceMatcher计算相似度 + + Args: + text1: 第一个文本 + text2: 第二个文本 + + Returns: + float: 相似度值,范围0-1 + """ + return difflib.SequenceMatcher(None, text1, text2).ratio() + + +def format_create_date(timestamp: float) -> str: + """ + 将时间戳格式化为可读的日期字符串 + + Args: + timestamp: 时间戳 + + Returns: + str: 格式化后的日期字符串 + """ + try: + return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") + except (ValueError, OSError): + return "未知时间" + + +def _compute_weights(population: List[Dict]) -> List[float]: + """ + 根据表达的count计算权重,范围限定在1~5之间。 + count越高,权重越高,但最多为基础权重的5倍。 + 如果表达已checked,权重会再乘以3倍。 + """ + if not population: + return [] + + counts = [] + checked_flags = [] + for item in population: + count = item.get("count", 1) + try: + count_value = float(count) + except (TypeError, ValueError): + count_value = 1.0 + counts.append(max(count_value, 0.0)) + # 获取checked状态 + checked = item.get("checked", False) + checked_flags.append(bool(checked)) + + min_count = min(counts) + max_count = max(counts) + + if max_count == min_count: + base_weights = [1.0 for _ in counts] + else: + base_weights = [] + for count_value in counts: + # 线性映射到[1,5]区间 + normalized = (count_value - min_count) / (max_count - min_count) + base_weights.append(1.0 + normalized * 4.0) # 1~5 + + # 如果checked,权重乘以3 + weights = [] + for base_weight, checked in zip(base_weights, checked_flags, strict=False): + if checked: + weights.append(base_weight * 3.0) + else: + 
weights.append(base_weight) + return weights + + +def weighted_sample(population: List[Dict], k: int) -> List[Dict]: + """ + 随机抽样函数 + + Args: + population: 总体数据列表 + k: 需要抽取的数量 + + Returns: + List[Dict]: 抽取的数据列表 + """ + if not population or k <= 0: + return [] + + if len(population) <= k: + return population.copy() + + selected: List[Dict] = [] + population_copy = population.copy() + + for _ in range(min(k, len(population_copy))): + weights = _compute_weights(population_copy) + total_weight = sum(weights) + if total_weight <= 0: + # 回退到均匀随机 + idx = random.randint(0, len(population_copy) - 1) + selected.append(population_copy.pop(idx)) + continue + + threshold = random.uniform(0, total_weight) + cumulative = 0.0 + for idx, weight in enumerate(weights): + cumulative += weight + if threshold <= cumulative: + selected.append(population_copy.pop(idx)) + break + + return selected def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: @@ -62,25 +206,37 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr Returns: List[List[Any]]: 更新后的chat_id列表 """ - # 查找是否已存在该chat_id - found = False - for item in chat_id_list: - if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): - # 找到匹配的chat_id,增加计数 - if len(item) >= 2: - item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment - else: - item.append(increment) - found = True - break - - if not found: + item = _find_chat_id_item(chat_id_list, target_chat_id) + if item is not None: + # 找到匹配的chat_id,增加计数 + if len(item) >= 2: + item[1] = (item[1] if isinstance(item[1], (int, float)) else 0) + increment + else: + item.append(increment) + else: # 未找到,添加新条目 chat_id_list.append([target_chat_id, increment]) return chat_id_list +def _find_chat_id_item(chat_id_list: List[List[Any]], target_chat_id: str) -> Optional[List[Any]]: + """ + 在chat_id列表中查找匹配的项(辅助函数) + + Args: + chat_id_list: chat_id列表,格式为 [[chat_id, count], ...] 
+ target_chat_id: 要查找的chat_id + + Returns: + 如果找到则返回匹配的项,否则返回None + """ + for item in chat_id_list: + if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): + return item + return None + + def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool: """ 检查chat_id列表中是否包含指定的chat_id @@ -92,10 +248,7 @@ def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> Returns: bool: 如果包含则返回True """ - for item in chat_id_list: - if isinstance(item, list) and len(item) >= 1 and str(item[0]) == str(target_chat_id): - return True - return False + return _find_chat_id_item(chat_id_list, target_chat_id) is not None def contains_bot_self_name(content: str) -> bool: @@ -115,7 +268,7 @@ def contains_bot_self_name(content: str) -> bool: candidates = [name for name in [nickname, *alias_names] if name] - return any(name in target for name in candidates if target) + return any(name in target for name in candidates) def build_context_paragraph(messages: List[Any], center_index: int) -> Optional[str]: diff --git a/src/bw_learner/message_recorder.py b/src/bw_learner/message_recorder.py new file mode 100644 index 00000000..ec310184 --- /dev/null +++ b/src/bw_learner/message_recorder.py @@ -0,0 +1,217 @@ +import time +import asyncio +from typing import List, Any +from src.common.logger import get_logger +from src.config.config import global_config +from src.chat.message_receive.chat_stream import get_chat_manager +from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat_inclusive +from src.bw_learner.expression_learner import expression_learner_manager +from src.bw_learner.jargon_miner import miner_manager + +logger = get_logger("bw_learner") + + +class MessageRecorder: + """ + 统一的消息记录器,负责管理时间窗口和消息提取,并将消息分发给 expression_learner 和 jargon_miner + """ + + def __init__(self, chat_id: str) -> None: + self.chat_id = chat_id + self.chat_stream = get_chat_manager().get_stream(chat_id) + self.chat_name = get_chat_manager().get_stream_name(chat_id) or chat_id + + # 维护每个chat的上次提取时间 + self.last_extraction_time: float = time.time() + + # 提取锁,防止并发执行 + self._extraction_lock = asyncio.Lock() + + # 获取 expression 和 jargon 的配置参数 + self._init_parameters() + + # 获取 expression_learner 和 jargon_miner 实例 + self.expression_learner = expression_learner_manager.get_expression_learner(chat_id) + self.jargon_miner = miner_manager.get_miner(chat_id) + + def _init_parameters(self) -> None: + """初始化提取参数""" + # 获取 expression 配置 + _, self.enable_expression_learning, self.enable_jargon_learning = ( + global_config.expression.get_expression_config_for_chat(self.chat_id) + ) + self.min_messages_for_extraction = 30 + self.min_extraction_interval = 60 + + logger.debug( + f"MessageRecorder 初始化: chat_id={self.chat_id}, " + f"min_messages={self.min_messages_for_extraction}, " + f"min_interval={self.min_extraction_interval}" + ) + + def should_trigger_extraction(self) -> bool: + """ + 检查是否应该触发消息提取 + + Returns: + bool: 是否应该触发提取 + """ + # 检查时间间隔 + time_diff = time.time() - self.last_extraction_time + if time_diff < self.min_extraction_interval: + return False + + # 检查消息数量 + recent_messages = get_raw_msg_by_timestamp_with_chat_inclusive( + chat_id=self.chat_id, + timestamp_start=self.last_extraction_time, + timestamp_end=time.time(), + ) + + if not recent_messages or len(recent_messages) < self.min_messages_for_extraction: + return False + + return True + + async def extract_and_distribute(self) -> None: + """ + 提取消息并分发给 expression_learner 和 jargon_miner + """ 
+ # 使用异步锁防止并发执行 + async with self._extraction_lock: + # 在锁内检查,避免并发触发 + if not self.should_trigger_extraction(): + return + + # 检查 chat_stream 是否存在 + if not self.chat_stream: + return + + # 记录本次提取的时间窗口,避免重复提取 + extraction_start_time = self.last_extraction_time + extraction_end_time = time.time() + + # 立即更新提取时间,防止并发触发 + self.last_extraction_time = extraction_end_time + + try: + logger.info(f"在聊天流 {self.chat_name} 开始统一消息提取和分发") + + # 拉取提取窗口内的消息 + messages = get_raw_msg_by_timestamp_with_chat_inclusive( + chat_id=self.chat_id, + timestamp_start=extraction_start_time, + timestamp_end=extraction_end_time, + ) + + if not messages: + logger.debug(f"聊天流 {self.chat_name} 没有新消息,跳过提取") + return + + # 按时间排序,确保顺序一致 + messages = sorted(messages, key=lambda msg: msg.time or 0) + + logger.info( + f"聊天流 {self.chat_name} 提取到 {len(messages)} 条消息," + f"时间窗口: {extraction_start_time:.2f} - {extraction_end_time:.2f}" + ) + + + # 分别触发 expression_learner 和 jargon_miner 的处理 + # 传递提取的消息,避免它们重复获取 + # 触发 expression 学习(如果启用) + if self.enable_expression_learning: + asyncio.create_task( + self._trigger_expression_learning(extraction_start_time, extraction_end_time, messages) + ) + + # 触发 jargon 提取(如果启用),传递消息 + if self.enable_jargon_learning: + asyncio.create_task( + self._trigger_jargon_extraction(extraction_start_time, extraction_end_time, messages) + ) + + except Exception as e: + logger.error(f"为聊天流 {self.chat_name} 提取和分发消息失败: {e}") + import traceback + traceback.print_exc() + # 即使失败也保持时间戳更新,避免频繁重试 + + async def _trigger_expression_learning( + self, + timestamp_start: float, + timestamp_end: float, + messages: List[Any] + ) -> None: + """ + 触发 expression 学习,使用指定的消息列表 + + Args: + timestamp_start: 开始时间戳 + timestamp_end: 结束时间戳 + messages: 消息列表 + """ + try: + # 传递消息给 ExpressionLearner(必需参数) + learnt_style = await self.expression_learner.learn_and_store(messages=messages) + + if learnt_style: + logger.info(f"聊天流 {self.chat_name} 表达学习完成") + else: + logger.debug(f"聊天流 {self.chat_name} 表达学习未获得有效结果") + except Exception as e: + logger.error(f"为聊天流 {self.chat_name} 触发表达学习失败: {e}") + import traceback + traceback.print_exc() + + async def _trigger_jargon_extraction( + self, + timestamp_start: float, + timestamp_end: float, + messages: List[Any] + ) -> None: + """ + 触发 jargon 提取,使用指定的消息列表 + + Args: + timestamp_start: 开始时间戳 + timestamp_end: 结束时间戳 + messages: 消息列表 + """ + try: + # 传递消息给 JargonMiner,避免它重复获取 + await self.jargon_miner.run_once(messages=messages) + + except Exception as e: + logger.error(f"为聊天流 {self.chat_name} 触发黑话提取失败: {e}") + import traceback + traceback.print_exc() + + +class MessageRecorderManager: + """MessageRecorder 管理器""" + + def __init__(self) -> None: + self._recorders: dict[str, MessageRecorder] = {} + + def get_recorder(self, chat_id: str) -> MessageRecorder: + """获取或创建指定 chat_id 的 MessageRecorder""" + if chat_id not in self._recorders: + self._recorders[chat_id] = MessageRecorder(chat_id) + return self._recorders[chat_id] + + +# 全局管理器实例 +recorder_manager = MessageRecorderManager() + + +async def extract_and_distribute_messages(chat_id: str) -> None: + """ + 统一的消息提取和分发入口函数 + + Args: + chat_id: 聊天流ID + """ + recorder = recorder_manager.get_recorder(chat_id) + await recorder.extract_and_distribute() + diff --git a/src/express/reflect_tracker.py b/src/bw_learner/reflect_tracker.py similarity index 100% rename from src/express/reflect_tracker.py rename to src/bw_learner/reflect_tracker.py diff --git a/src/chat/brain_chat/brain_chat.py b/src/chat/brain_chat/brain_chat.py index 9945b496..4691f89a 100644 --- 
a/src/chat/brain_chat/brain_chat.py +++ b/src/chat/brain_chat/brain_chat.py @@ -16,7 +16,8 @@ from src.chat.brain_chat.brain_planner import BrainPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail -from src.express.expression_learner import expression_learner_manager +from src.bw_learner.expression_learner import expression_learner_manager +from src.bw_learner.message_recorder import extract_and_distribute_messages from src.person_info.person_info import Person from src.plugin_system.base.component_types import EventType, ActionInfo from src.plugin_system.core import events_manager @@ -252,7 +253,7 @@ class BrainChatting: # ReflectTracker Check # 在每次回复前检查一次上下文,看是否有反思问题得到了解答 # ------------------------------------------------------------------------- - from src.express.reflect_tracker import reflect_tracker_manager + from src.bw_learner.reflect_tracker import reflect_tracker_manager tracker = reflect_tracker_manager.get_tracker(self.stream_id) if tracker: @@ -265,13 +266,15 @@ class BrainChatting: # Expression Reflection Check # 检查是否需要提问表达反思 # ------------------------------------------------------------------------- - from src.express.expression_reflector import expression_reflector_manager + from src.bw_learner.expression_reflector import expression_reflector_manager reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) asyncio.create_task(reflector.check_and_ask()) async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): - asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) + # 通过 MessageRecorder 统一提取消息并分发给 expression_learner 和 jargon_miner + # 在 replyer 执行时触发,统一管理时间窗口,避免重复获取消息 + asyncio.create_task(extract_and_distribute_messages(self.stream_id)) cycle_timers, thinking_id = self.start_cycle() logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考") diff --git a/src/chat/heart_flow/heartFC_chat.py b/src/chat/heart_flow/heartFC_chat.py index 763c1663..7ad757de 100644 --- a/src/chat/heart_flow/heartFC_chat.py +++ b/src/chat/heart_flow/heartFC_chat.py @@ -16,10 +16,10 @@ from src.chat.planner_actions.planner import ActionPlanner from src.chat.planner_actions.action_modifier import ActionModifier from src.chat.planner_actions.action_manager import ActionManager from src.chat.heart_flow.hfc_utils import CycleDetail -from src.express.expression_learner import expression_learner_manager +from src.bw_learner.expression_learner import expression_learner_manager from src.chat.heart_flow.frequency_control import frequency_control_manager -from src.express.reflect_tracker import reflect_tracker_manager -from src.express.expression_reflector import expression_reflector_manager +from src.bw_learner.reflect_tracker import reflect_tracker_manager +from src.bw_learner.expression_reflector import expression_reflector_manager from src.bw_learner.message_recorder import extract_and_distribute_messages from src.person_info.person_info import Person from src.plugin_system.base.component_types import EventType, ActionInfo diff --git a/src/chat/replyer/group_generator.py b/src/chat/replyer/group_generator.py index 9ba979a9..054b1ea9 100644 --- a/src/chat/replyer/group_generator.py +++ b/src/chat/replyer/group_generator.py @@ -23,7 +23,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, replace_user_references, ) -from 
src.express.expression_selector import expression_selector +from src.bw_learner.expression_selector import expression_selector from src.plugin_system.apis.message_api import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator @@ -35,7 +35,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt -from src.jargon.jargon_explainer import explain_jargon_in_context +from src.bw_learner.jargon_explainer import explain_jargon_in_context init_lpmm_prompt() init_replyer_prompt() diff --git a/src/chat/replyer/private_generator.py b/src/chat/replyer/private_generator.py index e241759b..c0fef546 100644 --- a/src/chat/replyer/private_generator.py +++ b/src/chat/replyer/private_generator.py @@ -23,7 +23,7 @@ from src.chat.utils.chat_message_builder import ( get_raw_msg_before_timestamp_with_chat, replace_user_references, ) -from src.express.expression_selector import expression_selector +from src.bw_learner.expression_selector import expression_selector from src.plugin_system.apis.message_api import translate_pid_to_description # from src.memory_system.memory_activator import MemoryActivator @@ -36,7 +36,7 @@ from src.chat.replyer.prompt.lpmm_prompt import init_lpmm_prompt from src.chat.replyer.prompt.replyer_prompt import init_replyer_prompt from src.chat.replyer.prompt.rewrite_prompt import init_rewrite_prompt from src.memory_system.memory_retrieval import init_memory_retrieval_prompt, build_memory_retrieval_prompt -from src.jargon.jargon_explainer import explain_jargon_in_context +from src.bw_learner.jargon_explainer import explain_jargon_in_context init_lpmm_prompt() init_replyer_prompt() diff --git a/src/config/config.py b/src/config/config.py index d4ecade8..44125df4 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -33,7 +33,6 @@ from src.config.official_configs import ( VoiceConfig, MemoryConfig, DebugConfig, - JargonConfig, DreamConfig, ) @@ -355,7 +354,6 @@ class Config(ConfigBase): memory: MemoryConfig debug: DebugConfig voice: VoiceConfig - jargon: JargonConfig dream: DreamConfig diff --git a/src/config/official_configs.py b/src/config/official_configs.py index d93b7429..e9037c02 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -284,20 +284,20 @@ class ExpressionConfig(ConfigBase): learning_list: list[list] = field(default_factory=lambda: []) """ 表达学习配置列表,支持按聊天流配置 - 格式: [["chat_stream_id", "use_expression", "enable_learning", learning_intensity], ...] + 格式: [["chat_stream_id", "use_expression", "enable_learning", "enable_jargon_learning"], ...] 
示例: [ - ["", "enable", "enable", 1.0], # 全局配置:使用表达,启用学习,学习强度1.0 - ["qq:1919810:private", "enable", "enable", 1.5], # 特定私聊配置:使用表达,启用学习,学习强度1.5 - ["qq:114514:private", "enable", "disable", 0.5], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 + ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习 + ["qq:1919810:private", "enable", "enable", "enable"], # 特定私聊配置:使用表达,启用学习,启用jargon学习 + ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习 ] 说明: - 第一位: chat_stream_id,空字符串表示全局配置 - 第二位: 是否使用学到的表达 ("enable"/"disable") - 第三位: 是否学习表达 ("enable"/"disable") - - 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) + - 第四位: 是否启用jargon学习 ("enable"/"disable") """ expression_groups: list[list[str]] = field(default_factory=list) @@ -320,6 +320,9 @@ class ExpressionConfig(ConfigBase): 如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true) """ + all_global_jargon: bool = False + """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id。注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除""" + def _parse_stream_config_to_chat_id(self, stream_config_str: str) -> Optional[str]: """ 解析流配置字符串并生成对应的 chat_id @@ -355,7 +358,7 @@ class ExpressionConfig(ConfigBase): except (ValueError, IndexError): return None - def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, int]: + def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, bool]: """ 根据聊天流ID获取表达配置 @@ -363,35 +366,27 @@ class ExpressionConfig(ConfigBase): chat_stream_id: 聊天流ID,格式为哈希值 Returns: - tuple: (是否使用表达, 是否学习表达, 学习间隔) + tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习) """ if not self.learning_list: - # 如果没有配置,使用默认值:启用表达,启用学习,学习强度1.0(对应300秒间隔) - return True, True, 1.0 + # 如果没有配置,使用默认值:启用表达,启用学习,启用jargon学习 + return True, True, True # 优先检查聊天流特定的配置 if chat_stream_id: specific_expression_config = self._get_stream_specific_config(chat_stream_id) if specific_expression_config is not None: - use_expression, enable_learning, learning_intensity = specific_expression_config - # 防止学习强度为0,自动转换为0.0001 - if learning_intensity == 0: - learning_intensity = 0.0000001 - return use_expression, enable_learning, learning_intensity + return specific_expression_config # 检查全局配置(第一个元素为空字符串的配置) global_expression_config = self._get_global_config() if global_expression_config is not None: - use_expression, enable_learning, learning_intensity = global_expression_config - # 防止学习强度为0,自动转换为0.0001 - if learning_intensity == 0: - learning_intensity = 0.0000001 - return use_expression, enable_learning, learning_intensity + return global_expression_config - # 如果都没有匹配,返回默认值:启用表达,启用学习,学习强度1.0(对应300秒间隔) - return True, True, 1.0 + # 如果都没有匹配,返回默认值:启用表达,启用学习,启用jargon学习 + return True, True, True - def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, int]]: + def _get_stream_specific_config(self, chat_stream_id: str) -> Optional[tuple[bool, bool, bool]]: """ 获取特定聊天流的表达配置 @@ -399,7 +394,7 @@ class ExpressionConfig(ConfigBase): chat_stream_id: 聊天流ID(哈希值) Returns: - tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None + tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None """ for config_item in self.learning_list: if not config_item or len(config_item) < 4: @@ -424,22 +419,19 @@ class ExpressionConfig(ConfigBase): try: use_expression: bool = config_item[1].lower() == "enable" enable_learning: bool = config_item[2].lower() == "enable" - learning_intensity: float = float(config_item[3]) - # 防止学习强度为0,自动转换为0.0001 - if learning_intensity == 0: - learning_intensity = 0.0000001 - return use_expression, enable_learning, 
learning_intensity # type: ignore + enable_jargon_learning: bool = config_item[3].lower() == "enable" + return use_expression, enable_learning, enable_jargon_learning # type: ignore except (ValueError, IndexError): continue return None - def _get_global_config(self) -> Optional[tuple[bool, bool, int]]: + def _get_global_config(self) -> Optional[tuple[bool, bool, bool]]: """ 获取全局表达配置 Returns: - tuple: (是否使用表达, 是否学习表达, 学习间隔),如果没有配置则返回 None + tuple: (是否使用表达, 是否学习表达, 是否启用jargon学习),如果没有配置则返回 None """ for config_item in self.learning_list: if not config_item or len(config_item) < 4: @@ -450,11 +442,8 @@ class ExpressionConfig(ConfigBase): try: use_expression: bool = config_item[1].lower() == "enable" enable_learning: bool = config_item[2].lower() == "enable" - learning_intensity = float(config_item[3]) - # 防止学习强度为0,自动转换为0.0001 - if learning_intensity == 0: - learning_intensity = 0.0000001 - return use_expression, enable_learning, learning_intensity # type: ignore + enable_jargon_learning: bool = config_item[3].lower() == "enable" + return use_expression, enable_learning, enable_jargon_learning # type: ignore except (ValueError, IndexError): continue @@ -732,14 +721,6 @@ class LPMMKnowledgeConfig(ConfigBase): """嵌入向量维度,应该与模型的输出维度一致""" -@dataclass -class JargonConfig(ConfigBase): - """Jargon配置类""" - - all_global: bool = False - """是否将所有新增的jargon项目默认为全局(is_global=True),chat_id记录第一次存储时的id""" - - @dataclass class DreamConfig(ConfigBase): """Dream配置类""" diff --git a/src/dream/tools/search_jargon_tool.py b/src/dream/tools/search_jargon_tool.py index d20f170b..0429a6a7 100644 --- a/src/dream/tools/search_jargon_tool.py +++ b/src/dream/tools/search_jargon_tool.py @@ -4,7 +4,7 @@ from src.common.logger import get_logger from src.common.database.database_model import Jargon from src.config.config import global_config from src.chat.utils.utils import parse_keywords_string -from src.jargon.jargon_utils import parse_chat_id_list, chat_id_list_contains +from src.bw_learner.learner_utils import parse_chat_id_list, chat_id_list_contains logger = get_logger("dream_agent") @@ -24,7 +24,7 @@ def make_search_jargon(chat_id: str): query = Jargon.select().where(Jargon.is_jargon) # 根据 all_global 配置决定 chat_id 作用域 - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: # 开启全局黑话:只看 is_global=True 的记录,不区分 chat_id query = query.where(Jargon.is_global) else: @@ -63,7 +63,7 @@ def make_search_jargon(chat_id: str): if any_matched: filtered_keyword.append(r) - if global_config.jargon.all_global: + if global_config.expression.all_global_jargon: # 全局黑话模式:不再做 chat_id 过滤,直接使用关键词过滤结果 records = filtered_keyword else: @@ -80,7 +80,7 @@ def make_search_jargon(chat_id: str): if not records: scope_note = ( "(当前为全局黑话模式,仅统计 is_global=True 的条目)" - if global_config.jargon.all_global + if global_config.expression.all_global_jargon else "(当前为按 chat_id 作用域模式,仅统计全局黑话或与当前 chat_id 相关的条目)" ) return f"未找到包含关键词'{keyword}'的 Jargon 记录{scope_note}" diff --git a/src/express/express_utils.py b/src/express/express_utils.py deleted file mode 100644 index b0702e30..00000000 --- a/src/express/express_utils.py +++ /dev/null @@ -1,145 +0,0 @@ -import re -import difflib -import random -from datetime import datetime -from typing import Optional, List, Dict - - -def filter_message_content(content: Optional[str]) -> str: - """ - 过滤消息内容,移除回复、@、图片等格式 - - Args: - content: 原始消息内容 - - Returns: - str: 过滤后的内容 - """ - if not content: - return "" - - # 移除以[回复开头、]结尾的部分,包括后面的",说:"部分 - content = re.sub(r"\[回复.*?\],说:\s*", "", content) - # 
移除@<...>格式的内容 - content = re.sub(r"@<[^>]*>", "", content) - # 移除[picid:...]格式的图片ID - content = re.sub(r"\[picid:[^\]]*\]", "", content) - # 移除[表情包:...]格式的内容 - content = re.sub(r"\[表情包:[^\]]*\]", "", content) - - return content.strip() - - -def calculate_similarity(text1: str, text2: str) -> float: - """ - 计算两个文本的相似度,返回0-1之间的值 - 使用SequenceMatcher计算相似度 - - Args: - text1: 第一个文本 - text2: 第二个文本 - - Returns: - float: 相似度值,范围0-1 - """ - return difflib.SequenceMatcher(None, text1, text2).ratio() - - -def format_create_date(timestamp: float) -> str: - """ - 将时间戳格式化为可读的日期字符串 - - Args: - timestamp: 时间戳 - - Returns: - str: 格式化后的日期字符串 - """ - try: - return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - except (ValueError, OSError): - return "未知时间" - - -def _compute_weights(population: List[Dict]) -> List[float]: - """ - 根据表达的count计算权重,范围限定在1~5之间。 - count越高,权重越高,但最多为基础权重的5倍。 - 如果表达已checked,权重会再乘以3倍。 - """ - if not population: - return [] - - counts = [] - checked_flags = [] - for item in population: - count = item.get("count", 1) - try: - count_value = float(count) - except (TypeError, ValueError): - count_value = 1.0 - counts.append(max(count_value, 0.0)) - # 获取checked状态 - checked = item.get("checked", False) - checked_flags.append(bool(checked)) - - min_count = min(counts) - max_count = max(counts) - - if max_count == min_count: - base_weights = [1.0 for _ in counts] - else: - base_weights = [] - for count_value in counts: - # 线性映射到[1,5]区间 - normalized = (count_value - min_count) / (max_count - min_count) - base_weights.append(1.0 + normalized * 4.0) # 1~3 - - # 如果checked,权重乘以3 - weights = [] - for base_weight, checked in zip(base_weights, checked_flags, strict=False): - if checked: - weights.append(base_weight * 3.0) - else: - weights.append(base_weight) - return weights - - -def weighted_sample(population: List[Dict], k: int) -> List[Dict]: - """ - 随机抽样函数 - - Args: - population: 总体数据列表 - k: 需要抽取的数量 - - Returns: - List[Dict]: 抽取的数据列表 - """ - if not population or k <= 0: - return [] - - if len(population) <= k: - return population.copy() - - selected: List[Dict] = [] - population_copy = population.copy() - - for _ in range(min(k, len(population_copy))): - weights = _compute_weights(population_copy) - total_weight = sum(weights) - if total_weight <= 0: - # 回退到均匀随机 - idx = random.randint(0, len(population_copy) - 1) - selected.append(population_copy.pop(idx)) - continue - - threshold = random.uniform(0, total_weight) - cumulative = 0.0 - for idx, weight in enumerate(weights): - cumulative += weight - if threshold <= cumulative: - selected.append(population_copy.pop(idx)) - break - - return selected diff --git a/src/jargon/__init__.py b/src/jargon/__init__.py deleted file mode 100644 index 37b61644..00000000 --- a/src/jargon/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .jargon_miner import extract_and_store_jargon - -__all__ = [ - "extract_and_store_jargon", -] diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index aa20ce0f..2cd4dfde 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -11,7 +11,7 @@ from src.common.database.database_model import ThinkingBack from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.memory_system.memory_utils import parse_questions_json from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message -from src.jargon.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon +from 
src.bw_learner.jargon_explainer import match_jargon_from_text, retrieve_concepts_with_jargon logger = get_logger("memory_retrieval") @@ -972,6 +972,7 @@ async def _process_single_question( context: str, initial_info: str = "", initial_jargon_concepts: Optional[List[str]] = None, + max_iterations: Optional[int] = None, ) -> Optional[str]: """处理单个问题的查询 @@ -996,10 +997,14 @@ async def _process_single_question( jargon_concepts_for_agent = initial_jargon_concepts if global_config.memory.enable_jargon_detection else None + # 如果未指定max_iterations,使用配置的默认值 + if max_iterations is None: + max_iterations = global_config.memory.max_agent_iterations + found_answer, answer, thinking_steps, is_timeout = await _react_agent_solve_question( question=question, chat_id=chat_id, - max_iterations=global_config.memory.max_agent_iterations, + max_iterations=max_iterations, timeout=global_config.memory.agent_timeout_seconds, initial_info=question_initial_info, initial_jargon_concepts=jargon_concepts_for_agent, @@ -1030,6 +1035,7 @@ async def build_memory_retrieval_prompt( target: str, chat_stream, tool_executor, + think_level: int = 1, ) -> str: """构建记忆检索提示 使用两段式查询:第一步生成问题,第二步使用ReAct Agent查询答案 @@ -1117,9 +1123,14 @@ async def build_memory_retrieval_prompt( return "" # 第二步:并行处理所有问题(使用配置的最大迭代次数和超时时间) - max_iterations = global_config.memory.max_agent_iterations + base_max_iterations = global_config.memory.max_agent_iterations + # 根据think_level调整迭代次数:think_level=1时不变,think_level=0时减半 + if think_level == 0: + max_iterations = max(1, base_max_iterations // 2) # 至少为1 + else: + max_iterations = base_max_iterations timeout_seconds = global_config.memory.agent_timeout_seconds - logger.debug(f"问题数量: {len(questions)},设置最大迭代次数: {max_iterations},超时时间: {timeout_seconds}秒") + logger.debug(f"问题数量: {len(questions)},think_level={think_level},设置最大迭代次数: {max_iterations}(基础值: {base_max_iterations}),超时时间: {timeout_seconds}秒") # 并行处理所有问题,将概念检索结果作为初始信息传递 question_tasks = [ @@ -1129,6 +1140,7 @@ async def build_memory_retrieval_prompt( context=message, initial_info=initial_info, initial_jargon_concepts=concepts if enable_jargon_detection else None, + max_iterations=max_iterations, ) for question in questions ] diff --git a/src/webui/config_routes.py b/src/webui/config_routes.py index 4cfc20ca..25127fe5 100644 --- a/src/webui/config_routes.py +++ b/src/webui/config_routes.py @@ -30,7 +30,6 @@ from src.config.official_configs import ( MemoryConfig, DebugConfig, VoiceConfig, - JargonConfig, ) from src.config.api_ada_configs import ( ModelTaskConfig, @@ -129,7 +128,6 @@ async def get_config_section_schema(section_name: str): "memory": MemoryConfig, "debug": DebugConfig, "voice": VoiceConfig, - "jargon": JargonConfig, "model_task_config": ModelTaskConfig, "api_provider": APIProvider, "model_info": ModelInfo, diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 2b6a63e6..16e4c235 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.0.2" +version = "7.1.0" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- # 如果你想要修改配置文件,请递增version的值 @@ -60,16 +60,14 @@ state_probability = 0.3 [expression] # 表达学习配置 learning_list = [ # 表达学习配置列表,支持按聊天流配置 - ["", "enable", "enable", "1.0"], # 全局配置:使用表达,启用学习,学习强度1.0 - ["qq:1919810:group", "enable", "enable", "1.5"], # 特定群聊配置:使用表达,启用学习,学习强度1.5 - ["qq:114514:private", "enable", "disable", "0.5"], # 特定私聊配置:使用表达,禁用学习,学习强度0.5 + ["", "enable", "enable", "enable"], # 全局配置:使用表达,启用学习,启用jargon学习 + ["qq:1919810:group", "enable", 
"enable", "enable"], # 特定群聊配置:使用表达,启用学习,启用jargon学习 + ["qq:114514:private", "enable", "disable", "disable"], # 特定私聊配置:使用表达,禁用学习,禁用jargon学习 # 格式说明: # 第一位: chat_stream_id,空字符串表示全局配置 # 第二位: 是否使用学到的表达 ("enable"/"disable") # 第三位: 是否学习表达 ("enable"/"disable") - # 第四位: 学习强度(浮点数),影响学习频率,最短学习时间间隔 = 300/学习强度(秒) - # 学习强度越高,学习越频繁;学习强度越低,学习越少 - # 如果学习强度设置为0会自动转换为0.0001以避免除以零错误 + # 第四位: 是否启用jargon学习 ("enable"/"disable") ] expression_groups = [ @@ -85,6 +83,8 @@ reflect = false # 是否启用表达反思(Bot主动向管理员询问表达 reflect_operator_id = "" # 表达反思操作员ID,格式:platform:id:type (例如 "qq:123456:private" 或 "qq:654321:group") allow_reflect = [] # 允许进行表达反思的聊天流ID列表,格式:["qq:123456:private", "qq:654321:group", ...],只有在此列表中的聊天流才会提出问题并跟踪。如果列表为空,则所有聊天流都可以进行表达反思(前提是 reflect = true) +all_global_jargon = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除 + [chat] # 麦麦的聊天设置 talk_value = 1 # 聊天频率,越小越沉默,范围0-1,如果设置为0会自动转换为0.0001以避免除以零错误 @@ -131,9 +131,6 @@ dream_time_ranges = [ ] # dream_time_ranges = [] -[jargon] -all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除 - [tool] enable_tool = true # 是否启用工具