diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py
index 9b5497e9..fa865d48 100644
--- a/src/chat/utils/statistic.py
+++ b/src/chat/utils/statistic.py
@@ -227,6 +227,8 @@ class StatisticOutputTask(AsyncTask):
"",
self._format_model_classified_stat(stats["last_hour"]),
"",
+ self._format_module_classified_stat(stats["last_hour"]),
+ "",
self._format_chat_stat(stats["last_hour"]),
self.SEP_LINE,
"",
@@ -737,11 +739,13 @@ class StatisticOutputTask(AsyncTask):
"""
if stats[TOTAL_REQ_CNT] <= 0:
return ""
- data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f}"
+ data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
+ total_replies = stats.get(TOTAL_REPLY_CNT, 0)
+
output = [
"按模型分类统计:",
- " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒)",
+ " 模型名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
]
for model_name, count in sorted(stats[REQ_CNT_BY_MODEL].items()):
name = f"{model_name[:29]}..." if len(model_name) > 32 else model_name
@@ -751,11 +755,19 @@ class StatisticOutputTask(AsyncTask):
cost = stats[COST_BY_MODEL][model_name]
avg_time_cost = stats[AVG_TIME_COST_BY_MODEL][model_name]
std_time_cost = stats[STD_TIME_COST_BY_MODEL][model_name]
+
+ # 计算每次回复平均值
+ avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
+ avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
+
# 格式化大数字
formatted_count = _format_large_number(count)
formatted_in_tokens = _format_large_number(in_tokens)
formatted_out_tokens = _format_large_number(out_tokens)
formatted_tokens = _format_large_number(tokens)
+ formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
+ formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
+
output.append(
data_fmt.format(
name,
@@ -766,6 +778,62 @@ class StatisticOutputTask(AsyncTask):
cost,
avg_time_cost,
std_time_cost,
+ formatted_avg_count,
+ formatted_avg_tokens,
+ )
+ )
+
+ output.append("")
+ return "\n".join(output)
+
+ @staticmethod
+ def _format_module_classified_stat(stats: Dict[str, Any]) -> str:
+ """
+ 格式化按模块分类的统计数据
+ """
+ if stats[TOTAL_REQ_CNT] <= 0:
+ return ""
+ data_fmt = "{:<32} {:>10} {:>12} {:>12} {:>12} {:>9.2f}¥ {:>10.1f} {:>10.1f} {:>12} {:>12}"
+
+ total_replies = stats.get(TOTAL_REPLY_CNT, 0)
+
+ output = [
+ "按模块分类统计:",
+ " 模块名称 调用次数 输入Token 输出Token Token总量 累计花费 平均耗时(秒) 标准差(秒) 每次回复平均调用次数 每次回复平均Token数",
+ ]
+ for module_name, count in sorted(stats[REQ_CNT_BY_MODULE].items()):
+ name = f"{module_name[:29]}..." if len(module_name) > 32 else module_name
+ in_tokens = stats[IN_TOK_BY_MODULE][module_name]
+ out_tokens = stats[OUT_TOK_BY_MODULE][module_name]
+ tokens = stats[TOTAL_TOK_BY_MODULE][module_name]
+ cost = stats[COST_BY_MODULE][module_name]
+ avg_time_cost = stats[AVG_TIME_COST_BY_MODULE][module_name]
+ std_time_cost = stats[STD_TIME_COST_BY_MODULE][module_name]
+
+ # 计算每次回复平均值
+ avg_count_per_reply = count / total_replies if total_replies > 0 else 0.0
+ avg_tokens_per_reply = tokens / total_replies if total_replies > 0 else 0.0
+
+ # 格式化大数字
+ formatted_count = _format_large_number(count)
+ formatted_in_tokens = _format_large_number(in_tokens)
+ formatted_out_tokens = _format_large_number(out_tokens)
+ formatted_tokens = _format_large_number(tokens)
+ formatted_avg_count = _format_large_number(avg_count_per_reply) if total_replies > 0 else "N/A"
+ formatted_avg_tokens = _format_large_number(avg_tokens_per_reply) if total_replies > 0 else "N/A"
+
+ output.append(
+ data_fmt.format(
+ name,
+ formatted_count,
+ formatted_in_tokens,
+ formatted_out_tokens,
+ formatted_tokens,
+ cost,
+ avg_time_cost,
+ std_time_cost,
+ formatted_avg_count,
+ formatted_avg_tokens,
)
)
@@ -849,6 +917,7 @@ class StatisticOutputTask(AsyncTask):
# format总在线时间
# 按模型分类统计
+ total_replies = stat_data.get(TOTAL_REPLY_CNT, 0)
model_rows = "\n".join(
[
f"
"
@@ -860,11 +929,13 @@ class StatisticOutputTask(AsyncTask):
f"| {stat_data[COST_BY_MODEL][model_name]:.2f} ¥ | "
f"{stat_data[AVG_TIME_COST_BY_MODEL][model_name]:.1f} 秒 | "
f"{stat_data[STD_TIME_COST_BY_MODEL][model_name]:.1f} 秒 | "
+ f"{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
+ f"{_format_large_number(stat_data[TOTAL_TOK_BY_MODEL][model_name] / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
f"
"
for model_name, count in sorted(stat_data[REQ_CNT_BY_MODEL].items())
]
if stat_data[REQ_CNT_BY_MODEL]
- else ["| 暂无数据 |
"]
+ else ["| 暂无数据 |
"]
)
# 按请求类型分类统计
type_rows = "\n".join(
@@ -878,11 +949,13 @@ class StatisticOutputTask(AsyncTask):
f"{stat_data[COST_BY_TYPE][req_type]:.2f} ¥ | "
f"{stat_data[AVG_TIME_COST_BY_TYPE][req_type]:.1f} 秒 | "
f"{stat_data[STD_TIME_COST_BY_TYPE][req_type]:.1f} 秒 | "
+ f"{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
+ f"{_format_large_number(stat_data[TOTAL_TOK_BY_TYPE][req_type] / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
f""
for req_type, count in sorted(stat_data[REQ_CNT_BY_TYPE].items())
]
if stat_data[REQ_CNT_BY_TYPE]
- else ["| 暂无数据 |
"]
+ else ["| 暂无数据 |
"]
)
# 按模块分类统计
module_rows = "\n".join(
@@ -896,11 +969,13 @@ class StatisticOutputTask(AsyncTask):
f"{stat_data[COST_BY_MODULE][module_name]:.2f} ¥ | "
f"{stat_data[AVG_TIME_COST_BY_MODULE][module_name]:.1f} 秒 | "
f"{stat_data[STD_TIME_COST_BY_MODULE][module_name]:.1f} 秒 | "
+ f"{_format_large_number(count / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
+ f"{_format_large_number(stat_data[TOTAL_TOK_BY_MODULE][module_name] / total_replies, html=True) if total_replies > 0 else 'N/A'} | "
f""
for module_name, count in sorted(stat_data[REQ_CNT_BY_MODULE].items())
]
if stat_data[REQ_CNT_BY_MODULE]
- else ["| 暂无数据 |
"]
+ else ["| 暂无数据 |
"]
)
# 聊天消息统计
@@ -975,7 +1050,7 @@ class StatisticOutputTask(AsyncTask):
按模型分类统计
- | 模型名称 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) |
+ | 模型名称 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) | 每次回复平均调用次数 | 每次回复平均Token数 |
{model_rows}
@@ -986,7 +1061,7 @@ class StatisticOutputTask(AsyncTask):
- | 模块名称 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) |
+ | 模块名称 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) | 每次回复平均调用次数 | 每次回复平均Token数 |
{module_rows}
@@ -998,7 +1073,7 @@ class StatisticOutputTask(AsyncTask):
- | 请求类型 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) |
+ | 请求类型 | 调用次数 | 输入Token | 输出Token | Token总量 | 累计花费 | 平均耗时(秒) | 标准差(秒) | 每次回复平均调用次数 | 每次回复平均Token数 |
{type_rows}
diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py
index a244fecb..56a3ae0f 100644
--- a/src/chat/utils/utils_image.py
+++ b/src/chat/utils/utils_image.py
@@ -130,10 +130,12 @@ class ImageManager:
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
-
+
# 清理ImageDescriptions表中type为emoji的记录
- deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
-
+ deleted_descriptions = (
+ ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
+ )
+
total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0:
logger.info(
@@ -162,6 +164,47 @@ class ImageManager:
tag_str = ",".join(emotion_list)
return f"[表情包:{tag_str}]"
+ async def _save_emoji_file_if_needed(self, image_base64: str, image_hash: str, image_format: str) -> None:
+ """如果启用了steal_emoji且表情包未注册,保存文件到data/emoji目录
+
+ Args:
+ image_base64: 图片的base64编码
+ image_hash: 图片的MD5哈希值
+ image_format: 图片格式
+ """
+ if not global_config.emoji.steal_emoji:
+ return
+
+ try:
+ from src.chat.emoji_system.emoji_manager import EMOJI_DIR
+ from src.chat.emoji_system.emoji_manager import get_emoji_manager
+
+ # 确保目录存在
+ os.makedirs(EMOJI_DIR, exist_ok=True)
+
+ # 检查是否已存在该表情包(通过哈希值)
+ emoji_manager = get_emoji_manager()
+ existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
+ if existing_emoji:
+ logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
+ return
+
+ # 生成文件名:使用哈希值前8位 + 格式
+ filename = f"{image_hash[:8]}.{image_format}"
+ file_path = os.path.join(EMOJI_DIR, filename)
+
+ # 检查文件是否已存在(可能之前保存过但未注册)
+ if not os.path.exists(file_path):
+ # 保存文件
+ if base64_to_image(image_base64, file_path):
+ logger.info(f"[自动保存] 表情包已保存到 {file_path} (Hash: {image_hash[:8]}...)")
+ else:
+ logger.warning(f"[自动保存] 保存表情包文件失败: {file_path}")
+ else:
+ logger.debug(f"[自动保存] 表情包文件已存在,跳过: {file_path}")
+ except Exception as save_error:
+ logger.warning(f"[自动保存] 保存表情包文件时出错: {save_error}")
+
async def get_emoji_description(self, image_base64: str) -> str:
"""获取表情包描述,优先使用EmojiDescriptionCache表中的缓存数据"""
try:
@@ -191,16 +234,18 @@ class ImageManager:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
if cache_record:
# 优先使用情感标签,如果没有则使用详细描述
+ result_text = ""
if cache_record.emotion_tags:
- logger.info(
- f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
- )
- return f"[表情包:{cache_record.emotion_tags}]"
+ logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
+ result_text = f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description:
- logger.info(
- f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
- )
- return f"[表情包:{cache_record.description}]"
+ logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
+ result_text = f"[表情包:{cache_record.description}]"
+
+ # 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
+ if result_text:
+ await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
+ return result_text
except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
@@ -292,6 +337,9 @@ class ImageManager:
except Exception as e:
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
+ # 如果启用了steal_emoji,自动保存表情包文件到data/emoji目录
+ await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
+
return f"[表情包:{final_emotion}]"
except Exception as e:
diff --git a/src/config/config.py b/src/config/config.py
index 3f5db816..9342cbb4 100644
--- a/src/config/config.py
+++ b/src/config/config.py
@@ -57,7 +57,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/
-MMC_VERSION = "0.11.6-snapshot.1"
+MMC_VERSION = "0.11.6"
def get_key_comment(toml_table, key):
diff --git a/src/hippo_memorizer/chat_history_summarizer.py b/src/hippo_memorizer/chat_history_summarizer.py
index 840f349d..357f46dd 100644
--- a/src/hippo_memorizer/chat_history_summarizer.py
+++ b/src/hippo_memorizer/chat_history_summarizer.py
@@ -8,8 +8,9 @@ import json
import time
import re
from pathlib import Path
-from typing import Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
from dataclasses import dataclass, field
+from json_repair import repair_json
from src.common.logger import get_logger
from src.common.data_models.database_data_model import DatabaseMessages
@@ -315,9 +316,7 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time
- logger.info(
- f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
- )
+ logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
# 更新批次后持久化
self._persist_topic_cache()
else:
@@ -363,18 +362,20 @@ class ChatHistorySummarizer:
else:
time_str = f"{time_since_last_check / 3600:.1f}小时"
- logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
+ logger.info(
+ f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
+ )
# 检查“话题检查”触发条件
should_check = False
# 条件1: 消息数量 >= 100,触发一次检查
- if message_count >= 50:
+ if message_count >= 80:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 消息数量达到 {message_count} 条(阈值: 100条)")
# 条件2: 距离上一次检查 > 3600 秒(1小时),触发一次检查
- elif time_since_last_check > 1200:
+ elif time_since_last_check > 2400:
should_check = True
logger.info(f"{self.log_prefix} 触发检查条件: 距上次检查 {time_str}(阈值: 1小时)")
@@ -413,7 +414,7 @@ class ChatHistorySummarizer:
# 说明 bot 没有参与这段对话,不应该记录
bot_user_id = str(global_config.bot.qq_account)
has_bot_message = False
-
+
for msg in messages:
if msg.user_info.user_id == bot_user_id:
has_bot_message = True
@@ -426,9 +427,7 @@ class ChatHistorySummarizer:
return
# 2. 构造编号后的消息字符串和参与者信息
- numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
- self._build_numbered_messages_for_llm(messages)
- )
+ numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
# 3. 调用 LLM 识别话题,并得到 topic -> indices
existing_topics = list(self.topic_cache.keys())
@@ -485,11 +484,11 @@ class ChatHistorySummarizer:
topics_to_finalize: List[str] = []
for topic, item in self.topic_cache.items():
if item.no_update_checks >= 3:
- logger.info(f"{self.log_prefix} 话题[{topic}] 连续 5 次检查无新增内容,触发打包存储")
+ logger.info(f"{self.log_prefix} 话题[{topic}] 连续 3 次检查无新增内容,触发打包存储")
topics_to_finalize.append(topic)
continue
- if len(item.messages) > 8:
- logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 30,触发打包存储")
+ if len(item.messages) > 5:
+ logger.info(f"{self.log_prefix} 话题[{topic}] 消息条数超过 4,触发打包存储")
topics_to_finalize.append(topic)
for topic in topics_to_finalize:
@@ -590,7 +589,9 @@ class ChatHistorySummarizer:
if not numbered_lines:
return False, {}
- history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
+ history_topics_block = (
+ "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
+ )
messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt(
@@ -606,19 +607,42 @@ class ChatHistorySummarizer:
max_tokens=800,
)
- import re
-
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
- json_str = response.strip()
- # 移除可能的 markdown 代码块标记
- json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
- json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
- json_str = json_str.strip()
+ # 尝试从响应中提取JSON代码块
+ json_str = None
+ json_pattern = r"```json\s*(.*?)\s*```"
+ matches = re.findall(json_pattern, response, re.DOTALL)
+
+ if matches:
+ # 找到JSON代码块,使用第一个匹配
+ json_str = matches[0].strip()
+ else:
+ # 如果没有找到代码块,尝试查找JSON数组的开始和结束位置
+ # 查找第一个 [ 和最后一个 ]
+ start_idx = response.find('[')
+ end_idx = response.rfind(']')
+ if start_idx != -1 and end_idx != -1 and end_idx > start_idx:
+ json_str = response[start_idx:end_idx + 1].strip()
+ else:
+ # 如果还是找不到,尝试直接使用整个响应(移除可能的markdown标记)
+ json_str = response.strip()
+ json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
+ json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
+ json_str = json_str.strip()
- # 尝试直接解析为 JSON 数组
- result = json.loads(json_str)
+ # 使用json_repair修复可能的JSON错误
+ if json_str:
+ try:
+ repaired_json = repair_json(json_str)
+ result = json.loads(repaired_json) if isinstance(repaired_json, str) else repaired_json
+ except Exception as repair_error:
+ # 如果repair失败,尝试直接解析
+ logger.warning(f"{self.log_prefix} JSON修复失败,尝试直接解析: {repair_error}")
+ result = json.loads(json_str)
+ else:
+ raise ValueError("无法从响应中提取JSON内容")
if not isinstance(result, list):
logger.error(f"{self.log_prefix} 话题识别返回的 JSON 不是列表: {result}")
@@ -723,41 +747,30 @@ class ChatHistorySummarizer:
)
# 解析JSON响应
- import re
-
- # 移除可能的markdown代码块标记
json_str = response.strip()
json_str = re.sub(r"^```json\s*", "", json_str, flags=re.MULTILINE)
json_str = re.sub(r"^```\s*", "", json_str, flags=re.MULTILINE)
json_str = json_str.strip()
- # 尝试找到JSON对象的开始和结束位置
- # 查找第一个 { 和最后一个匹配的 }
+ # 查找JSON对象的开始与结束
start_idx = json_str.find("{")
if start_idx == -1:
raise ValueError("未找到JSON对象开始标记")
- # 从后往前查找最后一个 }
end_idx = json_str.rfind("}")
if end_idx == -1 or end_idx <= start_idx:
- raise ValueError("未找到JSON对象结束标记")
+ logger.warning(f"{self.log_prefix} JSON缺少结束标记,尝试自动修复")
+ extracted_json = json_str[start_idx:]
+ else:
+ extracted_json = json_str[start_idx : end_idx + 1]
- # 提取JSON字符串
- json_str = json_str[start_idx : end_idx + 1]
-
- # 尝试解析JSON
- try:
- result = json.loads(json_str)
- except json.JSONDecodeError:
- # 如果解析失败,尝试修复字符串值中的中文引号
- # 简单方法:将字符串值中的中文引号替换为转义的英文引号
- # 使用状态机方法:遍历字符串,在字符串值内部替换中文引号
- fixed_chars = []
+ def _parse_with_quote_fix(payload: str) -> Dict[str, Any]:
+ fixed_chars: List[str] = []
in_string = False
escape_next = False
i = 0
- while i < len(json_str):
- char = json_str[i]
+ while i < len(payload):
+ char = payload[i]
if escape_next:
fixed_chars.append(char)
escape_next = False
@@ -767,16 +780,28 @@ class ChatHistorySummarizer:
elif char == '"' and not escape_next:
fixed_chars.append(char)
in_string = not in_string
- elif in_string and (char == '"' or char == '"'):
+ elif in_string and char in {"“", "”"}:
# 在字符串值内部,将中文引号替换为转义的英文引号
fixed_chars.append('\\"')
else:
fixed_chars.append(char)
i += 1
- json_str = "".join(fixed_chars)
- # 再次尝试解析
- result = json.loads(json_str)
+ repaired = "".join(fixed_chars)
+ return json.loads(repaired)
+
+ try:
+ result = json.loads(extracted_json)
+ except json.JSONDecodeError:
+ try:
+ repaired_json = repair_json(extracted_json)
+ if isinstance(repaired_json, str):
+ result = json.loads(repaired_json)
+ else:
+ result = repaired_json
+ except Exception as repair_error:
+ logger.warning(f"{self.log_prefix} repair_json 失败,使用引号修复: {repair_error}")
+ result = _parse_with_quote_fix(extracted_json)
keywords = result.get("keywords", [])
summary = result.get("summary", "无概括")
@@ -896,3 +921,4 @@ class ChatHistorySummarizer:
init_prompt()
+
diff --git a/src/plugin_system/base/plugin_base.py b/src/plugin_system/base/plugin_base.py
index 30d33ec2..f21c2a40 100644
--- a/src/plugin_system/base/plugin_base.py
+++ b/src/plugin_system/base/plugin_base.py
@@ -212,6 +212,22 @@ class PluginBase(ABC):
return value
+ def _format_toml_value(self, value: Any) -> str:
+ """将Python值格式化为合法的TOML字符串"""
+ if isinstance(value, str):
+ return json.dumps(value, ensure_ascii=False)
+ if isinstance(value, bool):
+ return str(value).lower()
+ if isinstance(value, (int, float)):
+ return str(value)
+ if isinstance(value, list):
+ inner = ", ".join(self._format_toml_value(item) for item in value)
+ return f"[{inner}]"
+ if isinstance(value, dict):
+ items = [f"{k} = {self._format_toml_value(v)}" for k, v in value.items()]
+ return "{ " + ", ".join(items) + " }"
+ return json.dumps(value, ensure_ascii=False)
+
def _generate_and_save_default_config(self, config_file_path: str):
"""根据插件的Schema生成并保存默认配置文件"""
if not self.config_schema:
@@ -251,12 +267,7 @@ class PluginBase(ABC):
# 添加字段值
value = field.default
- if isinstance(value, str):
- toml_str += f'{field_name} = "{value}"\n'
- elif isinstance(value, bool):
- toml_str += f"{field_name} = {str(value).lower()}\n"
- else:
- toml_str += f"{field_name} = {value}\n"
+ toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"
@@ -429,19 +440,7 @@ class PluginBase(ABC):
# 添加字段值(使用迁移后的值)
value = section_data.get(field_name, field.default)
- if isinstance(value, str):
- toml_str += f'{field_name} = "{value}"\n'
- elif isinstance(value, bool):
- toml_str += f"{field_name} = {str(value).lower()}\n"
- elif isinstance(value, list):
- # 格式化列表
- if all(isinstance(item, str) for item in value):
- formatted_list = "[" + ", ".join(f'"{item}"' for item in value) + "]"
- else:
- formatted_list = str(value)
- toml_str += f"{field_name} = {formatted_list}\n"
- else:
- toml_str += f"{field_name} = {value}\n"
+ toml_str += f"{field_name} = {self._format_toml_value(value)}\n"
toml_str += "\n"
toml_str += "\n"