better:优化表达方式提取的token消耗

pull/1414/head
SengokuCola 2025-12-03 00:21:02 +08:00
parent 34129bafad
commit 7f66d5588d
6 changed files with 155 additions and 332 deletions

View File

@ -8,7 +8,7 @@ from typing import Any, Dict, Tuple, List
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database import db from src.common.database.database import db
from src.common.database.database_model import OnlineTime, LLMUsage, Messages from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords
from src.manager.async_task_manager import AsyncTask from src.manager.async_task_manager import AsyncTask
from src.manager.local_store_manager import local_storage from src.manager.local_store_manager import local_storage
from src.config.config import global_config from src.config.config import global_config
@ -505,13 +505,6 @@ class StatisticOutputTask(AsyncTask):
for period_key, _ in collect_period for period_key, _ in collect_period
} }
# 获取bot的QQ账号
bot_qq_account = (
str(global_config.bot.qq_account)
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
else ""
)
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
message_time_ts = message.time # This is a float timestamp message_time_ts = message.time # This is a float timestamp
@ -537,7 +530,7 @@ class StatisticOutputTask(AsyncTask):
if not chat_id: # Should not happen if above logic is correct if not chat_id: # Should not happen if above logic is correct
continue continue
# Update name_mapping # Update name_mapping(仅用于展示聊天名称)
try: try:
if chat_id in self.name_mapping: if chat_id in self.name_mapping:
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]: if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
@ -549,19 +542,30 @@ class StatisticOutputTask(AsyncTask):
# 重置为正确的格式 # 重置为正确的格式
self.name_mapping[chat_id] = (chat_name, message_time_ts) self.name_mapping[chat_id] = (chat_name, message_time_ts)
# 检查是否是bot发送的消息回复
is_bot_reply = False
if bot_qq_account and message.user_id == bot_qq_account:
is_bot_reply = True
for idx, (_, period_start_dt) in enumerate(collect_period): for idx, (_, period_start_dt) in enumerate(collect_period):
if message_time_ts >= period_start_dt.timestamp(): if message_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]: for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_MSG_CNT] += 1 stats[period_key][TOTAL_MSG_CNT] += 1
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
if is_bot_reply:
stats[period_key][TOTAL_REPLY_CNT] += 1
break break
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
try:
action_query_start_timestamp = collect_period[-1][1].timestamp()
for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore
# 仅统计已完成的 reply 动作
if action.action_name != "reply" or not action.action_done:
continue
action_time_ts = action.time
for idx, (_, period_start_dt) in enumerate(collect_period):
if action_time_ts >= period_start_dt.timestamp():
for period_key, _ in collect_period[idx:]:
stats[period_key][TOTAL_REPLY_CNT] += 1
break
except Exception as e:
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0错误信息{e}")
return stats return stats
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:

View File

@ -324,7 +324,6 @@ class Expression(BaseModel):
# new mode fields # new mode fields
context = TextField(null=True) context = TextField(null=True)
up_content = TextField(null=True)
content_list = TextField(null=True) content_list = TextField(null=True)
count = IntegerField(default=1) count = IntegerField(default=1)

View File

@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/ # 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.6" MMC_VERSION = "0.11.7-snapshot.1"
def get_key_comment(toml_table, key): def get_key_comment(toml_table, key):

View File

