Ruff Fix & format

pull/1389/head
墨梓柒 2025-11-29 14:38:42 +08:00
parent d7932595e8
commit 3935ce817e
31 changed files with 678 additions and 684 deletions

bot.py
View File

@@ -78,6 +78,7 @@ async def graceful_shutdown():  # sourcery skip: use-named-expression
     # 关闭 WebUI 服务器
     try:
         from src.webui.webui_server import get_webui_server
+
         webui_server = get_webui_server()
         if webui_server and webui_server._server:
             await webui_server.shutdown()
@@ -236,15 +237,15 @@ if __name__ == "__main__":
     except KeyboardInterrupt:
         logger.warning("收到中断信号,正在优雅关闭...")
         # 取消主任务
-        if 'main_tasks' in locals() and main_tasks and not main_tasks.done():
+        if "main_tasks" in locals() and main_tasks and not main_tasks.done():
            main_tasks.cancel()
            try:
                loop.run_until_complete(main_tasks)
            except asyncio.CancelledError:
                pass
        # 执行优雅关闭
        if loop and not loop.is_closed():
            try:

View File

@@ -235,13 +235,13 @@ class BrainChatting:
         if recent_messages_list is None:
             recent_messages_list = []
         _reply_text = ""  # 初始化reply_text变量避免UnboundLocalError
         # -------------------------------------------------------------------------
         # ReflectTracker Check
         # 在每次回复前检查一次上下文,看是否有反思问题得到了解答
         # -------------------------------------------------------------------------
         from src.express.reflect_tracker import reflect_tracker_manager
         tracker = reflect_tracker_manager.get_tracker(self.stream_id)
         if tracker:
             resolved = await tracker.trigger_tracker()
@@ -254,6 +254,7 @@ class BrainChatting:
         # 检查是否需要提问表达反思
         # -------------------------------------------------------------------------
         from src.express.expression_reflector import expression_reflector_manager
+
         reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
         asyncio.create_task(reflector.check_and_ask())

View File

@@ -400,7 +400,7 @@ class HeartFChatting:
         # ReflectTracker Check
         # 在每次回复前检查一次上下文,看是否有反思问题得到了解答
         # -------------------------------------------------------------------------
         reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
         await reflector.check_and_ask()
         tracker = reflect_tracker_manager.get_tracker(self.stream_id)
@@ -410,7 +410,6 @@ class HeartFChatting:
             reflect_tracker_manager.remove_tracker(self.stream_id)
             logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
         start_time = time.time()
         async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
             asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
@@ -427,7 +426,9 @@ class HeartFChatting:
             # asyncio.create_task(self.chat_history_summarizer.process())
             cycle_timers, thinking_id = self.start_cycle()
-            logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
+            logger.info(
+                f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
+            )
             # 第一步:动作检查
             available_actions: Dict[str, ActionInfo] = {}

View File

@@ -25,6 +25,7 @@ def get_webui_chat_broadcaster():
     if _webui_chat_broadcaster is None:
         try:
             from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
+
             _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
         except ImportError:
             _webui_chat_broadcaster = (None, None)
@@ -43,26 +44,28 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
         # WebUI 聊天室消息,通过 WebSocket 广播
         import time
         from src.config.config import global_config
-        await chat_manager.broadcast({
-            "type": "bot_message",
-            "content": message.processed_plain_text,
-            "message_type": "text",
-            "timestamp": time.time(),
-            "sender": {
-                "name": global_config.bot.nickname,
-                "avatar": None,
-                "is_bot": True,
-            }
-        })
+        await chat_manager.broadcast(
+            {
+                "type": "bot_message",
+                "content": message.processed_plain_text,
+                "message_type": "text",
+                "timestamp": time.time(),
+                "sender": {
+                    "name": global_config.bot.nickname,
+                    "avatar": None,
+                    "is_bot": True,
+                },
+            }
+        )
         # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
         # 无需手动保存
         if show_log:
             logger.info(f"已将消息 '{message_preview}' 发往 WebUI 聊天室")
         return True
     # 直接调用API发送消息
     await get_global_api().send_message(message)
     if show_log:

View File

@@ -181,8 +181,12 @@ class ActionPlanner:
         found_ids = set(matches)
         missing_ids = found_ids - available_ids
         if missing_ids:
-            logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
-        logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中")
+            logger.info(
+                f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
+            )
+        logger.info(
+            f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中"
+        )
         def _replace(match: re.Match[str]) -> str:
             msg_id = match.group(0)
@@ -234,17 +238,11 @@ class ActionPlanner:
                 target_message = message_id_list[-1][1]
                 logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id使用最新消息作为target_message")
-            if (
-                action != "no_reply"
-                and target_message is not None
-                and self._is_message_from_self(target_message)
-            ):
+            if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
                 logger.info(
                     f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
                 )
-                reasoning = (
-                    f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
-                )
+                reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
                 action = "no_reply"
                 target_message = None
@@ -295,10 +293,9 @@ class ActionPlanner:
     def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
         """判断消息是否由机器人自身发送"""
         try:
-            return (
-                str(message.user_info.user_id) == str(global_config.bot.qq_account)
-                and (message.user_info.platform or "") == (global_config.bot.platform or "")
-            )
+            return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
+                message.user_info.platform or ""
+            ) == (global_config.bot.platform or "")
         except AttributeError:
             logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
             return False
@@ -780,20 +777,20 @@ class ActionPlanner:
             json_content_start = json_start_pos + 7  # ```json的长度
             # 提取从```json之后到内容结尾的所有内容
             incomplete_json_str = content[json_content_start:].strip()
             # 提取JSON之前的内容作为推理文本
             if json_start_pos > 0:
                 reasoning_content = content[:json_start_pos].strip()
                 reasoning_content = re.sub(r"^//\s*", "", reasoning_content, flags=re.MULTILINE)
                 reasoning_content = reasoning_content.strip()
             if incomplete_json_str:
                 try:
                     # 清理可能的注释和格式问题
                     json_str = re.sub(r"//.*?\n", "\n", incomplete_json_str)
                     json_str = re.sub(r"/\*.*?\*/", "", json_str, flags=re.DOTALL)
                     json_str = json_str.strip()
                     if json_str:
                         # 尝试按行分割每行可能是一个JSON对象
                         lines = [line.strip() for line in json_str.split("\n") if line.strip()]
@@ -808,7 +805,7 @@ class ActionPlanner:
                                 json_objects.append(item)
                             except json.JSONDecodeError:
                                 pass
                 # 如果按行解析没有成功尝试将整个块作为一个JSON对象或数组
                 if not json_objects:
                     try:

View File

@@ -959,7 +959,7 @@ async def build_anonymous_messages(messages: List[DatabaseMessages], show_ids: b
             header = f"[{i + 1}] {anon_name}"
         else:
             header = f"{anon_name}"
         output_lines.append(header)
         stripped_line = content.strip()
         if stripped_line:

View File

@@ -130,12 +130,10 @@ class ImageManager:
         try:
             # 清理Images表中type为emoji的记录
             deleted_images = Images.delete().where(Images.type == "emoji").execute()
             # 清理ImageDescriptions表中type为emoji的记录
-            deleted_descriptions = (
-                ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
-            )
+            deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
             total_deleted = deleted_images + deleted_descriptions
             if total_deleted > 0:
                 logger.info(
@@ -194,10 +192,14 @@ class ImageManager:
                 if cache_record:
                     # 优先使用情感标签,如果没有则使用详细描述
                     if cache_record.emotion_tags:
-                        logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
+                        logger.info(
+                            f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
+                        )
                         return f"[表情包:{cache_record.emotion_tags}]"
                     elif cache_record.description:
-                        logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
+                        logger.info(
+                            f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
+                        )
                         return f"[表情包:{cache_record.description}]"
             except Exception as e:
                 logger.debug(f"查询EmojiDescriptionCache时出错: {e}")

View File

@@ -226,19 +226,19 @@ class ExpressionLearner:
         match_responses = []
         try:
             response = response.strip()
             # 尝试提取JSON代码块如果存在
             json_pattern = r"```json\s*(.*?)\s*```"
             matches = re.findall(json_pattern, response, re.DOTALL)
             if matches:
                 response = matches[0].strip()
             # 移除可能的markdown代码块标记如果没有找到```json但可能有```
             if not matches:
                 response = re.sub(r"^```\s*", "", response, flags=re.MULTILINE)
                 response = re.sub(r"```\s*$", "", response, flags=re.MULTILINE)
                 response = response.strip()
             # 检查是否已经是标准JSON数组格式
             if response.startswith("[") and response.endswith("]"):
                 match_responses = json.loads(response)

View File

@ -13,21 +13,21 @@ logger = get_logger("expression_reflector")
class ExpressionReflector: class ExpressionReflector:
"""表达反思器,管理单个聊天流的表达反思提问""" """表达反思器,管理单个聊天流的表达反思提问"""
def __init__(self, chat_id: str): def __init__(self, chat_id: str):
self.chat_id = chat_id self.chat_id = chat_id
self.last_ask_time: float = 0.0 self.last_ask_time: float = 0.0
async def check_and_ask(self) -> bool: async def check_and_ask(self) -> bool:
""" """
检查是否需要提问表达反思如果需要则提问 检查是否需要提问表达反思如果需要则提问
Returns: Returns:
bool: 是否执行了提问 bool: 是否执行了提问
""" """
try: try:
logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})") logger.debug(f"[Expression Reflection] 开始检查是否需要提问 (stream_id: {self.chat_id})")
if not global_config.expression.reflect: if not global_config.expression.reflect:
logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过") logger.debug(f"[Expression Reflection] 表达反思功能未启用,跳过")
return False return False
@ -48,7 +48,7 @@ class ExpressionReflector:
allow_reflect_chat_ids.append(parsed_chat_id) allow_reflect_chat_ids.append(parsed_chat_id)
else: else:
logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}") logger.warning(f"[Expression Reflection] 无法解析 allow_reflect 配置项: {stream_config}")
if self.chat_id not in allow_reflect_chat_ids: if self.chat_id not in allow_reflect_chat_ids:
logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过") logger.info(f"[Expression Reflection] 当前聊天流 {self.chat_id} 不在允许列表中,跳过")
return False return False
@ -56,17 +56,21 @@ class ExpressionReflector:
# 检查上一次提问时间 # 检查上一次提问时间
current_time = time.time() current_time = time.time()
time_since_last_ask = current_time - self.last_ask_time time_since_last_ask = current_time - self.last_ask_time
# 5-10分钟间隔随机选择 # 5-10分钟间隔随机选择
min_interval = 10 * 60 # 5分钟 min_interval = 10 * 60 # 5分钟
max_interval = 15 * 60 # 10分钟 max_interval = 15 * 60 # 10分钟
interval = random.uniform(min_interval, max_interval) interval = random.uniform(min_interval, max_interval)
logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)") logger.info(
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
)
if time_since_last_ask < interval: if time_since_last_ask < interval:
remaining_time = interval - time_since_last_ask remaining_time = interval - time_since_last_ask
logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过") logger.info(
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
)
return False return False
# 检查是否已经有针对该 Operator 的 Tracker 在运行 # 检查是否已经有针对该 Operator 的 Tracker 在运行
@ -78,21 +82,22 @@ class ExpressionReflector:
# 获取未检查的表达 # 获取未检查的表达
try: try:
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达") logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (Expression expressions = (
.select() Expression.select().where((Expression.checked == False) & (Expression.rejected == False)).limit(50)
.where((Expression.checked == False) & (Expression.rejected == False)) )
.limit(50))
expr_list = list(expressions) expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
if not expr_list: if not expr_list:
logger.info(f"[Expression Reflection] 没有可用的表达,跳过") logger.info(f"[Expression Reflection] 没有可用的表达,跳过")
return False return False
target_expr: Expression = random.choice(expr_list) target_expr: Expression = random.choice(expr_list)
logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}") logger.info(
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
)
# 生成询问文本 # 生成询问文本
ask_text = _generate_ask_text(target_expr) ask_text = _generate_ask_text(target_expr)
if not ask_text: if not ask_text:
@ -102,31 +107,33 @@ class ExpressionReflector:
logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问") logger.info(f"[Expression Reflection] 准备向 Operator {operator_config} 发送提问")
# 发送给 Operator # 发送给 Operator
await _send_to_operator(operator_config, ask_text, target_expr) await _send_to_operator(operator_config, ask_text, target_expr)
# 更新上一次提问时间 # 更新上一次提问时间
self.last_ask_time = current_time self.last_ask_time = current_time
logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}") logger.info(f"[Expression Reflection] 提问成功,已更新上次提问时间为 {current_time:.2f}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
except Exception as e: except Exception as e:
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
class ExpressionReflectorManager: class ExpressionReflectorManager:
"""表达反思管理器,管理多个聊天流的表达反思实例""" """表达反思管理器,管理多个聊天流的表达反思实例"""
def __init__(self): def __init__(self):
self.reflectors: Dict[str, ExpressionReflector] = {} self.reflectors: Dict[str, ExpressionReflector] = {}
def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector: def get_or_create_reflector(self, chat_id: str) -> ExpressionReflector:
"""获取或创建指定聊天流的表达反思实例""" """获取或创建指定聊天流的表达反思实例"""
if chat_id not in self.reflectors: if chat_id not in self.reflectors:
@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager()
async def _check_tracker_exists(operator_config: str) -> bool: async def _check_tracker_exists(operator_config: str) -> bool:
"""检查指定 Operator 是否已有活跃的 Tracker""" """检查指定 Operator 是否已有活跃的 Tracker"""
from src.express.reflect_tracker import reflect_tracker_manager from src.express.reflect_tracker import reflect_tracker_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = None chat_stream = None
@ -150,12 +158,12 @@ async def _check_tracker_exists(operator_config: str) -> bool:
platform = parts[0] platform = parts[0]
id_str = parts[1] id_str = parts[1]
stream_type = parts[2] stream_type = parts[2]
user_info = None user_info = None
group_info = None group_info = None
from maim_message import UserInfo, GroupInfo from maim_message import UserInfo, GroupInfo
if stream_type == "group": if stream_type == "group":
group_info = GroupInfo(group_id=id_str, platform=platform) group_info = GroupInfo(group_id=id_str, platform=platform)
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
@ -203,12 +211,12 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
platform = parts[0] platform = parts[0]
id_str = parts[1] id_str = parts[1]
stream_type = parts[2] stream_type = parts[2]
user_info = None user_info = None
group_info = None group_info = None
from maim_message import UserInfo, GroupInfo from maim_message import UserInfo, GroupInfo
if stream_type == "group": if stream_type == "group":
group_info = GroupInfo(group_id=id_str, platform=platform) group_info = GroupInfo(group_id=id_str, platform=platform)
user_info = UserInfo(user_id="system", user_nickname="System", platform=platform) user_info = UserInfo(user_id="system", user_nickname="System", platform=platform)
@ -232,20 +240,13 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
return return
stream_id = chat_stream.stream_id stream_id = chat_stream.stream_id
# 注册 Tracker # 注册 Tracker
from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager from src.express.reflect_tracker import ReflectTracker, reflect_tracker_manager
tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time()) tracker = ReflectTracker(chat_stream=chat_stream, expression=expr, created_time=time.time())
reflect_tracker_manager.add_tracker(stream_id, tracker) reflect_tracker_manager.add_tracker(stream_id, tracker)
# 发送消息 # 发送消息
await send_api.text_to_stream( await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
text=text,
stream_id=stream_id,
typing=True
)
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}") logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")

View File

@@ -17,21 +17,20 @@ if TYPE_CHECKING:
 logger = get_logger("reflect_tracker")
 class ReflectTracker:
     def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
         self.chat_stream = chat_stream
         self.expression = expression
         self.created_time = created_time
         # self.message_count = 0  # Replaced by checking message list length
         self.last_check_msg_count = 0
         self.max_message_count = 30
         self.max_duration = 15 * 60  # 15 minutes
         # LLM for judging response
-        self.judge_model = LLMRequest(
-            model_set=model_config.model_task_config.utils, request_type="reflect.tracker"
-        )
+        self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
         self._init_prompts()
     def _init_prompts(self):
@@ -72,16 +71,16 @@ class ReflectTracker:
         if time.time() - self.created_time > self.max_duration:
             logger.info(f"ReflectTracker for expr {self.expression.id} timed out (duration).")
             return True
         # Fetch messages since creation
         msg_list = get_raw_msg_by_timestamp_with_chat(
             chat_id=self.chat_stream.stream_id,
             timestamp_start=self.created_time,
             timestamp_end=time.time(),
         )
         current_msg_count = len(msg_list)
         # Check message limit
         if current_msg_count > self.max_message_count:
             logger.info(f"ReflectTracker for expr {self.expression.id} timed out (message count).")
@@ -90,9 +89,9 @@ class ReflectTracker:
         # If no new messages since last check, skip
         if current_msg_count <= self.last_check_msg_count:
             return False
         self.last_check_msg_count = current_msg_count
         # Build context block
         # Use simple readable format
         context_block = build_readable_messages(
@@ -109,78 +108,83 @@ class ReflectTracker:
                 "reflect_judge_prompt",
                 situation=self.expression.situation,
                 style=self.expression.style,
-                context_block=context_block
+                context_block=context_block,
             )
             logger.info(f"ReflectTracker LLM Prompt: {prompt}")
             response, _ = await self.judge_model.generate_response_async(prompt, temperature=0.1)
             logger.info(f"ReflectTracker LLM Response: {response}")
             # Parse JSON
             import json
             import re
             from json_repair import repair_json
+
             json_pattern = r"```json\s*(.*?)\s*```"
             matches = re.findall(json_pattern, response, re.DOTALL)
             if not matches:
                 # Try to parse raw response if no code block
                 matches = [response]
             json_obj = json.loads(repair_json(matches[0]))
             judgment = json_obj.get("judgment")
             if judgment == "Approve":
                 self.expression.checked = True
                 self.expression.rejected = False
                 self.expression.save()
                 logger.info(f"Expression {self.expression.id} approved by operator.")
                 return True
             elif judgment == "Reject":
                 self.expression.checked = True
                 corrected_situation = json_obj.get("corrected_situation")
                 corrected_style = json_obj.get("corrected_style")
                 # 检查是否有更新
                 has_update = bool(corrected_situation or corrected_style)
                 if corrected_situation:
                     self.expression.situation = corrected_situation
                 if corrected_style:
                     self.expression.style = corrected_style
                 # 如果拒绝但未更新,标记为 rejected=1
                 if not has_update:
                     self.expression.rejected = True
                 else:
                     self.expression.rejected = False
                 self.expression.save()
                 if has_update:
-                    logger.info(f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}")
+                    logger.info(
+                        f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
+                    )
                 else:
-                    logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.")
+                    logger.info(
+                        f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
+                    )
                 return True
             elif judgment == "Ignore":
                 logger.info(f"ReflectTracker for expr {self.expression.id} judged as Ignore.")
                 return False
         except Exception as e:
             logger.error(f"Error in ReflectTracker check: {e}")
             return False
         return False
 # Global manager for trackers
 class ReflectTrackerManager:
     def __init__(self):
         self.trackers: Dict[str, ReflectTracker] = {}  # chat_id -> tracker
     def add_tracker(self, chat_id: str, tracker: ReflectTracker):
         self.trackers[chat_id] = tracker
@@ -192,5 +196,5 @@ class ReflectTrackerManager:
         if chat_id in self.trackers:
             del self.trackers[chat_id]
 reflect_tracker_manager = ReflectTrackerManager()
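
Aside: the tracker above extracts a ```json code block from the LLM reply and repairs it with json_repair before loading. A minimal standalone sketch of that parsing step (the helper name parse_judgment is illustrative, not from the repository; repair_json is the same function the tracker imports):

import json
import re

from json_repair import repair_json


def parse_judgment(response: str) -> dict:
    """Pull a ```json ...``` block out of an LLM reply and load it, repairing minor defects."""
    matches = re.findall(r"```json\s*(.*?)\s*```", response, re.DOTALL)
    raw = matches[0] if matches else response  # fall back to the raw reply, as above
    return json.loads(repair_json(raw))


# Example: parse_judgment('```json\n{"judgment": "Approve",}\n```') -> {"judgment": "Approve"}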

View File

@@ -315,7 +315,9 @@ class ChatHistorySummarizer:
             before_count = len(self.current_batch.messages)
             self.current_batch.messages.extend(new_messages)
             self.current_batch.end_time = current_time
-            logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
+            logger.info(
+                f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
+            )
             # 更新批次后持久化
             self._persist_topic_cache()
         else:
@@ -361,9 +363,7 @@ class ChatHistorySummarizer:
         else:
             time_str = f"{time_since_last_check / 3600:.1f}小时"
-        logger.info(
-            f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
-        )
+        logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
         # 检查“话题检查”触发条件
         should_check = False
@@ -413,7 +413,7 @@ class ChatHistorySummarizer:
         # 说明 bot 没有参与这段对话,不应该记录
         bot_user_id = str(global_config.bot.qq_account)
         has_bot_message = False
         for msg in messages:
             if msg.user_info.user_id == bot_user_id:
                 has_bot_message = True
@@ -426,7 +426,9 @@ class ChatHistorySummarizer:
             return
         # 2. 构造编号后的消息字符串和参与者信息
-        numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
+        numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
+            self._build_numbered_messages_for_llm(messages)
+        )
         # 3. 调用 LLM 识别话题,并得到 topic -> indices
         existing_topics = list(self.topic_cache.keys())
@@ -588,9 +590,7 @@ class ChatHistorySummarizer:
         if not numbered_lines:
             return False, {}
-        history_topics_block = (
-            "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
-        )
+        history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
         messages_block = "\n".join(numbered_lines)
         prompt = await global_prompt_manager.format_prompt(
@@ -607,6 +607,7 @@ class ChatHistorySummarizer:
         )
         import re
+
         logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
         logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
@@ -895,4 +896,3 @@ class ChatHistorySummarizer:
 init_prompt()

View File

@@ -44,9 +44,7 @@ class JargonExplainer:
             request_type="jargon.explain",
         )
-    def match_jargon_from_messages(
-        self, messages: List[Any]
-    ) -> List[Dict[str, str]]:
+    def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
         """
         通过直接匹配数据库中的jargon字符串来提取黑话
@@ -57,7 +55,7 @@ class JargonExplainer:
             List[Dict[str, str]]: 提取到的黑话列表每个元素包含content
         """
         start_time = time.time()
         if not messages:
             return []
@@ -67,8 +65,10 @@ class JargonExplainer:
             # 跳过机器人自己的消息
             if is_bot_message(msg):
                 continue
-            msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
+            msg_text = (
+                getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
+            ).strip()
             if msg_text:
                 message_texts.append(msg_text)
@@ -79,9 +79,7 @@ class JargonExplainer:
         combined_text = " ".join(message_texts)
         # 查询所有有meaning的jargon记录
-        query = Jargon.select().where(
-            (Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
-        )
+        query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
         # 根据all_global配置决定查询逻辑
         if global_config.jargon.all_global:
@@ -98,7 +96,7 @@ class JargonExplainer:
         # 执行查询并匹配
         matched_jargon: Dict[str, Dict[str, str]] = {}
         query_time = time.time()
         for jargon in query:
             content = jargon.content or ""
             if not content or not content.strip():
@@ -123,13 +121,13 @@ class JargonExplainer:
             pattern = re.escape(content)
             # 使用单词边界或中文字符边界来匹配,避免部分匹配
             # 对于中文使用Unicode字符类对于英文使用单词边界
-            if re.search(r'[\u4e00-\u9fff]', content):
+            if re.search(r"[\u4e00-\u9fff]", content):
                 # 包含中文,使用更宽松的匹配
                 search_pattern = pattern
             else:
                 # 纯英文/数字,使用单词边界
-                search_pattern = r'\b' + pattern + r'\b'
+                search_pattern = r"\b" + pattern + r"\b"
             if re.search(search_pattern, combined_text, re.IGNORECASE):
                 # 找到匹配,记录(去重)
                 if content not in matched_jargon:
@@ -139,7 +137,7 @@ class JargonExplainer:
         total_time = match_time - start_time
         query_duration = query_time - start_time
         match_duration = match_time - query_time
         logger.info(
             f"黑话匹配完成: 查询耗时 {query_duration:.3f}s, 匹配耗时 {match_duration:.3f}s, "
             f"总耗时 {total_time:.3f}s, 匹配到 {len(matched_jargon)} 个黑话"
@@ -147,9 +145,7 @@ class JargonExplainer:
         return list(matched_jargon.values())
-    async def explain_jargon(
-        self, messages: List[Any], chat_context: str
-    ) -> Optional[str]:
+    async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
         """
         解释上下文中的黑话
@@ -183,7 +179,7 @@ class JargonExplainer:
         jargon_explanations: List[str] = []
         for entry in jargon_list:
             content = entry["content"]
             # 根据是否开启全局黑话,决定查询方式
             if global_config.jargon.all_global:
                 # 开启全局黑话查询所有is_global=True的记录
@@ -239,9 +235,7 @@ class JargonExplainer:
         return summary
-async def explain_jargon_in_context(
-    chat_id: str, messages: List[Any], chat_context: str
-) -> Optional[str]:
+async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
     """
     解释上下文中的黑话便捷函数
@@ -255,4 +249,3 @@ async def explain_jargon_in_context(
     """
     explainer = JargonExplainer(chat_id)
     return await explainer.explain_jargon(messages, chat_context)

View File

@@ -17,20 +17,18 @@ from src.chat.utils.chat_message_builder import (
 )
 from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
 from src.jargon.jargon_utils import (
     is_bot_message,
     build_context_paragraph,
     contains_bot_self_name,
     parse_chat_id_list,
     chat_id_list_contains,
-    update_chat_id_list
+    update_chat_id_list,
 )
 logger = get_logger("jargon")
 def _init_prompt() -> None:
     prompt_str = """
**聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID**
@@ -126,7 +124,6 @@ _init_prompt()
 _init_inference_prompts()
 def _should_infer_meaning(jargon_obj: Jargon) -> bool:
     """
     判断是否需要进行含义推断
@@ -211,7 +208,9 @@ class JargonMiner:
         processed_pairs = set()
         for idx, msg in enumerate(messages):
-            msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
+            msg_text = (
+                getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
+            ).strip()
             if not msg_text or is_bot_message(msg):
                 continue
@@ -270,7 +269,7 @@ class JargonMiner:
             prompt1 = await global_prompt_manager.format_prompt(
                 "jargon_inference_with_context_prompt",
                 content=content,
-                bot_name = global_config.bot.nickname,
+                bot_name=global_config.bot.nickname,
                 raw_content_list=raw_content_text,
             )
@@ -588,7 +587,6 @@ class JargonMiner:
             content = entry["content"]
             raw_content_list = entry["raw_content"]  # 已经是列表
             try:
                 # 查询所有content匹配的记录
                 query = Jargon.select().where(Jargon.content == content)
@@ -782,15 +780,15 @@ def search_jargon(
             # 如果记录是is_global=True或者chat_id列表包含目标chat_id则包含
             if not jargon.is_global and not chat_id_list_contains(chat_id_list, chat_id):
                 continue
         # 只返回有meaning的记录
         if not jargon.meaning or jargon.meaning.strip() == "":
             continue
         results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
         # 达到限制数量后停止
         if len(results) >= limit:
             break
     return results

View File

@@ -13,19 +13,20 @@ from src.chat.utils.utils import parse_platform_accounts
 logger = get_logger("jargon")
 def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
     """
     解析chat_id字段兼容旧格式字符串和新格式JSON列表
     Args:
         chat_id_value: 可能是字符串旧格式或JSON字符串新格式
     Returns:
         List[List[Any]]: 格式为 [[chat_id, count], ...] 的列表
     """
     if not chat_id_value:
         return []
     # 如果是字符串尝试解析为JSON
     if isinstance(chat_id_value, str):
         # 尝试解析JSON
@@ -54,12 +55,12 @@ def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
 def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
     """
     更新chat_id列表如果target_chat_id已存在则增加计数否则添加新条目
     Args:
         chat_id_list: 当前的chat_id列表格式为 [[chat_id, count], ...]
         target_chat_id: 要更新或添加的chat_id
         increment: 增加的计数默认为1
     Returns:
         List[List[Any]]: 更新后的chat_id列表
     """
@@ -74,22 +75,22 @@ def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, incr
                 item.append(increment)
             found = True
             break
     if not found:
         # 未找到,添加新条目
         chat_id_list.append([target_chat_id, increment])
     return chat_id_list
 def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
     """
     检查chat_id列表中是否包含指定的chat_id
     Args:
         chat_id_list: chat_id列表格式为 [[chat_id, count], ...]
         target_chat_id: 要查找的chat_id
     Returns:
         bool: 如果包含则返回True
     """
@@ -168,10 +169,7 @@ def is_bot_message(msg: Any) -> bool:
         .strip()
         .lower()
     )
-    user_id = (
-        str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "")
-        .strip()
-    )
+    user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
     if not platform or not user_id:
         return False
@@ -196,4 +194,4 @@ def is_bot_message(msg: Any) -> bool:
             bot_accounts[plat] = account
     bot_account = bot_accounts.get(platform)
     return bool(bot_account and user_id == bot_account)
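
Aside: the docstrings above describe the [[chat_id, count], ...] bookkeeping that update_chat_id_list and chat_id_list_contains operate on. A minimal self-contained sketch of that structure under the same semantics (function bodies here are illustrative, not copied from the repository):

from typing import Any, List


def update_chat_id_list(chat_id_list: List[List[Any]], target_chat_id: str, increment: int = 1) -> List[List[Any]]:
    """Bump the count for target_chat_id, or append [target_chat_id, increment] if unseen."""
    for item in chat_id_list:
        if item and item[0] == target_chat_id:
            if len(item) > 1:
                item[1] += increment  # existing entry: increase its count
            else:
                item.append(increment)  # old-format entry without a count yet
            return chat_id_list
    chat_id_list.append([target_chat_id, increment])
    return chat_id_list


def chat_id_list_contains(chat_id_list: List[List[Any]], target_chat_id: str) -> bool:
    """True if any entry's first element equals target_chat_id."""
    return any(item and item[0] == target_chat_id for item in chat_id_list)


# Example: update_chat_id_list([["qq:123", 2]], "qq:123") -> [["qq:123", 3]]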

View File

@@ -338,8 +338,10 @@ class LLMRequest:
                 if e.__cause__:
                     original_error_type = type(e.__cause__).__name__
                     original_error_msg = str(e.__cause__)
-                    original_error_info = f"\n  底层异常类型: {original_error_type}\n  底层异常信息: {original_error_msg}"
+                    original_error_info = (
+                        f"\n  底层异常类型: {original_error_type}\n  底层异常信息: {original_error_msg}"
+                    )
                 retry_remain -= 1
                 if retry_remain <= 0:
                     logger.error(f"模型 '{model_info.name}' 在网络错误重试用尽后仍然失败。{original_error_info}")

View File

@@ -56,7 +56,7 @@ class MainSystem:
             from src.webui.webui_server import get_webui_server
             self.webui_server = get_webui_server()
             if webui_mode == "development":
                 logger.info("📝 WebUI 开发模式已启用")
                 logger.info("🌐 后端 API 将运行在 http://0.0.0.0:8001")
@@ -66,7 +66,7 @@ class MainSystem:
                 logger.info("✅ WebUI 生产模式已启用")
                 logger.info(f"🌐 WebUI 将运行在 http://0.0.0.0:8001")
                 logger.info("💡 请确保已构建前端: cd MaiBot-Dashboard && bun run build")
         except Exception as e:
             logger.error(f"❌ 初始化 WebUI 服务器失败: {e}")

View File

@@ -296,7 +296,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
         if not content:
             continue
         if not global_config.jargon.all_global and not jargon.is_global:
             chat_id_list = parse_chat_id_list(jargon.chat_id)
             if not chat_id_list_contains(chat_id_list, chat_id):
@@ -586,9 +585,7 @@ async def _react_agent_solve_question(
             step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
             step["observations"] = ["从LLM输出内容中检测到found_answer"]
             thinking_steps.append(step)
-            logger.info(
-                f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}"
-            )
+            logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}")
             return True, found_answer_content, thinking_steps, False
         if not_enough_info_reason:
@@ -1016,9 +1013,7 @@ async def build_memory_retrieval_prompt(
     if question_results:
         retrieved_memory = "\n\n".join(question_results)
-        logger.info(
-            f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆"
-        )
+        logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆")
         return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
     else:
         logger.debug("所有问题均未找到答案")

View File

@@ -54,7 +54,9 @@ async def search_chat_history(
             if record.participants:
                 try:
                     participants_data = (
-                        json.loads(record.participants) if isinstance(record.participants, str) else record.participants
+                        json.loads(record.participants)
+                        if isinstance(record.participants, str)
+                        else record.participants
                     )
                     if isinstance(participants_data, list):
                         participants_list = [str(p).lower() for p in participants_data]
@@ -156,9 +158,7 @@ async def search_chat_history(
             # 添加关键词
             if record.keywords:
                 try:
-                    keywords_data = (
-                        json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
-                    )
+                    keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
                     if isinstance(keywords_data, list) and keywords_data:
                         keywords_str = "".join([str(k) for k in keywords_data])
                         result_parts.append(f"关键词:{keywords_str}")
@@ -208,9 +208,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
         return "未提供有效的记忆ID"
     # 查询记录
-    query = ChatHistory.select().where(
-        (ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))
-    )
+    query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)))
     records = list(query.order_by(ChatHistory.start_time.desc()))
     if not records:
@@ -256,9 +254,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
             # 添加关键词
             if record.keywords:
                 try:
-                    keywords_data = (
-                        json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
-                    )
+                    keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
                     if isinstance(keywords_data, list) and keywords_data:
                         keywords_str = "".join([str(k) for k in keywords_data])
                         result_parts.append(f"关键词:{keywords_str}")

View File

@@ -12,12 +12,12 @@ from dataclasses import dataclass, field
 class ConfigField:
     """
     配置字段定义
     用于定义插件配置项的元数据支持类型验证UI 渲染等功能
     基础示例:
         ConfigField(type=str, default="", description="API密钥")
     完整示例:
         ConfigField(
             type=str,
             type=str,
@@ -73,9 +73,9 @@ class ConfigField:
     def get_ui_type(self) -> str:
         """
         获取 UI 控件类型
         如果指定了 input_type 则直接返回否则根据 type choices 自动推断
         Returns:
             控件类型字符串
         """
@@ -103,7 +103,7 @@ class ConfigField:
     def to_dict(self) -> Dict[str, Any]:
         """
         转换为可序列化的字典用于 API 传输
         Returns:
             包含所有配置信息的字典
         """
@@ -139,9 +139,9 @@ class ConfigField:
 class ConfigSection:
     """
     配置节定义
     用于描述配置文件中一个 section 的元数据
     示例:
         ConfigSection(
             title="API配置",
@@ -150,6 +150,7 @@ class ConfigSection:
             order=1
         )
     """
+
     title: str  # 显示标题
     description: Optional[str] = None  # 详细描述
     icon: Optional[str] = None  # 图标名称
@@ -171,9 +172,9 @@ class ConfigSection:
 class ConfigTab:
     """
     配置标签页定义
     用于将多个 section 组织到一个标签页中
     示例:
         ConfigTab(
             id="general",
@@ -182,6 +183,7 @@ class ConfigTab:
             sections=["plugin", "api"]
         )
     """
+
     id: str  # 标签页 ID
     title: str  # 显示标题
     sections: List[str] = field(default_factory=list)  # 包含的 section 名称列表
@@ -201,18 +203,18 @@ class ConfigTab:
         }
 @dataclass
 class ConfigLayout:
     """
     配置页面布局定义
     用于定义插件配置页面的整体布局结构
     布局类型:
         - "auto": 自动布局sections 作为折叠面板显示
         - "tabs": 标签页布局
         - "pages": 分页布局左侧导航 + 右侧内容
     简单示例标签页布局:
         ConfigLayout(
             type="tabs",
@@ -222,9 +224,10 @@ class ConfigLayout:
             ]
         )
     """
+
     type: str = "auto"  # 布局类型: auto, tabs, pages
     tabs: List[ConfigTab] = field(default_factory=list)  # 标签页列表
     def to_dict(self) -> Dict[str, Any]:
         """转换为可序列化的字典"""
         return {
@@ -234,37 +237,27 @@ class ConfigLayout:
 def section_meta(
-    title: str,
-    description: Optional[str] = None,
-    icon: Optional[str] = None,
-    collapsed: bool = False,
-    order: int = 0
+    title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0
 ) -> Union[str, ConfigSection]:
     """
     便捷函数创建 section 元数据
     可以在 config_section_descriptions 中使用提供比纯字符串更丰富的信息
     Args:
         title: 显示标题
         description: 详细描述
         icon: 图标名称
         collapsed: 默认是否折叠
         order: 排序权重
     Returns:
         ConfigSection 实例
     示例:
         config_section_descriptions = {
             "api": section_meta("API配置", icon="cloud", order=1),
             "debug": section_meta("调试设置", collapsed=True, order=99),
         }
     """
-    return ConfigSection(
-        title=title,
-        description=description,
-        icon=icon,
-        collapsed=collapsed,
-        order=order
-    )
+    return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order)

View File

@@ -574,14 +574,14 @@ class PluginBase(ABC):
     def get_webui_config_schema(self) -> Dict[str, Any]:
         """
         获取 WebUI 配置 Schema
         返回完整的配置 schema包含
         - 插件基本信息
         - 所有 section 及其字段定义
         - 布局配置
         用于 WebUI 动态生成配置表单
         Returns:
             Dict: 完整的配置 schema
         """
@@ -596,12 +596,12 @@ class PluginBase(ABC):
             "sections": {},
             "layout": None,
         }
         # 处理 sections
         for section_name, fields in self.config_schema.items():
             if not isinstance(fields, dict):
                 continue
             section_data = {
                 "name": section_name,
                 "title": section_name,
@@ -611,7 +611,7 @@ class PluginBase(ABC):
                 "order": 0,
                 "fields": {},
             }
             # 获取 section 元数据
             section_meta = self.config_section_descriptions.get(section_name)
             if section_meta:
@@ -625,16 +625,16 @@ class PluginBase(ABC):
                     section_data["order"] = section_meta.order
                 elif isinstance(section_meta, dict):
                     section_data.update(section_meta)
             # 处理字段
             for field_name, field_def in fields.items():
                 if isinstance(field_def, ConfigField):
                     field_data = field_def.to_dict()
                     field_data["name"] = field_name
                     section_data["fields"][field_name] = field_data
             schema["sections"][section_name] = section_data
         # 处理布局
         if self.config_layout:
             schema["layout"] = self.config_layout.to_dict()
@@ -644,15 +644,15 @@ class PluginBase(ABC):
                 "type": "auto",
                 "tabs": [],
             }
         return schema
     def get_current_config_values(self) -> Dict[str, Any]:
         """
         获取当前配置值
         返回插件当前的配置值已从配置文件加载
         Returns:
             Dict: 当前配置值
         """

View File

@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_"
class ChatHistoryMessage(BaseModel): class ChatHistoryMessage(BaseModel):
"""聊天历史消息""" """聊天历史消息"""
id: str id: str
type: str # 'user' | 'bot' | 'system' type: str # 'user' | 'bot' | 'system'
content: str content: str
@ -36,17 +37,17 @@ class ChatHistoryMessage(BaseModel):
class ChatHistoryManager: class ChatHistoryManager:
"""聊天历史管理器 - 使用 SQLite 数据库存储""" """聊天历史管理器 - 使用 SQLite 数据库存储"""
def __init__(self, max_messages: int = 200): def __init__(self, max_messages: int = 200):
self.max_messages = max_messages self.max_messages = max_messages
def _message_to_dict(self, msg: Messages) -> Dict[str, Any]: def _message_to_dict(self, msg: Messages) -> Dict[str, Any]:
"""将数据库消息转换为前端格式""" """将数据库消息转换为前端格式"""
# 判断是否是机器人消息 # 判断是否是机器人消息
# WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息 # WebUI 用户的 user_id 以 "webui_" 开头,其他都是机器人消息
user_id = msg.user_id or "" user_id = msg.user_id or ""
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX) is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
return { return {
"id": msg.message_id, "id": msg.message_id,
"type": "bot" if is_bot else "user", "type": "bot" if is_bot else "user",
@ -56,7 +57,7 @@ class ChatHistoryManager:
"sender_id": "bot" if is_bot else user_id, "sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot, "is_bot": is_bot,
} }
def get_history(self, limit: int = 50) -> List[Dict[str, Any]]: def get_history(self, limit: int = 50) -> List[Dict[str, Any]]:
"""从数据库获取最近的历史记录""" """从数据库获取最近的历史记录"""
try: try:
@ -67,25 +68,21 @@ class ChatHistoryManager:
.order_by(Messages.time.desc()) .order_by(Messages.time.desc())
.limit(limit) .limit(limit)
) )
# 转换为列表并反转(使最旧的消息在前) # 转换为列表并反转(使最旧的消息在前)
result = [self._message_to_dict(msg) for msg in messages] result = [self._message_to_dict(msg) for msg in messages]
result.reverse() result.reverse()
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录") logger.debug(f"从数据库加载了 {len(result)} 条聊天记录")
return result return result
except Exception as e: except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}") logger.error(f"从数据库加载聊天记录失败: {e}")
return [] return []
def clear_history(self) -> int: def clear_history(self) -> int:
"""清空 WebUI 聊天历史记录""" """清空 WebUI 聊天历史记录"""
try: try:
deleted = ( deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
Messages.delete()
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
.execute()
)
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录") logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
return deleted return deleted
except Exception as e: except Exception as e:
@ -100,31 +97,31 @@ chat_history = ChatHistoryManager()
# 存储 WebSocket 连接 # 存储 WebSocket 连接
class ChatConnectionManager: class ChatConnectionManager:
"""聊天连接管理器""" """聊天连接管理器"""
def __init__(self): def __init__(self):
self.active_connections: Dict[str, WebSocket] = {} self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射 self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
async def connect(self, websocket: WebSocket, session_id: str, user_id: str): async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
await websocket.accept() await websocket.accept()
self.active_connections[session_id] = websocket self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}") logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str): def disconnect(self, session_id: str, user_id: str):
if session_id in self.active_connections: if session_id in self.active_connections:
del self.active_connections[session_id] del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id: if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id] del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}") logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict): async def send_message(self, session_id: str, message: dict):
if session_id in self.active_connections: if session_id in self.active_connections:
try: try:
await self.active_connections[session_id].send_json(message) await self.active_connections[session_id].send_json(message)
except Exception as e: except Exception as e:
logger.error(f"发送消息失败: {e}") logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict): async def broadcast(self, message: dict):
"""广播消息给所有连接""" """广播消息给所有连接"""
for session_id in list(self.active_connections.keys()): for session_id in list(self.active_connections.keys()):
@ -135,16 +132,12 @@ chat_manager = ChatConnectionManager()
def create_message_data( def create_message_data(
content: str, content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据""" """创建符合麦麦消息格式的消息数据"""
if message_id is None: if message_id is None:
message_id = str(uuid.uuid4()) message_id = str(uuid.uuid4())
return { return {
"message_info": { "message_info": {
"platform": WEBUI_CHAT_PLATFORM, "platform": WEBUI_CHAT_PLATFORM,
@ -163,7 +156,7 @@ def create_message_data(
}, },
"additional_config": { "additional_config": {
"at_bot": is_at_bot, "at_bot": is_at_bot,
} },
}, },
"message_segment": { "message_segment": {
"type": "seglist", "type": "seglist",
@ -175,8 +168,8 @@ def create_message_data(
{ {
"type": "mention_bot", "type": "mention_bot",
"data": "1.0", "data": "1.0",
} },
] ],
}, },
"raw_message": content, "raw_message": content,
"processed_plain_text": content, "processed_plain_text": content,
@ -186,10 +179,10 @@ def create_message_data(
@router.get("/history") @router.get("/history")
async def get_chat_history( async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200), limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤 user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
): ):
"""获取聊天历史记录 """获取聊天历史记录
所有 WebUI 用户共享同一个聊天室因此返回所有历史记录 所有 WebUI 用户共享同一个聊天室因此返回所有历史记录
""" """
history = chat_history.get_history(limit) history = chat_history.get_history(limit)
@ -217,76 +210,87 @@ async def websocket_chat(
user_name: Optional[str] = Query(default="WebUI用户"), user_name: Optional[str] = Query(default="WebUI用户"),
): ):
"""WebSocket 聊天端点 """WebSocket 聊天端点
Args: Args:
user_id: 用户唯一标识由前端生成并持久化 user_id: 用户唯一标识由前端生成并持久化
user_name: 用户显示昵称可修改 user_name: 用户显示昵称可修改
""" """
# 生成会话 ID每次连接都是新的 # 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
# 如果没有提供 user_id生成一个新的 # 如果没有提供 user_id生成一个新的
if not user_id: if not user_id:
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}" user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
elif not user_id.startswith(WEBUI_USER_ID_PREFIX): elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
# 确保 user_id 有正确的前缀 # 确保 user_id 有正确的前缀
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}" user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
await chat_manager.connect(websocket, session_id, user_id) await chat_manager.connect(websocket, session_id, user_id)
try: try:
# 发送会话信息(包含用户 ID前端需要保存 # 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "session_info", session_id,
"session_id": session_id, {
"user_id": user_id, "type": "session_info",
"user_name": user_name, "session_id": session_id,
"bot_name": global_config.bot.nickname, "user_id": user_id,
}) "user_name": user_name,
"bot_name": global_config.bot.nickname,
},
)
# 发送历史记录 # 发送历史记录
history = chat_history.get_history(50) history = chat_history.get_history(50)
if history: if history:
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "history", session_id,
"messages": history, {
}) "type": "history",
"messages": history,
},
)
# 发送欢迎消息(不保存到历史) # 发送欢迎消息(不保存到历史)
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "system", session_id,
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!", {
"timestamp": time.time(), "type": "system",
}) "content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
"timestamp": time.time(),
},
)
while True: while True:
data = await websocket.receive_json() data = await websocket.receive_json()
if data.get("type") == "message": if data.get("type") == "message":
content = data.get("content", "").strip() content = data.get("content", "").strip()
if not content: if not content:
continue continue
# 用户可以更新昵称 # 用户可以更新昵称
current_user_name = data.get("user_name", user_name) current_user_name = data.get("user_name", user_name)
message_id = str(uuid.uuid4()) message_id = str(uuid.uuid4())
timestamp = time.time() timestamp = time.time()
# 广播用户消息给所有连接(包括发送者) # 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库 # 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast({ await chat_manager.broadcast(
"type": "user_message", {
"content": content, "type": "user_message",
"message_id": message_id, "content": content,
"timestamp": timestamp, "message_id": message_id,
"sender": { "timestamp": timestamp,
"name": current_user_name, "sender": {
"user_id": user_id, "name": current_user_name,
"is_bot": False, "user_id": user_id,
"is_bot": False,
},
} }
}) )
# 创建麦麦消息格式 # 创建麦麦消息格式
message_data = create_message_data( message_data = create_message_data(
content=content, content=content,
@ -295,46 +299,59 @@ async def websocket_chat(
message_id=message_id, message_id=message_id,
is_at_bot=True, is_at_bot=True,
) )
try: try:
# 显示正在输入状态 # 显示正在输入状态
await chat_manager.broadcast({ await chat_manager.broadcast(
"type": "typing", {
"is_typing": True, "type": "typing",
}) "is_typing": True,
}
)
# 调用麦麦的消息处理 # 调用麦麦的消息处理
await chat_bot.message_process(message_data) await chat_bot.message_process(message_data)
except Exception as e: except Exception as e:
logger.error(f"处理消息时出错: {e}") logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "error", session_id,
"content": f"处理消息时出错: {str(e)}", {
"timestamp": time.time(), "type": "error",
}) "content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
},
)
finally: finally:
await chat_manager.broadcast({ await chat_manager.broadcast(
"type": "typing", {
"is_typing": False, "type": "typing",
}) "is_typing": False,
}
)
elif data.get("type") == "ping": elif data.get("type") == "ping":
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "pong", session_id,
"timestamp": time.time(), {
}) "type": "pong",
"timestamp": time.time(),
},
)
elif data.get("type") == "update_nickname": elif data.get("type") == "update_nickname":
# 允许用户更新昵称 # 允许用户更新昵称
if new_name := data.get("user_name", "").strip(): if new_name := data.get("user_name", "").strip():
current_user_name = new_name current_user_name = new_name
await chat_manager.send_message(session_id, { await chat_manager.send_message(
"type": "nickname_updated", session_id,
"user_name": current_user_name, {
"timestamp": time.time(), "type": "nickname_updated",
}) "user_name": current_user_name,
"timestamp": time.time(),
},
)
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}") logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
except Exception as e: except Exception as e:
@ -356,7 +373,7 @@ async def get_chat_info():
def get_webui_chat_broadcaster() -> tuple: def get_webui_chat_broadcaster() -> tuple:
"""获取 WebUI 聊天广播器,供外部模块使用 """获取 WebUI 聊天广播器,供外部模块使用
Returns: Returns:
(chat_manager, WEBUI_CHAT_PLATFORM) 元组 (chat_manager, WEBUI_CHAT_PLATFORM) 元组
""" """
View File
@ -5,7 +5,7 @@
import os import os
import tomlkit import tomlkit
from fastapi import APIRouter, HTTPException, Body from fastapi import APIRouter, HTTPException, Body
from typing import Any from typing import Any, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui") logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"]) router = APIRouter(prefix="/config", tags=["config"])
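The Annotated aliases above exist because ruff's B008 flags calls such as Body(...) sitting in a parameter default (a function call evaluated as a default value). FastAPI reads the same metadata from inside Annotated, so the endpoints keep their behaviour while the default slot stays empty. A schematic before/after with a made-up endpoint name:

from typing import Annotated, Any
from fastapi import Body

# Before (triggers B008): the Body(...) call sits in the default position.
# async def update_example(config_data: dict[str, Any] = Body(...)): ...

# After: Body() moves into the type annotation; no call in the default position.
ExampleBody = Annotated[dict[str, Any], Body()]

async def update_example(config_data: ExampleBody):  # hypothetical endpoint
    return {"received_keys": list(config_data)}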
@ -90,7 +96,7 @@ async def get_bot_config_schema():
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取配置架构失败: {e}") logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model") @router.get("/schema/model")
@ -101,7 +107,7 @@ async def get_model_config_schema():
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取模型配置架构失败: {e}") logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
# ===== 子配置架构获取接口 ===== # ===== 子配置架构获取接口 =====
@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str):
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取配置节架构失败: {e}") logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== 配置读取接口 ===== # ===== 配置读取接口 =====
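The recurring `raise ... from e` change throughout this file chains the new HTTPException to the original error, so tracebacks report the real cause ("The above exception was the direct cause of ...") instead of the vaguer "During handling of the above exception, another exception occurred". A tiny self-contained illustration:

def load():
    try:
        {}["missing"]
    except KeyError as e:
        # chained: the KeyError is preserved as __cause__
        raise RuntimeError("config load failed") from e

try:
    load()
except RuntimeError as err:
    print(type(err.__cause__).__name__)  # KeyError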
@ -196,7 +202,7 @@ async def get_bot_config():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model") @router.get("/model")
@ -215,21 +221,21 @@ async def get_model_config():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== 配置更新接口 ===== # ===== 配置更新接口 =====
@router.post("/bot") @router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)): async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置""" """更新麦麦主程序配置"""
try: try:
# 验证配置数据 # 验证配置数据
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model") @router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)): async def update_model_config(config_data: ConfigBody):
"""更新模型配置""" """更新模型配置"""
try: try:
# 验证配置数据 # 验证配置数据
try: try:
APIAdapterConfig.from_dict(config_data) APIAdapterConfig.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "model_config.toml") config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== 配置节更新接口 ===== # ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}") @router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)): async def update_bot_config_section(section_name: str, section_data: SectionBody):
"""更新麦麦主程序配置的指定节(保留注释和格式)""" """更新麦麦主程序配置的指定节(保留注释和格式)"""
try: try:
# 读取现有配置 # 读取现有配置
@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释) # 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
raise raise
except Exception as e: except Exception as e:
logger.error(f"更新配置节失败: {e}") logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 原始 TOML 文件操作接口 ===== # ===== 原始 TOML 文件操作接口 =====
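The section-update endpoints go through tomlkit rather than a plain TOML dump precisely so that comments and key order in bot_config.toml survive a partial edit. A minimal round-trip sketch with illustrative content:

import tomlkit

raw = '# bot settings\n[bot]\nnickname = "Mai"  # display name\n'
doc = tomlkit.loads(raw)
doc["bot"]["nickname"] = "NewName"   # edit one key in place
print(tomlkit.dumps(doc))            # both comments are still present in the output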
@ -338,24 +344,24 @@ async def get_bot_config_raw():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw") @router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)): async def update_bot_config_raw(raw_content: RawContentBody):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)""" """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try: try:
# 验证 TOML 格式 # 验证 TOML 格式
try: try:
config_data = tomlkit.loads(raw_content) config_data = tomlkit.loads(raw_content)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 验证配置数据结构 # 验证配置数据结构
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}") @router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)): async def update_model_config_section(section_name: str, section_data: SectionBody):
"""更新模型配置的指定节(保留注释和格式)""" """更新模型配置的指定节(保留注释和格式)"""
try: try:
# 读取现有配置 # 读取现有配置
@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
try: try:
APIAdapterConfig.from_dict(config_data) APIAdapterConfig.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释) # 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
raise raise
except Exception as e: except Exception as e:
logger.error(f"更新配置节失败: {e}") logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 ===== # ===== 适配器配置管理接口 =====
@ -425,11 +431,11 @@ def _normalize_adapter_path(path: str) -> str:
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)""" """将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
if not path: if not path:
return path return path
# 如果已经是绝对路径,直接返回 # 如果已经是绝对路径,直接返回
if os.path.isabs(path): if os.path.isabs(path):
return path return path
# 相对路径,转换为相对于项目根目录的绝对路径 # 相对路径,转换为相对于项目根目录的绝对路径
return os.path.normpath(os.path.join(PROJECT_ROOT, path)) return os.path.normpath(os.path.join(PROJECT_ROOT, path))
@ -438,17 +444,17 @@ def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径""" """尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path): if not path or not os.path.isabs(path):
return path return path
try: try:
# 尝试获取相对路径 # 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT) rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径 # 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith('..'): if not rel_path.startswith(".."):
return rel_path return rel_path
except (ValueError, TypeError): except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError # 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
pass pass
# 无法转换为相对路径,返回绝对路径 # 无法转换为相对路径,返回绝对路径
return path return path
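_to_relative_path above leans on os.path.relpath: paths inside the project come back without a leading "..", anything outside keeps its absolute form, and the ValueError branch covers Windows paths on a different drive. A small sketch with a placeholder project root:

import os

PROJECT_ROOT = "/srv/maibot"  # placeholder; the real value comes from src.config.config

def to_relative(path: str) -> str:
    rel = os.path.relpath(path, PROJECT_ROOT)
    return path if rel.startswith("..") else rel

print(to_relative("/srv/maibot/config/adapter.toml"))  # config/adapter.toml
print(to_relative("/etc/adapter.toml"))                # kept absolute: /etc/adapter.toml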
@ -463,6 +469,7 @@ async def get_adapter_config_path():
return {"success": True, "path": None} return {"success": True, "path": None}
import json import json
with open(webui_data_path, "r", encoding="utf-8") as f: with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f) webui_data = json.load(f)
@ -472,10 +479,11 @@ async def get_adapter_config_path():
# 将路径规范化为绝对路径 # 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path) abs_path = _normalize_adapter_path(adapter_config_path)
# 检查文件是否存在并返回最后修改时间 # 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path): if os.path.exists(abs_path):
import datetime import datetime
mtime = os.path.getmtime(abs_path) mtime = os.path.getmtime(abs_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat() last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
# 返回相对路径(如果可能) # 返回相对路径(如果可能)
@ -487,11 +495,11 @@ async def get_adapter_config_path():
except Exception as e: except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}") logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path") @router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)): async def save_adapter_config_path(data: PathBody):
"""保存适配器配置文件路径偏好""" """保存适配器配置文件路径偏好"""
try: try:
path = data.get("path") path = data.get("path")
@ -511,10 +519,10 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
# 将路径规范化为绝对路径 # 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path) abs_path = _normalize_adapter_path(path)
# 尝试转换为相对路径保存(如果文件在项目目录内) # 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path) save_path = _to_relative_path(abs_path)
# 更新路径 # 更新路径
webui_data["adapter_config_path"] = save_path webui_data["adapter_config_path"] = save_path
@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}") logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config") @router.get("/adapter-config")
@ -542,7 +550,7 @@ async def get_adapter_config(path: str):
# 将路径规范化为绝对路径 # 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path) abs_path = _normalize_adapter_path(path)
# 检查文件是否存在 # 检查文件是否存在
if not os.path.exists(abs_path): if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}") raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
@ -562,11 +570,11 @@ async def get_adapter_config(path: str):
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取适配器配置失败: {e}") logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config") @router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)): async def save_adapter_config(data: PathBody):
"""保存适配器配置到指定路径""" """保存适配器配置到指定路径"""
try: try:
path = data.get("path") path = data.get("path")
@ -579,17 +587,16 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
# 将路径规范化为绝对路径 # 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path) abs_path = _normalize_adapter_path(path)
# 检查文件扩展名 # 检查文件扩展名
if not abs_path.endswith(".toml"): if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件") raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式 # 验证 TOML 格式
try: try:
import tomlkit
tomlkit.loads(content) tomlkit.loads(content)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 确保目录存在 # 确保目录存在
dir_path = os.path.dirname(abs_path) dir_path = os.path.dirname(abs_path)
@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存适配器配置失败: {e}") logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
View File
@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
if next_line.startswith('"""') or next_line.startswith("'''"): if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串 # 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2: if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip()) description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else: else:
# 多行文档字符串 # 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''" quote = '"""' if next_line.startswith('"""') else "'''"
@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
next_line = lines[i + 1].strip() next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"): if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2: if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip()) description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else: else:
quote = '"""' if next_line.startswith('"""') else "'''" quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip()) description_lines.append(next_line.strip(quote).strip())
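The switch from .strip('"""') to .replace('"""', "") matters because str.strip treats its argument as a set of characters, not a literal sequence: chaining .strip('"""').strip("'''") also eats any stray quote or apostrophe at the ends of the docstring text itself. A quick comparison with illustrative strings:

line = '"""Config field for the "primary" provider"""'
print(line.strip('"""').strip("'''"))     # same result here, but:

edge = '""""quoted" at the start"""'
print(edge.strip('"""'))                  # quoted" at the start   (leading quote swallowed)
print(edge.replace('"""', "").strip())    # "quoted" at the start  (only the literal triple quotes go)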
@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
return FieldType.ARRAY, None, items return FieldType.ARRAY, None, items
# 处理基本类型 # 处理基本类型
if field_type is bool or field_type == bool: if field_type is bool:
return FieldType.BOOLEAN, None, None return FieldType.BOOLEAN, None, None
elif field_type is int or field_type == int: elif field_type is int:
return FieldType.INTEGER, None, None return FieldType.INTEGER, None, None
elif field_type is float or field_type == float: elif field_type is float:
return FieldType.NUMBER, None, None return FieldType.NUMBER, None, None
elif field_type is str or field_type == str: elif field_type is str:
return FieldType.STRING, None, None return FieldType.STRING, None, None
elif field_type is dict or origin is dict: elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None return FieldType.OBJECT, None, None
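Dropping the duplicated `== bool` / `== int` halves is safe because, for plain class objects, the identity check already covers the equality check, and ruff's type-comparison rules prefer `is` here. A short, purely illustrative check:

field_type = int
print(field_type is int)    # True
print(field_type == int)    # True as well, so the extra comparison was redundant
print(bool is int)          # False: bool subclasses int but is a distinct class object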
View File
@ -3,20 +3,25 @@
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
from .token_manager import get_token_manager from .token_manager import get_token_manager
import json
import time import time
import os import os
import hashlib import hashlib
import base64
from PIL import Image from PIL import Image
import io import io
logger = get_logger("webui.emoji") logger = get_logger("webui.emoji")
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
DescriptionForm = Annotated[str, Form(description="表情包描述")]
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
# 创建路由器 # 创建路由器
router = APIRouter(prefix="/emoji", tags=["Emoji"]) router = APIRouter(prefix="/emoji", tags=["Emoji"])
@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel):
@router.post("/upload", response_model=EmojiUploadResponse) @router.post("/upload", response_model=EmojiUploadResponse)
async def upload_emoji( async def upload_emoji(
file: UploadFile = File(..., description="表情包图片文件"), file: EmojiFile,
description: str = Form("", description="表情包描述"), description: DescriptionForm = "",
emotion: str = Form("", description="情感标签,多个用逗号分隔"), emotion: EmotionForm = "",
is_registered: bool = Form(True, description="是否直接注册"), is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
): ):
""" """
@ -713,9 +718,9 @@ async def upload_emoji(
@router.post("/batch/upload") @router.post("/batch/upload")
async def batch_upload_emoji( async def batch_upload_emoji(
files: List[UploadFile] = File(..., description="多个表情包图片文件"), files: EmojiFiles,
emotion: str = Form("", description="情感标签,多个用逗号分隔"), emotion: EmotionForm = "",
is_registered: bool = Form(True, description="是否直接注册"), is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
): ):
""" """
@ -749,11 +754,13 @@ async def batch_upload_emoji(
# 验证文件类型 # 验证文件类型
if file.content_type not in allowed_types: if file.content_type not in allowed_types:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": False, "filename": file.filename,
"error": f"不支持的文件类型: {file.content_type}", "success": False,
}) "error": f"不支持的文件类型: {file.content_type}",
}
)
continue continue
# 读取文件内容 # 读取文件内容
@ -761,11 +768,13 @@ async def batch_upload_emoji(
if not file_content: if not file_content:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": False, "filename": file.filename,
"error": "文件内容为空", "success": False,
}) "error": "文件内容为空",
}
)
continue continue
# 验证图片 # 验证图片
@ -774,11 +783,13 @@ async def batch_upload_emoji(
img_format = img.format.lower() if img.format else "png" img_format = img.format.lower() if img.format else "png"
except Exception as e: except Exception as e:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": False, "filename": file.filename,
"error": f"无效的图片: {str(e)}", "success": False,
}) "error": f"无效的图片: {str(e)}",
}
)
continue continue
# 计算哈希 # 计算哈希
@ -787,11 +798,13 @@ async def batch_upload_emoji(
# 检查重复 # 检查重复
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash): if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": False, "filename": file.filename,
"error": "已存在相同的表情包", "success": False,
}) "error": "已存在相同的表情包",
}
)
continue continue
# 生成文件名并保存 # 生成文件名并保存
@ -829,19 +842,23 @@ async def batch_upload_emoji(
) )
results["uploaded"] += 1 results["uploaded"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": True, "filename": file.filename,
"id": emoji.id, "success": True,
}) "id": emoji.id,
}
)
except Exception as e: except Exception as e:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
"filename": file.filename, {
"success": False, "filename": file.filename,
"error": str(e), "success": False,
}) "error": str(e),
}
)
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}" results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}"
return results return results
@ -850,4 +867,4 @@ async def batch_upload_emoji(
raise raise
except Exception as e: except Exception as e:
logger.exception(f"批量上传表情包失败: {e}") logger.exception(f"批量上传表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e raise HTTPException(status_code=500, detail=f"批量上传失败: {str(e)}") from e
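Each uploaded file above is validated by actually decoding it with Pillow before it is hashed and saved. A stripped-down version of that probe, shown here only as a sketch:

import io
from PIL import Image

def probe_image(data: bytes) -> str:
    """Return the lowercase image format, raising if the bytes do not decode."""
    img = Image.open(io.BytesIO(data))
    return (img.format or "png").lower()

# probe_image(open("emoji.png", "rb").read())  # e.g. "png"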
View File
@ -602,9 +602,9 @@ class GitMirrorService:
# 执行 git clone在线程池中运行以避免阻塞 # 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def run_git_clone(): def run_git_clone(clone_cmd=cmd):
return subprocess.run( return subprocess.run(
cmd, clone_cmd,
capture_output=True, capture_output=True,
text=True, text=True,
timeout=300, # 5分钟超时 timeout=300, # 5分钟超时
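Binding the command as a default argument (clone_cmd=cmd) is the standard fix for ruff's B023: a nested function otherwise closes over the variable itself, so if that variable is later reassigned (typically in a loop) every call sees the final value. A compact demonstration of the difference:

late = [lambda: i for i in range(3)]        # closes over i itself
print([f() for f in late])                   # [2, 2, 2]

bound = [lambda i=i: i for i in range(3)]    # default argument freezes the current value
print([f() for f in bound])                  # [0, 1, 2]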
View File
@ -1,4 +1,5 @@
"""知识库图谱可视化 API 路由""" """知识库图谱可视化 API 路由"""
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
class KnowledgeNode(BaseModel): class KnowledgeNode(BaseModel):
"""知识节点""" """知识节点"""
id: str id: str
type: str # 'entity' or 'paragraph' type: str # 'entity' or 'paragraph'
content: str content: str
@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
class KnowledgeEdge(BaseModel): class KnowledgeEdge(BaseModel):
"""知识边""" """知识边"""
source: str source: str
target: str target: str
weight: float weight: float
@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
"""知识图谱""" """知识图谱"""
nodes: List[KnowledgeNode] nodes: List[KnowledgeNode]
edges: List[KnowledgeEdge] edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel): class KnowledgeStats(BaseModel):
"""知识库统计信息""" """知识库统计信息"""
total_nodes: int total_nodes: int
total_edges: int total_edges: int
entity_nodes: int entity_nodes: int
@ -45,7 +50,7 @@ def _load_kg_manager():
"""延迟加载 KGManager""" """延迟加载 KGManager"""
try: try:
from src.chat.knowledge.kg_manager import KGManager from src.chat.knowledge.kg_manager import KGManager
kg_manager = KGManager() kg_manager = KGManager()
kg_manager.load_from_file() kg_manager.load_from_file()
return kg_manager return kg_manager
@ -58,31 +63,26 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
"""将 DiGraph 转换为 JSON 格式""" """将 DiGraph 转换为 JSON 格式"""
if kg_manager is None or kg_manager.graph is None: if kg_manager is None or kg_manager.graph is None:
return KnowledgeGraph(nodes=[], edges=[]) return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph graph = kg_manager.graph
nodes = [] nodes = []
edges = [] edges = []
# 转换节点 # 转换节点
node_list = graph.get_node_list() node_list = graph.get_node_list()
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph" # 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode( nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}") logger.warning(f"跳过节点 {node_id}: {e}")
continue continue
# 转换边 # 转换边
edge_list = graph.get_edge_list() edge_list = graph.get_edge_list()
for edge_tuple in edge_list: for edge_tuple in edge_list:
@ -91,37 +91,35 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
source, target = edge_tuple[0], edge_tuple[1] source, target = edge_tuple[0], edge_tuple[1]
# 通过 graph[source, target] 获取边的属性数据 # 通过 graph[source, target] 获取边的属性数据
edge_data = graph[source, target] edge_data = graph[source, target]
# edge_data 支持 [] 操作符但不支持 .get() # edge_data 支持 [] 操作符但不支持 .get()
weight = edge_data['weight'] if 'weight' in edge_data else 1.0 weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(KnowledgeEdge( edges.append(
source=source, KnowledgeEdge(
target=target, source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
weight=weight, )
create_time=create_time, )
update_time=update_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}") logger.warning(f"跳过边 {edge_tuple}: {e}")
continue continue
return KnowledgeGraph(nodes=nodes, edges=edges) return KnowledgeGraph(nodes=nodes, edges=edges)
@router.get("/graph", response_model=KnowledgeGraph) @router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph( async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"), limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph") node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
): ):
"""获取知识图谱(限制节点数量) """获取知识图谱(限制节点数量)
Args: Args:
limit: 返回的最大节点数,默认 100,最大 10000 limit: 返回的最大节点数,默认 100,最大 10000
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落) node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
Returns: Returns:
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱 KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
""" """
@ -130,46 +128,43 @@ async def get_knowledge_graph(
if kg_manager is None: if kg_manager is None:
logger.warning("KGManager 未初始化,返回空图谱") logger.warning("KGManager 未初始化,返回空图谱")
return KnowledgeGraph(nodes=[], edges=[]) return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph graph = kg_manager.graph
all_node_list = graph.get_node_list() all_node_list = graph.get_node_list()
# 按类型过滤节点 # 按类型过滤节点
if node_type == "entity": if node_type == "entity":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent'] all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph": elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg'] all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量 # 限制节点数量
total_nodes = len(all_node_list) total_nodes = len(all_node_list)
if len(all_node_list) > limit: if len(all_node_list) > limit:
node_list = all_node_list[:limit] node_list = all_node_list[:limit]
else: else:
node_list = all_node_list node_list = all_node_list
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})") logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
# 转换节点 # 转换节点
nodes = [] nodes = []
node_ids = set() node_ids = set()
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode( nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
id=node_id,
type=node_type_val,
content=content,
create_time=create_time
))
node_ids.add(node_id) node_ids.add(node_id)
except Exception as e: except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}") logger.warning(f"跳过节点 {node_id}: {e}")
continue continue
# 只获取涉及当前节点集的边(保证图的完整性) # 只获取涉及当前节点集的边(保证图的完整性)
edges = [] edges = []
edge_list = graph.get_edge_list() edge_list = graph.get_edge_list()
@ -179,27 +174,25 @@ async def get_knowledge_graph(
# 只包含两端都在当前节点集中的边 # 只包含两端都在当前节点集中的边
if source not in node_ids or target not in node_ids: if source not in node_ids or target not in node_ids:
continue continue
edge_data = graph[source, target] edge_data = graph[source, target]
weight = edge_data['weight'] if 'weight' in edge_data else 1.0 weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(KnowledgeEdge( edges.append(
source=source, KnowledgeEdge(
target=target, source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
weight=weight, )
create_time=create_time, )
update_time=update_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}") logger.warning(f"跳过边 {edge_tuple}: {e}")
continue continue
graph_data = KnowledgeGraph(nodes=nodes, edges=edges) graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边") logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
return graph_data return graph_data
except Exception as e: except Exception as e:
logger.error(f"获取知识图谱失败: {e}", exc_info=True) logger.error(f"获取知识图谱失败: {e}", exc_info=True)
return KnowledgeGraph(nodes=[], edges=[]) return KnowledgeGraph(nodes=[], edges=[])
@ -208,71 +201,59 @@ async def get_knowledge_graph(
@router.get("/stats", response_model=KnowledgeStats) @router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats(): async def get_knowledge_stats():
"""获取知识库统计信息 """获取知识库统计信息
Returns: Returns:
KnowledgeStats: 统计信息 KnowledgeStats: 统计信息
""" """
try: try:
kg_manager = _load_kg_manager() kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None: if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats( return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
graph = kg_manager.graph graph = kg_manager.graph
node_list = graph.get_node_list() node_list = graph.get_node_list()
edge_list = graph.get_edge_list() edge_list = graph.get_edge_list()
total_nodes = len(node_list) total_nodes = len(node_list)
total_edges = len(edge_list) total_edges = len(edge_list)
# 统计节点类型 # 统计节点类型
entity_nodes = 0 entity_nodes = 0
paragraph_nodes = 0 paragraph_nodes = 0
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type = node_data['type'] if 'type' in node_data else 'ent' node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == 'ent': if node_type == "ent":
entity_nodes += 1 entity_nodes += 1
elif node_type == 'pg': elif node_type == "pg":
paragraph_nodes += 1 paragraph_nodes += 1
except Exception: except Exception:
continue continue
# 计算平均连接数 # 计算平均连接数
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0 avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
return KnowledgeStats( return KnowledgeStats(
total_nodes=total_nodes, total_nodes=total_nodes,
total_edges=total_edges, total_edges=total_edges,
entity_nodes=entity_nodes, entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes, paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2) avg_connections=round(avg_connections, 2),
) )
except Exception as e: except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True) logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats( return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
@router.get("/search", response_model=List[KnowledgeNode]) @router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1)): async def search_knowledge_node(query: str = Query(..., min_length=1)):
"""搜索知识节点 """搜索知识节点
Args: Args:
query: 搜索关键词 query: 搜索关键词
Returns: Returns:
List[KnowledgeNode]: 匹配的节点列表 List[KnowledgeNode]: 匹配的节点列表
""" """
@ -280,33 +261,28 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
kg_manager = _load_kg_manager() kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None: if kg_manager is None or kg_manager.graph is None:
return [] return []
graph = kg_manager.graph graph = kg_manager.graph
node_list = graph.get_node_list() node_list = graph.get_node_list()
results = [] results = []
query_lower = query.lower() query_lower = query.lower()
# 在节点内容中搜索 # 在节点内容中搜索
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
if query_lower in content.lower() or query_lower in node_id.lower(): if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode( results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
except Exception: except Exception:
continue continue
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点") logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
return results[:50] # 限制返回数量 return results[:50] # 限制返回数量
except Exception as e: except Exception as e:
logger.error(f"搜索节点失败: {e}", exc_info=True) logger.error(f"搜索节点失败: {e}", exc_info=True)
return [] return []
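The repeated `value = data["key"] if "key" in data else default` pattern in this file exists because the graph's node and edge payloads support `[]` and `in` but not dict.get(). The same idea as a tiny helper, shown with a plain dict standing in for the payload:

def field(payload, key, default=None):
    # payload only needs to support "in" and "[]", like the graph payloads above
    return payload[key] if key in payload else default

print(field({"weight": 2.5}, "weight", 1.0))   # 2.5
print(field({}, "create_time"))                # None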
View File
@ -43,25 +43,27 @@ def _normalize_url(url: str) -> str:
def _parse_openai_response(data: dict) -> list[dict]: def _parse_openai_response(data: dict) -> list[dict]:
""" """
解析 OpenAI 格式的模型列表响应 解析 OpenAI 格式的模型列表响应
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] } 格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
""" """
models = [] models = []
if "data" in data and isinstance(data["data"], list): if "data" in data and isinstance(data["data"], list):
for model in data["data"]: for model in data["data"]:
if isinstance(model, dict) and "id" in model: if isinstance(model, dict) and "id" in model:
models.append({ models.append(
"id": model["id"], {
"name": model.get("name") or model["id"], "id": model["id"],
"owned_by": model.get("owned_by", ""), "name": model.get("name") or model["id"],
}) "owned_by": model.get("owned_by", ""),
}
)
return models return models
def _parse_gemini_response(data: dict) -> list[dict]: def _parse_gemini_response(data: dict) -> list[dict]:
""" """
解析 Gemini 格式的模型列表响应 解析 Gemini 格式的模型列表响应
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] } 格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
""" """
models = [] models = []
@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
model_id = model["name"] model_id = model["name"]
if model_id.startswith("models/"): if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀 model_id = model_id[7:] # 去掉 "models/" 前缀
models.append({ models.append(
"id": model_id, {
"name": model.get("displayName") or model_id, "id": model_id,
"owned_by": "google", "name": model.get("displayName") or model_id,
}) "owned_by": "google",
}
)
return models return models
@ -89,55 +93,54 @@ async def _fetch_models_from_provider(
) -> list[dict]: ) -> list[dict]:
""" """
从提供商 API 获取模型列表 从提供商 API 获取模型列表
Args: Args:
base_url: 提供商的基础 URL base_url: 提供商的基础 URL
api_key: API 密钥 api_key: API 密钥
endpoint: 获取模型列表的端点 endpoint: 获取模型列表的端点
parser: 响应解析器类型 ('openai' | 'gemini') parser: 响应解析器类型 ('openai' | 'gemini')
client_type: 客户端类型 ('openai' | 'gemini') client_type: 客户端类型 ('openai' | 'gemini')
Returns: Returns:
模型列表 模型列表
""" """
url = f"{_normalize_url(base_url)}{endpoint}" url = f"{_normalize_url(base_url)}{endpoint}"
# 根据客户端类型设置请求头 # 根据客户端类型设置请求头
headers = {} headers = {}
params = {} params = {}
if client_type == "gemini": if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key # Gemini 使用 URL 参数传递 API Key
params["key"] = api_key params["key"] = api_key
else: else:
# OpenAI 兼容格式使用 Authorization 头 # OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}" headers["Authorization"] = f"Bearer {api_key}"
try: try:
async with httpx.AsyncClient(timeout=30.0) as client: async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers, params=params) response = await client.get(url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
except httpx.TimeoutException: except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
# 注意:使用 502 Bad Gateway 而不是原始的 401/403 # 注意:使用 502 Bad Gateway 而不是原始的 401/403
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理 # 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
if e.response.status_code == 401: if e.response.status_code == 401:
raise HTTPException(status_code=502, detail="API Key 无效或已过期") raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
elif e.response.status_code == 403: elif e.response.status_code == 403:
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
elif e.response.status_code == 404: elif e.response.status_code == 404:
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
else: else:
raise HTTPException( raise HTTPException(
status_code=502, status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}" ) from e
)
except Exception as e: except Exception as e:
logger.error(f"获取模型列表失败: {e}") logger.error(f"获取模型列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
# 根据解析器类型解析响应 # 根据解析器类型解析响应
if parser == "openai": if parser == "openai":
return _parse_openai_response(data) return _parse_openai_response(data)
@ -150,26 +153,26 @@ async def _fetch_models_from_provider(
def _get_provider_config(provider_name: str) -> Optional[dict]: def _get_provider_config(provider_name: str) -> Optional[dict]:
""" """
model_config.toml 获取指定提供商的配置 model_config.toml 获取指定提供商的配置
Args: Args:
provider_name: 提供商名称 provider_name: 提供商名称
Returns: Returns:
提供商配置如果未找到则返回 None 提供商配置如果未找到则返回 None
""" """
config_path = os.path.join(CONFIG_DIR, "model_config.toml") config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path): if not os.path.exists(config_path):
return None return None
try: try:
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f) config_data = tomlkit.load(f)
providers = config_data.get("api_providers", []) providers = config_data.get("api_providers", [])
for provider in providers: for provider in providers:
if provider.get("name") == provider_name: if provider.get("name") == provider_name:
return dict(provider) return dict(provider)
return None return None
except Exception as e: except Exception as e:
logger.error(f"读取提供商配置失败: {e}") logger.error(f"读取提供商配置失败: {e}")
@@ -184,23 +187,23 @@ async def get_provider_models(
):
    """
    获取指定提供商的可用模型列表
    通过提供商名称查找配置,然后请求对应的模型列表端点
    """
    # 获取提供商配置
    provider_config = _get_provider_config(provider_name)
    if not provider_config:
        raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
    base_url = provider_config.get("base_url")
    api_key = provider_config.get("api_key")
    client_type = provider_config.get("client_type", "openai")
    if not base_url:
        raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
    if not api_key:
        raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
    # 获取模型列表
    models = await _fetch_models_from_provider(
        base_url=base_url,
@@ -209,7 +212,7 @@ async def get_provider_models(
        parser=parser,
        client_type=client_type,
    )
    return {
        "success": True,
        "models": models,
@@ -236,7 +239,7 @@ async def get_models_by_url(
        parser=parser,
        client_type=client_type,
    )
    return {
        "success": True,
        "models": models,
@@ -251,11 +254,11 @@ async def test_provider_connection(
):
    """
    测试提供商连接状态
    分两步测试:
    1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
    2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
    返回:
    - network_ok: 网络是否连通
    - api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
@@ -263,11 +266,11 @@ async def test_provider_connection(
    - error: 错误信息(如果有)
    """
    import time
    base_url = _normalize_url(base_url)
    if not base_url:
        raise HTTPException(status_code=400, detail="base_url 不能为空")
    result = {
        "network_ok": False,
        "api_key_valid": None,
@@ -275,7 +278,7 @@ async def test_provider_connection(
        "error": None,
        "http_status": None,
    }
    # 第一步:测试网络连通性
    try:
        start_time = time.time()
@@ -283,11 +286,11 @@ async def test_provider_connection(
            # 尝试 GET 请求 base_url(不需要 API Key)
            response = await client.get(base_url)
            latency = (time.time() - start_time) * 1000
            result["network_ok"] = True
            result["latency_ms"] = round(latency, 2)
            result["http_status"] = response.status_code
    except httpx.ConnectError as e:
        result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
        return result
@@ -300,7 +303,7 @@ async def test_provider_connection(
    except Exception as e:
        result["error"] = f"未知错误:{str(e)}"
        return result
    # 第二步:如果提供了 API Key,验证其有效性
    if api_key:
        try:
@@ -313,7 +316,7 @@ async def test_provider_connection(
                # 尝试获取模型列表
                models_url = f"{base_url}/models"
                response = await client.get(models_url, headers=headers)
                if response.status_code == 200:
                    result["api_key_valid"] = True
                elif response.status_code in (401, 403):
@@ -322,12 +325,12 @@ async def test_provider_connection(
                else:
                    # 其他状态码,可能是端点不支持,但 Key 可能是有效的
                    result["api_key_valid"] = None
        except Exception as e:
            # API Key 验证失败不影响网络连通性结果
            logger.warning(f"API Key 验证失败: {e}")
            result["api_key_valid"] = None
    return result
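The docstring above describes a two-step check: network reachability first, then optional API-key validation. A small consumer-side sketch of how the returned dict can be summarized; the field names come from the docstring, everything else is illustrative:

# Illustrative only: interprets the result dict documented above.
def summarize_connection_test(result: dict) -> str:
    if not result.get("network_ok"):
        return f"unreachable: {result.get('error')}"
    latency = result.get("latency_ms")
    if result.get("api_key_valid") is True:
        return f"reachable ({latency} ms), API key valid"
    if result.get("api_key_valid") is False:
        return f"reachable ({latency} ms), API key rejected"
    return f"reachable ({latency} ms), API key not verified"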
@@ -342,10 +345,10 @@ async def test_provider_connection_by_name(
    model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
    if not os.path.exists(model_config_path):
        raise HTTPException(status_code=404, detail="配置文件不存在")
    with open(model_config_path, "r", encoding="utf-8") as f:
        config = tomlkit.load(f)
    # 查找提供商
    providers = config.get("api_providers", [])
    provider = None
@@ -353,15 +356,15 @@ async def test_provider_connection_by_name(
        if p.get("name") == provider_name:
            provider = p
            break
    if not provider:
        raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
    base_url = provider.get("base_url", "")
    api_key = provider.get("api_key", "")
    if not base_url:
        raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
    # 调用测试接口
    return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
View File
@@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
    """
    # 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
    import re
    # 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
-    base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0]
+    base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
    parts = base_version.split(".")
    if len(parts) < 3:
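A quick standalone check of the suffix-stripping split used above (run outside the project; the behavior is noted in the trailing comment):

# The regex drops pre-release suffixes before the major.minor.patch split.
import re

for v in ("0.10.2", "0.10.2-snapshot.3", "0.10.2.dev1", "0.10.2-RC1"):
    base = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", v, flags=re.IGNORECASE)[0]
    print(v, "->", base)  # all of the suffixed examples reduce to "0.10.2"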
@@ -613,7 +614,7 @@ async def install_plugin(request: InstallPluginRequest, authorization: Optional[
        for field in required_fields:
            if field not in manifest:
                raise ValueError(f"缺少必需字段: {field}")
        # 将插件 ID 写入 manifest,用于后续准确识别
        # 这样即使文件夹名称改变,也能通过 manifest 准确识别插件
        manifest["id"] = request.plugin_id
@@ -705,7 +706,7 @@ async def uninstall_plugin(
        plugin_path = plugins_dir / folder_name
        # 旧格式:点
        old_format_path = plugins_dir / request.plugin_id
        # 优先使用新格式,如果不存在则尝试旧格式
        if not plugin_path.exists():
            if old_format_path.exists():
@@ -839,7 +840,7 @@ async def update_plugin(request: UpdatePluginRequest, authorization: Optional[st
        plugin_path = plugins_dir / folder_name
        # 旧格式:点
        old_format_path = plugins_dir / request.plugin_id
        # 优先使用新格式,如果不存在则尝试旧格式
        if not plugin_path.exists():
            if old_format_path.exists():
@@ -1092,21 +1093,21 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
                # 尝试从 author.name 和 repository_url 构建标准 ID
                author_name = None
                repo_name = None
                # 获取作者名
                if "author" in manifest:
                    if isinstance(manifest["author"], dict) and "name" in manifest["author"]:
                        author_name = manifest["author"]["name"]
                    elif isinstance(manifest["author"], str):
                        author_name = manifest["author"]
                # 从 repository_url 获取仓库名
                if "repository_url" in manifest:
                    repo_url = manifest["repository_url"].rstrip("/")
                    if repo_url.endswith(".git"):
                        repo_url = repo_url[:-4]
                    repo_name = repo_url.split("/")[-1]
                # 构建 ID
                if author_name and repo_name:
                    # 标准格式: Author.RepoName
@@ -1122,7 +1123,7 @@ async def get_installed_plugins(authorization: Optional[str] = Header(None)) ->
                else:
                    # 直接使用文件夹名
                    plugin_id = folder_name
                # 将推断的 ID 写入 manifest,方便下次识别
                logger.info(f"为插件 {folder_name} 自动生成 ID: {plugin_id}")
                manifest["id"] = plugin_id
@@ -1167,12 +1168,10 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema")
-async def get_plugin_config_schema(
-    plugin_id: str, authorization: Optional[str] = Header(None)
-) -> Dict[str, Any]:
+async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
    """
    获取插件配置 Schema
    返回插件的完整配置 schema,包含所有 section、字段定义和布局信息
    用于前端动态生成配置表单
    """
@@ -1187,10 +1186,10 @@ async def get_plugin_config_schema(
    try:
        # 尝试从已加载的插件中获取
        from src.plugin_system.core.plugin_manager import plugin_manager
        # 查找插件实例
        plugin_instance = None
        # 遍历所有已加载的插件
        for loaded_plugin_name in plugin_manager.list_loaded_plugins():
            instance = plugin_manager.get_plugin_instance(loaded_plugin_name)
@@ -1204,17 +1203,17 @@ async def get_plugin_config_schema(
                if manifest_id == plugin_id:
                    plugin_instance = instance
                    break
-        if plugin_instance and hasattr(plugin_instance, 'get_webui_config_schema'):
+        if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
            # 从插件实例获取 schema
            schema = plugin_instance.get_webui_config_schema()
            return {"success": True, "schema": schema}
        # 如果插件未加载,尝试从文件系统读取
        # 查找插件目录
        plugins_dir = Path("plugins")
        plugin_path = None
        for p in plugins_dir.iterdir():
            if p.is_dir():
                manifest_path = p / "_manifest.json"
@@ -1227,18 +1226,19 @@ async def get_plugin_config_schema(
                        break
                except Exception:
                    continue
        if not plugin_path:
            raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
        # 读取配置文件获取当前配置
        config_path = plugin_path / "config.toml"
        current_config = {}
        if config_path.exists():
            import tomlkit
            with open(config_path, "r", encoding="utf-8") as f:
                current_config = tomlkit.load(f)
        # 构建基础 schema(无法获取完整的 ConfigField 信息)
        schema = {
            "plugin_id": plugin_id,
@@ -1252,7 +1252,7 @@ async def get_plugin_config_schema(
            "layout": {"type": "auto", "tabs": []},
            "_note": "插件未加载,仅返回当前配置结构",
        }
        # 从当前配置推断 schema
        for section_name, section_data in current_config.items():
            if isinstance(section_data, dict):
@@ -1277,7 +1277,7 @@ async def get_plugin_config_schema(
                        ui_type = "list"
                    elif isinstance(field_value, dict):
                        ui_type = "json"
                    schema["sections"][section_name]["fields"][field_name] = {
                        "name": field_name,
                        "type": field_type,
@@ -1290,7 +1290,7 @@ async def get_plugin_config_schema(
                        "disabled": False,
                        "order": 0,
                    }
        return {"success": True, "schema": schema}
    except HTTPException:
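When the plugin is not loaded, the code above infers a rough schema from the values in config.toml. Only the list and dict branches of the type mapping are visible in this hunk; a sketch of what the full mapping might look like, with the remaining branches as assumptions:

# Only "list" and "json" are confirmed by the hunk above; the other branches are guesses.
def guess_ui_type(value) -> str:
    if isinstance(value, bool):  # check bool before int: bool is a subclass of int
        return "switch"
    if isinstance(value, (int, float)):
        return "number"
    if isinstance(value, list):
        return "list"
    if isinstance(value, dict):
        return "json"
    return "text"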
@@ -1301,12 +1301,10 @@ async def get_plugin_config_schema(
@router.get("/config/{plugin_id}")
-async def get_plugin_config(
-    plugin_id: str, authorization: Optional[str] = Header(None)
-) -> Dict[str, Any]:
+async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
    """
    获取插件当前配置值
    返回插件的当前配置值
    """
    # Token 验证
@@ -1321,7 +1319,7 @@ async def get_plugin_config(
        # 查找插件目录
        plugins_dir = Path("plugins")
        plugin_path = None
        for p in plugins_dir.iterdir():
            if p.is_dir():
                manifest_path = p / "_manifest.json"
@@ -1334,19 +1332,20 @@ async def get_plugin_config(
                        break
                except Exception:
                    continue
        if not plugin_path:
            raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
        # 读取配置文件
        config_path = plugin_path / "config.toml"
        if not config_path.exists():
            return {"success": True, "config": {}, "message": "配置文件不存在"}
        import tomlkit
        with open(config_path, "r", encoding="utf-8") as f:
            config = tomlkit.load(f)
        return {"success": True, "config": dict(config)}
    except HTTPException:
@@ -1358,13 +1357,11 @@ async def get_plugin_config(
@router.put("/config/{plugin_id}")
async def update_plugin_config(
-    plugin_id: str,
-    request: UpdatePluginConfigRequest,
-    authorization: Optional[str] = Header(None)
+    plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
    """
    更新插件配置
    保存新的配置值到插件的配置文件
    """
    # Token 验证
@@ -1379,7 +1376,7 @@ async def update_plugin_config(
        # 查找插件目录
        plugins_dir = Path("plugins")
        plugin_path = None
        for p in plugins_dir.iterdir():
            if p.is_dir():
                manifest_path = p / "_manifest.json"
@@ -1392,23 +1389,25 @@ async def update_plugin_config(
                        break
                except Exception:
                    continue
        if not plugin_path:
            raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
        config_path = plugin_path / "config.toml"
        # 备份旧配置
        import shutil
        import datetime
        if config_path.exists():
            backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
            backup_path = plugin_path / backup_name
            shutil.copy(config_path, backup_path)
            logger.info(f"已备份配置文件: {backup_path}")
        # 写入新配置(使用 tomlkit 保留注释)
        import tomlkit
        # 先读取原配置以保留注释和格式
        existing_doc = tomlkit.document()
        if config_path.exists():
@@ -1419,14 +1418,10 @@ async def update_plugin_config(
            existing_doc[key] = value
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(existing_doc, f)
        logger.info(f"已更新插件配置: {plugin_id}")
-        return {
-            "success": True,
-            "message": "配置已保存",
-            "note": "配置更改将在插件重新加载后生效"
-        }
+        return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
    except HTTPException:
        raise
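The update path above re-reads config.toml with tomlkit before writing so that comments and layout survive the save. A minimal standalone demonstration of that round-trip:

# tomlkit keeps comments and formatting across a parse -> modify -> dump cycle.
import tomlkit

doc = tomlkit.parse("# plugin switch\n[plugin]\nenabled = true\n")
doc["plugin"]["enabled"] = False
print(tomlkit.dumps(doc))  # the "# plugin switch" comment is still present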
@@ -1436,12 +1431,10 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset")
-async def reset_plugin_config(
-    plugin_id: str, authorization: Optional[str] = Header(None)
-) -> Dict[str, Any]:
+async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
    """
    重置插件配置为默认值
    删除当前配置文件,下次加载插件时将使用默认配置
    """
    # Token 验证
@@ -1456,7 +1449,7 @@ async def reset_plugin_config(
        # 查找插件目录
        plugins_dir = Path("plugins")
        plugin_path = None
        for p in plugins_dir.iterdir():
            if p.is_dir():
                manifest_path = p / "_manifest.json"
@@ -1469,29 +1462,26 @@ async def reset_plugin_config(
                        break
                except Exception:
                    continue
        if not plugin_path:
            raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
        config_path = plugin_path / "config.toml"
        if not config_path.exists():
            return {"success": True, "message": "配置文件不存在,无需重置"}
        # 备份并删除
        import shutil
        import datetime
        backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
        backup_path = plugin_path / backup_name
        shutil.move(config_path, backup_path)
        logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
-        return {
-            "success": True,
-            "message": "配置已重置,下次加载插件时将使用默认配置",
-            "backup": str(backup_path)
-        }
+        return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
    except HTTPException:
        raise
@@ -1501,12 +1491,10 @@ async def reset_plugin_config(
@router.post("/config/{plugin_id}/toggle")
-async def toggle_plugin(
-    plugin_id: str, authorization: Optional[str] = Header(None)
-) -> Dict[str, Any]:
+async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
    """
    切换插件启用状态
    切换插件配置中的 enabled 字段
    """
    # Token 验证
@@ -1521,7 +1509,7 @@ async def toggle_plugin(
        # 查找插件目录
        plugins_dir = Path("plugins")
        plugin_path = None
        for p in plugins_dir.iterdir():
            if p.is_dir():
                manifest_path = p / "_manifest.json"
@@ -1534,40 +1522,40 @@ async def toggle_plugin(
                        break
                except Exception:
                    continue
        if not plugin_path:
            raise HTTPException(status_code=404, detail=f"未找到插件: {plugin_id}")
        config_path = plugin_path / "config.toml"
        import tomlkit
        # 读取当前配置(保留注释和格式)
        config = tomlkit.document()
        if config_path.exists():
            with open(config_path, "r", encoding="utf-8") as f:
                config = tomlkit.load(f)
        # 切换 enabled 状态
        if "plugin" not in config:
            config["plugin"] = tomlkit.table()
        current_enabled = config["plugin"].get("enabled", True)
        new_enabled = not current_enabled
        config["plugin"]["enabled"] = new_enabled
        # 写入配置(保留注释)
        with open(config_path, "w", encoding="utf-8") as f:
            tomlkit.dump(config, f)
        status = "启用" if new_enabled else "禁用"
        logger.info(f"已{status}插件: {plugin_id}")
        return {
            "success": True,
            "enabled": new_enabled,
            "message": f"插件已{status}",
-            "note": "状态更改将在下次加载插件时生效"
+            "note": "状态更改将在下次加载插件时生效",
        }
    except HTTPException:
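A client-side sketch for the toggle route above. The /api/plugins prefix, host/port, and Bearer scheme are assumptions about the deployment, not taken from this diff:

# Hedged client sketch; adjust the base URL, router prefix, and auth scheme to the real setup.
import httpx


def toggle_plugin_sketch(plugin_id: str, token: str) -> bool:
    resp = httpx.post(
        f"http://127.0.0.1:8000/api/plugins/config/{plugin_id}/toggle",
        headers={"Authorization": f"Bearer {token}"},
        timeout=10.0,
    )
    resp.raise_for_status()
    return bool(resp.json().get("enabled"))  # new enabled state returned by the handler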
View File
@@ -43,7 +43,7 @@ async def restart_maibot():
    注意:此操作会使麦麦暂时离线
    """
    import asyncio
    try:
        # 记录重启操作
        print(f"[{datetime.now()}] WebUI 触发重启操作")
@@ -54,7 +54,7 @@ async def restart_maibot():
            python = sys.executable
            args = [python] + sys.argv
            os.execv(python, args)
        # 创建后台任务执行重启
        asyncio.create_task(delayed_restart())
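Only the tail of delayed_restart is visible in this hunk. The overall pattern it implements, sketched self-contained (the 1-second delay is an assumption):

# Sleep briefly so the HTTP response can be delivered, then replace the process image.
import asyncio
import os
import sys


async def delayed_restart_sketch(delay: float = 1.0) -> None:
    await asyncio.sleep(delay)
    python = sys.executable
    os.execv(python, [python] + sys.argv)  # never returns; the interpreter is replaced


# Scheduled from inside a running event loop, as above:
# asyncio.create_task(delayed_restart_sketch())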
View File
@@ -20,10 +20,10 @@ class WebUIServer:
        self.port = port
        self.app = FastAPI(title="MaiBot WebUI")
        self._server = None
        # 显示 Access Token
        self._show_access_token()
        # 重要:先注册 API 路由,再设置静态文件
        self._register_api_routes()
        self._setup_static_files()
@@ -32,7 +32,7 @@ class WebUIServer:
        """显示 WebUI Access Token"""
        try:
            from src.webui.token_manager import get_token_manager
            token_manager = get_token_manager()
            current_token = token_manager.get_token()
            logger.info(f"🔑 WebUI Access Token: {current_token}")
@@ -69,7 +69,7 @@ class WebUIServer:
            # 如果是根路径,直接返回 index.html
            if not full_path or full_path == "/":
                return FileResponse(static_path / "index.html", media_type="text/html")
            # 检查是否是静态文件
            file_path = static_path / full_path
            if file_path.is_file() and file_path.exists():
@@ -88,13 +88,15 @@ class WebUIServer:
        # 导入所有 WebUI 路由
        from src.webui.routes import router as webui_router
        from src.webui.logs_ws import router as logs_router
        logger.info("开始导入 knowledge_routes...")
        from src.webui.knowledge_routes import router as knowledge_router
        logger.info("knowledge_routes 导入成功")
        # 导入本地聊天室路由
        from src.webui.chat_routes import router as chat_router
        logger.info("chat_routes 导入成功")
        # 注册路由
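The registration calls themselves fall outside this hunk; the usual FastAPI pattern for the routers imported above looks like the following, with the prefixes purely illustrative:

# Illustrative registration only; the real prefixes and ordering are defined elsewhere.
from fastapi import APIRouter, FastAPI

app = FastAPI(title="MaiBot WebUI")
webui_router = APIRouter()  # stand-ins for the routers imported above
logs_router = APIRouter()

app.include_router(webui_router, prefix="/api")
app.include_router(logs_router, prefix="/api/logs")
# Static files are mounted after the API routes so the API keeps precedence.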
View File
@@ -8,23 +8,23 @@ if edges:
    e = edges[0]
    print(f"Edge tuple: {e}")
    print(f"Edge tuple type: {type(e)}")
    edge_data = kg.graph[e[0], e[1]]
    print(f"\nEdge data type: {type(edge_data)}")
    print(f"Edge data: {edge_data}")
    print(f"Has 'get' method: {hasattr(edge_data, 'get')}")
    print(f"Is dict: {isinstance(edge_data, dict)}")
    # 尝试不同的访问方式
    try:
        print(f"\nUsing []: {edge_data['weight']}")
    except Exception as e:
        print(f"Using [] failed: {e}")
    try:
        print(f"Using .get(): {edge_data.get('weight')}")
    except Exception as e:
        print(f"Using .get() failed: {e}")
    # 查看所有属性
    print(f"\nDir: {[x for x in dir(edge_data) if not x.startswith('_')]}")