mirror of https://github.com/Mai-with-u/MaiBot.git
Ruff Fix & format
parent
d7932595e8
commit
3935ce817e
3
bot.py
3
bot.py
|
|
@ -78,6 +78,7 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
|
||||||
# 关闭 WebUI 服务器
|
# 关闭 WebUI 服务器
|
||||||
try:
|
try:
|
||||||
from src.webui.webui_server import get_webui_server
|
from src.webui.webui_server import get_webui_server
|
||||||
|
|
||||||
webui_server = get_webui_server()
|
webui_server = get_webui_server()
|
||||||
if webui_server and webui_server._server:
|
if webui_server and webui_server._server:
|
||||||
await webui_server.shutdown()
|
await webui_server.shutdown()
|
||||||
|
|
@ -238,7 +239,7 @@ if __name__ == "__main__":
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
|
|
||||||
# 取消主任务
|
# 取消主任务
|
||||||
if 'main_tasks' in locals() and main_tasks and not main_tasks.done():
|
if "main_tasks" in locals() and main_tasks and not main_tasks.done():
|
||||||
main_tasks.cancel()
|
main_tasks.cancel()
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(main_tasks)
|
loop.run_until_complete(main_tasks)
|
||||||
|
|
|
||||||
|
|
@ -254,6 +254,7 @@ class BrainChatting:
|
||||||
# 检查是否需要提问表达反思
|
# 检查是否需要提问表达反思
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
from src.express.expression_reflector import expression_reflector_manager
|
from src.express.expression_reflector import expression_reflector_manager
|
||||||
|
|
||||||
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
reflector = expression_reflector_manager.get_or_create_reflector(self.stream_id)
|
||||||
asyncio.create_task(reflector.check_and_ask())
|
asyncio.create_task(reflector.check_and_ask())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -410,7 +410,6 @@ class HeartFChatting:
|
||||||
reflect_tracker_manager.remove_tracker(self.stream_id)
|
reflect_tracker_manager.remove_tracker(self.stream_id)
|
||||||
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
|
logger.info(f"{self.log_prefix} ReflectTracker resolved and removed.")
|
||||||
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
async with global_prompt_manager.async_message_scope(self.chat_stream.context.get_template_name()):
|
||||||
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
asyncio.create_task(self.expression_learner.trigger_learning_for_chat())
|
||||||
|
|
@ -427,7 +426,9 @@ class HeartFChatting:
|
||||||
# asyncio.create_task(self.chat_history_summarizer.process())
|
# asyncio.create_task(self.chat_history_summarizer.process())
|
||||||
|
|
||||||
cycle_timers, thinking_id = self.start_cycle()
|
cycle_timers, thinking_id = self.start_cycle()
|
||||||
logger.info(f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 开始第{self._cycle_counter}次思考(频率: {global_config.chat.get_talk_value(self.stream_id)})"
|
||||||
|
)
|
||||||
|
|
||||||
# 第一步:动作检查
|
# 第一步:动作检查
|
||||||
available_actions: Dict[str, ActionInfo] = {}
|
available_actions: Dict[str, ActionInfo] = {}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ def get_webui_chat_broadcaster():
|
||||||
if _webui_chat_broadcaster is None:
|
if _webui_chat_broadcaster is None:
|
||||||
try:
|
try:
|
||||||
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
from src.webui.chat_routes import chat_manager, WEBUI_CHAT_PLATFORM
|
||||||
|
|
||||||
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
_webui_chat_broadcaster = (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_webui_chat_broadcaster = (None, None)
|
_webui_chat_broadcaster = (None, None)
|
||||||
|
|
@ -44,7 +45,8 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
import time
|
import time
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
await chat_manager.broadcast({
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
"type": "bot_message",
|
"type": "bot_message",
|
||||||
"content": message.processed_plain_text,
|
"content": message.processed_plain_text,
|
||||||
"message_type": "text",
|
"message_type": "text",
|
||||||
|
|
@ -53,8 +55,9 @@ async def _send_message(message: MessageSending, show_log=True) -> bool:
|
||||||
"name": global_config.bot.nickname,
|
"name": global_config.bot.nickname,
|
||||||
"avatar": None,
|
"avatar": None,
|
||||||
"is_bot": True,
|
"is_bot": True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
# 注意:机器人消息会由 MessageStorage.store_message 自动保存到数据库
|
||||||
# 无需手动保存
|
# 无需手动保存
|
||||||
|
|
|
||||||
|
|
@ -181,8 +181,12 @@ class ActionPlanner:
|
||||||
found_ids = set(matches)
|
found_ids = set(matches)
|
||||||
missing_ids = found_ids - available_ids
|
missing_ids = found_ids - available_ids
|
||||||
if missing_ids:
|
if missing_ids:
|
||||||
logger.info(f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}...")
|
logger.info(
|
||||||
logger.info(f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中")
|
f"{self.log_prefix}planner理由中引用的消息ID不在当前上下文中: {missing_ids}, 可用ID: {list(available_ids)[:10]}..."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"{self.log_prefix}planner理由替换: 找到{len(matches)}个消息ID引用,其中{len(found_ids & available_ids)}个在上下文中"
|
||||||
|
)
|
||||||
|
|
||||||
def _replace(match: re.Match[str]) -> str:
|
def _replace(match: re.Match[str]) -> str:
|
||||||
msg_id = match.group(0)
|
msg_id = match.group(0)
|
||||||
|
|
@ -234,17 +238,11 @@ class ActionPlanner:
|
||||||
target_message = message_id_list[-1][1]
|
target_message = message_id_list[-1][1]
|
||||||
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
logger.debug(f"{self.log_prefix}动作'{action}'缺少target_message_id,使用最新消息作为target_message")
|
||||||
|
|
||||||
if (
|
if action != "no_reply" and target_message is not None and self._is_message_from_self(target_message):
|
||||||
action != "no_reply"
|
|
||||||
and target_message is not None
|
|
||||||
and self._is_message_from_self(target_message)
|
|
||||||
):
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
f"{self.log_prefix}Planner选择了自己的消息 {target_message_id or target_message.message_id} 作为目标,强制使用 no_reply"
|
||||||
)
|
)
|
||||||
reasoning = (
|
reasoning = f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
||||||
f"目标消息 {target_message_id or target_message.message_id} 来自机器人自身,违反不回复自身消息规则。原始理由: {reasoning}"
|
|
||||||
)
|
|
||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
target_message = None
|
target_message = None
|
||||||
|
|
||||||
|
|
@ -295,10 +293,9 @@ class ActionPlanner:
|
||||||
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
|
def _is_message_from_self(self, message: "DatabaseMessages") -> bool:
|
||||||
"""判断消息是否由机器人自身发送"""
|
"""判断消息是否由机器人自身发送"""
|
||||||
try:
|
try:
|
||||||
return (
|
return str(message.user_info.user_id) == str(global_config.bot.qq_account) and (
|
||||||
str(message.user_info.user_id) == str(global_config.bot.qq_account)
|
message.user_info.platform or ""
|
||||||
and (message.user_info.platform or "") == (global_config.bot.platform or "")
|
) == (global_config.bot.platform or "")
|
||||||
)
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
logger.warning(f"{self.log_prefix}检测消息发送者失败,缺少必要字段")
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -132,9 +132,7 @@ class ImageManager:
|
||||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
||||||
|
|
||||||
# 清理ImageDescriptions表中type为emoji的记录
|
# 清理ImageDescriptions表中type为emoji的记录
|
||||||
deleted_descriptions = (
|
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
||||||
ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
total_deleted = deleted_images + deleted_descriptions
|
total_deleted = deleted_images + deleted_descriptions
|
||||||
if total_deleted > 0:
|
if total_deleted > 0:
|
||||||
|
|
@ -194,10 +192,14 @@ class ImageManager:
|
||||||
if cache_record:
|
if cache_record:
|
||||||
# 优先使用情感标签,如果没有则使用详细描述
|
# 优先使用情感标签,如果没有则使用详细描述
|
||||||
if cache_record.emotion_tags:
|
if cache_record.emotion_tags:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
||||||
|
)
|
||||||
return f"[表情包:{cache_record.emotion_tags}]"
|
return f"[表情包:{cache_record.emotion_tags}]"
|
||||||
elif cache_record.description:
|
elif cache_record.description:
|
||||||
logger.info(f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}...")
|
logger.info(
|
||||||
|
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
||||||
|
)
|
||||||
return f"[表情包:{cache_record.description}]"
|
return f"[表情包:{cache_record.description}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
||||||
|
|
|
||||||
|
|
@ -62,11 +62,15 @@ class ExpressionReflector:
|
||||||
max_interval = 15 * 60 # 10分钟
|
max_interval = 15 * 60 # 10分钟
|
||||||
interval = random.uniform(min_interval, max_interval)
|
interval = random.uniform(min_interval, max_interval)
|
||||||
|
|
||||||
logger.info(f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask/60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval/60:.2f}分钟)")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 上次提问时间: {self.last_ask_time:.2f}, 当前时间: {current_time:.2f}, 已过时间: {time_since_last_ask:.2f}秒 ({time_since_last_ask / 60:.2f}分钟), 需要间隔: {interval:.2f}秒 ({interval / 60:.2f}分钟)"
|
||||||
|
)
|
||||||
|
|
||||||
if time_since_last_ask < interval:
|
if time_since_last_ask < interval:
|
||||||
remaining_time = interval - time_since_last_ask
|
remaining_time = interval - time_since_last_ask
|
||||||
logger.info(f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time/60:.2f}分钟),跳过")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 距离上次提问时间不足,还需等待 {remaining_time:.2f}秒 ({remaining_time / 60:.2f}分钟),跳过"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
# 检查是否已经有针对该 Operator 的 Tracker 在运行
|
||||||
|
|
@ -78,10 +82,9 @@ class ExpressionReflector:
|
||||||
# 获取未检查的表达
|
# 获取未检查的表达
|
||||||
try:
|
try:
|
||||||
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
|
logger.info(f"[Expression Reflection] 查询未检查且未拒绝的表达")
|
||||||
expressions = (Expression
|
expressions = (
|
||||||
.select()
|
Expression.select().where((Expression.checked == False) & (Expression.rejected == False)).limit(50)
|
||||||
.where((Expression.checked == False) & (Expression.rejected == False))
|
)
|
||||||
.limit(50))
|
|
||||||
|
|
||||||
expr_list = list(expressions)
|
expr_list = list(expressions)
|
||||||
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
logger.info(f"[Expression Reflection] 找到 {len(expr_list)} 个候选表达")
|
||||||
|
|
@ -91,7 +94,9 @@ class ExpressionReflector:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
target_expr: Expression = random.choice(expr_list)
|
target_expr: Expression = random.choice(expr_list)
|
||||||
logger.info(f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}")
|
logger.info(
|
||||||
|
f"[Expression Reflection] 随机选择了表达 ID: {target_expr.id}, Situation: {target_expr.situation}, Style: {target_expr.style}"
|
||||||
|
)
|
||||||
|
|
||||||
# 生成询问文本
|
# 生成询问文本
|
||||||
ask_text = _generate_ask_text(target_expr)
|
ask_text = _generate_ask_text(target_expr)
|
||||||
|
|
@ -112,11 +117,13 @@ class ExpressionReflector:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
logger.error(f"[Expression Reflection] 检查或提问过程中出错: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -141,6 +148,7 @@ expression_reflector_manager = ExpressionReflectorManager()
|
||||||
async def _check_tracker_exists(operator_config: str) -> bool:
|
async def _check_tracker_exists(operator_config: str) -> bool:
|
||||||
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
"""检查指定 Operator 是否已有活跃的 Tracker"""
|
||||||
from src.express.reflect_tracker import reflect_tracker_manager
|
from src.express.reflect_tracker import reflect_tracker_manager
|
||||||
|
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
chat_stream = None
|
chat_stream = None
|
||||||
|
|
||||||
|
|
@ -240,12 +248,5 @@ async def _send_to_operator(operator_config: str, text: str, expr: Expression):
|
||||||
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
reflect_tracker_manager.add_tracker(stream_id, tracker)
|
||||||
|
|
||||||
# 发送消息
|
# 发送消息
|
||||||
await send_api.text_to_stream(
|
await send_api.text_to_stream(text=text, stream_id=stream_id, typing=True)
|
||||||
text=text,
|
|
||||||
stream_id=stream_id,
|
|
||||||
typing=True
|
|
||||||
)
|
|
||||||
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
logger.info(f"Sent expression reflect query to operator {operator_config} for expr {expr.id}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@ if TYPE_CHECKING:
|
||||||
|
|
||||||
logger = get_logger("reflect_tracker")
|
logger = get_logger("reflect_tracker")
|
||||||
|
|
||||||
|
|
||||||
class ReflectTracker:
|
class ReflectTracker:
|
||||||
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
def __init__(self, chat_stream: ChatStream, expression: Expression, created_time: float):
|
||||||
self.chat_stream = chat_stream
|
self.chat_stream = chat_stream
|
||||||
|
|
@ -28,9 +29,7 @@ class ReflectTracker:
|
||||||
self.max_duration = 15 * 60 # 15 minutes
|
self.max_duration = 15 * 60 # 15 minutes
|
||||||
|
|
||||||
# LLM for judging response
|
# LLM for judging response
|
||||||
self.judge_model = LLMRequest(
|
self.judge_model = LLMRequest(model_set=model_config.model_task_config.utils, request_type="reflect.tracker")
|
||||||
model_set=model_config.model_task_config.utils, request_type="reflect.tracker"
|
|
||||||
)
|
|
||||||
|
|
||||||
self._init_prompts()
|
self._init_prompts()
|
||||||
|
|
||||||
|
|
@ -109,7 +108,7 @@ class ReflectTracker:
|
||||||
"reflect_judge_prompt",
|
"reflect_judge_prompt",
|
||||||
situation=self.expression.situation,
|
situation=self.expression.situation,
|
||||||
style=self.expression.style,
|
style=self.expression.style,
|
||||||
context_block=context_block
|
context_block=context_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
logger.info(f"ReflectTracker LLM Prompt: {prompt}")
|
||||||
|
|
@ -162,9 +161,13 @@ class ReflectTracker:
|
||||||
self.expression.save()
|
self.expression.save()
|
||||||
|
|
||||||
if has_update:
|
if has_update:
|
||||||
logger.info(f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}")
|
logger.info(
|
||||||
|
f"Expression {self.expression.id} rejected and updated by operator. New situation: {corrected_situation}, New style: {corrected_style}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1.")
|
logger.info(
|
||||||
|
f"Expression {self.expression.id} rejected but no correction provided, marked as rejected=1."
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
elif judgment == "Ignore":
|
elif judgment == "Ignore":
|
||||||
|
|
@ -177,6 +180,7 @@ class ReflectTracker:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
# Global manager for trackers
|
# Global manager for trackers
|
||||||
class ReflectTrackerManager:
|
class ReflectTrackerManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|
@ -192,5 +196,5 @@ class ReflectTrackerManager:
|
||||||
if chat_id in self.trackers:
|
if chat_id in self.trackers:
|
||||||
del self.trackers[chat_id]
|
del self.trackers[chat_id]
|
||||||
|
|
||||||
reflect_tracker_manager = ReflectTrackerManager()
|
|
||||||
|
|
||||||
|
reflect_tracker_manager = ReflectTrackerManager()
|
||||||
|
|
|
||||||
|
|
@ -315,7 +315,9 @@ class ChatHistorySummarizer:
|
||||||
before_count = len(self.current_batch.messages)
|
before_count = len(self.current_batch.messages)
|
||||||
self.current_batch.messages.extend(new_messages)
|
self.current_batch.messages.extend(new_messages)
|
||||||
self.current_batch.end_time = current_time
|
self.current_batch.end_time = current_time
|
||||||
logger.info(f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息")
|
logger.info(
|
||||||
|
f"{self.log_prefix} 更新聊天检查批次: {before_count} -> {len(self.current_batch.messages)} 条消息"
|
||||||
|
)
|
||||||
# 更新批次后持久化
|
# 更新批次后持久化
|
||||||
self._persist_topic_cache()
|
self._persist_topic_cache()
|
||||||
else:
|
else:
|
||||||
|
|
@ -361,9 +363,7 @@ class ChatHistorySummarizer:
|
||||||
else:
|
else:
|
||||||
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
time_str = f"{time_since_last_check / 3600:.1f}小时"
|
||||||
|
|
||||||
logger.info(
|
logger.info(f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}")
|
||||||
f"{self.log_prefix} 批次状态检查 | 消息数: {message_count} | 距上次检查: {time_str}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查“话题检查”触发条件
|
# 检查“话题检查”触发条件
|
||||||
should_check = False
|
should_check = False
|
||||||
|
|
@ -426,7 +426,9 @@ class ChatHistorySummarizer:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 2. 构造编号后的消息字符串和参与者信息
|
# 2. 构造编号后的消息字符串和参与者信息
|
||||||
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = self._build_numbered_messages_for_llm(messages)
|
numbered_lines, index_to_msg_str, index_to_msg_text, index_to_participants = (
|
||||||
|
self._build_numbered_messages_for_llm(messages)
|
||||||
|
)
|
||||||
|
|
||||||
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
# 3. 调用 LLM 识别话题,并得到 topic -> indices
|
||||||
existing_topics = list(self.topic_cache.keys())
|
existing_topics = list(self.topic_cache.keys())
|
||||||
|
|
@ -588,9 +590,7 @@ class ChatHistorySummarizer:
|
||||||
if not numbered_lines:
|
if not numbered_lines:
|
||||||
return False, {}
|
return False, {}
|
||||||
|
|
||||||
history_topics_block = (
|
history_topics_block = "\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
||||||
"\n".join(f"- {t}" for t in existing_topics) if existing_topics else "(当前无历史话题)"
|
|
||||||
)
|
|
||||||
messages_block = "\n".join(numbered_lines)
|
messages_block = "\n".join(numbered_lines)
|
||||||
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
|
|
@ -607,6 +607,7 @@ class ChatHistorySummarizer:
|
||||||
)
|
)
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
logger.info(f"{self.log_prefix} 话题识别LLM Prompt: {prompt}")
|
||||||
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
|
logger.info(f"{self.log_prefix} 话题识别LLM Response: {response}")
|
||||||
|
|
||||||
|
|
@ -895,4 +896,3 @@ class ChatHistorySummarizer:
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,9 +44,7 @@ class JargonExplainer:
|
||||||
request_type="jargon.explain",
|
request_type="jargon.explain",
|
||||||
)
|
)
|
||||||
|
|
||||||
def match_jargon_from_messages(
|
def match_jargon_from_messages(self, messages: List[Any]) -> List[Dict[str, str]]:
|
||||||
self, messages: List[Any]
|
|
||||||
) -> List[Dict[str, str]]:
|
|
||||||
"""
|
"""
|
||||||
通过直接匹配数据库中的jargon字符串来提取黑话
|
通过直接匹配数据库中的jargon字符串来提取黑话
|
||||||
|
|
||||||
|
|
@ -68,7 +66,9 @@ class JargonExplainer:
|
||||||
if is_bot_message(msg):
|
if is_bot_message(msg):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
msg_text = (
|
||||||
|
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||||
|
).strip()
|
||||||
if msg_text:
|
if msg_text:
|
||||||
message_texts.append(msg_text)
|
message_texts.append(msg_text)
|
||||||
|
|
||||||
|
|
@ -79,9 +79,7 @@ class JargonExplainer:
|
||||||
combined_text = " ".join(message_texts)
|
combined_text = " ".join(message_texts)
|
||||||
|
|
||||||
# 查询所有有meaning的jargon记录
|
# 查询所有有meaning的jargon记录
|
||||||
query = Jargon.select().where(
|
query = Jargon.select().where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
|
||||||
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
|
|
||||||
)
|
|
||||||
|
|
||||||
# 根据all_global配置决定查询逻辑
|
# 根据all_global配置决定查询逻辑
|
||||||
if global_config.jargon.all_global:
|
if global_config.jargon.all_global:
|
||||||
|
|
@ -123,12 +121,12 @@ class JargonExplainer:
|
||||||
pattern = re.escape(content)
|
pattern = re.escape(content)
|
||||||
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
# 使用单词边界或中文字符边界来匹配,避免部分匹配
|
||||||
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
# 对于中文,使用Unicode字符类;对于英文,使用单词边界
|
||||||
if re.search(r'[\u4e00-\u9fff]', content):
|
if re.search(r"[\u4e00-\u9fff]", content):
|
||||||
# 包含中文,使用更宽松的匹配
|
# 包含中文,使用更宽松的匹配
|
||||||
search_pattern = pattern
|
search_pattern = pattern
|
||||||
else:
|
else:
|
||||||
# 纯英文/数字,使用单词边界
|
# 纯英文/数字,使用单词边界
|
||||||
search_pattern = r'\b' + pattern + r'\b'
|
search_pattern = r"\b" + pattern + r"\b"
|
||||||
|
|
||||||
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
if re.search(search_pattern, combined_text, re.IGNORECASE):
|
||||||
# 找到匹配,记录(去重)
|
# 找到匹配,记录(去重)
|
||||||
|
|
@ -147,9 +145,7 @@ class JargonExplainer:
|
||||||
|
|
||||||
return list(matched_jargon.values())
|
return list(matched_jargon.values())
|
||||||
|
|
||||||
async def explain_jargon(
|
async def explain_jargon(self, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||||
self, messages: List[Any], chat_context: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
"""
|
||||||
解释上下文中的黑话
|
解释上下文中的黑话
|
||||||
|
|
||||||
|
|
@ -239,9 +235,7 @@ class JargonExplainer:
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
async def explain_jargon_in_context(
|
async def explain_jargon_in_context(chat_id: str, messages: List[Any], chat_context: str) -> Optional[str]:
|
||||||
chat_id: str, messages: List[Any], chat_context: str
|
|
||||||
) -> Optional[str]:
|
|
||||||
"""
|
"""
|
||||||
解释上下文中的黑话(便捷函数)
|
解释上下文中的黑话(便捷函数)
|
||||||
|
|
||||||
|
|
@ -255,4 +249,3 @@ async def explain_jargon_in_context(
|
||||||
"""
|
"""
|
||||||
explainer = JargonExplainer(chat_id)
|
explainer = JargonExplainer(chat_id)
|
||||||
return await explainer.explain_jargon(messages, chat_context)
|
return await explainer.explain_jargon(messages, chat_context)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,15 +22,13 @@ from src.jargon.jargon_utils import (
|
||||||
contains_bot_self_name,
|
contains_bot_self_name,
|
||||||
parse_chat_id_list,
|
parse_chat_id_list,
|
||||||
chat_id_list_contains,
|
chat_id_list_contains,
|
||||||
update_chat_id_list
|
update_chat_id_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _init_prompt() -> None:
|
def _init_prompt() -> None:
|
||||||
prompt_str = """
|
prompt_str = """
|
||||||
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
**聊天内容,其中的{bot_name}的发言内容是你自己的发言,[msg_id] 是消息ID**
|
||||||
|
|
@ -126,7 +124,6 @@ _init_prompt()
|
||||||
_init_inference_prompts()
|
_init_inference_prompts()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
def _should_infer_meaning(jargon_obj: Jargon) -> bool:
|
||||||
"""
|
"""
|
||||||
判断是否需要进行含义推断
|
判断是否需要进行含义推断
|
||||||
|
|
@ -211,7 +208,9 @@ class JargonMiner:
|
||||||
processed_pairs = set()
|
processed_pairs = set()
|
||||||
|
|
||||||
for idx, msg in enumerate(messages):
|
for idx, msg in enumerate(messages):
|
||||||
msg_text = (getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or "").strip()
|
msg_text = (
|
||||||
|
getattr(msg, "display_message", None) or getattr(msg, "processed_plain_text", None) or ""
|
||||||
|
).strip()
|
||||||
if not msg_text or is_bot_message(msg):
|
if not msg_text or is_bot_message(msg):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -588,7 +587,6 @@ class JargonMiner:
|
||||||
content = entry["content"]
|
content = entry["content"]
|
||||||
raw_content_list = entry["raw_content"] # 已经是列表
|
raw_content_list = entry["raw_content"] # 已经是列表
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 查询所有content匹配的记录
|
# 查询所有content匹配的记录
|
||||||
query = Jargon.select().where(Jargon.content == content)
|
query = Jargon.select().where(Jargon.content == content)
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from src.chat.utils.utils import parse_platform_accounts
|
||||||
|
|
||||||
logger = get_logger("jargon")
|
logger = get_logger("jargon")
|
||||||
|
|
||||||
|
|
||||||
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
def parse_chat_id_list(chat_id_value: Any) -> List[List[Any]]:
|
||||||
"""
|
"""
|
||||||
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
解析chat_id字段,兼容旧格式(字符串)和新格式(JSON列表)
|
||||||
|
|
@ -168,10 +169,7 @@ def is_bot_message(msg: Any) -> bool:
|
||||||
.strip()
|
.strip()
|
||||||
.lower()
|
.lower()
|
||||||
)
|
)
|
||||||
user_id = (
|
user_id = str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "").strip()
|
||||||
str(getattr(msg, "user_id", "") or getattr(getattr(msg, "user_info", None), "user_id", "") or "")
|
|
||||||
.strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
if not platform or not user_id:
|
if not platform or not user_id:
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
|
|
@ -338,7 +338,9 @@ class LLMRequest:
|
||||||
if e.__cause__:
|
if e.__cause__:
|
||||||
original_error_type = type(e.__cause__).__name__
|
original_error_type = type(e.__cause__).__name__
|
||||||
original_error_msg = str(e.__cause__)
|
original_error_msg = str(e.__cause__)
|
||||||
original_error_info = f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
original_error_info = (
|
||||||
|
f"\n 底层异常类型: {original_error_type}\n 底层异常信息: {original_error_msg}"
|
||||||
|
)
|
||||||
|
|
||||||
retry_remain -= 1
|
retry_remain -= 1
|
||||||
if retry_remain <= 0:
|
if retry_remain <= 0:
|
||||||
|
|
|
||||||
|
|
@ -296,7 +296,6 @@ def _match_jargon_from_text(chat_text: str, chat_id: str) -> List[str]:
|
||||||
if not content:
|
if not content:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
||||||
if not global_config.jargon.all_global and not jargon.is_global:
|
if not global_config.jargon.all_global and not jargon.is_global:
|
||||||
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
chat_id_list = parse_chat_id_list(jargon.chat_id)
|
||||||
if not chat_id_list_contains(chat_id_list, chat_id):
|
if not chat_id_list_contains(chat_id_list, chat_id):
|
||||||
|
|
@ -586,9 +585,7 @@ async def _react_agent_solve_question(
|
||||||
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
step["actions"].append({"action_type": "found_answer", "action_params": {"answer": found_answer_content}})
|
||||||
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
step["observations"] = ["从LLM输出内容中检测到found_answer"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(
|
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}")
|
||||||
f"ReAct Agent 第 {iteration + 1} 次迭代 找到关于问题{question}的答案: {found_answer_content}"
|
|
||||||
)
|
|
||||||
return True, found_answer_content, thinking_steps, False
|
return True, found_answer_content, thinking_steps, False
|
||||||
|
|
||||||
if not_enough_info_reason:
|
if not_enough_info_reason:
|
||||||
|
|
@ -1016,9 +1013,7 @@ async def build_memory_retrieval_prompt(
|
||||||
|
|
||||||
if question_results:
|
if question_results:
|
||||||
retrieved_memory = "\n\n".join(question_results)
|
retrieved_memory = "\n\n".join(question_results)
|
||||||
logger.info(
|
logger.info(f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆")
|
||||||
f"记忆检索成功,耗时: {(end_time - start_time):.3f}秒,包含 {len(question_results)} 条记忆"
|
|
||||||
)
|
|
||||||
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
|
||||||
else:
|
else:
|
||||||
logger.debug("所有问题均未找到答案")
|
logger.debug("所有问题均未找到答案")
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,9 @@ async def search_chat_history(
|
||||||
if record.participants:
|
if record.participants:
|
||||||
try:
|
try:
|
||||||
participants_data = (
|
participants_data = (
|
||||||
json.loads(record.participants) if isinstance(record.participants, str) else record.participants
|
json.loads(record.participants)
|
||||||
|
if isinstance(record.participants, str)
|
||||||
|
else record.participants
|
||||||
)
|
)
|
||||||
if isinstance(participants_data, list):
|
if isinstance(participants_data, list):
|
||||||
participants_list = [str(p).lower() for p in participants_data]
|
participants_list = [str(p).lower() for p in participants_data]
|
||||||
|
|
@ -156,9 +158,7 @@ async def search_chat_history(
|
||||||
# 添加关键词
|
# 添加关键词
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
|
||||||
)
|
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
result_parts.append(f"关键词:{keywords_str}")
|
result_parts.append(f"关键词:{keywords_str}")
|
||||||
|
|
@ -208,9 +208,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
|
||||||
return "未提供有效的记忆ID"
|
return "未提供有效的记忆ID"
|
||||||
|
|
||||||
# 查询记录
|
# 查询记录
|
||||||
query = ChatHistory.select().where(
|
query = ChatHistory.select().where((ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list)))
|
||||||
(ChatHistory.chat_id == chat_id) & (ChatHistory.id.in_(id_list))
|
|
||||||
)
|
|
||||||
records = list(query.order_by(ChatHistory.start_time.desc()))
|
records = list(query.order_by(ChatHistory.start_time.desc()))
|
||||||
|
|
||||||
if not records:
|
if not records:
|
||||||
|
|
@ -256,9 +254,7 @@ async def get_chat_history_detail(chat_id: str, memory_ids: str) -> str:
|
||||||
# 添加关键词
|
# 添加关键词
|
||||||
if record.keywords:
|
if record.keywords:
|
||||||
try:
|
try:
|
||||||
keywords_data = (
|
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
||||||
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
|
|
||||||
)
|
|
||||||
if isinstance(keywords_data, list) and keywords_data:
|
if isinstance(keywords_data, list) and keywords_data:
|
||||||
keywords_str = "、".join([str(k) for k in keywords_data])
|
keywords_str = "、".join([str(k) for k in keywords_data])
|
||||||
result_parts.append(f"关键词:{keywords_str}")
|
result_parts.append(f"关键词:{keywords_str}")
|
||||||
|
|
|
||||||
|
|
@ -150,6 +150,7 @@ class ConfigSection:
|
||||||
order=1
|
order=1
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
title: str # 显示标题
|
title: str # 显示标题
|
||||||
description: Optional[str] = None # 详细描述
|
description: Optional[str] = None # 详细描述
|
||||||
icon: Optional[str] = None # 图标名称
|
icon: Optional[str] = None # 图标名称
|
||||||
|
|
@ -182,6 +183,7 @@ class ConfigTab:
|
||||||
sections=["plugin", "api"]
|
sections=["plugin", "api"]
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
id: str # 标签页 ID
|
id: str # 标签页 ID
|
||||||
title: str # 显示标题
|
title: str # 显示标题
|
||||||
sections: List[str] = field(default_factory=list) # 包含的 section 名称列表
|
sections: List[str] = field(default_factory=list) # 包含的 section 名称列表
|
||||||
|
|
@ -222,6 +224,7 @@ class ConfigLayout:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: str = "auto" # 布局类型: auto, tabs, pages
|
type: str = "auto" # 布局类型: auto, tabs, pages
|
||||||
tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表
|
tabs: List[ConfigTab] = field(default_factory=list) # 标签页列表
|
||||||
|
|
||||||
|
|
@ -234,11 +237,7 @@ class ConfigLayout:
|
||||||
|
|
||||||
|
|
||||||
def section_meta(
|
def section_meta(
|
||||||
title: str,
|
title: str, description: Optional[str] = None, icon: Optional[str] = None, collapsed: bool = False, order: int = 0
|
||||||
description: Optional[str] = None,
|
|
||||||
icon: Optional[str] = None,
|
|
||||||
collapsed: bool = False,
|
|
||||||
order: int = 0
|
|
||||||
) -> Union[str, ConfigSection]:
|
) -> Union[str, ConfigSection]:
|
||||||
"""
|
"""
|
||||||
便捷函数:创建 section 元数据
|
便捷函数:创建 section 元数据
|
||||||
|
|
@ -261,10 +260,4 @@ def section_meta(
|
||||||
"debug": section_meta("调试设置", collapsed=True, order=99),
|
"debug": section_meta("调试设置", collapsed=True, order=99),
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
return ConfigSection(
|
return ConfigSection(title=title, description=description, icon=icon, collapsed=collapsed, order=order)
|
||||||
title=title,
|
|
||||||
description=description,
|
|
||||||
icon=icon,
|
|
||||||
collapsed=collapsed,
|
|
||||||
order=order
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||||
|
|
||||||
class ChatHistoryMessage(BaseModel):
|
class ChatHistoryMessage(BaseModel):
|
||||||
"""聊天历史消息"""
|
"""聊天历史消息"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
type: str # 'user' | 'bot' | 'system'
|
type: str # 'user' | 'bot' | 'system'
|
||||||
content: str
|
content: str
|
||||||
|
|
@ -81,11 +82,7 @@ class ChatHistoryManager:
|
||||||
def clear_history(self) -> int:
|
def clear_history(self) -> int:
|
||||||
"""清空 WebUI 聊天历史记录"""
|
"""清空 WebUI 聊天历史记录"""
|
||||||
try:
|
try:
|
||||||
deleted = (
|
deleted = Messages.delete().where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID).execute()
|
||||||
Messages.delete()
|
|
||||||
.where(Messages.chat_info_group_id == WEBUI_CHAT_GROUP_ID)
|
|
||||||
.execute()
|
|
||||||
)
|
|
||||||
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
|
logger.info(f"已清空 {deleted} 条 WebUI 聊天记录")
|
||||||
return deleted
|
return deleted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -135,11 +132,7 @@ chat_manager = ChatConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
def create_message_data(
|
def create_message_data(
|
||||||
content: str,
|
content: str, user_id: str, user_name: str, message_id: Optional[str] = None, is_at_bot: bool = True
|
||||||
user_id: str,
|
|
||||||
user_name: str,
|
|
||||||
message_id: Optional[str] = None,
|
|
||||||
is_at_bot: bool = True
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""创建符合麦麦消息格式的消息数据"""
|
"""创建符合麦麦消息格式的消息数据"""
|
||||||
if message_id is None:
|
if message_id is None:
|
||||||
|
|
@ -163,7 +156,7 @@ def create_message_data(
|
||||||
},
|
},
|
||||||
"additional_config": {
|
"additional_config": {
|
||||||
"at_bot": is_at_bot,
|
"at_bot": is_at_bot,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"message_segment": {
|
"message_segment": {
|
||||||
"type": "seglist",
|
"type": "seglist",
|
||||||
|
|
@ -175,8 +168,8 @@ def create_message_data(
|
||||||
{
|
{
|
||||||
"type": "mention_bot",
|
"type": "mention_bot",
|
||||||
"data": "1.0",
|
"data": "1.0",
|
||||||
}
|
},
|
||||||
]
|
],
|
||||||
},
|
},
|
||||||
"raw_message": content,
|
"raw_message": content,
|
||||||
"processed_plain_text": content,
|
"processed_plain_text": content,
|
||||||
|
|
@ -186,7 +179,7 @@ def create_message_data(
|
||||||
@router.get("/history")
|
@router.get("/history")
|
||||||
async def get_chat_history(
|
async def get_chat_history(
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
user_id: Optional[str] = Query(default=None) # 保留参数兼容性,但不用于过滤
|
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||||
):
|
):
|
||||||
"""获取聊天历史记录
|
"""获取聊天历史记录
|
||||||
|
|
||||||
|
|
@ -236,28 +229,37 @@ async def websocket_chat(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "session_info",
|
"type": "session_info",
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"user_name": user_name,
|
"user_name": user_name,
|
||||||
"bot_name": global_config.bot.nickname,
|
"bot_name": global_config.bot.nickname,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# 发送历史记录
|
# 发送历史记录
|
||||||
history = chat_history.get_history(50)
|
history = chat_history.get_history(50)
|
||||||
if history:
|
if history:
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "history",
|
"type": "history",
|
||||||
"messages": history,
|
"messages": history,
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# 发送欢迎消息(不保存到历史)
|
# 发送欢迎消息(不保存到历史)
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "system",
|
"type": "system",
|
||||||
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
|
"content": f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!",
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
data = await websocket.receive_json()
|
data = await websocket.receive_json()
|
||||||
|
|
@ -275,7 +277,8 @@ async def websocket_chat(
|
||||||
|
|
||||||
# 广播用户消息给所有连接(包括发送者)
|
# 广播用户消息给所有连接(包括发送者)
|
||||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||||
await chat_manager.broadcast({
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
"type": "user_message",
|
"type": "user_message",
|
||||||
"content": content,
|
"content": content,
|
||||||
"message_id": message_id,
|
"message_id": message_id,
|
||||||
|
|
@ -284,8 +287,9 @@ async def websocket_chat(
|
||||||
"name": current_user_name,
|
"name": current_user_name,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"is_bot": False,
|
"is_bot": False,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
|
|
||||||
# 创建麦麦消息格式
|
# 创建麦麦消息格式
|
||||||
message_data = create_message_data(
|
message_data = create_message_data(
|
||||||
|
|
@ -298,42 +302,55 @@ async def websocket_chat(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 显示正在输入状态
|
# 显示正在输入状态
|
||||||
await chat_manager.broadcast({
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
"type": "typing",
|
"type": "typing",
|
||||||
"is_typing": True,
|
"is_typing": True,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 调用麦麦的消息处理
|
# 调用麦麦的消息处理
|
||||||
await chat_bot.message_process(message_data)
|
await chat_bot.message_process(message_data)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息时出错: {e}")
|
logger.error(f"处理消息时出错: {e}")
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "error",
|
"type": "error",
|
||||||
"content": f"处理消息时出错: {str(e)}",
|
"content": f"处理消息时出错: {str(e)}",
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
await chat_manager.broadcast({
|
await chat_manager.broadcast(
|
||||||
|
{
|
||||||
"type": "typing",
|
"type": "typing",
|
||||||
"is_typing": False,
|
"is_typing": False,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
elif data.get("type") == "ping":
|
elif data.get("type") == "ping":
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "pong",
|
"type": "pong",
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
elif data.get("type") == "update_nickname":
|
elif data.get("type") == "update_nickname":
|
||||||
# 允许用户更新昵称
|
# 允许用户更新昵称
|
||||||
if new_name := data.get("user_name", "").strip():
|
if new_name := data.get("user_name", "").strip():
|
||||||
current_user_name = new_name
|
current_user_name = new_name
|
||||||
await chat_manager.send_message(session_id, {
|
await chat_manager.send_message(
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
"type": "nickname_updated",
|
"type": "nickname_updated",
|
||||||
"user_name": current_user_name,
|
"user_name": current_user_name,
|
||||||
"timestamp": time.time(),
|
"timestamp": time.time(),
|
||||||
})
|
},
|
||||||
|
)
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
except WebSocketDisconnect:
|
||||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
import os
|
import os
|
||||||
import tomlkit
|
import tomlkit
|
||||||
from fastapi import APIRouter, HTTPException, Body
|
from fastapi import APIRouter, HTTPException, Body
|
||||||
from typing import Any
|
from typing import Any, Annotated
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||||
|
|
@ -41,6 +41,12 @@ from src.webui.config_schema import ConfigSchemaGenerator
|
||||||
|
|
||||||
logger = get_logger("webui")
|
logger = get_logger("webui")
|
||||||
|
|
||||||
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
|
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||||
|
SectionBody = Annotated[Any, Body()]
|
||||||
|
RawContentBody = Annotated[str, Body(embed=True)]
|
||||||
|
PathBody = Annotated[dict[str, str], Body()]
|
||||||
|
|
||||||
router = APIRouter(prefix="/config", tags=["config"])
|
router = APIRouter(prefix="/config", tags=["config"])
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -90,7 +96,7 @@ async def get_bot_config_schema():
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取配置架构失败: {e}")
|
logger.error(f"获取配置架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/schema/model")
|
@router.get("/schema/model")
|
||||||
|
|
@ -101,7 +107,7 @@ async def get_model_config_schema():
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模型配置架构失败: {e}")
|
logger.error(f"获取模型配置架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 子配置架构获取接口 =====
|
# ===== 子配置架构获取接口 =====
|
||||||
|
|
@ -174,7 +180,7 @@ async def get_config_section_schema(section_name: str):
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取配置节架构失败: {e}")
|
logger.error(f"获取配置节架构失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置读取接口 =====
|
# ===== 配置读取接口 =====
|
||||||
|
|
@ -196,7 +202,7 @@ async def get_bot_config():
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/model")
|
@router.get("/model")
|
||||||
|
|
@ -215,21 +221,21 @@ async def get_model_config():
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置更新接口 =====
|
# ===== 配置更新接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot")
|
@router.post("/bot")
|
||||||
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
async def update_bot_config(config_data: ConfigBody):
|
||||||
"""更新麦麦主程序配置"""
|
"""更新麦麦主程序配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
|
@ -242,18 +248,18 @@ async def update_bot_config(config_data: dict[str, Any] = Body(...)):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model")
|
@router.post("/model")
|
||||||
async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
async def update_model_config(config_data: ConfigBody):
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
try:
|
try:
|
||||||
# 验证配置数据
|
# 验证配置数据
|
||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
APIAdapterConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件
|
||||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||||
|
|
@ -266,14 +272,14 @@ async def update_model_config(config_data: dict[str, Any] = Body(...)):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 配置节更新接口 =====
|
# ===== 配置节更新接口 =====
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/section/{section_name}")
|
@router.post("/bot/section/{section_name}")
|
||||||
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
|
async def update_bot_config_section(section_name: str, section_data: SectionBody):
|
||||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
|
|
@ -304,7 +310,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置(tomlkit.dump 会保留注释)
|
# 保存配置(tomlkit.dump 会保留注释)
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
|
@ -316,7 +322,7 @@ async def update_bot_config_section(section_name: str, section_data: Any = Body(
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新配置节失败: {e}")
|
logger.error(f"更新配置节失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 原始 TOML 文件操作接口 =====
|
# ===== 原始 TOML 文件操作接口 =====
|
||||||
|
|
@ -338,24 +344,24 @@ async def get_bot_config_raw():
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取配置文件失败: {e}")
|
logger.error(f"读取配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bot/raw")
|
@router.post("/bot/raw")
|
||||||
async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
async def update_bot_config_raw(raw_content: RawContentBody):
|
||||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||||
try:
|
try:
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
try:
|
try:
|
||||||
config_data = tomlkit.loads(raw_content)
|
config_data = tomlkit.loads(raw_content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||||
|
|
||||||
# 验证配置数据结构
|
# 验证配置数据结构
|
||||||
try:
|
try:
|
||||||
Config.from_dict(config_data)
|
Config.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置文件
|
# 保存配置文件
|
||||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||||
|
|
@ -368,11 +374,11 @@ async def update_bot_config_raw(raw_content: str = Body(..., embed=True)):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存配置文件失败: {e}")
|
logger.error(f"保存配置文件失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/model/section/{section_name}")
|
@router.post("/model/section/{section_name}")
|
||||||
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
|
async def update_model_config_section(section_name: str, section_data: SectionBody):
|
||||||
"""更新模型配置的指定节(保留注释和格式)"""
|
"""更新模型配置的指定节(保留注释和格式)"""
|
||||||
try:
|
try:
|
||||||
# 读取现有配置
|
# 读取现有配置
|
||||||
|
|
@ -403,7 +409,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||||
try:
|
try:
|
||||||
APIAdapterConfig.from_dict(config_data)
|
APIAdapterConfig.from_dict(config_data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||||
|
|
||||||
# 保存配置(tomlkit.dump 会保留注释)
|
# 保存配置(tomlkit.dump 会保留注释)
|
||||||
with open(config_path, "w", encoding="utf-8") as f:
|
with open(config_path, "w", encoding="utf-8") as f:
|
||||||
|
|
@ -415,7 +421,7 @@ async def update_model_config_section(section_name: str, section_data: Any = Bod
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新配置节失败: {e}")
|
logger.error(f"更新配置节失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
# ===== 适配器配置管理接口 =====
|
# ===== 适配器配置管理接口 =====
|
||||||
|
|
@ -443,7 +449,7 @@ def _to_relative_path(path: str) -> str:
|
||||||
# 尝试获取相对路径
|
# 尝试获取相对路径
|
||||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||||
if not rel_path.startswith('..'):
|
if not rel_path.startswith(".."):
|
||||||
return rel_path
|
return rel_path
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||||
|
|
@ -463,6 +469,7 @@ async def get_adapter_config_path():
|
||||||
return {"success": True, "path": None}
|
return {"success": True, "path": None}
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||||
webui_data = json.load(f)
|
webui_data = json.load(f)
|
||||||
|
|
||||||
|
|
@ -476,6 +483,7 @@ async def get_adapter_config_path():
|
||||||
# 检查文件是否存在并返回最后修改时间
|
# 检查文件是否存在并返回最后修改时间
|
||||||
if os.path.exists(abs_path):
|
if os.path.exists(abs_path):
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
mtime = os.path.getmtime(abs_path)
|
mtime = os.path.getmtime(abs_path)
|
||||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||||
# 返回相对路径(如果可能)
|
# 返回相对路径(如果可能)
|
||||||
|
|
@ -487,11 +495,11 @@ async def get_adapter_config_path():
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取适配器配置路径失败: {e}")
|
logger.error(f"获取适配器配置路径失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config/path")
|
@router.post("/adapter-config/path")
|
||||||
async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
async def save_adapter_config_path(data: PathBody):
|
||||||
"""保存适配器配置文件路径偏好"""
|
"""保存适配器配置文件路径偏好"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
|
|
@ -530,7 +538,7 @@ async def save_adapter_config_path(data: dict[str, str] = Body(...)):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存适配器配置路径失败: {e}")
|
logger.error(f"保存适配器配置路径失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.get("/adapter-config")
|
@router.get("/adapter-config")
|
||||||
|
|
@ -562,11 +570,11 @@ async def get_adapter_config(path: str):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取适配器配置失败: {e}")
|
logger.error(f"读取适配器配置失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
||||||
@router.post("/adapter-config")
|
@router.post("/adapter-config")
|
||||||
async def save_adapter_config(data: dict[str, str] = Body(...)):
|
async def save_adapter_config(data: PathBody):
|
||||||
"""保存适配器配置到指定路径"""
|
"""保存适配器配置到指定路径"""
|
||||||
try:
|
try:
|
||||||
path = data.get("path")
|
path = data.get("path")
|
||||||
|
|
@ -586,10 +594,9 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||||
|
|
||||||
# 验证 TOML 格式
|
# 验证 TOML 格式
|
||||||
try:
|
try:
|
||||||
import tomlkit
|
|
||||||
tomlkit.loads(content)
|
tomlkit.loads(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}")
|
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||||
|
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
dir_path = os.path.dirname(abs_path)
|
dir_path = os.path.dirname(abs_path)
|
||||||
|
|
@ -607,5 +614,4 @@ async def save_adapter_config(data: dict[str, str] = Body(...)):
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存适配器配置失败: {e}")
|
logger.error(f"保存适配器配置失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -117,7 +117,7 @@ class ConfigSchemaGenerator:
|
||||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||||
# 单行文档字符串
|
# 单行文档字符串
|
||||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||||
else:
|
else:
|
||||||
# 多行文档字符串
|
# 多行文档字符串
|
||||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||||
|
|
@ -135,7 +135,7 @@ class ConfigSchemaGenerator:
|
||||||
next_line = lines[i + 1].strip()
|
next_line = lines[i + 1].strip()
|
||||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||||
description_lines.append(next_line.strip('"""').strip("'''").strip())
|
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||||
else:
|
else:
|
||||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||||
description_lines.append(next_line.strip(quote).strip())
|
description_lines.append(next_line.strip(quote).strip())
|
||||||
|
|
@ -199,13 +199,13 @@ class ConfigSchemaGenerator:
|
||||||
return FieldType.ARRAY, None, items
|
return FieldType.ARRAY, None, items
|
||||||
|
|
||||||
# 处理基本类型
|
# 处理基本类型
|
||||||
if field_type is bool or field_type == bool:
|
if field_type is bool:
|
||||||
return FieldType.BOOLEAN, None, None
|
return FieldType.BOOLEAN, None, None
|
||||||
elif field_type is int or field_type == int:
|
elif field_type is int:
|
||||||
return FieldType.INTEGER, None, None
|
return FieldType.INTEGER, None, None
|
||||||
elif field_type is float or field_type == float:
|
elif field_type is float:
|
||||||
return FieldType.NUMBER, None, None
|
return FieldType.NUMBER, None, None
|
||||||
elif field_type is str or field_type == str:
|
elif field_type is str:
|
||||||
return FieldType.STRING, None, None
|
return FieldType.STRING, None, None
|
||||||
elif field_type is dict or origin is dict:
|
elif field_type is dict or origin is dict:
|
||||||
return FieldType.OBJECT, None, None
|
return FieldType.OBJECT, None, None
|
||||||
|
|
|
||||||
|
|
@ -3,20 +3,25 @@
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
|
from fastapi import APIRouter, HTTPException, Header, Query, UploadFile, File, Form
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Annotated
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Emoji
|
from src.common.database.database_model import Emoji
|
||||||
from .token_manager import get_token_manager
|
from .token_manager import get_token_manager
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import base64
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import io
|
import io
|
||||||
|
|
||||||
logger = get_logger("webui.emoji")
|
logger = get_logger("webui.emoji")
|
||||||
|
|
||||||
|
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||||
|
EmojiFile = Annotated[UploadFile, File(description="表情包图片文件")]
|
||||||
|
EmojiFiles = Annotated[List[UploadFile], File(description="多个表情包图片文件")]
|
||||||
|
DescriptionForm = Annotated[str, Form(description="表情包描述")]
|
||||||
|
EmotionForm = Annotated[str, Form(description="情感标签,多个用逗号分隔")]
|
||||||
|
IsRegisteredForm = Annotated[bool, Form(description="是否直接注册")]
|
||||||
|
|
||||||
# 创建路由器
|
# 创建路由器
|
||||||
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
router = APIRouter(prefix="/emoji", tags=["Emoji"])
|
||||||
|
|
||||||
|
|
@ -592,10 +597,10 @@ class EmojiUploadResponse(BaseModel):
|
||||||
|
|
||||||
@router.post("/upload", response_model=EmojiUploadResponse)
|
@router.post("/upload", response_model=EmojiUploadResponse)
|
||||||
async def upload_emoji(
|
async def upload_emoji(
|
||||||
file: UploadFile = File(..., description="表情包图片文件"),
|
file: EmojiFile,
|
||||||
description: str = Form("", description="表情包描述"),
|
description: DescriptionForm = "",
|
||||||
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
|
emotion: EmotionForm = "",
|
||||||
is_registered: bool = Form(True, description="是否直接注册"),
|
is_registered: IsRegisteredForm = True,
|
||||||
authorization: Optional[str] = Header(None),
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -713,9 +718,9 @@ async def upload_emoji(
|
||||||
|
|
||||||
@router.post("/batch/upload")
|
@router.post("/batch/upload")
|
||||||
async def batch_upload_emoji(
|
async def batch_upload_emoji(
|
||||||
files: List[UploadFile] = File(..., description="多个表情包图片文件"),
|
files: EmojiFiles,
|
||||||
emotion: str = Form("", description="情感标签,多个用逗号分隔"),
|
emotion: EmotionForm = "",
|
||||||
is_registered: bool = Form(True, description="是否直接注册"),
|
is_registered: IsRegisteredForm = True,
|
||||||
authorization: Optional[str] = Header(None),
|
authorization: Optional[str] = Header(None),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -749,11 +754,13 @@ async def batch_upload_emoji(
|
||||||
# 验证文件类型
|
# 验证文件类型
|
||||||
if file.content_type not in allowed_types:
|
if file.content_type not in allowed_types:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": f"不支持的文件类型: {file.content_type}",
|
"error": f"不支持的文件类型: {file.content_type}",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 读取文件内容
|
# 读取文件内容
|
||||||
|
|
@ -761,11 +768,13 @@ async def batch_upload_emoji(
|
||||||
|
|
||||||
if not file_content:
|
if not file_content:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "文件内容为空",
|
"error": "文件内容为空",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 验证图片
|
# 验证图片
|
||||||
|
|
@ -774,11 +783,13 @@ async def batch_upload_emoji(
|
||||||
img_format = img.format.lower() if img.format else "png"
|
img_format = img.format.lower() if img.format else "png"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": f"无效的图片: {str(e)}",
|
"error": f"无效的图片: {str(e)}",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 计算哈希
|
# 计算哈希
|
||||||
|
|
@ -787,11 +798,13 @@ async def batch_upload_emoji(
|
||||||
# 检查重复
|
# 检查重复
|
||||||
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
if Emoji.get_or_none(Emoji.emoji_hash == emoji_hash):
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": "已存在相同的表情包",
|
"error": "已存在相同的表情包",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 生成文件名并保存
|
# 生成文件名并保存
|
||||||
|
|
@ -829,19 +842,23 @@ async def batch_upload_emoji(
|
||||||
)
|
)
|
||||||
|
|
||||||
results["uploaded"] += 1
|
results["uploaded"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": True,
|
"success": True,
|
||||||
"id": emoji.id,
|
"id": emoji.id,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
results["failed"] += 1
|
results["failed"] += 1
|
||||||
results["details"].append({
|
results["details"].append(
|
||||||
|
{
|
||||||
"filename": file.filename,
|
"filename": file.filename,
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
results["message"] = f"成功上传 {results['uploaded']} 个,失败 {results['failed']} 个"
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -602,9 +602,9 @@ class GitMirrorService:
|
||||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
def run_git_clone():
|
def run_git_clone(clone_cmd=cmd):
|
||||||
return subprocess.run(
|
return subprocess.run(
|
||||||
cmd,
|
clone_cmd,
|
||||||
capture_output=True,
|
capture_output=True,
|
||||||
text=True,
|
text=True,
|
||||||
timeout=300, # 5分钟超时
|
timeout=300, # 5分钟超时
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
"""知识库图谱可视化 API 路由"""
|
"""知识库图谱可视化 API 路由"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from fastapi import APIRouter, Query
|
from fastapi import APIRouter, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
@ -11,6 +12,7 @@ router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||||
|
|
||||||
class KnowledgeNode(BaseModel):
|
class KnowledgeNode(BaseModel):
|
||||||
"""知识节点"""
|
"""知识节点"""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
type: str # 'entity' or 'paragraph'
|
type: str # 'entity' or 'paragraph'
|
||||||
content: str
|
content: str
|
||||||
|
|
@ -19,6 +21,7 @@ class KnowledgeNode(BaseModel):
|
||||||
|
|
||||||
class KnowledgeEdge(BaseModel):
|
class KnowledgeEdge(BaseModel):
|
||||||
"""知识边"""
|
"""知识边"""
|
||||||
|
|
||||||
source: str
|
source: str
|
||||||
target: str
|
target: str
|
||||||
weight: float
|
weight: float
|
||||||
|
|
@ -28,12 +31,14 @@ class KnowledgeEdge(BaseModel):
|
||||||
|
|
||||||
class KnowledgeGraph(BaseModel):
|
class KnowledgeGraph(BaseModel):
|
||||||
"""知识图谱"""
|
"""知识图谱"""
|
||||||
|
|
||||||
nodes: List[KnowledgeNode]
|
nodes: List[KnowledgeNode]
|
||||||
edges: List[KnowledgeEdge]
|
edges: List[KnowledgeEdge]
|
||||||
|
|
||||||
|
|
||||||
class KnowledgeStats(BaseModel):
|
class KnowledgeStats(BaseModel):
|
||||||
"""知识库统计信息"""
|
"""知识库统计信息"""
|
||||||
|
|
||||||
total_nodes: int
|
total_nodes: int
|
||||||
total_edges: int
|
total_edges: int
|
||||||
entity_nodes: int
|
entity_nodes: int
|
||||||
|
|
@ -69,16 +74,11 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
|
|
||||||
nodes.append(KnowledgeNode(
|
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
@ -93,17 +93,15 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||||
edge_data = graph[source, target]
|
edge_data = graph[source, target]
|
||||||
|
|
||||||
# edge_data 支持 [] 操作符但不支持 .get()
|
# edge_data 支持 [] 操作符但不支持 .get()
|
||||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||||
|
|
||||||
edges.append(KnowledgeEdge(
|
edges.append(
|
||||||
source=source,
|
KnowledgeEdge(
|
||||||
target=target,
|
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||||
weight=weight,
|
)
|
||||||
create_time=create_time,
|
)
|
||||||
update_time=update_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
@ -114,7 +112,7 @@ def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||||
@router.get("/graph", response_model=KnowledgeGraph)
|
@router.get("/graph", response_model=KnowledgeGraph)
|
||||||
async def get_knowledge_graph(
|
async def get_knowledge_graph(
|
||||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph")
|
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||||
):
|
):
|
||||||
"""获取知识图谱(限制节点数量)
|
"""获取知识图谱(限制节点数量)
|
||||||
|
|
||||||
|
|
@ -136,9 +134,11 @@ async def get_knowledge_graph(
|
||||||
|
|
||||||
# 按类型过滤节点
|
# 按类型过滤节点
|
||||||
if node_type == "entity":
|
if node_type == "entity":
|
||||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'ent']
|
all_node_list = [
|
||||||
|
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||||
|
]
|
||||||
elif node_type == "paragraph":
|
elif node_type == "paragraph":
|
||||||
all_node_list = [n for n in all_node_list if n in graph and 'type' in graph[n] and graph[n]['type'] == 'pg']
|
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||||
|
|
||||||
# 限制节点数量
|
# 限制节点数量
|
||||||
total_nodes = len(all_node_list)
|
total_nodes = len(all_node_list)
|
||||||
|
|
@ -155,16 +155,11 @@ async def get_knowledge_graph(
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
node_type_val = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
|
|
||||||
nodes.append(KnowledgeNode(
|
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type_val,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
node_ids.add(node_id)
|
node_ids.add(node_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||||
|
|
@ -181,17 +176,15 @@ async def get_knowledge_graph(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
edge_data = graph[source, target]
|
edge_data = graph[source, target]
|
||||||
weight = edge_data['weight'] if 'weight' in edge_data else 1.0
|
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||||
create_time = edge_data['create_time'] if 'create_time' in edge_data else None
|
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||||
update_time = edge_data['update_time'] if 'update_time' in edge_data else None
|
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||||
|
|
||||||
edges.append(KnowledgeEdge(
|
edges.append(
|
||||||
source=source,
|
KnowledgeEdge(
|
||||||
target=target,
|
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||||
weight=weight,
|
)
|
||||||
create_time=create_time,
|
)
|
||||||
update_time=update_time
|
|
||||||
))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
@ -215,13 +208,7 @@ async def get_knowledge_stats():
|
||||||
try:
|
try:
|
||||||
kg_manager = _load_kg_manager()
|
kg_manager = _load_kg_manager()
|
||||||
if kg_manager is None or kg_manager.graph is None:
|
if kg_manager is None or kg_manager.graph is None:
|
||||||
return KnowledgeStats(
|
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||||
total_nodes=0,
|
|
||||||
total_edges=0,
|
|
||||||
entity_nodes=0,
|
|
||||||
paragraph_nodes=0,
|
|
||||||
avg_connections=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
graph = kg_manager.graph
|
graph = kg_manager.graph
|
||||||
node_list = graph.get_node_list()
|
node_list = graph.get_node_list()
|
||||||
|
|
@ -236,10 +223,10 @@ async def get_knowledge_stats():
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
node_type = node_data['type'] if 'type' in node_data else 'ent'
|
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||||
if node_type == 'ent':
|
if node_type == "ent":
|
||||||
entity_nodes += 1
|
entity_nodes += 1
|
||||||
elif node_type == 'pg':
|
elif node_type == "pg":
|
||||||
paragraph_nodes += 1
|
paragraph_nodes += 1
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
@ -252,18 +239,12 @@ async def get_knowledge_stats():
|
||||||
total_edges=total_edges,
|
total_edges=total_edges,
|
||||||
entity_nodes=entity_nodes,
|
entity_nodes=entity_nodes,
|
||||||
paragraph_nodes=paragraph_nodes,
|
paragraph_nodes=paragraph_nodes,
|
||||||
avg_connections=round(avg_connections, 2)
|
avg_connections=round(avg_connections, 2),
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||||
return KnowledgeStats(
|
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||||
total_nodes=0,
|
|
||||||
total_edges=0,
|
|
||||||
entity_nodes=0,
|
|
||||||
paragraph_nodes=0,
|
|
||||||
avg_connections=0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/search", response_model=List[KnowledgeNode])
|
@router.get("/search", response_model=List[KnowledgeNode])
|
||||||
|
|
@ -290,17 +271,12 @@ async def search_knowledge_node(query: str = Query(..., min_length=1)):
|
||||||
for node_id in node_list:
|
for node_id in node_list:
|
||||||
try:
|
try:
|
||||||
node_data = graph[node_id]
|
node_data = graph[node_id]
|
||||||
content = node_data['content'] if 'content' in node_data else node_id
|
content = node_data["content"] if "content" in node_data else node_id
|
||||||
node_type = "entity" if ('type' in node_data and node_data['type'] == 'ent') else "paragraph"
|
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||||
|
|
||||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||||
create_time = node_data['create_time'] if 'create_time' in node_data else None
|
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||||
results.append(KnowledgeNode(
|
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||||
id=node_id,
|
|
||||||
type=node_type,
|
|
||||||
content=content,
|
|
||||||
create_time=create_time
|
|
||||||
))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,11 +50,13 @@ def _parse_openai_response(data: dict) -> list[dict]:
|
||||||
if "data" in data and isinstance(data["data"], list):
|
if "data" in data and isinstance(data["data"], list):
|
||||||
for model in data["data"]:
|
for model in data["data"]:
|
||||||
if isinstance(model, dict) and "id" in model:
|
if isinstance(model, dict) and "id" in model:
|
||||||
models.append({
|
models.append(
|
||||||
|
{
|
||||||
"id": model["id"],
|
"id": model["id"],
|
||||||
"name": model.get("name") or model["id"],
|
"name": model.get("name") or model["id"],
|
||||||
"owned_by": model.get("owned_by", ""),
|
"owned_by": model.get("owned_by", ""),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -72,11 +74,13 @@ def _parse_gemini_response(data: dict) -> list[dict]:
|
||||||
model_id = model["name"]
|
model_id = model["name"]
|
||||||
if model_id.startswith("models/"):
|
if model_id.startswith("models/"):
|
||||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||||
models.append({
|
models.append(
|
||||||
|
{
|
||||||
"id": model_id,
|
"id": model_id,
|
||||||
"name": model.get("displayName") or model_id,
|
"name": model.get("displayName") or model_id,
|
||||||
"owned_by": "google",
|
"owned_by": "google",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -118,25 +122,24 @@ async def _fetch_models_from_provider(
|
||||||
response = await client.get(url, headers=headers, params=params)
|
response = await client.get(url, headers=headers, params=params)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException as e:
|
||||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试")
|
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||||
if e.response.status_code == 401:
|
if e.response.status_code == 401:
|
||||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期")
|
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||||
elif e.response.status_code == 403:
|
elif e.response.status_code == 403:
|
||||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限")
|
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||||
elif e.response.status_code == 404:
|
elif e.response.status_code == 404:
|
||||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表")
|
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=502,
|
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||||
detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
) from e
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取模型列表失败: {e}")
|
logger.error(f"获取模型列表失败: {e}")
|
||||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||||
|
|
||||||
# 根据解析器类型解析响应
|
# 根据解析器类型解析响应
|
||||||
if parser == "openai":
|
if parser == "openai":
|
||||||
|
|
|
||||||
|
|
@ -31,8 +31,9 @@ def parse_version(version_str: str) -> tuple[int, int, int]:
|
||||||
"""
|
"""
|
||||||
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
# 移除 snapshot、dev、alpha、beta 等后缀(支持 - 和 . 分隔符)
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
# 匹配 -snapshot.X, .snapshot, -dev, .dev, -alpha, .alpha, -beta, .beta 等后缀
|
||||||
base_version = re.split(r'[-.](?:snapshot|dev|alpha|beta|rc)', version_str, flags=re.IGNORECASE)[0]
|
base_version = re.split(r"[-.](?:snapshot|dev|alpha|beta|rc)", version_str, flags=re.IGNORECASE)[0]
|
||||||
|
|
||||||
parts = base_version.split(".")
|
parts = base_version.split(".")
|
||||||
if len(parts) < 3:
|
if len(parts) < 3:
|
||||||
|
|
@ -1167,9 +1168,7 @@ class UpdatePluginConfigRequest(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}/schema")
|
@router.get("/config/{plugin_id}/schema")
|
||||||
async def get_plugin_config_schema(
|
async def get_plugin_config_schema(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
获取插件配置 Schema
|
获取插件配置 Schema
|
||||||
|
|
||||||
|
|
@ -1205,7 +1204,7 @@ async def get_plugin_config_schema(
|
||||||
plugin_instance = instance
|
plugin_instance = instance
|
||||||
break
|
break
|
||||||
|
|
||||||
if plugin_instance and hasattr(plugin_instance, 'get_webui_config_schema'):
|
if plugin_instance and hasattr(plugin_instance, "get_webui_config_schema"):
|
||||||
# 从插件实例获取 schema
|
# 从插件实例获取 schema
|
||||||
schema = plugin_instance.get_webui_config_schema()
|
schema = plugin_instance.get_webui_config_schema()
|
||||||
return {"success": True, "schema": schema}
|
return {"success": True, "schema": schema}
|
||||||
|
|
@ -1236,6 +1235,7 @@ async def get_plugin_config_schema(
|
||||||
current_config = {}
|
current_config = {}
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
current_config = tomlkit.load(f)
|
current_config = tomlkit.load(f)
|
||||||
|
|
||||||
|
|
@ -1301,9 +1301,7 @@ async def get_plugin_config_schema(
|
||||||
|
|
||||||
|
|
||||||
@router.get("/config/{plugin_id}")
|
@router.get("/config/{plugin_id}")
|
||||||
async def get_plugin_config(
|
async def get_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
获取插件当前配置值
|
获取插件当前配置值
|
||||||
|
|
||||||
|
|
@ -1344,6 +1342,7 @@ async def get_plugin_config(
|
||||||
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
return {"success": True, "config": {}, "message": "配置文件不存在"}
|
||||||
|
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
with open(config_path, "r", encoding="utf-8") as f:
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
config = tomlkit.load(f)
|
config = tomlkit.load(f)
|
||||||
|
|
||||||
|
|
@ -1358,9 +1357,7 @@ async def get_plugin_config(
|
||||||
|
|
||||||
@router.put("/config/{plugin_id}")
|
@router.put("/config/{plugin_id}")
|
||||||
async def update_plugin_config(
|
async def update_plugin_config(
|
||||||
plugin_id: str,
|
plugin_id: str, request: UpdatePluginConfigRequest, authorization: Optional[str] = Header(None)
|
||||||
request: UpdatePluginConfigRequest,
|
|
||||||
authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
更新插件配置
|
更新插件配置
|
||||||
|
|
@ -1401,6 +1398,7 @@ async def update_plugin_config(
|
||||||
# 备份旧配置
|
# 备份旧配置
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
backup_name = f"config.toml.backup.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
backup_path = plugin_path / backup_name
|
backup_path = plugin_path / backup_name
|
||||||
|
|
@ -1409,6 +1407,7 @@ async def update_plugin_config(
|
||||||
|
|
||||||
# 写入新配置(使用 tomlkit 保留注释)
|
# 写入新配置(使用 tomlkit 保留注释)
|
||||||
import tomlkit
|
import tomlkit
|
||||||
|
|
||||||
# 先读取原配置以保留注释和格式
|
# 先读取原配置以保留注释和格式
|
||||||
existing_doc = tomlkit.document()
|
existing_doc = tomlkit.document()
|
||||||
if config_path.exists():
|
if config_path.exists():
|
||||||
|
|
@ -1422,11 +1421,7 @@ async def update_plugin_config(
|
||||||
|
|
||||||
logger.info(f"已更新插件配置: {plugin_id}")
|
logger.info(f"已更新插件配置: {plugin_id}")
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": "配置已保存", "note": "配置更改将在插件重新加载后生效"}
|
||||||
"success": True,
|
|
||||||
"message": "配置已保存",
|
|
||||||
"note": "配置更改将在插件重新加载后生效"
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -1436,9 +1431,7 @@ async def update_plugin_config(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/reset")
|
@router.post("/config/{plugin_id}/reset")
|
||||||
async def reset_plugin_config(
|
async def reset_plugin_config(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
重置插件配置为默认值
|
重置插件配置为默认值
|
||||||
|
|
||||||
|
|
@ -1481,17 +1474,14 @@ async def reset_plugin_config(
|
||||||
# 备份并删除
|
# 备份并删除
|
||||||
import shutil
|
import shutil
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
backup_name = f"config.toml.reset.{datetime.datetime.now().strftime('%Y%m%d%H%M%S')}"
|
||||||
backup_path = plugin_path / backup_name
|
backup_path = plugin_path / backup_name
|
||||||
shutil.move(config_path, backup_path)
|
shutil.move(config_path, backup_path)
|
||||||
|
|
||||||
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
logger.info(f"已重置插件配置: {plugin_id},备份: {backup_path}")
|
||||||
|
|
||||||
return {
|
return {"success": True, "message": "配置已重置,下次加载插件时将使用默认配置", "backup": str(backup_path)}
|
||||||
"success": True,
|
|
||||||
"message": "配置已重置,下次加载插件时将使用默认配置",
|
|
||||||
"backup": str(backup_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|
@ -1501,9 +1491,7 @@ async def reset_plugin_config(
|
||||||
|
|
||||||
|
|
||||||
@router.post("/config/{plugin_id}/toggle")
|
@router.post("/config/{plugin_id}/toggle")
|
||||||
async def toggle_plugin(
|
async def toggle_plugin(plugin_id: str, authorization: Optional[str] = Header(None)) -> Dict[str, Any]:
|
||||||
plugin_id: str, authorization: Optional[str] = Header(None)
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
切换插件启用状态
|
切换插件启用状态
|
||||||
|
|
||||||
|
|
@ -1567,7 +1555,7 @@ async def toggle_plugin(
|
||||||
"success": True,
|
"success": True,
|
||||||
"enabled": new_enabled,
|
"enabled": new_enabled,
|
||||||
"message": f"插件已{status}",
|
"message": f"插件已{status}",
|
||||||
"note": "状态更改将在下次加载插件时生效"
|
"note": "状态更改将在下次加载插件时生效",
|
||||||
}
|
}
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
|
|
|
||||||
|
|
@ -91,10 +91,12 @@ class WebUIServer:
|
||||||
|
|
||||||
logger.info("开始导入 knowledge_routes...")
|
logger.info("开始导入 knowledge_routes...")
|
||||||
from src.webui.knowledge_routes import router as knowledge_router
|
from src.webui.knowledge_routes import router as knowledge_router
|
||||||
|
|
||||||
logger.info("knowledge_routes 导入成功")
|
logger.info("knowledge_routes 导入成功")
|
||||||
|
|
||||||
# 导入本地聊天室路由
|
# 导入本地聊天室路由
|
||||||
from src.webui.chat_routes import router as chat_router
|
from src.webui.chat_routes import router as chat_router
|
||||||
|
|
||||||
logger.info("chat_routes 导入成功")
|
logger.info("chat_routes 导入成功")
|
||||||
|
|
||||||
# 注册路由
|
# 注册路由
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue