Ruff Fix & format

pull/1389/head
墨梓柒 2025-11-29 14:38:42 +08:00
parent d7932595e8
commit 3935ce817e
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
31 changed files with 678 additions and 684 deletions

3
bot.py
View File

@ -78,6 +78,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
# 关闭 WebUI 服务器 # 关闭 WebUI 服务器
try: try:
from src.webui.webui_server import get_webui_server from src.webui.webui_server import get_webui_server
webui_server = get_webui_server() webui_server = get_webui_server()
if webui_server and webui_server._server: if webui_server and webui_server._server:
await webui_server.shutdown() await webui_server.shutdown()
@ -238,7 +239,7 @@ if __name__ == "__main__":
logger.warning("收到中断信号,正在优雅关闭...") logger.warning("收到中断信号,正在优雅关闭...")
# 取消主任务 # 取消主任务
if 'main_tasks' in locals() and main_tasks and not main_tasks.done(): if "main_tasks" in locals() and main_tasks and not main_tasks.done():
main_tasks.cancel() main_tasks.cancel()
try: try:
loop.run_until_complete(main_tasks) loop.run_until_complete(main_tasks)

View File

@ -254,6 +254,7 @@ class BrainChatting:
# 检查是否需要提问表达反思 # 检查是否需要提问表达反思
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
from src.express.expression_reflector import expression_reflector_manager from src.express.expression_reflector import expression_reflector_manager
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id) reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
asyncio.create_task(reflector.check_and_ask()) asyncio.create_task(reflector.check_and_ask())

View File

@ -410,7 +410,6 @@ class HeartFChatting:
reflect_tracker_manager.remove_tracker(self.stream_id) reflect_tracker_manager.remove_tracker(self.stream_id)
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.") logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
start_time = time.time() start_time = time.time()
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()): async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
asyncio.create_task(self.expression_learner.trigger_learning_for_chat()) asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
@ -427,7 +426,9 @@ class HeartFChatting:
# asyncio.create_task(self.chat_history_summarizer.process()) # asyncio.create_task(self.chat_history_summarizer.process())
cycle_timers, thinking_id = self.start_cycle() cycle_timers, thinking_id = self.start_cycle()
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})") logger.info(
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
)
# 第一步:动作检查 # 第一步:动作检查
available_actions: Dict[str, ActionInfo] = {} available_actions: Dict[str, ActionInfo] = {}

View File

@ -25,6 +25,7 @@ def get_webui_chat_broadcaster():
if _webui_chat_broadcaster is None: if _webui_chat_broadcaster is None:
try: try:
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM) _webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
except ImportError: except ImportError:
_webui_chat_broadcaster = (None, None) _webui_chat_broadcaster = (None, None)
@ -44,7 +45,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
import time import time
from src.config.config import global_config from src.config.config import global_config
await chat_manager.broadcast({ await chat_manager.broadcast(
{
"type": "bot_message", "type": "bot_message",
"content": message.processed_plain_text, "content": message.processed_plain_text,
"message_type": "text", "message_type": "text",
@ -53,8 +55,9 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
"name": global_config.bot.nickname, "name": global_config.bot.nickname,
"avatar": None, "avatar": None,
"is_bot": True, "is_bot": True,
},
} }
}) )
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库 # 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
# 无需手动保存 # 无需手动保存

View File

@ -181,8 +181,12 @@ class ActionPlanner:
found_ids = set(matches) found_ids = set(matches)
missing_ids = found_ids - available_ids missing_ids = found_ids - available_ids
if missing_ids: if missing_ids:
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...") logger.info(
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中") f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
)
logger.info(
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用其中{len(found_ids & available_ids)}个在上下文中"
)
def _replace(match: re.Match[str]) -> str: def _replace(match: re.Match[str]) -> str:
msg_id = match.group(0) msg_id = match.group(0)
@ -234,17 +238,11 @@ class ActionPlanner:
target_message = message_id_list[-1][1] target_message = message_id_list[-1][1]
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id使用最新消息作为target_message") logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id使用最新消息作为target_message")
if ( if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
action != "no_reply"
and target_message is not None
and self._is_message_from_self(target_message)
):
logger.info( logger.info(
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply" f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
) )
reasoning = ( reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
)
action = "no_reply" action = "no_reply"
target_message = None target_message = None
@ -295,10 +293,9 @@ class ActionPlanner:
def _is_message_from_self(self, message: "DatabaseMessages") -> bool: def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
"""判断消息是否由机器人自身发送""" """判断消息是否由机器人自身发送"""
try: try:
return ( return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
str(message.user_info.user_id) == str(global_config.bot.qq_account) message.user_info.platform or ""
and (message.user_info.platform or "") == (global_config.bot.platform or "") ) == (global_config.bot.platform or "")
)
except AttributeError: except AttributeError:
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段") logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
return False return False

View File

@ -132,9 +132,7 @@ class ImageManager:
deleted_images = Images.delete().where(Images.type == "emoji").execute() deleted_images = Images.delete().where(Images.type == "emoji").execute()
# 清理ImageDescriptions表中type为emoji的记录 # 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ( deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
)
total_deleted = deleted_images + deleted_descriptions total_deleted = deleted_images + deleted_descriptions
if total_deleted > 0: if total_deleted > 0:
@ -194,10 +192,14 @@ class ImageManager:
if cache_record: if cache_record:
# 优先使用情感标签,如果没有则使用详细描述 # 优先使用情感标签,如果没有则使用详细描述
if cache_record.emotion_tags: if cache_record.emotion_tags:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
return f"[表情包:{cache_record.emotion_tags}]" return f"[表情包:{cache_record.emotion_tags}]"
elif cache_record.description: elif cache_record.description:
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...") logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
return f"[表情包:{cache_record.description}]" return f"[表情包:{cache_record.description}]"
except Exception as e: except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}") logger.debug(f"查询EmojiDescriptionCache时出错: {e}")

View File

@ -62,11 +62,15 @@ class ExpressionReflector:
max_interval = 15 * 60 # 10分钟 max_interval = 15 * 60 # 10分钟
interval = random.uniform(min_interval, max_interval) interval = random.uniform(min_interval, max_interval)
logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)") logger.info(
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
)
if time_since_last_ask < interval: if time_since_last_ask < interval:
remaining_time = interval - time_since_last_ask remaining_time = interval - time_since_last_ask
logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过") logger.info(
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
)
return False return False
# 检查是否已经有针对该 Operator 的 Tracker 在运行 # 检查是否已经有针对该 Operator 的 Tracker 在运行
@ -78,10 +82,9 @@ class ExpressionReflector:
# 获取未检查的表达 # 获取未检查的表达
try: try:
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达") logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
expressions = (Expression expressions = (
.select() Expression.select().where((Expression.checked == False) & (Expression.rejected == False)).limit(50)
.where((Expression.checked == False) & (Expression.rejected == False)) )
.limit(50))
expr_list = list(expressions) expr_list = list(expressions)
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达") logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
@ -91,7 +94,9 @@ class ExpressionReflector:
return False return False
target_expr: Expression = random.choice(expr_list) target_expr: Expression = random.choice(expr_list)
logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}") logger.info(
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
)
# 生成询问文本 # 生成询问文本
ask_text = _generate_ask_text(target_expr) ask_text = _generate_ask_text(target_expr)
@ -112,11 +117,13 @@ class ExpressionReflector:
except Exception as e: except Exception as e:
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
except Exception as e: except Exception as e:
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}") logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return False return False
@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager()
async def _check_tracker_exists(operator_config: str) -> bool: async def _check_tracker_exists(operator_config: str) -> bool:
"""检查指定 Operator 是否已有活跃的 Tracker""" """检查指定 Operator 是否已有活跃的 Tracker"""
from src.express.reflect_tracker import reflect_tracker_manager from src.express.reflect_tracker import reflect_tracker_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = None chat_stream = None
@ -240,12 +248,5 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
reflect_tracker_manager.add_tracker(stream_id, tracker) reflect_tracker_manager.add_tracker(stream_id, tracker)
# 发送消息 # 发送消息
await send_api.text_to_stream( await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
text=text,
stream_id=stream_id,
typing=True
)
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}") logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")

View File

@ -17,6 +17,7 @@ if TYPE_CHECKING:
logger = get_logger("reflect_tracker") logger = get_logger("reflect_tracker")
class ReflectTracker: class ReflectTracker:
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float): def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
self.chat_stream = chat_stream self.chat_stream = chat_stream
@ -28,9 +29,7 @@ class ReflectTracker:
self.max_duration = 15 * 60 # 15 minutes self.max_duration = 15 * 60 # 15 minutes
# LLM for judging response # LLM for judging response
self.judge_model = LLMRequest( self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
model_set=model_config.model_task_config.utils, request_type="reflect.tracker"
)
self._init_prompts() self._init_prompts()
@ -109,7 +108,7 @@ class ReflectTracker:
"reflect_judge_prompt", "reflect_judge_prompt",
situation=self.expression.situation, situation=self.expression.situation,
style=self.expression.style, style=self.expression.style,
context_block=context_block context_block=context_block,
) )
logger.info(f"ReflectTracker LLM Prompt: {prompt}") logger.info(f"ReflectTracker LLM Prompt: {prompt}")
@ -162,9 +161,13 @@ class ReflectTracker:
self.expression.save() self.expression.save()
if has_update: if has_update:
logger.info(f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}") logger.info(
f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
)
else: else:
logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.") logger.info(
f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
)
return True return True
elif judgment == "Ignore": elif judgment == "Ignore":
@ -177,6 +180,7 @@ class ReflectTracker:
return False return False
# Global manager for trackers # Global manager for trackers
class ReflectTrackerManager: class ReflectTrackerManager:
def __init__(self): def __init__(self):
@ -192,5 +196,5 @@ class ReflectTrackerManager:
if chat_id in self.trackers: if chat_id in self.trackers:
del self.trackers[chat_id] del self.trackers[chat_id]
reflect_tracker_manager = ReflectTrackerManager()
reflect_tracker_manager = ReflectTrackerManager()

View File

@ -315,7 +315,9 @@ class ChatHistorySummarizer:
before_count = len(self.current_batch.messages) before_count = len(self.current_batch.messages)
self.current_batch.messages.extend(new_messages) self.current_batch.messages.extend(new_messages)
self.current_batch.end_time = current_time self.current_batch.end_time = current_time
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息") logger.info(
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
)
# 更新批次后持久化 # 更新批次后持久化
self._persist_topic_cache() self._persist_topic_cache()
else: else:
@ -361,9 +363,7 @@ class ChatHistorySummarizer:
else: else:
time_str = f"{time_since_last_check / 3600:.1f}小时" time_str = f"{time_since_last_check / 3600:.1f}小时"
logger.info( logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
)
# 检查“话题检查”触发条件 # 检查“话题检查”触发条件
should_check = False should_check = False
@ -426,7 +426,9 @@ class ChatHistorySummarizer:
return return
# 2. 构造编号后的消息字符串和参与者信息 # 2. 构造编号后的消息字符串和参与者信息
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages) numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
self._build_numbered_messages_for_llm(messages)
)
# 3. 调用 LLM 识别话题,并得到 topic -> indices # 3. 调用 LLM 识别话题,并得到 topic -> indices
existing_topics = list(self.topic_cache.keys()) existing_topics = list(self.topic_cache.keys())
@ -588,9 +590,7 @@ class ChatHistorySummarizer:
if not numbered_lines: if not numbered_lines:
return False, {} return False, {}
history_topics_block = ( history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
)
messages_block = "\n".join(numbered_lines) messages_block = "\n".join(numbered_lines)
prompt = await global_prompt_manager.format_prompt( prompt = await global_prompt_manager.format_prompt(
@ -607,6 +607,7 @@ class ChatHistorySummarizer:
) )
import re import re
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}") logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}") logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
@ -895,4 +896,3 @@ class ChatHistorySummarizer:
init_prompt() init_prompt()

View File

@ -44,9 +44,7 @@ class JargonExplainer:
request_type="jargon.explain", request_type="jargon.explain",
) )
def match_jargon_from_messages( def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
self, messages: List[Any]
) -> List[Dict[str, str]]:
""" """
通过直接匹配数据库中的jargon字符串来提取黑话 通过直接匹配数据库中的jargon字符串来提取黑话
@ -68,7 +66,9 @@ class JargonExplainer:
if is_bot_message(msg): if is_bot_message(msg):
continue continue
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip() msg_text = (
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
).strip()
if msg_text: if msg_text:
message_texts.append(msg_text) message_texts.append(msg_text)
@ -79,9 +79,7 @@ class JargonExplainer:
combined_text = " ".join(message_texts) combined_text = " ".join(message_texts)
# 查询所有有meaning的jargon记录 # 查询所有有meaning的jargon记录
query = Jargon.select().where( query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
# 根据all_global配置决定查询逻辑 # 根据all_global配置决定查询逻辑
if global_config.jargon.all_global: if global_config.jargon.all_global:
@ -123,12 +121,12 @@ class JargonExplainer:
pattern = re.escape(content) pattern = re.escape(content)
# 使用单词边界或中文字符边界来匹配,避免部分匹配 # 使用单词边界或中文字符边界来匹配,避免部分匹配
# 对于中文使用Unicode字符类对于英文使用单词边界 # 对于中文使用Unicode字符类对于英文使用单词边界
if re.search(r'[\u4e00-\u9fff]', content): if re.search(r"[\u4e00-\u9fff]", content):
# 包含中文,使用更宽松的匹配 # 包含中文,使用更宽松的匹配
search_pattern = pattern search_pattern = pattern
else: else:
# 纯英文/数字,使用单词边界 # 纯英文/数字,使用单词边界
search_pattern = r'\b' + pattern + r'\b' search_pattern = r"\b" + pattern + r"\b"
if re.search(search_pattern, combined_text, re.IGNORECASE): if re.search(search_pattern, combined_text, re.IGNORECASE):
# 找到匹配,记录(去重) # 找到匹配,记录(去重)
@ -147,9 +145,7 @@ class JargonExplainer:
return list(matched_jargon.values()) return list(matched_jargon.values())
async def explain_jargon( async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
self, messages: List[Any], chat_context: str
) -> Optional[str]:
""" """
解释上下文中的黑话 解释上下文中的黑话
@ -239,9 +235,7 @@ class JargonExplainer:
return summary return summary
async def explain_jargon_in_context( async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
chat_id: str, messages: List[Any], chat_context: str
) -> Optional[str]:
""" """
解释上下文中的黑话便捷函数 解释上下文中的黑话便捷函数
@ -255,4 +249,3 @@ async def explain_jargon_in_context(
""" """
explainer = JargonExplainer(chat_id) explainer = JargonExplainer(chat_id)
return await explainer.explain_jargon(messages, chat_context) return await explainer.explain_jargon(messages, chat_context)

View File

@ -22,15 +22,13 @@ from src.jargon.jargon_utils import (
contains_bot_self_name, contains_bot_self_name,
parse_chat_id_list, parse_chat_id_list,
chat_id_list_contains, chat_id_list_contains,
update_chat_id_list update_chat_id_list,
) )
logger = get_logger("jargon") logger = get_logger("jargon")
def _init_prompt() -> None: def _init_prompt() -> None:
prompt_str = """ prompt_str = """
**聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID** **聊天内容其中的{bot_name}的发言内容是你自己的发言[msg_id] 是消息ID**
@ -126,7 +124,6 @@ _init_prompt()
_init_inference_prompts() _init_inference_prompts()
def _should_infer_meaning(jargon_obj: Jargon) -> bool: def _should_infer_meaning(jargon_obj: Jargon) -> bool:
""" """
判断是否需要进行含义推断 判断是否需要进行含义推断
@ -211,7 +208,9 @@ class JargonMiner:
processed_pairs = set() processed_pairs = set()
for idx, msg in enumerate(messages): for idx, msg in enumerate(messages):
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip() msg_text = (
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
).strip()
if not msg_text or is_bot_message(msg): if not msg_text or is_bot_message(msg):
continue continue
@ -588,7 +587,6 @@ class JargonMiner:
content = entry["content"] content = entry["content"]
raw_content_list = entry["raw_content"] # 已经是列表 raw_content_list = entry["raw_content"] # 已经是列表
try: try:
# 查询所有content匹配的记录 # 查询所有content匹配的记录
query = Jargon.select().where(Jargon.content == content) query = Jargon.select().where(Jargon.content == content)

View File

@ -13,6 +13,7 @@ from src.chat.utils.utils import parse_platform_accounts
logger = get_logger("jargon") logger = get_logger("jargon")
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]: def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
""" """
解析chat_id字段兼容旧格式字符串和新格式JSON列表 解析chat_id字段兼容旧格式字符串和新格式JSON列表
@ -168,10 +169,7 @@ def is_bot_message(msg: Any) -> bool:
.strip() .strip()
.lower() .lower()
) )
user_id = ( user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "")
.strip()
)
if not platform or not user_id: if not platform or not user_id:
return False return False

View File

@ -338,7 +338,9 @@ class LLMRequest:
if e.__cause__: if e.__cause__:
original_error_type = type(e.__cause__).__name__ original_error_type = type(e.__cause__).__name__
original_error_msg = str(e.__cause__) original_error_msg = str(e.__cause__)
original_error_info = f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}" original_error_info = (
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
)
retry_remain -= 1 retry_remain -= 1
if retry_remain <= 0: if retry_remain <= 0:

View File