@ -12,11 +12,10 @@ from src.config.config import model_config, global_config
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp_with_chat_inclusive, get_raw_msg_by_timestamp_with_chat_inclusive,
build_anonymous_messages, build_anonymous_messages,
build_bare_messages,
) )
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.express.express_utils import filter_message_content, calculate_similarity from src.express.express_utils import filter_message_content
from json_repair import repair_json from json_repair import repair_json
@ -26,10 +25,10 @@ logger = get_logger("expressor")
def init_prompt() -> None: def init_prompt() -> None:
learn_style_prompt = """ learn_style_prompt = """{chat_str}
{chat_str}
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格 请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
每一行消息前面的方括号中的数字 [1][2]是该行消息的唯一编号请在输出中引用这些编号来标注表达方式的来源行
1. 只考虑文字不要考虑表情包和图片 1. 只考虑文字不要考虑表情包和图片
2. 不要涉及具体的人名但是可以涉及具体名词 2. 不要涉及具体的人名但是可以涉及具体名词
3. 思考有没有特殊的梗一并总结成语言风格 3. 思考有没有特殊的梗一并总结成语言风格
@ -37,41 +36,29 @@ def init_prompt() -> None:
注意总结成如下格式的规律总结的内容要详细但具有概括性 注意总结成如下格式的规律总结的内容要详细但具有概括性
例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字 例如"AAAAA"可以"BBBBB", AAAAA代表某个具体的场景不超过20个字BBBBB代表对应的语言风格特定句式或表达方式不超过20个字
例如 请严格以 JSON 数组的形式输出结果每个元素为一个对象结构如下注意字段名
"对某件事表示十分惊叹"使用"我嘞个xxxx" [
"表示讽刺的赞同,不讲道理"使用"对对对" {{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂" {{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
"当涉及游戏相关时,夸赞,略带戏谑意味"使用"这么强!" {{"situation": "对某件事表示十分惊叹", "style": "我嘞个xxxx", "source_id": "[消息编号]"}},
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "这么强!", "source_id": "[消息编号]"}},
]
请注意不要总结你自己SELF的发言尽量保证总结内容的逻辑性 请注意
现在请你概括 - 不要总结你自己SELF的发言尽量保证总结内容的逻辑性
- 请只针对最重要的若干条表达方式进行总结避免输出太多重复或相似的条目
其中
- situation表示在什么情境下的简短概括不超过20个字
- style表示对应的语言风格或常用表达不超过20个字
- source_id该表达方式对应的来源行编号即上方聊天记录中方括号里的数字例如 [3]请只输出数字本身不要包含方括号
现在请你输出 JSON
""" """
Prompt(learn_style_prompt, "learn_style_prompt") Prompt(learn_style_prompt, "learn_style_prompt")
match_expression_context_prompt = """
**聊天内容**
{chat_str}
**从聊天内容总结的表达方式pairs**
{expression_pairs}
请你为上面的每一条表达方式找到该表达方式的原文句子并输出匹配结果expression_pair不能有重复每个expression_pair仅输出一个最合适的context
如果找不到原句就不输出该句的匹配结果
以json格式输出
格式如下
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
{{
"expression_pair": "表达方式pair的序号数字",
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
}}
...
现在请你输出匹配结果
"""
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
class ExpressionLearner: class ExpressionLearner:
@ -193,7 +180,6 @@ class ExpressionLearner:
situation, situation,
style, style,
_context, _context,
_up_content,
) in learnt_expressions: ) in learnt_expressions:
learnt_expressions_str += f"{situation}->{style}\n" learnt_expressions_str += f"{situation}->{style}\n"
logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}") logger.info(f"{self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
@ -205,193 +191,17 @@ class ExpressionLearner:
situation, situation,
style, style,
context, context,
up_content,
) in learnt_expressions: ) in learnt_expressions:
await self._upsert_expression_record( await self._upsert_expression_record(
situation=situation, situation=situation,
style=style, style=style,
context=context, context=context,
up_content=up_content,
current_time=current_time, current_time=current_time,
) )
return learnt_expressions return learnt_expressions
async def match_expression_context( async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str]]]:
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
) -> List[Tuple[str, str, str]]:
# 为expression_pairs逐个条目赋予编号并构建成字符串
numbered_pairs = []
for i, (situation, style) in enumerate(expression_pairs, 1):
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
expression_pairs_str = "\n".join(numbered_pairs)
prompt = "match_expression_context_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt,
expression_pairs=expression_pairs_str,
chat_str=random_msg_match_str,
)
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
# print(f"match_expression_context_prompt: {prompt}")
# print(f"{response}")
# 解析JSON响应
match_responses = []
try:
response = response.strip()
# 尝试提取JSON代码块如果存在
json_pattern = r"```json\s*(.*?)\s*```"
matches = re.findall(json_pattern, response, re.DOTALL)
if matches:
response = matches[0].strip()
# 移除可能的markdown代码块标记如果没有找到```json但可能有```
if not matches:
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
response = response.strip()
# 检查是否已经是标准JSON数组格式
if response.startswith("[") and response.endswith("]"):
match_responses = json.loads(response)
else:
# 尝试直接解析多个JSON对象
try:
# 如果是多个JSON对象用逗号分隔包装成数组
if response.startswith("{") and not response.startswith("["):
response = "[" + response + "]"
match_responses = json.loads(response)
else:
# 使用repair_json处理响应
repaired_content = repair_json(response)
# 确保repaired_content是列表格式
if isinstance(repaired_content, str):
try:
parsed_data = json.loads(repaired_content)
if isinstance(parsed_data, dict):
# 如果是字典,包装成列表
match_responses = [parsed_data]
elif isinstance(parsed_data, list):
match_responses = parsed_data
else:
match_responses = []
except json.JSONDecodeError:
match_responses = []
elif isinstance(repaired_content, dict):
# 如果是字典,包装成列表
match_responses = [repaired_content]
elif isinstance(repaired_content, list):
match_responses = repaired_content
else:
match_responses = []
except json.JSONDecodeError:
# 如果还是失败尝试repair_json
repaired_content = repair_json(response)
if isinstance(repaired_content, str):
parsed_data = json.loads(repaired_content)
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
else:
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
except (json.JSONDecodeError, Exception) as e:
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
return []
# 确保 match_responses 是一个列表
if not isinstance(match_responses, list):
if isinstance(match_responses, dict):
match_responses = [match_responses]
else:
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
return []
# 清理和规范化 match_responses 中的元素
normalized_responses = []
for item in match_responses:
if isinstance(item, dict):
# 已经是字典,直接添加
normalized_responses.append(item)
elif isinstance(item, str):
# 如果是字符串,尝试解析为 JSON
try:
parsed = json.loads(item)
if isinstance(parsed, dict):
normalized_responses.append(parsed)
elif isinstance(parsed, list):
# 如果是列表,递归处理
for sub_item in parsed:
if isinstance(sub_item, dict):
normalized_responses.append(sub_item)
else:
logger.debug(f"跳过非字典类型的子元素: {type(sub_item)}, 内容: {sub_item}")
else:
logger.debug(f"跳过无法转换为字典的字符串元素: {item}")
except (json.JSONDecodeError, TypeError):
logger.debug(f"跳过无法解析为JSON的字符串元素: {item}")
elif isinstance(item, list):
# 如果是列表,展开并处理其中的字典
for sub_item in item:
if isinstance(sub_item, dict):
normalized_responses.append(sub_item)
elif isinstance(sub_item, str):
# 尝试解析字符串
try:
parsed = json.loads(sub_item)
if isinstance(parsed, dict):
normalized_responses.append(parsed)
else:
logger.debug(f"跳过非字典类型的解析结果: {type(parsed)}, 内容: {parsed}")
except (json.JSONDecodeError, TypeError):
logger.debug(f"跳过无法解析为JSON的字符串子元素: {sub_item}")
else:
logger.debug(f"跳过非字典类型的列表元素: {type(sub_item)}, 内容: {sub_item}")
else:
logger.debug(f"跳过无法处理的元素类型: {type(item)}, 内容: {item}")
match_responses = normalized_responses
matched_expressions = []
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
logger.debug(f"规范化后的 match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
logger.debug(f"规范化后的 match_responses 内容: {match_responses}")
for match_response in match_responses:
try:
# 检查 match_response 的类型(此时应该都是字典)
if not isinstance(match_response, dict):
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
continue
# 获取表达方式序号
if "expression_pair" not in match_response:
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
continue
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
# 检查索引是否有效且未被使用过
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
situation, style = expression_pairs[pair_index]
context = match_response.get("context", "")
matched_expressions.append((situation, style, context))
used_pair_indices.add(pair_index) # 标记该索引已使用
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
elif pair_index in used_pair_indices:
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
except (ValueError, KeyError, IndexError, TypeError) as e:
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
continue
return matched_expressions
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str, str]]]:
"""从指定聊天流学习表达方式 """从指定聊天流学习表达方式
Args: Args:
@ -414,10 +224,8 @@ class ExpressionLearner:
if not random_msg or random_msg == []: if not random_msg or random_msg == []:
return None return None
# 学习用 # 学习用(开启行编号,便于溯源)
random_msg_str: str = await build_anonymous_messages(random_msg) random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
# 溯源用
random_msg_match_str: str = await build_bare_messages(random_msg)
prompt: str = await global_prompt_manager.format_prompt( prompt: str = await global_prompt_manager.format_prompt(
"learn_style_prompt", "learn_style_prompt",
@ -432,83 +240,107 @@ class ExpressionLearner:
except Exception as e: except Exception as e:
logger.error(f"学习表达方式失败,模型生成出错: {e}") logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
# 解析 LLM 返回的表达方式列表(包含来源行编号)
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions) expressions = self._filter_self_reference_styles(expressions)
if not expressions: if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)") logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return None return None
# logger.debug(f"学习{type_str}的response: {response}") # logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源 # 直接根据 source_id 在 random_msg 中溯源,获取 context
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context( filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
expressions, random_msg_match_str
)
# 为每条消息构建精简文本列表,保留到原消息索引的映射
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
# 将 matched_expressions 结合上一句 up_content若不存在上一句则跳过
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
for situation, style, context in matched_expressions:
# 在 bare_lines 中找到第一处相似度达到85%的行
pos = None
for i, (_, c) in enumerate(bare_lines):
similarity = calculate_similarity(c, context)
if similarity >= 0.85: # 85%相似度阈值
pos = i
break
if pos is None or pos == 0: for situation, style, source_id in expressions:
# 没有匹配到目标句或没有上一句,跳过该表达 source_id_str = (source_id or "").strip()
if not source_id_str.isdigit():
# 无效的来源行编号,跳过
continue continue
# 检查目标句是否为空 line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
target_content = bare_lines[pos][1] if line_index < 0 or line_index >= len(random_msg):
if not target_content: # 超出范围,跳过
# 目标句为空,跳过该表达
continue continue
prev_original_idx = bare_lines[pos - 1][0] # 当前行的原始内容
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "") current_msg = random_msg[line_index]
if not up_content: context = filter_message_content(current_msg.processed_plain_text or "")
# 上一句为空,跳过该表达 if not context:
continue continue
filtered_with_up.append((situation, style, context, up_content))
if not filtered_with_up: filtered_expressions.append((situation, style, context))
if not filtered_expressions:
return None return None
return filtered_with_up return filtered_expressions
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]: def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
""" """
解析LLM返回的表达风格总结每一行提取"""使用"之间的内容存储为(situation, style)元组 解析 LLM 返回的表达风格总结 JSON提取 (situation, style, source_id) 元组列表
期望的 JSON 结构
[
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},
...
]
""" """
if not response:
return []
raw = response.strip()
# 尝试提取 ```json 代码块
json_block_pattern = r"```json\s*(.*?)\s*```"
match = re.search(json_block_pattern, raw, re.DOTALL)
if match:
raw = match.group(1).strip()
else:
# 去掉可能存在的通用 ``` 包裹
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
raw = raw.strip()
parsed = None
expressions: List[Tuple[str, str, str]] = [] expressions: List[Tuple[str, str, str]] = []
for line in response.splitlines():
line = line.strip() try:
if not line: # 优先尝试直接解析
if raw.startswith("[") and raw.endswith("]"):
parsed = json.loads(raw)
else:
repaired = repair_json(raw)
if isinstance(repaired, str):
parsed = json.loads(repaired)
else:
parsed = repaired
except Exception:
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
return []
if isinstance(parsed, dict):
parsed_list = [parsed]
elif isinstance(parsed, list):
parsed_list = parsed
else:
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
return []
for item in parsed_list:
if not isinstance(item, dict):
continue continue
# 查找"当"和下一个引号 situation = str(item.get("situation", "")).strip()
idx_when = line.find('"') style = str(item.get("style", "")).strip()
if idx_when == -1: source_id = str(item.get("source_id", "")).strip()
if not situation or not style or not source_id:
# 三个字段必须同时存在
continue continue
idx_quote1 = idx_when + 1 expressions.append((situation, style, source_id))
idx_quote2 = line.find('"', idx_quote1 + 1)
if idx_quote2 == -1:
continue
situation = line[idx_quote1 + 1 : idx_quote2]
# 查找"使用"
idx_use = line.find('使用"', idx_quote2)
if idx_use == -1:
continue
idx_quote3 = idx_use + 2
idx_quote4 = line.find('"', idx_quote3 + 1)
if idx_quote4 == -1:
continue
style = line[idx_quote3 + 1 : idx_quote4]
expressions.append((situation, style))
return expressions return expressions
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]: def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
""" """
过滤掉style与机器人名称/昵称重复的表达 过滤掉style与机器人名称/昵称重复的表达
""" """
@ -525,12 +357,12 @@ class ExpressionLearner:
banned_casefold = {name.casefold() for name in banned_names if name} banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str]] = [] filtered: List[Tuple[str, str, str]] = []
removed_count = 0 removed_count = 0
for situation, style in expressions: for situation, style, source_id in expressions:
normalized_style = (style or "").strip() normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold: if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style)) filtered.append((situation, style, source_id))
else: else:
removed_count += 1 removed_count += 1
@ -544,7 +376,6 @@ class ExpressionLearner:
situation: str, situation: str,
style: str, style: str,
context: str, context: str,
up_content: str,
current_time: float, current_time: float,
) -> None: ) -> None:
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first() expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
@ -554,7 +385,6 @@ class ExpressionLearner:
expr_obj=expr_obj, expr_obj=expr_obj,
situation=situation, situation=situation,
context=context, context=context,
up_content=up_content,
current_time=current_time, current_time=current_time,
) )
return return
@ -563,7 +393,6 @@ class ExpressionLearner:
situation=situation, situation=situation,
style=style, style=style,
context=context, context=context,
up_content=up_content,
current_time=current_time, current_time=current_time,
) )
@ -572,7 +401,6 @@ class ExpressionLearner:
situation: str, situation: str,
style: str, style: str,
context: str, context: str,
up_content: str,
current_time: float, current_time: float,
) -> None: ) -> None:
content_list = [situation] content_list = [situation]
@ -587,7 +415,6 @@ class ExpressionLearner:
chat_id=self.chat_id, chat_id=self.chat_id,
create_date=current_time, create_date=current_time,
context=context, context=context,
up_content=up_content,
) )
async def _update_existing_expression( async def _update_existing_expression(
@ -595,7 +422,6 @@ class ExpressionLearner:
expr_obj: Expression, expr_obj: Expression,
situation: str, situation: str,
context: str, context: str,
up_content: str,
current_time: float, current_time: float,
) -> None: ) -> None:
content_list = self._parse_content_list(expr_obj.content_list) content_list = self._parse_content_list(expr_obj.content_list)
@ -605,7 +431,6 @@ class ExpressionLearner:
expr_obj.count = (expr_obj.count or 0) + 1 expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time expr_obj.last_active_time = current_time
expr_obj.context = context expr_obj.context = context
expr_obj.up_content = up_content
new_situation = await self._compose_situation_text( new_situation = await self._compose_situation_text(
content_list=content_list, content_list=content_list,
@ -651,27 +476,6 @@ class ExpressionLearner:
logger.error(f"概括表达情境失败: {e}") logger.error(f"概括表达情境失败: {e}")
return None return None
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表保留到原消息索引的映射
Args:
messages: 消息列表
Returns:
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
"""
bare_lines: List[Tuple[int, str]] = []
for idx, msg in enumerate(messages):
content = msg.processed_plain_text or ""
content = filter_message_content(content)
# 即使content为空也要记录防止错位
bare_lines.append((idx, content))
return bare_lines
init_prompt() init_prompt()

