mirror of https://github.com/Mai-with-u/MaiBot.git
Merge branch 'Mai-with-u:dev' into dev
commit
d23cc4a2e3
|
|
@ -227,6 +227,8 @@ class StatisticOutputTask(AsyncTask):
|
|||
"",
|
||||
self._format_model_classified_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_module_classified_stat(stats["last_hour"]),
|
||||
"",
|
||||
self._format_chat_stat(stats["last_hour"]),
|
||||
self.SEP_LINE,
|
||||
"",
|
||||
|
|
@ -737,11 +739,13 @@ class StatisticOutputTask(AsyncTask):
|
|||
"""
|
||||
if stats[TOTAL_REQ_CNT] <= 0:
|
||||
return ""
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f}"
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
output = [
|
||||
"按模型分类统计:",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒)",
|
||||
" 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
]
|
||||
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
|
||||
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
|
||||
|
|
@ -751,11 +755,19 @@ class StatisticOutputTask(AsyncTask):
|
|||
cost = stats[COST_BY_MODEL][model_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
formatted_out_tokens = _format_large_number(out_tokens)
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
|
|
@ -766,6 +778,62 @@ class StatisticOutputTask(AsyncTask):
|
|||
cost,
|
||||
avg_time_cost,
|
||||
std_time_cost,
|
||||
formatted_avg_count,
|
||||
formatted_avg_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
output.append("")
|
||||
return "\n".join(output)
|
||||
|
||||
@staticmethod
|
||||
def _format_module_classified_stat(stats: Dict[str, Any]) -> str:
|
||||
"""
|
||||
格式化按模块分类的统计数据
|
||||
"""
|
||||
if stats[TOTAL_REQ_CNT] <= 0:
|
||||
return ""
|
||||
data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
|
||||
|
||||
total_replies = stats.get(TOTAL_REPLY_CNT, 0)
|
||||
|
||||
output = [
|
||||
"按模块分类统计:",
|
||||
" 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
|
||||
]
|
||||
for module_name, count in sorted(stats[REQ_CNT_BY_MODULE].items()):
|
||||
name = f"{module_name[:29]}..." if len(module_name) > 32 else module_name
|
||||
in_tokens = stats[IN_TOK_BY_MODULE][module_name]
|
||||
out_tokens = stats[OUT_TOK_BY_MODULE][module_name]
|
||||
tokens = stats[TOTAL_TOK_BY_MODULE][module_name]
|
||||
cost = stats[COST_BY_MODULE][module_name]
|
||||
avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
|
||||
std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
|
||||
|
||||
# 计算每次回复平均值
|
||||
avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
|
||||
avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
|
||||
|
||||
# 格式化大数字
|
||||
formatted_count = _format_large_number(count)
|
||||
formatted_in_tokens = _format_large_number(in_tokens)
|
||||
formatted_out_tokens = _format_large_number(out_tokens)
|
||||
formatted_tokens = _format_large_number(tokens)
|
||||
formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
|
||||
formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
|
||||
|
||||
output.append(
|
||||
data_fmt.format(
|
||||
name,
|
||||
formatted_count,
|
||||
formatted_in_tokens,
|
||||
formatted_out_tokens,
|
||||
formatted_tokens,
|
||||
cost,
|
||||
avg_time_cost,
|
||||
std_time_cost,
|
||||
formatted_avg_count,
|
||||
formatted_avg_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -849,6 +917,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
# format总在线时间
|
||||
|
||||
# 按模型分类统计
|
||||
total_replies = stat_data.get(TOTAL_REPLY_CNT, 0)
|
||||
model_rows = "\n".join(
|
||||
[
|
||||
f"<tr>"
|
||||
|
|
@ -860,11 +929,13 @@ class StatisticOutputTask(AsyncTask):
|
|||
f"<td>{stat_data[COST_BY_MODEL][model_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODEL]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按请求类型分类统计
|
||||
type_rows = "\n".join(
|
||||
|
|
@ -878,11 +949,13 @@ class StatisticOutputTask(AsyncTask):
|
|||
f"<td>{stat_data[COST_BY_TYPE][req_type]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_TYPE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
# 按模块分类统计
|
||||
module_rows = "\n".join(
|
||||
|
|
@ -896,11 +969,13 @@ class StatisticOutputTask(AsyncTask):
|
|||
f"<td>{stat_data[COST_BY_MODULE][module_name]:.2f} ¥</td>"
|
||||
f"<td>{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"<td>{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒</td>"
|
||||
f"<td>{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"<td>{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / total_replies, html=True) if total_replies > 0 else 'N/A'}</td>"
|
||||
f"</tr>"
|
||||
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
|
||||
]
|
||||
if stat_data[REQ_CNT_BY_MODULE]
|
||||
else ["<tr><td colspan='8' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
else ["<tr><td colspan='10' style='text-align: center; color: #999;'>暂无数据</td></tr>"]
|
||||
)
|
||||
|
||||
# 聊天消息统计
|
||||
|
|
@ -975,7 +1050,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
<h2>按模型分类统计</h2>
|
||||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr></thead>
|
||||
<thead><tr><th>模型名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr></thead>
|
||||
<tbody>
|
||||
{model_rows}
|
||||
</tbody>
|
||||
|
|
@ -986,7 +1061,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
<tr><th>模块名称</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{module_rows}
|
||||
|
|
@ -998,7 +1073,7 @@ class StatisticOutputTask(AsyncTask):
|
|||
<div class=\"table-wrap\">
|
||||
<table>
|
||||
<thead>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th></tr>
|
||||
<tr><th>请求类型</th><th>调用次数</th><th>输入Token</th><th>输出Token</th><th>Token总量</th><th>累计花费</th><th>平均耗时(秒)</th><th>标准差(秒)</th><th>每次回复平均调用次数</th><th>每次回复平均Token数</th></tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{type_rows}
|
||||
|
|
|
|||
|
|
@ -130,10 +130,12 @@ class ImageManager:
|
|||
try:
|
||||
# 清理Images表中type为emoji的记录
|
||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||
|
||||
|
||||
# 清理ImageDescriptions表中type为emoji的记录
|
||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
|
||||
deleted_descriptions = (
|
||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||
)
|
||||
|
||||
total_deleted = deleted_images + deleted_descriptions
|
||||
if total_deleted > 0:
|
||||
logger.info(
|
||||
|
|
@ -162,6 +164,47 @@ class ImageManager:
|
|||
tag_str = ",".join(emotion_list)
|
||||
return f"[表情包:{tag_str}]"
|
||||
|
||||
async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
    """Save an unregistered emoji image to the emoji directory when stealing is enabled.

    Does nothing unless ``global_config.emoji.steal_emoji`` is truthy. Otherwise
    the emoji directory is created if missing, the emoji manager is asked whether
    the hash is already registered, and, if it is not and no file with the derived
    name exists yet, the base64 payload is written to disk. Every failure is
    logged as a warning and swallowed so the caller is never interrupted.

    Args:
        image_base64: Base64-encoded image payload.
        image_hash: Hash of the image (described by the original author as MD5);
            its first 8 characters become the file name stem.
        image_format: Image format, used verbatim as the file extension.
    """
    if not global_config.emoji.steal_emoji:
        return

    try:
        # Imported lazily inside the function, as in the original code.
        from src.chat.emoji_system.emoji_manager import EMOJI_DIR, get_emoji_manager

        # Make sure the target directory exists before anything else.
        os.makedirs(EMOJI_DIR, exist_ok=True)

        short_hash = image_hash[:8]

        # Already registered emojis (looked up by full hash) are not saved again.
        manager = get_emoji_manager()
        registered = await manager.get_emoji_from_manager(image_hash)
        if registered:
            logger.debug(f"[自动保存] 表情包已注册,跳过保存: {short_hash}...")
            return

        # File name: first 8 characters of the hash plus the format as extension.
        target_path = os.path.join(EMOJI_DIR, f"{short_hash}.{image_format}")

        # The file may have been saved earlier without being registered.
        if os.path.exists(target_path):
            logger.debug(f"[自动保存] 表情包文件已存在,跳过: {target_path}")
            return

        if base64_to_image(image_base64, target_path):
            logger.info(f"[自动保存] 表情包已保存到 {target_path} (Hash: {short_hash}...)")
        else:
            logger.warning(f"[自动保存] 保存表情包文件失败: {target_path}")
    except Exception as save_error:
        # Best-effort feature: report the problem and carry on.
        logger.warning(f"[自动保存] 保存表情包文件时出错: {save_error}")
|
||||
|
||||
async def get_emoji_description(self, image_base64: str) -> str:
|
||||
"""获取表情包描述,优先使用EmojiDescriptionCache表中的缓存数据"""
|
||||
try:
|
||||
|
|
@ -191,16 +234,18 @@ class ImageManager:
|
|||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
||||
if cache_record:
|
||||
# 优先使用情感标签,如果没有则使用详细描述
|
||||
result_text = ""
|
||||
if cache_record.emotion_tags:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||
)
|
||||
return f"[表情包:{cache_record.emotion_tags}]"
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
||||
elif cache_record.description:
|
||||
logger.info(
|
||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||
)
|
||||
return f"[表情包:{cache_record.description}]"
|
||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
||||
result_text = f"[表情包:{cache_record.description}]"
|
||||
|
||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
||||
if result_text:
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
return result_text
|
||||
except Exception as e:
|
||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
||||
|
||||
|
|
@ -292,6 +337,9 @@ class ImageManager:
|
|||
except Exception as e:
|
||||
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
||||
|
||||
# 如果启用了steal_emoji,自动保存表情包文件到data/emoji目录
|
||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||
|
||||
return f"[表情包:{final_emotion}]"
|
||||
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
|
|||
|
||||
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
|
||||
MMC_VERSION = "0.11.6-snapshot.1"
|
||||
MMC_VERSION = "0.11.6"
|
||||
|
||||
|
||||
def get_key_comment(toml_table, key):
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import json
|
|||
import time
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
from json_repair import repair_json
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
|
@ -315,9 +316,7 @@ class ChatHistorySummarizer:
|
|||
before_count = len(self.current_batch.messages)
|
||||
self.current_batch.messages.extend(new_messages)
|
||||
self.current_batch.end_time = current_time
|
||||
logger.info(
|
||||
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||
)
|
||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
||||
# 更新批次后持久化
|
||||
self._persist_topic_cache()
|
||||
else:
|
||||
|
|
@ -363,18 +362,20 @@ class ChatHistorySummarizer:
|
|||
else:
|
||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||
|
||||
logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
||||
)
|
||||
|
||||
# 检查“话题检查”触发条件
|
||||
should_check = False
|
||||
|
||||
# 条件1: 消息数量 >= 100,触发一次检查
|
||||
if message_count >= 50:
|
||||
if message_count >= 80:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条)")
|
||||
|
||||
# 条件2: 距离上一次检查 > 3600 秒(1小时),触发一次检查
|
||||
elif time_since_last_check > 1200:
|
||||
elif time_since_last_check > 2400:
|
||||
should_check = True
|
||||
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 1小时)")
|
||||
|
||||
|
|
@ -413,7 +414,7 @@ class ChatHistorySummarizer:
|
|||
# 说明 bot 没有参与这段对话,不应该记录
|
||||
bot_user_id = str(global_config.bot.qq_account)
|
||||
has_bot_message = False
|
||||
|
||||
|
||||
for msg in messages:
|
||||
if msg.user_info.user_id == bot_user_id:
|
||||
has_bot_message = True
|
||||
|
|
@ -426,9 +427,7 @@ class ChatHistorySummarizer:
|
|||
return
|
||||
|
||||
# 2. 构造编号后的消息字符串和参与者信息
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||
self._build_numbered_messages_for_llm(messages)
|
||||
)
|
||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
||||
|
||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
||||
existing_topics = list(self.topic_cache.keys())
|
||||
|
|
@ -485,11 +484,11 @@ class ChatHistorySummarizer:
|
|||
topics_to_finalize: List[str] = []
|
||||
for topic, item in self.topic_cache.items():
|
||||
if item.no_update_checks >= 3:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 连续 5 次检查无新增内容,触发打包存储")
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 连续 3 次检查无新增内容,触发打包存储")
|
||||
topics_to_finalize.append(topic)
|
||||
continue
|
||||
if len(item.messages) > 8:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 30,触发打包存储")
|
||||
if len(item.messages) > 5:
|
||||
logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 4,触发打包存储")
|
||||
topics_to_finalize.append(topic)
|
||||
|
||||
for topic in topics_to_finalize:
|
||||
|
|
@ -590,7 +589,9 @@ class ChatHistorySummarizer:
|
|||
if not numbered_lines:
|
||||
return False, {}
|
||||
|
||||
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
history_topics_block = (
|
||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||
)
|
||||
messages_block = "\n".join(numbered_lines)
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
|
|
@ -606,19 +607,42 @@ class ChatHistorySummarizer:
|
|||
max_tokens=800,
|
||||
)
|
||||
|
||||
import re
|
||||
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
||||
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
|
||||
|
||||
json_str = response.strip()
|
||||
# 移除可能的 markdown 代码块标记
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
# 尝试从响应中提取JSON代码块
|
||||
json_str = None
|
||||
json_pattern = r"```json\s*(.*?)\s*```"
|
||||
matches = re.findall(json_pattern, response, re.DOTALL)
|
||||
|
||||
if matches:
|
||||
# 找到JSON代码块,使用第一个匹配
|
||||
json_str = matches[0].strip()
|
||||
else:
|
||||
# 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
|
||||
# 查找第一个 [ 和最后一个 ]
|
||||
start_idx = response.find('[')
|
||||
end_idx = response.rfind(']')
|
||||
if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
|
||||
json_str = response[start_idx:end_idx + 1].strip()
|
||||
else:
|
||||
# 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 尝试直接解析为 JSON 数组
|
||||
result = json.loads(json_str)
|
||||
# 使用json_repair修复可能的JSON错误
|
||||
if json_str:
|
||||
try:
|
||||
repaired_json = repair_json(json_str)
|
||||
result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json
|
||||
except Exception as repair_error:
|
||||
# 如果repair失败,尝试直接解析
|
||||
logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}")
|
||||
result = json.loads(json_str)
|
||||
else:
|
||||
raise ValueError("无法从响应中提取JSON内容")
|
||||
|
||||
if not isinstance(result, list):
|
||||
logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}")
|
||||
|
|
@ -723,41 +747,30 @@ class ChatHistorySummarizer:
|
|||
)
|
||||
|
||||
# 解析JSON响应
|
||||
import re
|
||||
|
||||
# 移除可能的markdown代码块标记
|
||||
json_str = response.strip()
|
||||
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
|
||||
json_str = json_str.strip()
|
||||
|
||||
# 尝试找到JSON对象的开始和结束位置
|
||||
# 查找第一个 { 和最后一个匹配的 }
|
||||
# 查找JSON对象的开始与结束
|
||||
start_idx = json_str.find("{")
|
||||
if start_idx == -1:
|
||||
raise ValueError("未找到JSON对象开始标记")
|
||||
|
||||
# 从后往前查找最后一个 }
|
||||
end_idx = json_str.rfind("}")
|
||||
if end_idx == -1 or end_idx <= start_idx:
|
||||
raise ValueError("未找到JSON对象结束标记")
|
||||
logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复")
|
||||
extracted_json = json_str[start_idx:]
|
||||
else:
|
||||
extracted_json = json_str[start_idx : end_idx + 1]
|
||||
|
||||
# 提取JSON字符串
|
||||
json_str = json_str[start_idx : end_idx + 1]
|
||||
|
||||
# 尝试解析JSON
|
||||
try:
|
||||
result = json.loads(json_str)
|
||||
except json.JSONDecodeError:
|
||||
# 如果解析失败,尝试修复字符串值中的中文引号
|
||||
# 简单方法:将字符串值中的中文引号替换为转义的英文引号
|
||||
# 使用状态机方法:遍历字符串,在字符串值内部替换中文引号
|
||||
fixed_chars = []
|
||||
def _parse_with_quote_fix(payload: str) -> Dict[str, Any]:
|
||||
fixed_chars: List[str] = []
|
||||
in_string = False
|
||||
escape_next = False
|
||||
i = 0
|
||||
while i < len(json_str):
|
||||
char = json_str[i]
|
||||
while i < len(payload):
|
||||
char = payload[i]
|
||||
if escape_next:
|
||||
fixed_chars.append(char)
|
||||
escape_next = False
|
||||
|
|
@ -767,16 +780,28 @@ class ChatHistorySummarizer:
|
|||
elif char == '"' and not escape_next:
|
||||
fixed_chars.append(char)
|
||||
in_string = not in_string
|
||||
elif in_string and (char == '"' or char == '"'):
|
||||
elif in_string and char in {"“", "”"}:
|
||||
# 在字符串值内部,将中文引号替换为转义的英文引号
|
||||
fixed_chars.append('\\"')
|
||||
else:
|
||||
fixed_chars.append(char)
|
||||
i += 1
|
||||
|
||||
json_str = "".join(fixed_chars)
|
||||
# 再次尝试解析
|
||||
result = json.loads(json_str)
|
||||
repaired = "".join(fixed_chars)
|
||||
return json.loads(repaired)
|
||||
|
||||
try:
|
||||
result = json.loads(extracted_json)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
repaired_json = repair_json(extracted_json)
|
||||
if isinstance(repaired_json, str):
|
||||
result = json.loads(repaired_json)
|
||||
else:
|
||||
result = repaired_json
|
||||
except Exception as repair_error:
|
||||
logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}")
|
||||
result = _parse_with_quote_fix(extracted_json)
|
||||
|
||||
keywords = result.get("keywords", [])
|
||||
summary = result.get("summary", "无概括")
|
||||
|
|
@ -896,3 +921,4 @@ class ChatHistorySummarizer:
|
|||
|
||||
|
||||
init_prompt()
|
||||
|
||||
|
|
|
|||
|
|
@ -212,6 +212,22 @@ class PluginBase(ABC):
|
|||
|
||||
return value
|
||||
|
||||
def _format_toml_value(self, value: Any) -> str:
|
||||
"""将Python值格式化为合法的TOML字符串"""
|
||||
if isinstance(value, str):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
if isinstance(value, bool):
|
||||
return str(value).lower()
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
if isinstance(value, list):
|
||||
inner = ", ".join(self._format_toml_value(item) for item in value)
|
||||
return f"[{inner}]"
|
||||
if isinstance(value, dict):
|
||||
items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
|
||||
return "{ " + ", ".join(items) + " }"
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
|
||||
def _generate_and_save_default_config(self, config_file_path: str):
|
||||
"""根据插件的Schema生成并保存默认配置文件"""
|
||||
if not self.config_schema:
|
||||
|
|
@ -251,12 +267,7 @@ class PluginBase(ABC):
|
|||
|
||||
# 添加字段值
|
||||
value = field.default
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
|
@ -429,19 +440,7 @@ class PluginBase(ABC):
|
|||
|
||||
# 添加字段值(使用迁移后的值)
|
||||
value = section_data.get(field_name, field.default)
|
||||
if isinstance(value, str):
|
||||
toml_str += f'{field_name} = "{value}"\n'
|
||||
elif isinstance(value, bool):
|
||||
toml_str += f"{field_name} = {str(value).lower()}\n"
|
||||
elif isinstance(value, list):
|
||||
# 格式化列表
|
||||
if all(isinstance(item, str) for item in value):
|
||||
formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
|
||||
else:
|
||||
formatted_list = str(value)
|
||||
toml_str += f"{field_name} = {formatted_list}\n"
|
||||
else:
|
||||
toml_str += f"{field_name} = {value}\n"
|
||||
toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
|
||||
|
||||
toml_str += "\n"
|
||||
toml_str += "\n"
|
||||
|
|
|
|||
Loading…
Reference in New Issue