@ -296,7 +296,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
if not content: if not content:
continue continue
if not global_config.jargon.all_global and not jargon.is_global: if not global_config.jargon.all_global and not jargon.is_global:
chat_id_list = parse_chat_id_list(jargon.chat_id) chat_id_list = parse_chat_id_list(jargon.chat_id)
if not chat_id_list_contains(chat_id_list, chat_id): if not chat_id_list_contains(chat_id_list, chat_id):
@ -586,9 +585,7 @@ async def _react_agent_solve_question(
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}}) step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
step["observations"] = ["从LLM输出内容中检测到found_answer"] step["observations"] = ["从LLM输出内容中检测到found_answer"]
thinking_steps.append(step) thinking_steps.append(step)
logger.info( logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}")
f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}"
)
return True, found_answer_content, thinking_steps, False return True, found_answer_content, thinking_steps, False
if not_enough_info_reason: if not_enough_info_reason:
@ -1016,9 +1013,7 @@ async def build_memory_retrieval_prompt(
if question_results: if question_results:
retrieved_memory = "\n\n".join(question_results) retrieved_memory = "\n\n".join(question_results)
logger.info( logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆")
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆"
)
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n" return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else: else:
logger.debug("所有问题均未找到答案") logger.debug("所有问题均未找到答案")

View File

@ -54,7 +54,9 @@ async def search_chat_history(
if record.participants: if record.participants:
try: try:
participants_data = ( participants_data = (
json.loads(record.participants) if isinstance(record.participants, str) else record.participants json.loads(record.participants)
if isinstance(record.participants, str)
else record.participants
) )
if isinstance(participants_data, list): if isinstance(participants_data, list):
participants_list = [str(p).lower() for p in participants_data] participants_list = [str(p).lower() for p in participants_data]
@ -156,9 +158,7 @@ async def search_chat_history(
# 添加关键词 # 添加关键词
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data: if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data]) keywords_str = "".join([str(k) for k in keywords_data])
result_parts.append(f"关键词:{keywords_str}") result_parts.append(f"关键词:{keywords_str}")
@ -208,9 +208,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
return "未提供有效的记忆ID" return "未提供有效的记忆ID"
# 查询记录 # 查询记录
query = ChatHistory.select().where( query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)))
(ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))
)
records = list(query.order_by(ChatHistory.start_time.desc())) records = list(query.order_by(ChatHistory.start_time.desc()))
if not records: if not records:
@ -256,9 +254,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
# 添加关键词 # 添加关键词
if record.keywords: if record.keywords:
try: try:
keywords_data = ( keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list) and keywords_data: if isinstance(keywords_data, list) and keywords_data:
keywords_str = "".join([str(k) for k in keywords_data]) keywords_str = "".join([str(k) for k in keywords_data])
result_parts.append(f"关键词:{keywords_str}") result_parts.append(f"关键词:{keywords_str}")

View File

@ -150,6 +150,7 @@ class ConfigSection:
order=1 order=1
) )
""" """
title: str # 显示标题 title: str # 显示标题
description: Optional[str] = None # 详细描述 description: Optional[str] = None # 详细描述
icon: Optional[str] = None # 图标名称 icon: Optional[str] = None # 图标名称
@ -182,6 +183,7 @@ class ConfigTab:
sections=["plugin", "api"] sections=["plugin", "api"]
) )
""" """
id: str # 标签页 ID id: str # 标签页 ID
title: str # 显示标题 title: str # 显示标题
sections: List[str] = field(default_factory=list) # 包含的 section 名称列表 sections: List[str] = field(default_factory=list) # 包含的 section 名称列表
@ -222,6 +224,7 @@ class ConfigLayout:
] ]
) )
""" """
type: str = "auto" # 布局类型: auto, tabs, pages type: str = "auto" # 布局类型: auto, tabs, pages
tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表 tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表
@ -234,11 +237,7 @@ class ConfigLayout:
def section_meta( def section_meta(
title: str, title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0
description: Optional[str] = None,
icon: Optional[str] = None,
collapsed: bool = False,
order: int = 0
) -> Union[str, ConfigSection]: ) -> Union[str, ConfigSection]:
""" """
便捷函数创建 section 元数据 便捷函数创建 section 元数据
@ -261,10 +260,4 @@ def section_meta(
"debug": section_meta("调试设置", collapsed=True, order=99), "debug": section_meta("调试设置", collapsed=True, order=99),
} }
""" """
return ConfigSection( return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order)
title=title,
description=description,
icon=icon,
collapsed=collapsed,
order=order
)

View File

