mirror of https://github.com/Mai-with-u/MaiBot.git
better:优化表达方式提取的token消耗
parent
34129bafad
commit
7f66d5588d
|
|
@ -8,7 +8,7 @@ from typing import Any, Dict, Tuple, List
|
|||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database import db
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages
|
||||
from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords
|
||||
from src.manager.async_task_manager import AsyncTask
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.config.config import global_config
|
||||
|
|
@ -505,13 +505,6 @@ class StatisticOutputTask(AsyncTask):
|
|||
for period_key, _ in collect_period
|
||||
}
|
||||
|
||||
# 获取bot的QQ账号
|
||||
bot_qq_account = (
|
||||
str(global_config.bot.qq_account)
|
||||
if hasattr(global_config, "bot") and hasattr(global_config.bot, "qq_account")
|
||||
else ""
|
||||
)
|
||||
|
||||
query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp)
|
||||
for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore
|
||||
message_time_ts = message.time # This is a float timestamp
|
||||
|
|
@ -537,7 +530,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
if not chat_id: # Should not happen if above logic is correct
|
||||
continue
|
||||
|
||||
# Update name_mapping
|
||||
# Update name_mapping(仅用于展示聊天名称)
|
||||
try:
|
||||
if chat_id in self.name_mapping:
|
||||
if chat_name != self.name_mapping[chat_id][0] and message_time_ts > self.name_mapping[chat_id][1]:
|
||||
|
|
@ -549,19 +542,30 @@ class StatisticOutputTask(AsyncTask):
|
|||
# 重置为正确的格式
|
||||
self.name_mapping[chat_id] = (chat_name, message_time_ts)
|
||||
|
||||
# 检查是否是bot发送的消息(回复)
|
||||
is_bot_reply = False
|
||||
if bot_qq_account and message.user_id == bot_qq_account:
|
||||
is_bot_reply = True
|
||||
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if message_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_MSG_CNT] += 1
|
||||
stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1
|
||||
if is_bot_reply:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
|
||||
# 使用 ActionRecords 中的 reply 动作次数作为回复数基准
|
||||
try:
|
||||
action_query_start_timestamp = collect_period[-1][1].timestamp()
|
||||
for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore
|
||||
# 仅统计已完成的 reply 动作
|
||||
if action.action_name != "reply" or not action.action_done:
|
||||
continue
|
||||
|
||||
action_time_ts = action.time
|
||||
for idx, (_, period_start_dt) in enumerate(collect_period):
|
||||
if action_time_ts >= period_start_dt.timestamp():
|
||||
for period_key, _ in collect_period[idx:]:
|
||||
stats[period_key][TOTAL_REPLY_CNT] += 1
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}")
|
||||
|
||||
return stats
|
||||
|
||||
def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]:
|
||||
|
|
|
|||
|
|
@ -324,7 +324,6 @@ class Expression(BaseModel):
|
|||
|
||||
# new mode fields
|
||||
context = TextField(null=True)
|
||||
up_content = TextField(null=True)
|
||||
|
||||
content_list = TextField(null=True)
|
||||
count = IntegerField(default=1)
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.6"
|
||||
MMC_VERSION = "0.11.7-snapshot.1"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
|
|
|
|||
|
|
@ -12,11 +12,10 @@ from src.config.config import model_config, global_config
|
|||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp_with_chat_inclusive,
|
||||
build_anonymous_messages,
|
||||
build_bare_messages,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.express.express_utils import filter_message_content, calculate_similarity
|
||||
from src.express.express_utils import filter_message_content
|
||||
from json_repair import repair_json
|
||||
|
||||
|
||||
|
|
@ -26,10 +25,10 @@ logger = get_logger("expressor")
|
|||
|
||||
|
||||
def init_prompt() -> None:
|
||||
learn_style_prompt = """
|
||||
{chat_str}
|
||||
learn_style_prompt = """{chat_str}
|
||||
|
||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格
|
||||
请从上面这段群聊中概括除了人名为"SELF"之外的人的语言风格。
|
||||
每一行消息前面的方括号中的数字(如 [1]、[2])是该行消息的唯一编号,请在输出中引用这些编号来标注“表达方式的来源行”。
|
||||
1. 只考虑文字,不要考虑表情包和图片
|
||||
2. 不要涉及具体的人名,但是可以涉及具体名词
|
||||
3. 思考有没有特殊的梗,一并总结成语言风格
|
||||
|
|
@ -37,41 +36,29 @@ def init_prompt() -> None:
|
|||
注意:总结成如下格式的规律,总结的内容要详细,但具有概括性:
|
||||
例如:当"AAAAA"时,可以"BBBBB", AAAAA代表某个具体的场景,不超过20个字。BBBBB代表对应的语言风格,特定句式或表达方式,不超过20个字。
|
||||
|
||||
例如:
|
||||
当"对某件事表示十分惊叹"时,使用"我嘞个xxxx"
|
||||
当"表示讽刺的赞同,不讲道理"时,使用"对对对"
|
||||
当"想说明某个具体的事实观点,但懒得明说,使用"懂的都懂"
|
||||
当"当涉及游戏相关时,夸赞,略带戏谑意味"时,使用"这么强!"
|
||||
请严格以 JSON 数组的形式输出结果,每个元素为一个对象,结构如下(注意字段名):
|
||||
[
|
||||
{{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"}},
|
||||
{{"situation": "CCCC", "style": "DDDD", "source_id": "7"}}
|
||||
{{"situation": "对某件事表示十分惊叹", "style": "我嘞个xxxx", "source_id": "[消息编号]"}},
|
||||
{{"situation": "表示讽刺的赞同,不讲道理", "style": "对对对", "source_id": "[消息编号]"}},
|
||||
{{"situation": "当涉及游戏相关时,夸赞,略带戏谑意味", "style": "这么强!", "source_id": "[消息编号]"}},
|
||||
]
|
||||
|
||||
请注意:不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
现在请你概括
|
||||
请注意:
|
||||
- 不要总结你自己(SELF)的发言,尽量保证总结内容的逻辑性
|
||||
- 请只针对最重要的若干条表达方式进行总结,避免输出太多重复或相似的条目
|
||||
|
||||
其中:
|
||||
- situation:表示“在什么情境下”的简短概括(不超过20个字)
|
||||
- style:表示对应的语言风格或常用表达(不超过20个字)
|
||||
- source_id:该表达方式对应的“来源行编号”,即上方聊天记录中方括号里的数字(例如 [3]),请只输出数字本身,不要包含方括号
|
||||
|
||||
现在请你输出 JSON:
|
||||
"""
|
||||
Prompt(learn_style_prompt, "learn_style_prompt")
|
||||
|
||||
match_expression_context_prompt = """
|
||||
**聊天内容**
|
||||
{chat_str}
|
||||
|
||||
**从聊天内容总结的表达方式pairs**
|
||||
{expression_pairs}
|
||||
|
||||
请你为上面的每一条表达方式,找到该表达方式的原文句子,并输出匹配结果,expression_pair不能有重复,每个expression_pair仅输出一个最合适的context。
|
||||
如果找不到原句,就不输出该句的匹配结果。
|
||||
以json格式输出:
|
||||
格式如下:
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
{{
|
||||
"expression_pair": "表达方式pair的序号(数字)",
|
||||
"context": "与表达方式对应的原文句子的原始内容,不要修改原文句子的内容",
|
||||
}},
|
||||
...
|
||||
|
||||
现在请你输出匹配结果:
|
||||
"""
|
||||
Prompt(match_expression_context_prompt, "match_expression_context_prompt")
|
||||
|
||||
|
||||
class ExpressionLearner:
|
||||
|
|
@ -193,7 +180,6 @@ class ExpressionLearner:
|
|||
situation,
|
||||
style,
|
||||
_context,
|
||||
_up_content,
|
||||
) in learnt_expressions:
|
||||
learnt_expressions_str += f"{situation}->{style}\n"
|
||||
logger.info(f"在 {self.chat_name} 学习到表达风格:\n{learnt_expressions_str}")
|
||||
|
|
@ -205,193 +191,17 @@ class ExpressionLearner:
|
|||
situation,
|
||||
style,
|
||||
context,
|
||||
up_content,
|
||||
) in learnt_expressions:
|
||||
await self._upsert_expression_record(
|
||||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
return learnt_expressions
|
||||
|
||||
async def match_expression_context(
|
||||
self, expression_pairs: List[Tuple[str, str]], random_msg_match_str: str
|
||||
) -> List[Tuple[str, str, str]]:
|
||||
# 为expression_pairs逐个条目赋予编号,并构建成字符串
|
||||
numbered_pairs = []
|
||||
for i, (situation, style) in enumerate(expression_pairs, 1):
|
||||
numbered_pairs.append(f'{i}. 当"{situation}"时,使用"{style}"')
|
||||
|
||||
expression_pairs_str = "\n".join(numbered_pairs)
|
||||
|
||||
prompt = "match_expression_context_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt,
|
||||
expression_pairs=expression_pairs_str,
|
||||
chat_str=random_msg_match_str,
|
||||
)
|
||||
|
||||
response, _ = await self.express_learn_model.generate_response_async(prompt, temperature=0.3)
|
||||
|
||||
# print(f"match_expression_context_prompt: {prompt}")
|
||||
# print(f"{response}")
|
||||
|
||||
# 解析JSON响应
|
||||
match_responses = []
|
||||
try:
|
||||
response = response.strip()
|
||||
|
||||
# 尝试提取JSON代码块(如果存在)
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
if matches:
|
||||
response = matches[0].strip()
|
||||
|
||||
# 移除可能的markdown代码块标记(如果没有找到```json,但可能有```)
|
||||
if not matches:
|
||||
response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
|
||||
response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
|
||||
response = response.strip()
|
||||
|
||||
# 检查是否已经是标准JSON数组格式
|
||||
if response.startswith("[") and response.endswith("]"):
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 尝试直接解析多个JSON对象
|
||||
try:
|
||||
# 如果是多个JSON对象用逗号分隔,包装成数组
|
||||
if response.startswith("{") and not response.startswith("["):
|
||||
response = "[" + response + "]"
|
||||
match_responses = json.loads(response)
|
||||
else:
|
||||
# 使用repair_json处理响应
|
||||
repaired_content = repair_json(response)
|
||||
|
||||
# 确保repaired_content是列表格式
|
||||
if isinstance(repaired_content, str):
|
||||
try:
|
||||
parsed_data = json.loads(repaired_content)
|
||||
if isinstance(parsed_data, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [parsed_data]
|
||||
elif isinstance(parsed_data, list):
|
||||
match_responses = parsed_data
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
match_responses = []
|
||||
elif isinstance(repaired_content, dict):
|
||||
# 如果是字典,包装成列表
|
||||
match_responses = [repaired_content]
|
||||
elif isinstance(repaired_content, list):
|
||||
match_responses = repaired_content
|
||||
else:
|
||||
match_responses = []
|
||||
except json.JSONDecodeError:
|
||||
# 如果还是失败,尝试repair_json
|
||||
repaired_content = repair_json(response)
|
||||
if isinstance(repaired_content, str):
|
||||
parsed_data = json.loads(repaired_content)
|
||||
match_responses = parsed_data if isinstance(parsed_data, list) else [parsed_data]
|
||||
else:
|
||||
match_responses = repaired_content if isinstance(repaired_content, list) else [repaired_content]
|
||||
|
||||
except (json.JSONDecodeError, Exception) as e:
|
||||
logger.error(f"解析匹配响应JSON失败: {e}, 响应内容: \n{response}")
|
||||
return []
|
||||
|
||||
# 确保 match_responses 是一个列表
|
||||
if not isinstance(match_responses, list):
|
||||
if isinstance(match_responses, dict):
|
||||
match_responses = [match_responses]
|
||||
else:
|
||||
logger.error(f"match_responses 不是列表或字典类型: {type(match_responses)}, 内容: {match_responses}")
|
||||
return []
|
||||
|
||||
# 清理和规范化 match_responses 中的元素
|
||||
normalized_responses = []
|
||||
for item in match_responses:
|
||||
if isinstance(item, dict):
|
||||
# 已经是字典,直接添加
|
||||
normalized_responses.append(item)
|
||||
elif isinstance(item, str):
|
||||
# 如果是字符串,尝试解析为 JSON
|
||||
try:
|
||||
parsed = json.loads(item)
|
||||
if isinstance(parsed, dict):
|
||||
normalized_responses.append(parsed)
|
||||
elif isinstance(parsed, list):
|
||||
# 如果是列表,递归处理
|
||||
for sub_item in parsed:
|
||||
if isinstance(sub_item, dict):
|
||||
normalized_responses.append(sub_item)
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的子元素: {type(sub_item)}, 内容: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过无法转换为字典的字符串元素: {item}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.debug(f"跳过无法解析为JSON的字符串元素: {item}")
|
||||
elif isinstance(item, list):
|
||||
# 如果是列表,展开并处理其中的字典
|
||||
for sub_item in item:
|
||||
if isinstance(sub_item, dict):
|
||||
normalized_responses.append(sub_item)
|
||||
elif isinstance(sub_item, str):
|
||||
# 尝试解析字符串
|
||||
try:
|
||||
parsed = json.loads(sub_item)
|
||||
if isinstance(parsed, dict):
|
||||
normalized_responses.append(parsed)
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的解析结果: {type(parsed)}, 内容: {parsed}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.debug(f"跳过无法解析为JSON的字符串子元素: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过非字典类型的列表元素: {type(sub_item)}, 内容: {sub_item}")
|
||||
else:
|
||||
logger.debug(f"跳过无法处理的元素类型: {type(item)}, 内容: {item}")
|
||||
|
||||
match_responses = normalized_responses
|
||||
|
||||
matched_expressions = []
|
||||
used_pair_indices = set() # 用于跟踪已经使用的expression_pair索引
|
||||
|
||||
logger.debug(f"规范化后的 match_responses 类型: {type(match_responses)}, 长度: {len(match_responses)}")
|
||||
logger.debug(f"规范化后的 match_responses 内容: {match_responses}")
|
||||
|
||||
for match_response in match_responses:
|
||||
try:
|
||||
# 检查 match_response 的类型(此时应该都是字典)
|
||||
if not isinstance(match_response, dict):
|
||||
logger.error(f"match_response 不是字典类型: {type(match_response)}, 内容: {match_response}")
|
||||
continue
|
||||
|
||||
# 获取表达方式序号
|
||||
if "expression_pair" not in match_response:
|
||||
logger.error(f"match_response 缺少 'expression_pair' 字段: {match_response}")
|
||||
continue
|
||||
|
||||
pair_index = int(match_response["expression_pair"]) - 1 # 转换为0-based索引
|
||||
|
||||
# 检查索引是否有效且未被使用过
|
||||
if 0 <= pair_index < len(expression_pairs) and pair_index not in used_pair_indices:
|
||||
situation, style = expression_pairs[pair_index]
|
||||
context = match_response.get("context", "")
|
||||
matched_expressions.append((situation, style, context))
|
||||
used_pair_indices.add(pair_index) # 标记该索引已使用
|
||||
logger.debug(f"成功匹配表达方式 {pair_index + 1}: {situation} -> {style}")
|
||||
elif pair_index in used_pair_indices:
|
||||
logger.debug(f"跳过重复的表达方式 {pair_index + 1}")
|
||||
except (ValueError, KeyError, IndexError, TypeError) as e:
|
||||
logger.error(f"解析匹配条目失败: {e}, 条目: {match_response}")
|
||||
continue
|
||||
|
||||
return matched_expressions
|
||||
|
||||
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str, str]]]:
|
||||
async def learn_expression(self, num: int = 10, timestamp_start: Optional[float] = None) -> Optional[List[Tuple[str, str, str]]]:
|
||||
"""从指定聊天流学习表达方式
|
||||
|
||||
Args:
|
||||
|
|
@ -414,10 +224,8 @@ class ExpressionLearner:
|
|||
if not random_msg or random_msg == []:
|
||||
return None
|
||||
|
||||
# 学习用
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg)
|
||||
# 溯源用
|
||||
random_msg_match_str: str = await build_bare_messages(random_msg)
|
||||
# 学习用(开启行编号,便于溯源)
|
||||
random_msg_str: str = await build_anonymous_messages(random_msg, show_ids=True)
|
||||
|
||||
prompt: str = await global_prompt_manager.format_prompt(
|
||||
"learn_style_prompt",
|
||||
|
|
@ -432,83 +240,107 @@ class ExpressionLearner:
|
|||
except Exception as e:
|
||||
logger.error(f"学习表达方式失败,模型生成出错: {e}")
|
||||
return None
|
||||
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
|
||||
|
||||
# 解析 LLM 返回的表达方式列表(包含来源行编号)
|
||||
expressions: List[Tuple[str, str, str]] = self.parse_expression_response(response)
|
||||
expressions = self._filter_self_reference_styles(expressions)
|
||||
if not expressions:
|
||||
logger.info("过滤后没有可用的表达方式(style 与机器人名称重复)")
|
||||
return None
|
||||
# logger.debug(f"学习{type_str}的response: {response}")
|
||||
|
||||
# 对表达方式溯源
|
||||
matched_expressions: List[Tuple[str, str, str]] = await self.match_expression_context(
|
||||
expressions, random_msg_match_str
|
||||
)
|
||||
# 为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
bare_lines: List[Tuple[int, str]] = self._build_bare_lines(random_msg)
|
||||
# 将 matched_expressions 结合上一句 up_content(若不存在上一句则跳过)
|
||||
filtered_with_up: List[Tuple[str, str, str, str]] = [] # (situation, style, context, up_content)
|
||||
for situation, style, context in matched_expressions:
|
||||
# 在 bare_lines 中找到第一处相似度达到85%的行
|
||||
pos = None
|
||||
for i, (_, c) in enumerate(bare_lines):
|
||||
similarity = calculate_similarity(c, context)
|
||||
if similarity >= 0.85: # 85%相似度阈值
|
||||
pos = i
|
||||
break
|
||||
# 直接根据 source_id 在 random_msg 中溯源,获取 context
|
||||
filtered_expressions: List[Tuple[str, str, str]] = [] # (situation, style, context)
|
||||
|
||||
if pos is None or pos == 0:
|
||||
# 没有匹配到目标句或没有上一句,跳过该表达
|
||||
for situation, style, source_id in expressions:
|
||||
source_id_str = (source_id or "").strip()
|
||||
if not source_id_str.isdigit():
|
||||
# 无效的来源行编号,跳过
|
||||
continue
|
||||
|
||||
# 检查目标句是否为空
|
||||
target_content = bare_lines[pos][1]
|
||||
if not target_content:
|
||||
# 目标句为空,跳过该表达
|
||||
line_index = int(source_id_str) - 1 # build_anonymous_messages 的编号从 1 开始
|
||||
if line_index < 0 or line_index >= len(random_msg):
|
||||
# 超出范围,跳过
|
||||
continue
|
||||
|
||||
prev_original_idx = bare_lines[pos - 1][0]
|
||||
up_content = filter_message_content(random_msg[prev_original_idx].processed_plain_text or "")
|
||||
if not up_content:
|
||||
# 上一句为空,跳过该表达
|
||||
# 当前行的原始内容
|
||||
current_msg = random_msg[line_index]
|
||||
context = filter_message_content(current_msg.processed_plain_text or "")
|
||||
if not context:
|
||||
continue
|
||||
filtered_with_up.append((situation, style, context, up_content))
|
||||
|
||||
if not filtered_with_up:
|
||||
filtered_expressions.append((situation, style, context))
|
||||
|
||||
if not filtered_expressions:
|
||||
return None
|
||||
|
||||
return filtered_with_up
|
||||
return filtered_expressions
|
||||
|
||||
def parse_expression_response(self, response: str) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
解析LLM返回的表达风格总结,每一行提取"当"和"使用"之间的内容,存储为(situation, style)元组
|
||||
解析 LLM 返回的表达风格总结 JSON,提取 (situation, style, source_id) 元组列表。
|
||||
|
||||
期望的 JSON 结构:
|
||||
[
|
||||
{"situation": "AAAAA", "style": "BBBBB", "source_id": "3"},
|
||||
...
|
||||
]
|
||||
"""
|
||||
if not response:
|
||||
return []
|
||||
|
||||
raw = response.strip()
|
||||
|
||||
# 尝试提取 ```json 代码块
|
||||
json_block_pattern = r"```json\s*(.*?)\s*```"
|
||||
match = re.search(json_block_pattern, raw, re.DOTALL)
|
||||
if match:
|
||||
raw = match.group(1).strip()
|
||||
else:
|
||||
# 去掉可能存在的通用 ``` 包裹
|
||||
raw = re.sub(r"^```\s*", "", raw, flags=re.MULTILINE)
|
||||
raw = re.sub(r"```\s*$", "", raw, flags=re.MULTILINE)
|
||||
raw = raw.strip()
|
||||
|
||||
parsed = None
|
||||
expressions: List[Tuple[str, str, str]] = []
|
||||
for line in response.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
|
||||
try:
|
||||
# 优先尝试直接解析
|
||||
if raw.startswith("[") and raw.endswith("]"):
|
||||
parsed = json.loads(raw)
|
||||
else:
|
||||
repaired = repair_json(raw)
|
||||
if isinstance(repaired, str):
|
||||
parsed = json.loads(repaired)
|
||||
else:
|
||||
parsed = repaired
|
||||
except Exception:
|
||||
logger.error(f"解析表达风格 JSON 失败,原始响应:{response}")
|
||||
return []
|
||||
|
||||
if isinstance(parsed, dict):
|
||||
parsed_list = [parsed]
|
||||
elif isinstance(parsed, list):
|
||||
parsed_list = parsed
|
||||
else:
|
||||
logger.error(f"表达风格解析结果类型异常: {type(parsed)}, 内容: {parsed}")
|
||||
return []
|
||||
|
||||
for item in parsed_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
# 查找"当"和下一个引号
|
||||
idx_when = line.find('当"')
|
||||
if idx_when == -1:
|
||||
situation = str(item.get("situation", "")).strip()
|
||||
style = str(item.get("style", "")).strip()
|
||||
source_id = str(item.get("source_id", "")).strip()
|
||||
if not situation or not style or not source_id:
|
||||
# 三个字段必须同时存在
|
||||
continue
|
||||
idx_quote1 = idx_when + 1
|
||||
idx_quote2 = line.find('"', idx_quote1 + 1)
|
||||
if idx_quote2 == -1:
|
||||
continue
|
||||
situation = line[idx_quote1 + 1 : idx_quote2]
|
||||
# 查找"使用"
|
||||
idx_use = line.find('使用"', idx_quote2)
|
||||
if idx_use == -1:
|
||||
continue
|
||||
idx_quote3 = idx_use + 2
|
||||
idx_quote4 = line.find('"', idx_quote3 + 1)
|
||||
if idx_quote4 == -1:
|
||||
continue
|
||||
style = line[idx_quote3 + 1 : idx_quote4]
|
||||
expressions.append((situation, style))
|
||||
expressions.append((situation, style, source_id))
|
||||
|
||||
return expressions
|
||||
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
|
||||
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str, str]]) -> List[Tuple[str, str, str]]:
|
||||
"""
|
||||
过滤掉style与机器人名称/昵称重复的表达
|
||||
"""
|
||||
|
|
@ -525,12 +357,12 @@ class ExpressionLearner:
|
|||
|
||||
banned_casefold = {name.casefold() for name in banned_names if name}
|
||||
|
||||
filtered: List[Tuple[str, str]] = []
|
||||
filtered: List[Tuple[str, str, str]] = []
|
||||
removed_count = 0
|
||||
for situation, style in expressions:
|
||||
for situation, style, source_id in expressions:
|
||||
normalized_style = (style or "").strip()
|
||||
if normalized_style and normalized_style.casefold() not in banned_casefold:
|
||||
filtered.append((situation, style))
|
||||
filtered.append((situation, style, source_id))
|
||||
else:
|
||||
removed_count += 1
|
||||
|
||||
|
|
@ -544,7 +376,6 @@ class ExpressionLearner:
|
|||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
|
||||
|
|
@ -554,7 +385,6 @@ class ExpressionLearner:
|
|||
expr_obj=expr_obj,
|
||||
situation=situation,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
return
|
||||
|
|
@ -563,7 +393,6 @@ class ExpressionLearner:
|
|||
situation=situation,
|
||||
style=style,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
current_time=current_time,
|
||||
)
|
||||
|
||||
|
|
@ -572,7 +401,6 @@ class ExpressionLearner:
|
|||
situation: str,
|
||||
style: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = [situation]
|
||||
|
|
@ -587,7 +415,6 @@ class ExpressionLearner:
|
|||
chat_id=self.chat_id,
|
||||
create_date=current_time,
|
||||
context=context,
|
||||
up_content=up_content,
|
||||
)
|
||||
|
||||
async def _update_existing_expression(
|
||||
|
|
@ -595,7 +422,6 @@ class ExpressionLearner:
|
|||
expr_obj: Expression,
|
||||
situation: str,
|
||||
context: str,
|
||||
up_content: str,
|
||||
current_time: float,
|
||||
) -> None:
|
||||
content_list = self._parse_content_list(expr_obj.content_list)
|
||||
|
|
@ -605,7 +431,6 @@ class ExpressionLearner:
|
|||
expr_obj.count = (expr_obj.count or 0) + 1
|
||||
expr_obj.last_active_time = current_time
|
||||
expr_obj.context = context
|
||||
expr_obj.up_content = up_content
|
||||
|
||||
new_situation = await self._compose_situation_text(
|
||||
content_list=content_list,
|
||||
|
|
@ -651,27 +476,6 @@ class ExpressionLearner:
|
|||
logger.error(f"概括表达情境失败: {e}")
|
||||
return None
|
||||
|
||||
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
为每条消息构建精简文本列表,保留到原消息索引的映射
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
|
||||
Returns:
|
||||
List[Tuple[int, str]]: (original_index, bare_content) 元组列表
|
||||
"""
|
||||
bare_lines: List[Tuple[int, str]] = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
content = msg.processed_plain_text or ""
|
||||
content = filter_message_content(content)
|
||||
# 即使content为空也要记录,防止错位
|
||||
bare_lines.append((idx, content))
|
||||
|
||||
return bare_lines
|
||||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -429,15 +429,36 @@ class ChatHistorySummarizer:
|
|||
# 2. 构造编号后的消息字符串和参与者信息
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices(失败时最多重试 3 次)
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
success, topic_to_indices = await self._analyze_topics_with_llm(
|
||||
numbered_lines=numbered_lines,
|
||||
existing_topics=existing_topics,
|
||||
)
|
||||
max_retries = 3
|
||||
attempt = 0
|
||||
success = False
|
||||
topic_to_indices: Dict[str, List[int]] = {}
|
||||
|
||||
while attempt < max_retries:
|
||||
attempt += 1
|
||||
success, topic_to_indices = await self._analyze_topics_with_llm(
|
||||
numbered_lines=numbered_lines,
|
||||
existing_topics=existing_topics,
|
||||
)
|
||||
|
||||
if success and topic_to_indices:
|
||||
if attempt > 1:
|
||||
logger.info(
|
||||
f"{self.log_prefix} 话题识别在第 {attempt} 次重试后成功 | 话题数: {len(topic_to_indices)}"
|
||||
)
|
||||
break
|
||||
|
||||
logger.warning(
|
||||
f"{self.log_prefix} 话题识别失败或无有效话题,第 {attempt} 次尝试失败"
|
||||
+ ("" if attempt >= max_retries else ",准备重试")
|
||||
)
|
||||
|
||||
if not success or not topic_to_indices:
|
||||
logger.warning(f"{self.log_prefix} 话题识别失败或无有效话题,本次检查忽略")
|
||||
logger.error(
|
||||
f"{self.log_prefix} 话题识别连续 {max_retries} 次失败或始终无有效话题,本次检查放弃"
|
||||
)
|
||||
# 即使识别失败,也认为是一次“检查”,但不更新 no_update_checks(保持原状)
|
||||
return
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, NonNegativeFloat
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
|
|
@ -21,7 +21,6 @@ class ExpressionResponse(BaseModel):
|
|||
situation: str
|
||||
style: str
|
||||
context: Optional[str]
|
||||
up_content: Optional[str]
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
create_date: Optional[float]
|
||||
|
|
@ -49,8 +48,7 @@ class ExpressionCreateRequest(BaseModel):
|
|||
|
||||
situation: str
|
||||
style: str
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
context: Optional[str] = NonNegativeFloat
|
||||
chat_id: str
|
||||
|
||||
|
||||
|
|
@ -60,7 +58,6 @@ class ExpressionUpdateRequest(BaseModel):
|
|||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
context: Optional[str] = None
|
||||
up_content: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
|
||||
|
||||
|
|
@ -102,7 +99,6 @@ def expression_to_response(expression: Expression) -> ExpressionResponse:
|
|||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
context=expression.context,
|
||||
up_content=expression.up_content,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
|
|
@ -310,7 +306,6 @@ async def create_expression(request: ExpressionCreateRequest, maibot_session: Op
|
|||
situation=request.situation,
|
||||
style=request.style,
|
||||
context=request.context,
|
||||
up_content=request.up_content,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
|
|
|
|||
Loading…
Reference in New Issue