View File

@ -429,15 +429,36 @@ class ChatHistorySummarizer:
# 2. 构造编号后的消息字符串和参与者信息 # 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
# 3. 调用 LLM 识别话题,并得到 topic -> indices # 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
existing_topics = list(self.topic_cache.keys()) existing_topics = list(self.topic_cache.keys())
success, topic_to_indices = await self._analyze_topics_with_llm( max_retries = 3
numbered_lines=numbered_lines, attempt = 0
existing_topics=existing_topics, success = False
) topic_to_indices: Dict[str, List[int]] = {}
while attempt < max_retries:
attempt += 1
success, topic_to_indices = await self._analyze_topics_with_llm(
numbered_lines=numbered_lines,
existing_topics=existing_topics,
)
if success and topic_to_indices:
if attempt > 1:
logger.info(
f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}"
)
break
logger.warning(
f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败"
+ ("" if attempt >= max_retries else ",准备重试")
)
if not success or not topic_to_indices: if not success or not topic_to_indices:
logger.warning(f"{self.log_prefix} 话题识别失败或无有效话题,本次检查忽略") logger.error(
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
)
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状 # 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks保持原状
return return

View File

@ -1,7 +1,7 @@
"""表达方式管理 API 路由""" """表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel from pydantic import BaseModel, NonNegativeFloat
from typing import Optional, List, Dict from typing import Optional, List, Dict
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams from src.common.database.database_model import Expression, ChatStreams
@ -21,7 +21,6 @@ class ExpressionResponse(BaseModel):
situation: str situation: str
style: str style: str
context: Optional[str] context: Optional[str]
up_content: Optional[str]
last_active_time: float last_active_time: float
chat_id: str chat_id: str
create_date: Optional[float] create_date: Optional[float]
@ -49,8 +48,7 @@ class ExpressionCreateRequest(BaseModel):
situation: str situation: str
style: str style: str
context: Optional[str] = None context: Optional[str] = NonNegativeFloat
up_content: Optional[str] = None
chat_id: str chat_id: str
@ -60,7 +58,6 @@ class ExpressionUpdateRequest(BaseModel):
situation: Optional[str] = None situation: Optional[str] = None
style: Optional[str] = None style: Optional[str] = None
context: Optional[str] = None context: Optional[str] = None
up_content: Optional[str] = None
chat_id: Optional[str] = None chat_id: Optional[str] = None
@ -102,7 +99,6 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
situation=expression.situation, situation=expression.situation,
style=expression.style, style=expression.style,
context=expression.context, context=expression.context,
up_content=expression.up_content,
last_active_time=expression.last_active_time, last_active_time=expression.last_active_time,
chat_id=expression.chat_id, chat_id=expression.chat_id,
create_date=expression.create_date, create_date=expression.create_date,
@ -310,7 +306,6 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
situation=request.situation, situation=request.situation,
style=request.style, style=request.style,
context=request.context, context=request.context,
up_content=request.up_content,
chat_id=request.chat_id, chat_id=request.chat_id,
last_active_time=current_time, last_active_time=current_time,
create_date=current_time, create_date=current_time,