@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_"
class ChatHistoryMessage(BaseModel): class ChatHistoryMessage(BaseModel):
"""聊天历史消息""" """聊天历史消息"""
id: str id: str
type: str # 'user' | 'bot' | 'system' type: str # 'user' | 'bot' | 'system'
content: str content: str
@ -81,11 +82,7 @@ class ChatHistoryManager:
def clear_history(self) -> int: def clear_history(self) -> int:
"""清空 WebUI 聊天历史记录""" """清空 WebUI 聊天历史记录"""
try: try:
deleted = ( deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
Messages.delete()
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
.execute()
)
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录") logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
return deleted return deleted
except Exception as e: except Exception as e:
@ -135,11 +132,7 @@ chat_manager = ChatConnectionManager()
def create_message_data( def create_message_data(
content: str, content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据""" """创建符合麦麦消息格式的消息数据"""
if message_id is None: if message_id is None:
@ -163,7 +156,7 @@ def create_message_data(
}, },
"additional_config": { "additional_config": {
"at_bot": is_at_bot, "at_bot": is_at_bot,
} },
}, },
"message_segment": { "message_segment": {
"type": "seglist", "type": "seglist",
@ -175,8 +168,8 @@ def create_message_data(
{ {
"type": "mention_bot", "type": "mention_bot",
"data": "1.0", "data": "1.0",
} },
] ],
}, },
"raw_message": content, "raw_message": content,
"processed_plain_text": content, "processed_plain_text": content,
@ -186,7 +179,7 @@ def create_message_data(
@router.get("/history") @router.get("/history")
async def get_chat_history( async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200), limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤 user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
): ):
"""获取聊天历史记录 """获取聊天历史记录
@ -236,28 +229,37 @@ async def websocket_chat(
try: try:
# 发送会话信息(包含用户 ID前端需要保存 # 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "session_info", "type": "session_info",
"session_id": session_id, "session_id": session_id,
"user_id": user_id, "user_id": user_id,
"user_name": user_name, "user_name": user_name,
"bot_name": global_config.bot.nickname, "bot_name": global_config.bot.nickname,
}) },
)
# 发送历史记录 # 发送历史记录
history = chat_history.get_history(50) history = chat_history.get_history(50)
if history: if history:
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "history", "type": "history",
"messages": history, "messages": history,
}) },
)
# 发送欢迎消息(不保存到历史) # 发送欢迎消息(不保存到历史)
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "system", "type": "system",
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!", "content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
"timestamp": time.time(), "timestamp": time.time(),
}) },
)
while True: while True:
data = await websocket.receive_json() data = await websocket.receive_json()
@ -275,7 +277,8 @@ async def websocket_chat(
# 广播用户消息给所有连接(包括发送者) # 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库 # 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast({ await chat_manager.broadcast(
{
"type": "user_message", "type": "user_message",
"content": content, "content": content,
"message_id": message_id, "message_id": message_id,
@ -284,8 +287,9 @@ async def websocket_chat(
"name": current_user_name, "name": current_user_name,
"user_id": user_id, "user_id": user_id,
"is_bot": False, "is_bot": False,
},
} }
}) )
# 创建麦麦消息格式 # 创建麦麦消息格式
message_data = create_message_data( message_data = create_message_data(
@ -298,42 +302,55 @@ async def websocket_chat(
try: try:
# 显示正在输入状态 # 显示正在输入状态
await chat_manager.broadcast({ await chat_manager.broadcast(
{
"type": "typing", "type": "typing",
"is_typing": True, "is_typing": True,
}) }
)
# 调用麦麦的消息处理 # 调用麦麦的消息处理
await chat_bot.message_process(message_data) await chat_bot.message_process(message_data)
except Exception as e: except Exception as e:
logger.error(f"处理消息时出错: {e}") logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "error", "type": "error",
"content": f"处理消息时出错: {str(e)}", "content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(), "timestamp": time.time(),
}) },
)
finally: finally:
await chat_manager.broadcast({ await chat_manager.broadcast(
{
"type": "typing", "type": "typing",
"is_typing": False, "is_typing": False,
}) }
)
elif data.get("type") == "ping": elif data.get("type") == "ping":
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "pong", "type": "pong",
"timestamp": time.time(), "timestamp": time.time(),
}) },
)
elif data.get("type") == "update_nickname": elif data.get("type") == "update_nickname":
# 允许用户更新昵称 # 允许用户更新昵称
if new_name := data.get("user_name", "").strip(): if new_name := data.get("user_name", "").strip():
current_user_name = new_name current_user_name = new_name
await chat_manager.send_message(session_id, { await chat_manager.send_message(
session_id,
{
"type": "nickname_updated", "type": "nickname_updated",
"user_name": current_user_name, "user_name": current_user_name,
"timestamp": time.time(), "timestamp": time.time(),
}) },
)
except WebSocketDisconnect: except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}") logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")

View File

@ -5,7 +5,7 @@
import os import os
import tomlkit import tomlkit
from fastapi import APIRouter, HTTPException, Body from fastapi import APIRouter, HTTPException, Body
from typing import Any from typing import Any, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui") logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"]) router = APIRouter(prefix="/config", tags=["config"])
@ -90,7 +96,7 @@ async def get_bot_config_schema():
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取配置架构失败: {e}") logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model") @router.get("/schema/model")
@ -101,7 +107,7 @@ async def get_model_config_schema():
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取模型配置架构失败: {e}") logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
# ===== 子配置架构获取接口 ===== # ===== 子配置架构获取接口 =====
@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str):
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
except Exception as e: except Exception as e:
logger.error(f"获取配置节架构失败: {e}") logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== 配置读取接口 ===== # ===== 配置读取接口 =====
@ -196,7 +202,7 @@ async def get_bot_config():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model") @router.get("/model")
@ -215,21 +221,21 @@ async def get_model_config():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== 配置更新接口 ===== # ===== 配置更新接口 =====
@router.post("/bot") @router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)): async def update_bot_config(config_data: ConfigBody):
"""更新麦麦主程序配置""" """更新麦麦主程序配置"""
try: try:
# 验证配置数据 # 验证配置数据
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model") @router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)): async def update_model_config(config_data: ConfigBody):
"""更新模型配置""" """更新模型配置"""
try: try:
# 验证配置数据 # 验证配置数据
try: try:
APIAdapterConfig.from_dict(config_data) APIAdapterConfig.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "model_config.toml") config_path = os.path.join(CONFIG_DIR, "model_config.toml")
@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== 配置节更新接口 ===== # ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}") @router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)): async def update_bot_config_section(section_name: str, section_data: SectionBody):
"""更新麦麦主程序配置的指定节(保留注释和格式)""" """更新麦麦主程序配置的指定节(保留注释和格式)"""
try: try:
# 读取现有配置 # 读取现有配置
@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释) # 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
raise raise
except Exception as e: except Exception as e:
logger.error(f"更新配置节失败: {e}") logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 原始 TOML 文件操作接口 ===== # ===== 原始 TOML 文件操作接口 =====
@ -338,24 +344,24 @@ async def get_bot_config_raw():
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取配置文件失败: {e}") logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw") @router.post("/bot/raw")
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)): async def update_bot_config_raw(raw_content: RawContentBody):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)""" """更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try: try:
# 验证 TOML 格式 # 验证 TOML 格式
try: try:
config_data = tomlkit.loads(raw_content) config_data = tomlkit.loads(raw_content)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 验证配置数据结构 # 验证配置数据结构
try: try:
Config.from_dict(config_data) Config.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件 # 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml") config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存配置文件失败: {e}") logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}") @router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)): async def update_model_config_section(section_name: str, section_data: SectionBody):
"""更新模型配置的指定节(保留注释和格式)""" """更新模型配置的指定节(保留注释和格式)"""
try: try:
# 读取现有配置 # 读取现有配置
@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
try: try:
APIAdapterConfig.from_dict(config_data) APIAdapterConfig.from_dict(config_data)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置tomlkit.dump 会保留注释) # 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f: with open(config_path, "w", encoding="utf-8") as f:
@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
raise raise
except Exception as e: except Exception as e:
logger.error(f"更新配置节失败: {e}") logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 ===== # ===== 适配器配置管理接口 =====
@ -443,7 +449,7 @@ def _to_relative_path(path: str) -> str:
# 尝试获取相对路径 # 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT) rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径 # 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith('..'): if not rel_path.startswith(".."):
return rel_path return rel_path
except (ValueError, TypeError): except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError # 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
@ -463,6 +469,7 @@ async def get_adapter_config_path():
return {"success": True, "path": None} return {"success": True, "path": None}
import json import json
with open(webui_data_path, "r", encoding="utf-8") as f: with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f) webui_data = json.load(f)
@ -476,6 +483,7 @@ async def get_adapter_config_path():
# 检查文件是否存在并返回最后修改时间 # 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path): if os.path.exists(abs_path):
import datetime import datetime
mtime = os.path.getmtime(abs_path) mtime = os.path.getmtime(abs_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat() last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
# 返回相对路径(如果可能) # 返回相对路径(如果可能)
@ -487,11 +495,11 @@ async def get_adapter_config_path():
except Exception as e: except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}") logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path") @router.post("/adapter-config/path")
async def save_adapter_config_path(data: dict[str, str] = Body(...)): async def save_adapter_config_path(data: PathBody):
"""保存适配器配置文件路径偏好""" """保存适配器配置文件路径偏好"""
try: try:
path = data.get("path") path = data.get("path")
@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}") logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config") @router.get("/adapter-config")
@ -562,11 +570,11 @@ async def get_adapter_config(path: str):
raise raise
except Exception as e: except Exception as e:
logger.error(f"读取适配器配置失败: {e}") logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config") @router.post("/adapter-config")
async def save_adapter_config(data: dict[str, str] = Body(...)): async def save_adapter_config(data: PathBody):
"""保存适配器配置到指定路径""" """保存适配器配置到指定路径"""
try: try:
path = data.get("path") path = data.get("path")
@ -586,10 +594,9 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
# 验证 TOML 格式 # 验证 TOML 格式
try: try:
import tomlkit
tomlkit.loads(content) tomlkit.loads(content)
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 确保目录存在 # 确保目录存在
dir_path = os.path.dirname(abs_path) dir_path = os.path.dirname(abs_path)
@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
raise raise
except Exception as e: except Exception as e:
logger.error(f"保存适配器配置失败: {e}") logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e

View File

@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
if next_line.startswith('"""') or next_line.startswith("'''"): if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串 # 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2: if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip()) description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else: else:
# 多行文档字符串 # 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''" quote = '"""' if next_line.startswith('"""') else "'''"
@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
next_line = lines[i + 1].strip() next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"): if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2: if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip()) description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else: else:
quote = '"""' if next_line.startswith('"""') else "'''" quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip()) description_lines.append(next_line.strip(quote).strip())
@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
return FieldType.ARRAY, None, items return FieldType.ARRAY, None, items
# 处理基本类型 # 处理基本类型
if field_type is bool or field_type == bool: if field_type is bool:
return FieldType.BOOLEAN, None, None return FieldType.BOOLEAN, None, None
elif field_type is int or field_type == int: elif field_type is int:
return FieldType.INTEGER, None, None return FieldType.INTEGER, None, None
elif field_type is float or field_type == float: elif field_type is float:
return FieldType.NUMBER, None, None return FieldType.NUMBER, None, None
elif field_type is str or field_type == str: elif field_type is str:
return FieldType.STRING, None, None return FieldType.STRING, None, None
elif field_type is dict or origin is dict: elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None return FieldType.OBJECT, None, None

View File

@ -3,20 +3,25 @@
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
from fastapi.responses import FileResponse from fastapi.responses import FileResponse
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List, Annotated
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Emoji from src.common.database.database_model import Emoji
from .token_manager import get_token_manager from .token_manager import get_token_manager
import json
import time import time
import os import os
import hashlib import hashlib
import base64
from PIL import Image from PIL import Image
import io import io
logger = get_logger("webui.emoji") logger = get_logger("webui.emoji")
# 模块级别的类型别名(解决 B008 ruff 错误)
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
DescriptionForm = Annotated[str, Form(description="表情包描述")]
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
# 创建路由器 # 创建路由器
router = APIRouter(prefix="/emoji", tags=["Emoji"]) router = APIRouter(prefix="/emoji", tags=["Emoji"])
@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel):
@router.post("/upload", response_model=EmojiUploadResponse) @router.post("/upload", response_model=EmojiUploadResponse)
async def upload_emoji( async def upload_emoji(
file: UploadFile = File(..., description="表情包图片文件"), file: EmojiFile,
description: str = Form("", description="表情包描述"), description: DescriptionForm = "",
emotion: str = Form("", description="情感标签,多个用逗号分隔"), emotion: EmotionForm = "",
is_registered: bool = Form(True, description="是否直接注册"), is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
): ):
""" """
@ -713,9 +718,9 @@ async def upload_emoji(
@router.post("/batch/upload") @router.post("/batch/upload")
async def batch_upload_emoji( async def batch_upload_emoji(
files: List[UploadFile] = File(..., description="多个表情包图片文件"), files: EmojiFiles,
emotion: str = Form("", description="情感标签,多个用逗号分隔"), emotion: EmotionForm = "",
is_registered: bool = Form(True, description="是否直接注册"), is_registered: IsRegisteredForm = True,
authorization: Optional[str] = Header(None), authorization: Optional[str] = Header(None),
): ):
""" """
@ -749,11 +754,13 @@ async def batch_upload_emoji(
# 验证文件类型 # 验证文件类型
if file.content_type not in allowed_types: if file.content_type not in allowed_types:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": False, "success": False,
"error": f"不支持的文件类型: {file.content_type}", "error": f"不支持的文件类型: {file.content_type}",
}) }
)
continue continue
# 读取文件内容 # 读取文件内容
@ -761,11 +768,13 @@ async def batch_upload_emoji(
if not file_content: if not file_content:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": False, "success": False,
"error": "文件内容为空", "error": "文件内容为空",
}) }
)
continue continue
# 验证图片 # 验证图片
@ -774,11 +783,13 @@ async def batch_upload_emoji(
img_format = img.format.lower() if img.format else "png" img_format = img.format.lower() if img.format else "png"
except Exception as e: except Exception as e:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": False, "success": False,
"error": f"无效的图片: {str(e)}", "error": f"无效的图片: {str(e)}",
}) }
)
continue continue
# 计算哈希 # 计算哈希
@ -787,11 +798,13 @@ async def batch_upload_emoji(
# 检查重复 # 检查重复
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash): if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": False, "success": False,
"error": "已存在相同的表情包", "error": "已存在相同的表情包",
}) }
)
continue continue
# 生成文件名并保存 # 生成文件名并保存
@ -829,19 +842,23 @@ async def batch_upload_emoji(
) )
results["uploaded"] += 1 results["uploaded"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": True, "success": True,
"id": emoji.id, "id": emoji.id,
}) }
)
except Exception as e: except Exception as e:
results["failed"] += 1 results["failed"] += 1
results["details"].append({ results["details"].append(
{
"filename": file.filename, "filename": file.filename,
"success": False, "success": False,
"error": str(e), "error": str(e),
}) }
)
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}" results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']}"
return results return results

View File

@ -602,9 +602,9 @@ class GitMirrorService:
# 执行 git clone在线程池中运行以避免阻塞 # 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
def run_git_clone(): def run_git_clone(clone_cmd=cmd):
return subprocess.run( return subprocess.run(
cmd, clone_cmd,
capture_output=True, capture_output=True,
text=True, text=True,
timeout=300, # 5分钟超时 timeout=300, # 5分钟超时

View File

@ -1,4 +1,5 @@
"""知识库图谱可视化 API 路由""" """知识库图谱可视化 API 路由"""
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Query from fastapi import APIRouter, Query
from pydantic import BaseModel from pydantic import BaseModel
@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
class KnowledgeNode(BaseModel): class KnowledgeNode(BaseModel):
"""知识节点""" """知识节点"""
id: str id: str
type: str # 'entity' or 'paragraph' type: str # 'entity' or 'paragraph'
content: str content: str
@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
class KnowledgeEdge(BaseModel): class KnowledgeEdge(BaseModel):
"""知识边""" """知识边"""
source: str source: str
target: str target: str
weight: float weight: float
@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
class KnowledgeGraph(BaseModel): class KnowledgeGraph(BaseModel):
"""知识图谱""" """知识图谱"""
nodes: List[KnowledgeNode] nodes: List[KnowledgeNode]
edges: List[KnowledgeEdge] edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel): class KnowledgeStats(BaseModel):
"""知识库统计信息""" """知识库统计信息"""
total_nodes: int total_nodes: int
total_edges: int total_edges: int
entity_nodes: int entity_nodes: int
@ -69,16 +74,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph" # 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode( nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}") logger.warning(f"跳过节点 {node_id}: {e}")
continue continue
@ -93,17 +93,15 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
edge_data = graph[source, target] edge_data = graph[source, target]
# edge_data 支持 [] 操作符但不支持 .get() # edge_data 支持 [] 操作符但不支持 .get()
weight = edge_data['weight'] if 'weight' in edge_data else 1.0 weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(KnowledgeEdge( edges.append(
source=source, KnowledgeEdge(
target=target, source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
weight=weight, )
create_time=create_time, )
update_time=update_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}") logger.warning(f"跳过边 {edge_tuple}: {e}")
continue continue
@ -114,7 +112,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
@router.get("/graph", response_model=KnowledgeGraph) @router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph( async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"), limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph") node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
): ):
"""获取知识图谱(限制节点数量) """获取知识图谱(限制节点数量)
@ -136,9 +134,11 @@ async def get_knowledge_graph(
# 按类型过滤节点 # 按类型过滤节点
if node_type == "entity": if node_type == "entity":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent'] all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph": elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg'] all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量 # 限制节点数量
total_nodes = len(all_node_list) total_nodes = len(all_node_list)
@ -155,16 +155,11 @@ async def get_knowledge_graph(
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode( nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
id=node_id,
type=node_type_val,
content=content,
create_time=create_time
))
node_ids.add(node_id) node_ids.add(node_id)
except Exception as e: except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}") logger.warning(f"跳过节点 {node_id}: {e}")
@ -181,17 +176,15 @@ async def get_knowledge_graph(
continue continue
edge_data = graph[source, target] edge_data = graph[source, target]
weight = edge_data['weight'] if 'weight' in edge_data else 1.0 weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data['create_time'] if 'create_time' in edge_data else None create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data['update_time'] if 'update_time' in edge_data else None update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(KnowledgeEdge( edges.append(
source=source, KnowledgeEdge(
target=target, source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
weight=weight, )
create_time=create_time, )
update_time=update_time
))
except Exception as e: except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}") logger.warning(f"跳过边 {edge_tuple}: {e}")
continue continue
@ -215,13 +208,7 @@ async def get_knowledge_stats():
try: try:
kg_manager = _load_kg_manager() kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None: if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats( return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
graph = kg_manager.graph graph = kg_manager.graph
node_list = graph.get_node_list() node_list = graph.get_node_list()
@ -236,10 +223,10 @@ async def get_knowledge_stats():
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
node_type = node_data['type'] if 'type' in node_data else 'ent' node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == 'ent': if node_type == "ent":
entity_nodes += 1 entity_nodes += 1
elif node_type == 'pg': elif node_type == "pg":
paragraph_nodes += 1 paragraph_nodes += 1
except Exception: except Exception:
continue continue
@ -252,18 +239,12 @@ async def get_knowledge_stats():
total_edges=total_edges, total_edges=total_edges,
entity_nodes=entity_nodes, entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes, paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2) avg_connections=round(avg_connections, 2),
) )
except Exception as e: except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True) logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats( return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
total_nodes=0,
total_edges=0,
entity_nodes=0,
paragraph_nodes=0,
avg_connections=0.0
)
@router.get("/search", response_model=List[KnowledgeNode]) @router.get("/search", response_model=List[KnowledgeNode])
@ -290,17 +271,12 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
for node_id in node_list: for node_id in node_list:
try: try:
node_data = graph[node_id] node_data = graph[node_id]
content = node_data['content'] if 'content' in node_data else node_id content = node_data["content"] if "content" in node_data else node_id
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph" node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
if query_lower in content.lower() or query_lower in node_id.lower(): if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data['create_time'] if 'create_time' in node_data else None create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode( results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
id=node_id,
type=node_type,
content=content,
create_time=create_time
))
except Exception: except Exception:
continue continue

View File

@ -50,11 +50,13 @@ def _parse_openai_response(data: dict) -> list[dict]:
if "data" in data and isinstance(data["data"], list): if "data" in data and isinstance(data["data"], list):
for model in data["data"]: for model in data["data"]:
if isinstance(model, dict) and "id" in model: if isinstance(model, dict) and "id" in model:
models.append({ models.append(
{
"id": model["id"], "id": model["id"],
"name": model.get("name") or model["id"], "name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""), "owned_by": model.get("owned_by", ""),
}) }
)
return models return models
@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
model_id = model["name"] model_id = model["name"]
if model_id.startswith("models/"): if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀 model_id = model_id[7:] # 去掉 "models/" 前缀
models.append({ models.append(
{
"id": model_id, "id": model_id,
"name": model.get("displayName") or model_id, "name": model.get("displayName") or model_id,
"owned_by": "google", "owned_by": "google",
}) }
)
return models return models
@ -118,25 +122,24 @@ async def _fetch_models_from_provider(
response = await client.get(url, headers=headers, params=params) response = await client.get(url, headers=headers, params=params)
response.raise_for_status() response.raise_for_status()
data = response.json() data = response.json()
except httpx.TimeoutException: except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
except httpx.HTTPStatusError as e: except httpx.HTTPStatusError as e:
# 注意:使用 502 Bad Gateway 而不是原始的 401/403 # 注意:使用 502 Bad Gateway 而不是原始的 401/403
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理 # 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
if e.response.status_code == 401: if e.response.status_code == 401:
raise HTTPException(status_code=502, detail="API Key 无效或已过期") raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
elif e.response.status_code == 403: elif e.response.status_code == 403:
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
elif e.response.status_code == 404: elif e.response.status_code == 404:
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
else: else:
raise HTTPException( raise HTTPException(
status_code=502, status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}" ) from e
)
except Exception as e: except Exception as e:
logger.error(f"获取模型列表失败: {e}") logger.error(f"获取模型列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
# 根据解析器类型解析响应 # 根据解析器类型解析响应
if parser == "openai": if parser == "openai":

View File

@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
""" """
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符) # 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
import re import re
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀 # 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0] base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
parts = base_version.split(".") parts = base_version.split(".")
if len(parts) < 3: if len(parts) < 3:
@ -1167,9 +1168,7 @@ class UpdatePluginConfigRequest(BaseModel):
@router.get("/config/{plugin_id}/schema") @router.get("/config/{plugin_id}/schema")
async def get_plugin_config_schema( async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件配置 Schema 获取插件配置 Schema
@ -1205,7 +1204,7 @@ async def get_plugin_config_schema(
plugin_instance = instance plugin_instance = instance
break break
if plugin_instance and hasattr(plugin_instance, 'get_webui_config_schema'): if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
# 从插件实例获取 schema # 从插件实例获取 schema
schema = plugin_instance.get_webui_config_schema() schema = plugin_instance.get_webui_config_schema()
return {"success": True, "schema": schema} return {"success": True, "schema": schema}
@ -1236,6 +1235,7 @@ async def get_plugin_config_schema(
current_config = {} current_config = {}
if config_path.exists(): if config_path.exists():
import tomlkit import tomlkit
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
current_config = tomlkit.load(f) current_config = tomlkit.load(f)
@ -1301,9 +1301,7 @@ async def get_plugin_config_schema(
@router.get("/config/{plugin_id}") @router.get("/config/{plugin_id}")
async def get_plugin_config( async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
获取插件当前配置值 获取插件当前配置值
@ -1344,6 +1342,7 @@ async def get_plugin_config(
return {"success": True, "config": {}, "message": "配置文件不存在"} return {"success": True, "config": {}, "message": "配置文件不存在"}
import tomlkit import tomlkit
with open(config_path, "r", encoding="utf-8") as f: with open(config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f) config = tomlkit.load(f)
@ -1358,9 +1357,7 @@ async def get_plugin_config(
@router.put("/config/{plugin_id}") @router.put("/config/{plugin_id}")
async def update_plugin_config( async def update_plugin_config(
plugin_id: str, plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
request: UpdatePluginConfigRequest,
authorization: Optional[str] = Header(None)
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
更新插件配置 更新插件配置
@ -1401,6 +1398,7 @@ async def update_plugin_config(
# 备份旧配置 # 备份旧配置
import shutil import shutil
import datetime import datetime
if config_path.exists(): if config_path.exists():
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
backup_path = plugin_path / backup_name backup_path = plugin_path / backup_name
@ -1409,6 +1407,7 @@ async def update_plugin_config(
# 写入新配置(使用 tomlkit 保留注释) # 写入新配置(使用 tomlkit 保留注释)
import tomlkit import tomlkit
# 先读取原配置以保留注释和格式 # 先读取原配置以保留注释和格式
existing_doc = tomlkit.document() existing_doc = tomlkit.document()
if config_path.exists(): if config_path.exists():
@ -1422,11 +1421,7 @@ async def update_plugin_config(
logger.info(f"已更新插件配置: {plugin_id}") logger.info(f"已更新插件配置: {plugin_id}")
return { return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
"success": True,
"message": "配置已保存",
"note": "配置更改将在插件重新加载后生效"
}
except HTTPException: except HTTPException:
raise raise
@ -1436,9 +1431,7 @@ async def update_plugin_config(
@router.post("/config/{plugin_id}/reset") @router.post("/config/{plugin_id}/reset")
async def reset_plugin_config( async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
重置插件配置为默认值 重置插件配置为默认值
@ -1481,17 +1474,14 @@ async def reset_plugin_config(
# 备份并删除 # 备份并删除
import shutil import shutil
import datetime import datetime
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}" backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
backup_path = plugin_path / backup_name backup_path = plugin_path / backup_name
shutil.move(config_path, backup_path) shutil.move(config_path, backup_path)
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}") logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
return { return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
"success": True,
"message": "配置已重置,下次加载插件时将使用默认配置",
"backup": str(backup_path)
}
except HTTPException: except HTTPException:
raise raise
@ -1501,9 +1491,7 @@ async def reset_plugin_config(
@router.post("/config/{plugin_id}/toggle") @router.post("/config/{plugin_id}/toggle")
async def toggle_plugin( async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
plugin_id: str, authorization: Optional[str] = Header(None)
) -> Dict[str, Any]:
""" """
切换插件启用状态 切换插件启用状态
@ -1567,7 +1555,7 @@ async def toggle_plugin(
"success": True, "success": True,
"enabled": new_enabled, "enabled": new_enabled,
"message": f"插件已{status}", "message": f"插件已{status}",
"note": "状态更改将在下次加载插件时生效" "note": "状态更改将在下次加载插件时生效",
} }
except HTTPException: except HTTPException:

View File

@ -91,10 +91,12 @@ class WebUIServer:
logger.info("开始导入 knowledge_routes...") logger.info("开始导入 knowledge_routes...")
from src.webui.knowledge_routes import router as knowledge_router from src.webui.knowledge_routes import router as knowledge_router
logger.info("knowledge_routes 导入成功") logger.info("knowledge_routes 导入成功")
# 导入本地聊天室路由 # 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router from src.webui.chat_routes import router as chat_router
logger.info("chat_routes 导入成功") logger.info("chat_routes 导入成功")
# 注册路由 # 注册路由