pull/1374/head
墨梓柒 2025-11-19 23:35:14 +08:00
parent 2f58605644
commit 44f427dc64
No known key found for this signature in database
GPG Key ID: 4A65B9DBA35F7635
42 changed files with 1742 additions and 2062 deletions

8
bot.py
View File

@ -1,7 +1,6 @@
import asyncio
import hashlib
import os
import sys
import time
import platform
import traceback
@ -30,7 +29,7 @@ else:
raise
# 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式
from src.common.logger import initialize_logging, get_logger, shutdown_logging #noqa
from src.common.logger import initialize_logging, get_logger, shutdown_logging # noqa
initialize_logging()
@ -212,9 +211,10 @@ if __name__ == "__main__":
# 创建事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 初始化 WebSocket 日志推送
from src.common.logger import initialize_ws_handler
initialize_ws_handler(loop)
try:
@ -251,7 +251,7 @@ if __name__ == "__main__":
print(f"关闭日志系统时出错: {e}")
print("[主程序] 准备退出...")
# 使用 os._exit() 强制退出,避免被阻塞
# 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的
os._exit(exit_code)

View File

@ -16,8 +16,6 @@ if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
SECONDS_5_MINUTES = 5 * 60

View File

@ -12,7 +12,6 @@ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
# 设置中文字体
plt.rcParams["font.sans-serif"] = ["SimHei", "Microsoft YaHei", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False

File diff suppressed because it is too large Load Diff

View File

@ -333,7 +333,6 @@ class HeartFChatting:
# 重置连续 no_reply 计数
self.consecutive_no_reply_count = 0
reason = ""
await database_api.store_action_info(
chat_stream=self.chat_stream,

View File

@ -30,9 +30,11 @@ DATA_PATH = os.path.join(ROOT_PATH, "data")
qa_manager = None
inspire_manager = None
def get_qa_manager():
    """Return the module-level ``qa_manager`` object.

    The global is initialised to ``None`` in this module, so callers must be
    prepared to receive ``None``.
    NOTE(review): presumably assigned during ``lpmm_start_up`` — confirm.
    """
    return qa_manager
def lpmm_start_up(): # sourcery skip: extract-duplicate-method
# 检查LPMM知识库是否启用
if global_config.lpmm_knowledge.enable:

View File

@ -128,11 +128,10 @@ class QAManager:
selected_knowledge = knowledge[:limit]
formatted_knowledge = [
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
for i, k in enumerate(selected_knowledge)
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(selected_knowledge)
]
# if max_score is not None:
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
found_knowledge = "\n".join(formatted_knowledge)
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:

View File

@ -226,7 +226,9 @@ class DefaultReplyer:
traceback.print_exc()
return False, llm_response
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@ -1094,10 +1096,10 @@ class DefaultReplyer:
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname
@ -1115,10 +1117,10 @@ class DefaultReplyer:
model_config=model_config.model_task_config.tool_use,
tool_options=[SearchKnowledgeFromLPMMTool.get_tool_definition()],
)
# logger.info(f"工具调用提示词: {prompt}")
# logger.info(f"工具调用: {tool_calls}")
if tool_calls:
result = await self.tool_executor.execute_tool_call(tool_calls[0], SearchKnowledgeFromLPMMTool())
end_time = time.time()

View File

@ -241,7 +241,9 @@ class PrivateReplyer:
return f"{sender_relation}"
async def build_expression_habits(self, chat_history: str, target: str, reply_reason: str = "") -> Tuple[str, List[int]]:
async def build_expression_habits(
self, chat_history: str, target: str, reply_reason: str = ""
) -> Tuple[str, List[int]]:
# sourcery skip: for-append-to-extend
"""构建表达习惯块
@ -1032,10 +1034,10 @@ class PrivateReplyer:
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用跳过获取知识库内容")
return ""
if global_config.lpmm_knowledge.lpmm_mode == "agent":
return ""
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
bot_name = global_config.bot.nickname

View File

@ -106,7 +106,7 @@ class ChatHistorySummarizer:
await self._check_and_package(current_time)
self.last_check_time = current_time
return
logger.info(
f"{self.log_prefix} 开始处理聊天概括,时间窗口: {self.last_check_time:.2f} -> {current_time:.2f}"
)

View File

@ -72,16 +72,16 @@ def get_ws_handler():
def initialize_ws_handler(loop):
"""初始化 WebSocket handler 的事件循环
Args:
loop: asyncio 事件循环
"""
handler = get_ws_handler()
handler.set_loop(loop)
# 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同)
handler.setFormatter(file_formatter)
# 添加到根日志记录器
root_logger = logging.getLogger()
if handler not in root_logger.handlers:
@ -177,43 +177,44 @@ class TimestampedFileHandler(logging.Handler):
class WebSocketLogHandler(logging.Handler):
"""WebSocket 日志处理器 - 将日志实时推送到前端"""
_log_counter = 0 # 类级别计数器,确保 ID 唯一性
def __init__(self, loop=None):
super().__init__()
self.loop = loop
self._initialized = False
def set_loop(self, loop):
"""设置事件循环"""
self.loop = loop
self._initialized = True
def emit(self, record):
"""发送日志到 WebSocket 客户端"""
if not self._initialized or self.loop is None:
return
try:
# 获取格式化后的消息
# 对于 structlog,formatted message 包含完整的日志信息
formatted_msg = self.format(record) if self.formatter else record.getMessage()
# 如果是 JSON 格式(文件格式化器),解析它
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get('event', formatted_msg)
message = log_dict.get("event", formatted_msg)
except (json.JSONDecodeError, ValueError):
# 不是 JSON,直接使用消息
message = formatted_msg
# 生成唯一 ID: 时间戳毫秒 + 自增计数器
WebSocketLogHandler._log_counter += 1
log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
# 格式化日志数据
log_data = {
"id": log_id,
@ -222,20 +223,17 @@ class WebSocketLogHandler(logging.Handler):
"module": record.name,
"message": message,
}
# 异步广播日志(不阻塞日志记录)
try:
import asyncio
from src.webui.logs_ws import broadcast_log
asyncio.run_coroutine_threadsafe(
broadcast_log(log_data),
self.loop
)
asyncio.run_coroutine_threadsafe(broadcast_log(log_data), self.loop)
except Exception:
# WebSocket 推送失败不影响日志记录
pass
except Exception:
# 不要让 WebSocket 错误影响日志系统
self.handleError(record)
@ -255,7 +253,7 @@ def close_handlers():
if _console_handler:
_console_handler.close()
_console_handler = None
if _ws_handler:
_ws_handler.close()
_ws_handler = None

View File

@ -647,7 +647,7 @@ class LPMMKnowledgeConfig(ConfigBase):
enable: bool = True
"""是否启用LPMM知识库"""
lpmm_mode: Literal["classic", "agent"] = "classic"
"""LPMM知识库模式可选classic经典模式agent 模式,结合最新的记忆一同使用"""
@ -690,4 +690,4 @@ class JargonConfig(ConfigBase):
"""Jargon配置类"""
all_global: bool = False
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""
"""是否将所有新增的jargon项目默认为全局is_global=Truechat_id记录第一次存储时的id"""

View File

@ -467,11 +467,7 @@ class ExpressionLearner:
up_content: str,
current_time: float,
) -> None:
expr_obj = (
Expression.select()
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
.first()
)
expr_obj = Expression.select().where((Expression.chat_id == self.chat_id) & (Expression.style == style)).first()
if expr_obj:
await self._update_existing_expression(

View File

@ -42,8 +42,6 @@ def init_prompt():
Prompt(expression_evaluation_prompt, "expression_evaluation_prompt")
class ExpressionSelector:
def __init__(self):
self.llm_model = LLMRequest(
@ -238,9 +236,9 @@ class ExpressionSelector:
else:
target_message_str = ""
target_message_extra_block = ""
chat_context = f"以下是正在进行的聊天内容:{chat_info}"
# 构建reply_reason块
if reply_reason:
reply_reason_block = f"你的回复理由是:{reply_reason}"
@ -262,9 +260,8 @@ class ExpressionSelector:
# 4. 调用LLM
content, (reasoning_content, model_name, _) = await self.llm_model.generate_response_async(prompt=prompt)
# print(prompt)
if not content:
logger.warning("LLM返回空结果")
return [], []

View File

@ -36,10 +36,7 @@ def _contains_bot_self_name(content: str) -> bool:
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [
str(alias or "").strip().lower()
for alias in getattr(bot_config, "alias_names", []) or []
]
alias_names = [str(alias or "").strip().lower() for alias in getattr(bot_config, "alias_names", []) or []]
candidates = [name for name in [nickname, *alias_names] if name]
@ -149,7 +146,7 @@ async def _enrich_raw_content_if_needed(
) -> List[str]:
"""
检查raw_content是否只包含黑话本身如果是则获取该消息的前三条消息作为原始内容
Args:
content: 黑话内容
raw_content_list: 原始raw_content列表
@ -157,22 +154,22 @@ async def _enrich_raw_content_if_needed(
messages: 当前时间窗口内的消息列表
extraction_start_time: 提取开始时间
extraction_end_time: 提取结束时间
Returns:
处理后的raw_content列表
"""
enriched_list = []
for raw_content in raw_content_list:
# 检查raw_content是否只包含黑话本身去除空白字符后比较
raw_content_clean = raw_content.strip()
content_clean = content.strip()
# 如果raw_content只包含黑话本身可能有一些标点或空白则尝试获取上下文
# 去除所有空白字符后比较,确保只包含黑话本身
raw_content_normalized = raw_content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
content_normalized = content_clean.replace(" ", "").replace("\n", "").replace("\t", "")
if raw_content_normalized == content_normalized:
# 在消息列表中查找只包含该黑话的消息(去除空白后比较)
target_message = None
@ -183,22 +180,20 @@ async def _enrich_raw_content_if_needed(
if msg_content_normalized == content_normalized:
target_message = msg
break
if target_message and target_message.time:
# 获取该消息的前三条消息
try:
previous_messages = get_raw_msg_before_timestamp_with_chat(
chat_id=chat_id,
timestamp=target_message.time,
limit=3
chat_id=chat_id, timestamp=target_message.time, limit=3
)
if previous_messages:
# 将前三条消息和当前消息一起格式化
context_messages = previous_messages + [target_message]
# 按时间排序
context_messages.sort(key=lambda x: x.time or 0)
# 格式化为可读消息
formatted_context, _ = await build_readable_messages_with_list(
context_messages,
@ -206,7 +201,7 @@ async def _enrich_raw_content_if_needed(
timestamp_mode="relative",
truncate=False,
)
if formatted_context.strip():
enriched_list.append(formatted_context.strip())
logger.warning(f"为黑话 {content} 补充了上下文消息")
@ -226,7 +221,7 @@ async def _enrich_raw_content_if_needed(
else:
# raw_content包含更多内容直接使用
enriched_list.append(raw_content)
return enriched_list
@ -240,31 +235,31 @@ def _should_infer_meaning(jargon_obj: Jargon) -> bool:
# 如果已完成所有推断,不再推断
if jargon_obj.is_complete:
return False
count = jargon_obj.count or 0
last_inference = jargon_obj.last_inference_count or 0
# 阈值列表3,6, 10, 20, 40, 60, 100
thresholds = [3,6, 10, 20, 40, 60, 100]
thresholds = [3, 6, 10, 20, 40, 60, 100]
if count < thresholds[0]:
return False
# 如果count没有超过上次判定值不需要判定
if count <= last_inference:
return False
# 找到第一个大于last_inference的阈值
next_threshold = None
for threshold in thresholds:
if threshold > last_inference:
next_threshold = threshold
break
# 如果没有找到下一个阈值说明已经超过100不应该再推断
if next_threshold is None:
return False
# 检查count是否达到或超过这个阈值
return count >= next_threshold
@ -275,13 +270,13 @@ class JargonMiner:
self.last_learning_time: float = time.time()
# 频率控制,可按需调整
self.min_messages_for_learning: int = 10
self.min_learning_interval: float = 20
self.min_learning_interval: float = 20
self.llm = LLMRequest(
model_set=model_config.model_task_config.utils,
request_type="jargon.extract",
)
# 初始化stream_name作为类属性避免重复提取
chat_manager = get_chat_manager()
stream_name = chat_manager.get_stream_name(self.chat_id)
@ -306,17 +301,19 @@ class JargonMiner:
try:
content = jargon_obj.content
raw_content_str = jargon_obj.raw_content or ""
# 解析raw_content列表
raw_content_list = []
if raw_content_str:
try:
raw_content_list = json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
raw_content_list = (
json.loads(raw_content_str) if isinstance(raw_content_str, str) else raw_content_str
)
if not isinstance(raw_content_list, list):
raw_content_list = [raw_content_list] if raw_content_list else []
except (json.JSONDecodeError, TypeError):
raw_content_list = [raw_content_str] if raw_content_str else []
if not raw_content_list:
logger.warning(f"jargon {content} 没有raw_content跳过推断")
return
@ -328,12 +325,12 @@ class JargonMiner:
content=content,
raw_content_list=raw_content_text,
)
response1, _ = await self.llm.generate_response_async(prompt1, temperature=0.3)
if not response1:
logger.warning(f"jargon {content} 推断1失败无响应")
return
# 解析推断1结果
inference1 = None
try:
@ -349,7 +346,7 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon {content} 推断1解析失败: {e}")
return
# 检查推断1是否表示信息不足无法推断
no_info = inference1.get("no_info", False)
meaning1 = inference1.get("meaning", "").strip()
@ -360,18 +357,17 @@ class JargonMiner:
jargon_obj.save()
return
# 步骤2: 仅基于content推断
prompt2 = await global_prompt_manager.format_prompt(
"jargon_inference_content_only_prompt",
content=content,
)
response2, _ = await self.llm.generate_response_async(prompt2, temperature=0.3)
if not response2:
logger.warning(f"jargon {content} 推断2失败无响应")
return
# 解析推断2结果
inference2 = None
try:
@ -387,13 +383,12 @@ class JargonMiner:
except Exception as e:
logger.error(f"jargon {content} 推断2解析失败: {e}")
return
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
logger.info(f"jargon {content} 推断1提示词: {prompt1}")
logger.info(f"jargon {content} 推断1结果: {response1}")
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 推断2提示词: {prompt2}")
logger.info(f"jargon {content} 推断2结果: {response2}")
@ -404,22 +399,22 @@ class JargonMiner:
logger.debug(f"jargon {content} 推断2结果: {response2}")
logger.debug(f"jargon {content} 推断1提示词: {prompt1}")
logger.debug(f"jargon {content} 推断1结果: {response1}")
# 步骤3: 比较两个推断结果
prompt3 = await global_prompt_manager.format_prompt(
"jargon_compare_inference_prompt",
inference1=json.dumps(inference1, ensure_ascii=False),
inference2=json.dumps(inference2, ensure_ascii=False),
)
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon {content} 比较提示词: {prompt3}")
response3, _ = await self.llm.generate_response_async(prompt3, temperature=0.3)
if not response3:
logger.warning(f"jargon {content} 比较失败:无响应")
return
# 解析比较结果
comparison = None
try:
@ -439,7 +434,7 @@ class JargonMiner:
# 判断是否为黑话
is_similar = comparison.get("is_similar", False)
is_jargon = not is_similar # 如果相似,说明不是黑话;如果有差异,说明是黑话
# 更新数据库记录
jargon_obj.is_jargon = is_jargon
if is_jargon:
@ -448,17 +443,19 @@ class JargonMiner:
else:
# 不是黑话也记录含义使用推断2的结果因为含义明确
jargon_obj.meaning = inference2.get("meaning", "")
# 更新最后一次判定的count值避免重启后重复判定
jargon_obj.last_inference_count = jargon_obj.count or 0
# 如果count>=100标记为完成不再进行推断
if (jargon_obj.count or 0) >= 100:
jargon_obj.is_complete = True
jargon_obj.save()
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
logger.debug(
f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}"
)
# 固定输出推断结果,格式化为可读形式
if is_jargon:
# 是黑话,输出格式:[聊天名]xxx的含义是 xxxxxxxxxxx
@ -471,10 +468,11 @@ class JargonMiner:
else:
# 不是黑话,输出格式:[聊天名]xxx 不是黑话
logger.info(f"[{self.stream_name}]{content} 不是黑话")
except Exception as e:
logger.error(f"jargon推断失败: {e}")
import traceback
traceback.print_exc()
def should_trigger(self) -> bool:
@ -502,7 +500,7 @@ class JargonMiner:
# 记录本次提取的时间窗口,避免重复提取
extraction_start_time = self.last_learning_time
extraction_end_time = time.time()
# 拉取学习窗口内的消息
messages = get_raw_msg_by_timestamp_with_chat_inclusive(
chat_id=self.chat_id,
@ -525,7 +523,7 @@ class JargonMiner:
response, _ = await self.llm.generate_response_async(prompt, temperature=0.2)
if not response:
return
if global_config.debug.show_jargon_prompt:
logger.info(f"jargon提取提示词: {prompt}")
logger.info(f"jargon提取结果: {response}")
@ -555,7 +553,7 @@ class JargonMiner:
continue
content = str(item.get("content", "")).strip()
raw_content_value = item.get("raw_content", "")
# 处理raw_content可能是字符串或列表
raw_content_list = []
if isinstance(raw_content_value, list):
@ -566,15 +564,12 @@ class JargonMiner:
raw_content_str = raw_content_value.strip()
if raw_content_str:
raw_content_list = [raw_content_str]
if content and raw_content_list:
if _contains_bot_self_name(content):
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
entries.append({
"content": content,
"raw_content": raw_content_list
})
entries.append({"content": content, "raw_content": raw_content_list})
except Exception as e:
logger.error(f"解析jargon JSON失败: {e}; 原始: {response}")
return
@ -591,13 +586,13 @@ class JargonMiner:
if content_key not in seen:
seen.add(content_key)
uniq_entries.append(entry)
saved = 0
updated = 0
for entry in uniq_entries:
content = entry["content"]
raw_content_list = entry["raw_content"] # 已经是列表
# 检查并补充raw_content如果只包含黑话本身则获取前三条消息作为上下文
raw_content_list = await _enrich_raw_content_if_needed(
content=content,
@ -607,60 +602,53 @@ class JargonMiner:
extraction_start_time=extraction_start_time,
extraction_end_time=extraction_end_time,
)
try:
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
# 开启all_global无视chat_id查询所有content匹配的记录所有记录都是全局的
query = (
Jargon.select()
.where(Jargon.content == content)
)
query = Jargon.select().where(Jargon.content == content)
else:
# 关闭all_global只查询chat_id匹配的记录不考虑is_global
query = (
Jargon.select()
.where(
(Jargon.chat_id == self.chat_id) &
(Jargon.content == content)
)
)
query = Jargon.select().where((Jargon.chat_id == self.chat_id) & (Jargon.content == content))
if query.exists():
obj = query.get()
try:
obj.count = (obj.count or 0) + 1
except Exception:
obj.count = 1
# 合并raw_content列表读取现有列表追加新值去重
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
existing_raw_content = [obj.raw_content] if obj.raw_content else []
# 合并并去重
merged_list = list(dict.fromkeys(existing_raw_content + raw_content_list))
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
# 开启all_global时确保记录标记为is_global=True
if global_config.jargon.all_global:
obj.is_global = True
# 关闭all_global时保持原有is_global不变不修改
obj.save()
# 检查是否需要推断(达到阈值且超过上次判定值)
if _should_infer_meaning(obj):
# 异步触发推断,不阻塞主流程
# 重新加载对象以确保数据最新
jargon_id = obj.id
asyncio.create_task(self._infer_meaning_by_id(jargon_id))
updated += 1
else:
# 没找到匹配记录,创建新记录
@ -670,13 +658,13 @@ class JargonMiner:
else:
# 关闭all_global新记录is_global=False
is_global_new = False
Jargon.create(
content=content,
raw_content=json.dumps(raw_content_list, ensure_ascii=False),
chat_id=self.chat_id,
is_global=is_global_new,
count=1
count=1,
)
saved += 1
except Exception as e:
@ -688,13 +676,13 @@ class JargonMiner:
# 收集所有提取的jargon内容
jargon_list = [entry["content"] for entry in uniq_entries]
jargon_str = ",".join(jargon_list)
# 输出格式化的结果使用logger.info会自动应用jargon模块的颜色
logger.info(f"[{self.stream_name}]疑似黑话: {jargon_str}")
# 更新为本次提取的结束时间,确保不会重复提取相同的消息窗口
self.last_learning_time = extraction_end_time
if saved or updated:
logger.info(f"jargon写入: 新增 {saved} 条,更新 {updated}chat_id={self.chat_id}")
except Exception as e:
@ -720,15 +708,11 @@ async def extract_and_store_jargon(chat_id: str) -> None:
def search_jargon(
keyword: str,
chat_id: Optional[str] = None,
limit: int = 10,
case_sensitive: bool = False,
fuzzy: bool = True
keyword: str, chat_id: Optional[str] = None, limit: int = 10, case_sensitive: bool = False, fuzzy: bool = True
) -> List[Dict[str, str]]:
"""
搜索jargon支持大小写不敏感和模糊搜索
Args:
keyword: 搜索关键词
chat_id: 可选的聊天ID
@ -737,21 +721,18 @@ def search_jargon(
limit: 返回结果数量限制默认10
case_sensitive: 是否大小写敏感默认False不敏感
fuzzy: 是否模糊搜索默认True使用LIKE匹配
Returns:
List[Dict[str, str]]: 包含content, meaning的字典列表
"""
if not keyword or not keyword.strip():
return []
keyword = keyword.strip()
# 构建查询
query = Jargon.select(
Jargon.content,
Jargon.meaning
)
query = Jargon.select(Jargon.content, Jargon.meaning)
# 构建搜索条件
if case_sensitive:
# 大小写敏感
@ -760,7 +741,7 @@ def search_jargon(
search_condition = Jargon.content.contains(keyword)
else:
# 精确匹配
search_condition = (Jargon.content == keyword)
search_condition = Jargon.content == keyword
else:
# 大小写不敏感
if fuzzy:
@ -768,10 +749,10 @@ def search_jargon(
search_condition = fn.LOWER(Jargon.content).contains(keyword.lower())
else:
# 精确匹配使用LOWER函数
search_condition = (fn.LOWER(Jargon.content) == keyword.lower())
search_condition = fn.LOWER(Jargon.content) == keyword.lower()
query = query.where(search_condition)
# 根据all_global配置决定查询逻辑
if global_config.jargon.all_global:
# 开启all_global所有记录都是全局的查询所有is_global=True的记录无视chat_id
@ -779,35 +760,28 @@ def search_jargon(
else:
# 关闭all_global如果提供了chat_id优先搜索该聊天或global的jargon
if chat_id:
query = query.where(
(Jargon.chat_id == chat_id) | Jargon.is_global
)
query = query.where((Jargon.chat_id == chat_id) | Jargon.is_global)
# 只返回有meaning的记录
query = query.where(
(Jargon.meaning.is_null(False)) & (Jargon.meaning != "")
)
query = query.where((Jargon.meaning.is_null(False)) & (Jargon.meaning != ""))
# 按count降序排序优先返回出现频率高的
query = query.order_by(Jargon.count.desc())
# 限制结果数量
query = query.limit(limit)
# 执行查询并返回结果
results = []
for jargon in query:
results.append({
"content": jargon.content or "",
"meaning": jargon.meaning or ""
})
results.append({"content": jargon.content or "", "meaning": jargon.meaning or ""})
return results
async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: str) -> None:
"""将黑话存入jargon系统
Args:
jargon_keyword: 黑话关键词
answer: 答案内容将概括为raw_content
@ -820,53 +794,52 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
答案{answer}
只输出概括后的内容不要输出其他内容"""
success, summary, _, _ = await llm_api.generate_with_model(
summary_prompt,
model_config=model_config.model_task_config.utils_small,
request_type="memory.summarize_jargon",
)
logger.info(f"概括答案提示: {summary_prompt}")
logger.info(f"概括答案: {summary}")
if not success:
logger.warning(f"概括答案失败,使用原始答案: {summary}")
summary = answer[:100] # 截取前100字符作为备用
raw_content = summary.strip()[:200] # 限制长度
# 检查是否已存在
if global_config.jargon.all_global:
query = Jargon.select().where(Jargon.content == jargon_keyword)
else:
query = Jargon.select().where(
(Jargon.chat_id == chat_id) &
(Jargon.content == jargon_keyword)
)
query = Jargon.select().where((Jargon.chat_id == chat_id) & (Jargon.content == jargon_keyword))
if query.exists():
# 更新现有记录
obj = query.get()
obj.count = (obj.count or 0) + 1
# 合并raw_content列表
existing_raw_content = []
if obj.raw_content:
try:
existing_raw_content = json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
existing_raw_content = (
json.loads(obj.raw_content) if isinstance(obj.raw_content, str) else obj.raw_content
)
if not isinstance(existing_raw_content, list):
existing_raw_content = [existing_raw_content] if existing_raw_content else []
except (json.JSONDecodeError, TypeError):
existing_raw_content = [obj.raw_content] if obj.raw_content else []
# 合并并去重
merged_list = list(dict.fromkeys(existing_raw_content + [raw_content]))
obj.raw_content = json.dumps(merged_list, ensure_ascii=False)
if global_config.jargon.all_global:
obj.is_global = True
obj.save()
logger.info(f"更新jargon记录: {jargon_keyword}")
else:
@ -877,11 +850,9 @@ async def store_jargon_from_answer(jargon_keyword: str, answer: str, chat_id: st
raw_content=json.dumps([raw_content], ensure_ascii=False),
chat_id=chat_id,
is_global=is_global_new,
count=1
count=1,
)
logger.info(f"创建新jargon记录: {jargon_keyword}")
except Exception as e:
logger.error(f"存储jargon失败: {e}")

View File

@ -147,7 +147,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[FunctionDeclar
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
return_dict: dict[str, Any] = {
"type": param_type_value,
"description": tool_option_param.description,

View File

@ -122,7 +122,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict[str, Any]
param_type_value = tool_option_param.param_type.value
if param_type_value == "bool":
param_type_value = "boolean"
return_dict: dict[str, Any] = {
"type": param_type_value,
"description": tool_option_param.description,

View File

@ -116,9 +116,7 @@ class MessageBuilder:
构建消息对象
:return: Message对象
"""
if len(self.__content) == 0 and not (
self.__role == RoleType.Assistant and self.__tool_calls
):
if len(self.__content) == 0 and not (self.__role == RoleType.Assistant and self.__tool_calls):
raise ValueError("内容不能为空")
if self.__role == RoleType.Tool and self.__tool_call_id is None:
raise ValueError("Tool角色的工具调用ID不能为空")

View File

@ -166,7 +166,7 @@ class LLMRequest:
time_cost=time.time() - start_time,
)
return content or "", (reasoning_content, model_info.name, tool_calls)
async def generate_response_with_message_async(
self,
message_factory: Callable[[BaseClient], List[Message]],

View File

@ -36,10 +36,10 @@ class MainSystem:
# 使用消息API替代直接的FastAPI实例
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
# 注册 WebUI API 路由
self._register_webui_routes()
# 设置 WebUI开发/生产模式)
self._setup_webui()
@ -47,6 +47,7 @@ class MainSystem:
"""注册 WebUI API 路由"""
try:
from src.webui.routes import router as webui_router
self.server.register_router(webui_router)
logger.info("WebUI API 路由已注册")
except Exception as e:
@ -55,15 +56,17 @@ class MainSystem:
def _setup_webui(self):
"""设置 WebUI根据环境变量决定模式"""
import os
webui_enabled = os.getenv("WEBUI_ENABLED", "false").lower() == "true"
if not webui_enabled:
logger.info("WebUI 已禁用")
return
webui_mode = os.getenv("WEBUI_MODE", "production").lower()
try:
from src.webui.manager import setup_webui
setup_webui(mode=webui_mode)
except Exception as e:
logger.error(f"设置 WebUI 失败: {e}")

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
记忆系统工具函数
包含模糊查找相似度计算等工具函数
"""
import json
import re
from datetime import datetime
@ -14,6 +15,7 @@ from src.common.logger import get_logger
logger = get_logger("memory_utils")
def parse_md_json(json_text: str) -> list[str]:
"""从Markdown格式的内容中提取JSON对象和推理内容"""
json_objects = []
@ -52,14 +54,15 @@ def parse_md_json(json_text: str) -> list[str]:
return json_objects, reasoning_content
def calculate_similarity(text1: str, text2: str) -> float:
"""
计算两个文本的相似度
Args:
text1: 第一个文本
text2: 第二个文本
Returns:
float: 相似度分数 (0-1)
"""
@ -67,16 +70,16 @@ def calculate_similarity(text1: str, text2: str) -> float:
# 预处理文本
text1 = preprocess_text(text1)
text2 = preprocess_text(text2)
# 使用SequenceMatcher计算相似度
similarity = SequenceMatcher(None, text1, text2).ratio()
# 如果其中一个文本包含另一个,提高相似度
if text1 in text2 or text2 in text1:
similarity = max(similarity, 0.8)
return similarity
except Exception as e:
logger.error(f"计算相似度时出错: {e}")
return 0.0
@ -85,31 +88,30 @@ def calculate_similarity(text1: str, text2: str) -> float:
def preprocess_text(text: str) -> str:
    """Normalise text to make similarity matching more reliable.

    The input is lower-cased, every character that is neither a word
    character nor whitespace is removed, runs of whitespace are collapsed to
    a single space, and leading/trailing whitespace is stripped.

    Args:
        text: Raw text.

    Returns:
        str: The normalised text. If any step raises, the error is logged and
        the value of ``text`` at that point is returned instead.
    """
    try:
        # Lower-case so later comparisons are case-insensitive.
        text = text.lower()

        # Remove punctuation and other special characters.
        text = re.sub(r"[^\w\s]", "", text)

        # Collapse redundant whitespace and trim both ends.
        text = re.sub(r"\s+", " ", text).strip()

        return text
    except Exception as e:
        # Best-effort helper: log and hand back whatever we have.
        logger.error(f"预处理文本时出错: {e}")
        return text
def parse_datetime_to_timestamp(value: str) -> float:
"""
接受多种常见格式并转换为时间戳
@ -143,25 +145,24 @@ def parse_datetime_to_timestamp(value: str) -> float:
def parse_time_range(time_range: str) -> Tuple[float, float]:
    """Parse a time-range string into start and end timestamps.

    Args:
        time_range: Range string of the form
            ``"YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS"``. The string is
            split on the first ``" - "``; each side is stripped and passed to
            ``parse_datetime_to_timestamp``.

    Returns:
        Tuple[float, float]: ``(start_timestamp, end_timestamp)``.

    Raises:
        ValueError: If ``time_range`` does not contain the ``" - "``
            separator. Errors raised by ``parse_datetime_to_timestamp`` for
            either side propagate unchanged.
    """
    if " - " not in time_range:
        raise ValueError(f"时间范围格式错误,应为 '开始时间 - 结束时间': {time_range}")

    # The guard above guarantees the separator is present, so splitting once
    # always yields exactly two parts; no further length check is needed.
    start_str, end_str = time_range.split(" - ", 1)

    start_timestamp = parse_datetime_to_timestamp(start_str.strip())
    end_timestamp = parse_datetime_to_timestamp(end_str.strip())

    return start_timestamp, end_timestamp

View File

@ -17,6 +17,7 @@ from .query_lpmm_knowledge import register_tool as register_lpmm_knowledge
from .query_person_info import register_tool as register_query_person_info
from src.config.config import global_config
def init_all_tools():
"""初始化并注册所有记忆检索工具"""
register_query_jargon()

View File

@ -15,13 +15,10 @@ logger = get_logger("memory_retrieval_tools")
async def query_chat_history(
chat_id: str,
keyword: Optional[str] = None,
time_range: Optional[str] = None,
fuzzy: bool = True
chat_id: str, keyword: Optional[str] = None, time_range: Optional[str] = None, fuzzy: bool = True
) -> str:
"""根据时间或关键词在chat_history表中查询聊天记录概述
Args:
chat_id: 聊天ID
keyword: 关键词可选支持多个关键词可用空格逗号等分隔
@ -31,7 +28,7 @@ async def query_chat_history(
fuzzy: 是否使用模糊匹配模式默认True
- True: 模糊匹配只要包含任意一个关键词即匹配OR关系
- False: 全匹配必须包含所有关键词才匹配AND关系
Returns:
str: 查询结果
"""
@ -39,10 +36,10 @@ async def query_chat_history(
# 检查参数
if not keyword and not time_range:
return "未指定查询参数需要提供keyword或time_range之一"
# 构建查询条件
query = ChatHistory.select().where(ChatHistory.chat_id == chat_id)
# 时间过滤条件
if time_range:
# 判断是时间点还是时间范围
@ -50,79 +47,79 @@ async def query_chat_history(
# 时间范围:查询与时间范围有交集的记录
start_timestamp, end_timestamp = parse_time_range(time_range)
# 交集条件start_time < end_timestamp AND end_time > start_timestamp
time_filter = (
(ChatHistory.start_time < end_timestamp) &
(ChatHistory.end_time > start_timestamp)
)
time_filter = (ChatHistory.start_time < end_timestamp) & (ChatHistory.end_time > start_timestamp)
else:
# 时间点查询包含该时间点的记录start_time <= time_point <= end_time
target_timestamp = parse_datetime_to_timestamp(time_range)
time_filter = (
(ChatHistory.start_time <= target_timestamp) &
(ChatHistory.end_time >= target_timestamp)
)
time_filter = (ChatHistory.start_time <= target_timestamp) & (ChatHistory.end_time >= target_timestamp)
query = query.where(time_filter)
# 执行查询
records = list(query.order_by(ChatHistory.start_time.desc()).limit(50))
# 如果有关键词,进一步过滤
if keyword:
# 解析多个关键词(支持空格、逗号等分隔符)
keywords_list = parse_keywords_string(keyword)
if not keywords_list:
keywords_list = [keyword.strip()] if keyword.strip() else []
# 转换为小写以便匹配
keywords_lower = [kw.lower() for kw in keywords_list if kw.strip()]
if not keywords_lower:
return "关键词为空"
filtered_records = []
for record in records:
# 在theme、keywords、summary、original_text中搜索
theme = (record.theme or "").lower()
summary = (record.summary or "").lower()
original_text = (record.original_text or "").lower()
# 解析record中的keywords JSON
record_keywords_list = []
if record.keywords:
try:
keywords_data = json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
keywords_data = (
json.loads(record.keywords) if isinstance(record.keywords, str) else record.keywords
)
if isinstance(keywords_data, list):
record_keywords_list = [str(k).lower() for k in keywords_data]
except (json.JSONDecodeError, TypeError, ValueError):
pass
# 根据匹配模式检查关键词
matched = False
if fuzzy:
# 模糊匹配只要包含任意一个关键词即匹配OR关系
for kw in keywords_lower:
if (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list)):
if (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
):
matched = True
break
else:
# 全匹配必须包含所有关键词才匹配AND关系
matched = True
for kw in keywords_lower:
kw_matched = (kw in theme or
kw in summary or
kw in original_text or
any(kw in k for k in record_keywords_list))
kw_matched = (
kw in theme
or kw in summary
or kw in original_text
or any(kw in k for k in record_keywords_list)
)
if not kw_matched:
matched = False
break
if matched:
filtered_records.append(record)
if not filtered_records:
keywords_str = "".join(keywords_list)
match_mode = "包含任意一个关键词" if fuzzy else "包含所有关键词"
@ -130,9 +127,9 @@ async def query_chat_history(
return f"未找到{match_mode}'{keywords_str}'且在指定时间范围内的聊天记录概述"
else:
return f"未找到{match_mode}'{keywords_str}'的聊天记录概述"
records = filtered_records
# 如果没有记录(可能是时间范围查询但没有匹配的记录)
if not records:
if time_range:
@ -148,22 +145,23 @@ async def query_chat_history(
record.count = (record.count or 0) + 1
except Exception as update_error:
logger.error(f"更新聊天记录概述计数失败: {update_error}")
# 构建结果文本
results = []
for record in records_to_use: # 最多返回3条记录
result_parts = []
# 添加主题
if record.theme:
result_parts.append(f"主题:{record.theme}")
# 添加时间范围
from datetime import datetime
start_str = datetime.fromtimestamp(record.start_time).strftime("%Y-%m-%d %H:%M:%S")
end_str = datetime.fromtimestamp(record.end_time).strftime("%Y-%m-%d %H:%M:%S")
result_parts.append(f"时间:{start_str} - {end_str}")
# 添加概括优先使用summary如果没有则使用original_text的前200字符
if record.summary:
result_parts.append(f"概括:{record.summary}")
@ -172,18 +170,18 @@ async def query_chat_history(
if len(record.original_text) > 200:
text_preview += "..."
result_parts.append(f"内容:{text_preview}")
results.append("\n".join(result_parts))
if not results:
return "未找到相关聊天记录概述"
response_text = "\n\n---\n\n".join(results)
if len(records) > len(records_to_use):
omitted_count = len(records) - len(records_to_use)
response_text += f"\n\n(还有{omitted_count}条历史记录已省略)"
return response_text
except Exception as e:
logger.error(f"查询聊天历史概述失败: {e}")
return f"查询失败: {str(e)}"
@ -199,20 +197,20 @@ def register_tool():
"name": "keyword",
"type": "string",
"description": "关键词(可选,支持多个关键词,可用空格、逗号、斜杠等分隔,如:'麦麦 百度网盘''麦麦,百度网盘'。用于在主题、关键词、概括、原文中搜索)",
"required": False
"required": False,
},
{
"name": "time_range",
"type": "string",
"description": "时间范围或时间点(可选)。格式:'YYYY-MM-DD HH:MM:SS - YYYY-MM-DD HH:MM:SS'(时间范围,查询与时间范围有交集的记录)或 'YYYY-MM-DD HH:MM:SS'(时间点,查询包含该时间点的记录)",
"required": False
"required": False,
},
{
"name": "fuzzy",
"type": "boolean",
"description": "是否使用模糊匹配模式默认True。True表示模糊匹配只要包含任意一个关键词即匹配OR关系False表示全匹配必须包含所有关键词才匹配AND关系",
"required": False
}
"required": False,
},
],
execute_func=query_chat_history
execute_func=query_chat_history,
)

View File

@ -73,5 +73,3 @@ def register_tool():
],
execute_func=query_lpmm_knowledge,
)

View File

@ -14,23 +14,25 @@ logger = get_logger("memory_retrieval_tools")
def _format_group_nick_names(group_nick_name_field) -> str:
"""格式化群昵称信息
Args:
group_nick_name_field: 群昵称字段可能是字符串JSON或None
Returns:
str: 格式化后的群昵称信息字符串
"""
if not group_nick_name_field:
return ""
try:
# 解析JSON格式的群昵称列表
group_nick_names_data = json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
group_nick_names_data = (
json.loads(group_nick_name_field) if isinstance(group_nick_name_field, str) else group_nick_name_field
)
if not isinstance(group_nick_names_data, list) or not group_nick_names_data:
return ""
# 格式化群昵称列表
group_nick_list = []
for item in group_nick_names_data:
@ -41,7 +43,7 @@ def _format_group_nick_names(group_nick_name_field) -> str:
elif isinstance(item, str):
# 兼容旧格式(如果存在)
group_nick_list.append(f" - {item}")
if group_nick_list:
return "群昵称:\n" + "\n".join(group_nick_list)
return ""
@ -58,10 +60,10 @@ def _format_group_nick_names(group_nick_name_field) -> str:
async def query_person_info(person_name: str) -> str:
"""根据person_name查询用户信息使用模糊查询
Args:
person_name: 用户名称person_name字段
Returns:
str: 查询结果包含用户的所有信息
"""
@ -69,37 +71,35 @@ async def query_person_info(person_name: str) -> str:
person_name = str(person_name).strip()
if not person_name:
return "用户名称为空"
# 构建查询条件(使用模糊查询)
query = PersonInfo.select().where(
PersonInfo.person_name.contains(person_name)
)
query = PersonInfo.select().where(PersonInfo.person_name.contains(person_name))
# 执行查询
records = list(query.limit(20)) # 最多返回20条记录
if not records:
return f"未找到模糊匹配'{person_name}'的用户信息"
# 区分精确匹配和模糊匹配的结果
exact_matches = []
fuzzy_matches = []
for record in records:
# 检查是否是精确匹配
if record.person_name and record.person_name.strip() == person_name:
exact_matches.append(record)
else:
fuzzy_matches.append(record)
# 构建结果文本
results = []
# 先处理精确匹配的结果
for record in exact_matches:
result_parts = []
result_parts.append("【精确匹配】") # 标注为精确匹配
# 基本信息
if record.person_name:
result_parts.append(f"用户名称:{record.person_name}")
@ -111,19 +111,19 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"平台:{record.platform}")
if record.user_id:
result_parts.append(f"平台用户ID{record.user_id}")
# 群昵称信息
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
if group_nick_name_str:
result_parts.append(group_nick_name_str)
# 名称设定原因
if record.name_reason:
result_parts.append(f"名称设定原因:{record.name_reason}")
# 认识状态
result_parts.append(f"是否已认识:{'' if record.is_known else ''}")
# 时间信息
if record.know_since:
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@ -133,11 +133,15 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"最后认识时间:{last_know_str}")
if record.know_times:
result_parts.append(f"认识次数:{int(record.know_times)}")
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@ -151,7 +155,7 @@ async def query_person_info(person_name: str) -> str:
memory_list.append(f" - [{category}] {content} (权重: {weight})")
else:
memory_list.append(f" - {memory_point}")
if memory_list:
result_parts.append("记忆点:\n" + "\n".join(memory_list))
except (json.JSONDecodeError, TypeError, ValueError) as e:
@ -161,14 +165,14 @@ async def query_person_info(person_name: str) -> str:
if len(str(record.memory_points)) > 200:
memory_preview += "..."
result_parts.append(f"记忆点(原始数据):{memory_preview}")
results.append("\n".join(result_parts))
# 再处理模糊匹配的结果
for record in fuzzy_matches:
result_parts = []
result_parts.append("【模糊匹配】") # 标注为模糊匹配
# 基本信息
if record.person_name:
result_parts.append(f"用户名称:{record.person_name}")
@ -180,19 +184,19 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"平台:{record.platform}")
if record.user_id:
result_parts.append(f"平台用户ID{record.user_id}")
# 群昵称信息
group_nick_name_str = _format_group_nick_names(getattr(record, "group_nick_name", None))
if group_nick_name_str:
result_parts.append(group_nick_name_str)
# 名称设定原因
if record.name_reason:
result_parts.append(f"名称设定原因:{record.name_reason}")
# 认识状态
result_parts.append(f"是否已认识:{'' if record.is_known else ''}")
# 时间信息
if record.know_since:
know_since_str = datetime.fromtimestamp(record.know_since).strftime("%Y-%m-%d %H:%M:%S")
@ -202,11 +206,15 @@ async def query_person_info(person_name: str) -> str:
result_parts.append(f"最后认识时间:{last_know_str}")
if record.know_times:
result_parts.append(f"认识次数:{int(record.know_times)}")
# 记忆点memory_points
if record.memory_points:
try:
memory_points_data = json.loads(record.memory_points) if isinstance(record.memory_points, str) else record.memory_points
memory_points_data = (
json.loads(record.memory_points)
if isinstance(record.memory_points, str)
else record.memory_points
)
if isinstance(memory_points_data, list) and memory_points_data:
# 解析记忆点格式category:content:weight
memory_list = []
@ -220,7 +228,7 @@ async def query_person_info(person_name: str) -> str:
memory_list.append(f" - [{category}] {content} (权重: {weight})")
else:
memory_list.append(f" - {memory_point}")
if memory_list:
result_parts.append("记忆点:\n" + "\n".join(memory_list))
except (json.JSONDecodeError, TypeError, ValueError) as e:
@ -230,20 +238,20 @@ async def query_person_info(person_name: str) -> str:
if len(str(record.memory_points)) > 200:
memory_preview += "..."
result_parts.append(f"记忆点(原始数据):{memory_preview}")
results.append("\n".join(result_parts))
# 组合所有结果
if not results:
return f"未找到匹配'{person_name}'的用户信息"
response_text = "\n\n---\n\n".join(results)
# 添加统计信息
total_count = len(records)
exact_count = len(exact_matches)
fuzzy_count = len(fuzzy_matches)
# 显示精确匹配和模糊匹配的统计
if exact_count > 0 or fuzzy_count > 0:
stats_parts = []
@ -257,13 +265,13 @@ async def query_person_info(person_name: str) -> str:
response_text = f"找到 {total_count} 条匹配的用户信息:\n\n{response_text}"
else:
response_text = f"找到用户信息:\n\n{response_text}"
# 如果结果数量达到限制,添加提示
if total_count >= 20:
response_text += "\n\n(已显示前20条结果可能还有更多匹配记录)"
return response_text
except Exception as e:
logger.error(f"查询用户信息失败: {e}")
return f"查询失败: {str(e)}"
@ -275,13 +283,7 @@ def register_tool():
name="query_person_info",
description="根据查询某个用户的所有信息。名称、昵称、平台、用户ID、qq号、群昵称等",
parameters=[
{
"name": "person_name",
"type": "string",
"description": "用户名称,用于查询用户信息",
"required": True
}
{"name": "person_name", "type": "string", "description": "用户名称,用于查询用户信息", "required": True}
],
execute_func=query_person_info
execute_func=query_person_info,
)

View File

@ -47,10 +47,10 @@ class MemoryRetrievalTool:
async def execute(self, **kwargs) -> str:
    """Run the tool by awaiting its registered callable.

    Args:
        **kwargs: Keyword arguments forwarded unchanged to ``self.execute_func``.

    Returns:
        str: The value produced by awaiting ``self.execute_func(**kwargs)``.
    """
    outcome = await self.execute_func(**kwargs)
    return outcome
def get_tool_definition(self) -> Dict[str, Any]:
"""获取工具定义用于LLM function calling
Returns:
Dict[str, Any]: 工具定义字典格式与BaseTool一致
格式: {"name": str, "description": str, "parameters": List[Tuple]}
@ -58,14 +58,14 @@ class MemoryRetrievalTool:
# 转换参数格式为元组列表格式与BaseTool一致
# 格式: [("param_name", ToolParamType, "description", required, enum_values)]
param_tuples = []
for param in self.parameters:
param_name = param.get("name", "")
param_type_str = param.get("type", "string").lower()
param_desc = param.get("description", "")
is_required = param.get("required", False)
enum_values = param.get("enum", None)
# 转换类型字符串到ToolParamType
type_mapping = {
"string": ToolParamType.STRING,
@ -76,18 +76,14 @@ class MemoryRetrievalTool:
"bool": ToolParamType.BOOLEAN,
}
param_type = type_mapping.get(param_type_str, ToolParamType.STRING)
# 构建参数元组
param_tuple = (param_name, param_type, param_desc, is_required, enum_values)
param_tuples.append(param_tuple)
# 构建工具定义格式与BaseTool.get_tool_definition()一致
tool_def = {
"name": self.name,
"description": self.description,
"parameters": param_tuples
}
tool_def = {"name": self.name, "description": self.description, "parameters": param_tuples}
return tool_def
@ -126,10 +122,10 @@ class MemoryRetrievalToolRegistry:
action_types.append("final_answer")
action_types.append("no_answer")
return "".join([f'"{at}"' for at in action_types])
def get_tool_definitions(self) -> List[Dict[str, Any]]:
"""获取所有工具的定义列表用于LLM function calling
Returns:
List[Dict[str, Any]]: 工具定义列表每个元素是一个工具定义字典
"""

View File

@ -162,7 +162,12 @@ def levenshtein_distance(s1: str, s2: str) -> int:
class Person:
@classmethod
def register_person(
cls, platform: str, user_id: str, nickname: str, group_id: Optional[str] = None, group_nick_name: Optional[str] = None
cls,
platform: str,
user_id: str,
nickname: str,
group_id: Optional[str] = None,
group_nick_name: Optional[str] = None,
):
"""
注册新用户的类方法
@ -727,7 +732,7 @@ person_info_manager = PersonInfoManager()
async def store_person_memory_from_answer(person_name: str, memory_content: str, chat_id: str) -> None:
"""将人物信息存入person_info的memory_points
Args:
person_name: 人物名称
memory_content: 记忆内容
@ -739,13 +744,13 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if not chat_stream:
logger.warning(f"无法获取chat_stream for chat_id: {chat_id}")
return
platform = chat_stream.platform
# 尝试从person_name查找person_id
# 首先尝试通过person_name查找
person_id = get_person_id_by_person_name(person_name)
if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info
if chat_stream.user_info:
@ -754,25 +759,25 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
else:
logger.warning(f"无法确定person_id for person_name: {person_name}, chat_id: {chat_id}")
return
# 创建或获取Person对象
person = Person(person_id=person_id)
if not person.is_known:
logger.warning(f"用户 {person_name} (person_id: {person_id}) 尚未认识,无法存储记忆")
return
# 确定记忆分类可以根据memory_content判断这里使用通用分类
category = "其他" # 默认分类,可以根据需要调整
# 记忆点格式category:content:weight
weight = "1.0" # 默认权重
memory_point = f"{category}:{memory_content}:{weight}"
# 添加到memory_points
if not person.memory_points:
person.memory_points = []
# 检查是否已存在相似的记忆点(避免重复)
is_duplicate = False
for existing_point in person.memory_points:
@ -781,16 +786,20 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if len(parts) >= 2:
existing_content = parts[1].strip()
# 简单相似度检查(如果内容相同或非常相似,则跳过)
if existing_content == memory_content or memory_content in existing_content or existing_content in memory_content:
if (
existing_content == memory_content
or memory_content in existing_content
or existing_content in memory_content
):
is_duplicate = True
break
if not is_duplicate:
person.memory_points.append(memory_point)
person.sync_to_database()
logger.info(f"成功添加记忆点到 {person_name} (person_id: {person_id}): {memory_point}")
else:
logger.debug(f"记忆点已存在,跳过: {memory_point}")
except Exception as e:
logger.error(f"存储人物记忆失败: {e}")

View File

@ -124,7 +124,6 @@ class ToolExecutor:
response, (reasoning_content, model_name, tool_calls) = await self.llm_model.generate_response_async(
prompt=prompt, tools=tools, raise_when_empty=False
)
# 执行工具调用
tool_results, used_tools = await self.execute_tool_calls(tool_calls)

View File

@ -51,7 +51,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典保留 target 中的注释和格式
source 的值更新到 target 仅更新已存在的键
Args:
target: 目标字典tomlkit 对象包含注释
source: 源字典普通 dict list
@ -59,7 +59,7 @@ def _update_dict_preserve_comments(target: Any, source: Any) -> None:
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():

View File

@ -1,4 +1,5 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from fastapi.responses import FileResponse
from pydantic import BaseModel
@ -18,6 +19,7 @@ router = APIRouter(prefix="/emoji", tags=["Emoji"])
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
format: str
@ -35,6 +37,7 @@ class EmojiResponse(BaseModel):
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
@ -44,12 +47,14 @@ class EmojiListResponse(BaseModel):
class EmojiDetailResponse(BaseModel):
    """Response body returned for a single-emoji detail request."""

    # Outcome flag; the detail handler in this module builds the response with True.
    success: bool
    # The requested emoji, serialized as an EmojiResponse.
    data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
@ -58,6 +63,7 @@ class EmojiUpdateRequest(BaseModel):
class EmojiUpdateResponse(BaseModel):
    """Response body returned by the emoji update, register and ban handlers."""

    # Outcome flag; the handlers in this module build the response with True.
    success: bool
    # Human-readable result text supplied by the handler.
    message: str
    # The emoji after modification; optional and None unless the handler provides it.
    data: Optional[EmojiResponse] = None
@ -65,6 +71,7 @@ class EmojiUpdateResponse(BaseModel):
class EmojiDeleteResponse(BaseModel):
    """Response body returned after an emoji has been deleted."""

    # Outcome flag; the delete handler in this module builds the response with True.
    success: bool
    # Human-readable result text; the delete handler includes the removed emoji's hash.
    message: str
def verify_auth_token(authorization: Optional[str]) -> bool:
    """Validate the bearer token carried by an ``Authorization`` header value.

    Args:
        authorization: Raw header value, expected in the form ``"Bearer <token>"``.
            May be ``None`` when the header was not sent.

    Returns:
        bool: ``True`` when the token manager accepts the token.

    Raises:
        HTTPException: Status 401 when the header is missing or does not start
            with ``"Bearer "``, or when the token manager rejects the token.
    """
    bearer_prefix = "Bearer "
    if not authorization or not authorization.startswith(bearer_prefix):
        raise HTTPException(status_code=401, detail="未提供有效的认证信息")

    # Strip only the leading scheme prefix. Using str.replace here would also
    # delete any later "Bearer " substring and silently alter the token.
    token = authorization[len(bearer_prefix):]
    token_manager = get_token_manager()
    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token 无效或已过期")

    return True
@ -120,11 +127,11 @@ async def get_emoji_list(
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
@ -133,61 +140,51 @@ async def get_emoji_list(
is_banned: 是否被禁用筛选
format: 格式筛选
authorization: Authorization header
Returns:
表情包列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Emoji.select()
# 搜索过滤
if search:
query = query.where(
(Emoji.description.contains(search)) |
(Emoji.emoji_hash.contains(search))
)
query = query.where((Emoji.description.contains(search)) | (Emoji.emoji_hash.contains(search)))
# 注册状态过滤
if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered)
# 禁用状态过滤
if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned)
# 格式过滤
if format:
query = query.where(Emoji.format == format)
# 排序:使用次数倒序,然后按记录时间倒序
from peewee import Case
query = query.order_by(
Emoji.usage_count.desc(),
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
Emoji.record_time.desc()
Emoji.usage_count.desc(), Case(None, [(Emoji.record_time.is_null(), 1)], 0), Emoji.record_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis]
return EmojiListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return EmojiListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@ -196,33 +193,27 @@ async def get_emoji_list(
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def get_emoji_detail(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
获取表情包详细信息
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
表情包详细信息
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(
success=True,
data=emoji_to_response(emoji)
)
return EmojiDetailResponse(success=True, data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@ -231,61 +222,55 @@ async def get_emoji_detail(
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_emoji(emoji_id: int, request: EmojiUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新表情包只更新提供的字段
Args:
emoji_id: 表情包ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 处理情感标签(转换为 JSON
if 'emotion' in update_data:
if update_data['emotion'] is None:
update_data['emotion'] = None
if "emotion" in update_data:
if update_data["emotion"] is None:
update_data["emotion"] = None
else:
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
update_data["emotion"] = json.dumps(update_data["emotion"], ensure_ascii=False)
# 如果注册状态从 False 变为 True记录注册时间
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
update_data['register_time'] = time.time()
if "is_registered" in update_data and update_data["is_registered"] and not emoji.is_registered:
update_data["register_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(emoji, field, value)
emoji.save()
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=emoji_to_response(emoji)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
@ -294,41 +279,35 @@ async def update_emoji(
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
删除表情包
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 记录删除信息
emoji_hash = emoji.emoji_hash
# 执行删除
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(
success=True,
message=f"成功删除表情包: {emoji_hash}"
)
return EmojiDeleteResponse(success=True, message=f"成功删除表情包: {emoji_hash}")
except HTTPException:
raise
except Exception as e:
@ -337,31 +316,29 @@ async def delete_emoji(
@router.get("/stats/summary")
async def get_emoji_stats(
authorization: Optional[str] = Header(None)
):
async def get_emoji_stats(authorization: Optional[str] = Header(None)):
"""
获取表情包统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
banned = Emoji.select().where(Emoji.is_banned).count()
# 按格式统计
formats = {}
for emoji in Emoji.select(Emoji.format):
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
# 获取最常用的表情包前10
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
top_used_list = [
@ -369,11 +346,11 @@ async def get_emoji_stats(
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count
"usage_count": emoji.usage_count,
}
for emoji in top_used
]
return {
"success": True,
"data": {
@ -382,10 +359,10 @@ async def get_emoji_stats(
"banned": banned,
"unregistered": total - registered,
"formats": formats,
"top_used": top_used_list
}
"top_used": top_used_list,
},
}
except HTTPException:
raise
except Exception as e:
@ -394,47 +371,40 @@ async def get_emoji_stats(
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def register_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
注册表情包快捷操作
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册")
if emoji.is_banned:
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
# 注册表情包
emoji.is_registered = True
emoji.register_time = time.time()
emoji.save()
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包注册成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包注册成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@ -443,41 +413,34 @@ async def register_emoji(
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
async def ban_emoji(emoji_id: int, authorization: Optional[str] = Header(None)):
"""
禁用表情包快捷操作
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 禁用表情包(同时取消注册)
emoji.is_banned = True
emoji.is_registered = False
emoji.save()
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包禁用成功",
data=emoji_to_response(emoji)
)
return EmojiUpdateResponse(success=True, message="表情包禁用成功", data=emoji_to_response(emoji))
except HTTPException:
raise
except Exception as e:
@ -489,16 +452,16 @@ async def ban_emoji(
async def get_emoji_thumbnail(
emoji_id: int,
token: Optional[str] = Query(None, description="访问令牌"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表情包缩略图
Args:
emoji_id: 表情包ID
token: 访问令牌通过 query parameter
authorization: Authorization header
Returns:
表情包图片文件
"""
@ -511,37 +474,32 @@ async def get_emoji_thumbnail(
else:
# 如果没有 query token则验证 Authorization header
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 检查文件是否存在
if not os.path.exists(emoji.full_path):
raise HTTPException(status_code=404, detail="表情包文件不存在")
# 根据格式设置 MIME 类型
mime_types = {
'png': 'image/png',
'jpg': 'image/jpeg',
'jpeg': 'image/jpeg',
'gif': 'image/gif',
'webp': 'image/webp',
'bmp': 'image/bmp'
"png": "image/png",
"jpg": "image/jpeg",
"jpeg": "image/jpeg",
"gif": "image/gif",
"webp": "image/webp",
"bmp": "image/bmp",
}
media_type = mime_types.get(emoji.format.lower(), 'application/octet-stream')
return FileResponse(
path=emoji.full_path,
media_type=media_type,
filename=f"{emoji.emoji_hash}.{emoji.format}"
)
media_type = mime_types.get(emoji.format.lower(), "application/octet-stream")
return FileResponse(path=emoji.full_path, media_type=media_type, filename=f"{emoji.emoji_hash}.{emoji.format}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包缩略图失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包缩略图失败: {str(e)}") from e

View File

@ -1,4 +1,5 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
@ -15,6 +16,7 @@ router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
@ -27,6 +29,7 @@ class ExpressionResponse(BaseModel):
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
@ -36,12 +39,14 @@ class ExpressionListResponse(BaseModel):
class ExpressionDetailResponse(BaseModel):
    """Response body returned for a single-expression detail request."""

    # Outcome flag; the detail handler in this module builds the response with True.
    success: bool
    # The requested expression, serialized as an ExpressionResponse.
    data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
context: Optional[str] = None
@ -51,6 +56,7 @@ class ExpressionCreateRequest(BaseModel):
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
@ -60,6 +66,7 @@ class ExpressionUpdateRequest(BaseModel):
class ExpressionUpdateResponse(BaseModel):
    """Response body returned after an expression has been updated."""

    # Outcome flag; the update handler in this module builds the response with True.
    success: bool
    # Human-readable result text supplied by the handler.
    message: str
    # The expression after modification; optional and None unless the handler provides it.
    data: Optional[ExpressionResponse] = None
@ -67,12 +74,14 @@ class ExpressionUpdateResponse(BaseModel):
class ExpressionDeleteResponse(BaseModel):
    """Response body returned after an expression has been deleted."""

    # Outcome flag; the delete handler in this module builds the response with True.
    success: bool
    # Human-readable result text; the delete handler includes the removed situation.
    message: str
class ExpressionCreateResponse(BaseModel):
    """Response body returned after a new expression has been created."""

    # Outcome flag; the create handler in this module builds the response with True.
    success: bool
    # Human-readable result text supplied by the handler.
    message: str
    # The newly created expression, serialized as an ExpressionResponse.
    data: ExpressionResponse
def verify_auth_token(authorization: Optional[str]) -> bool:
    """Validate the bearer token carried by an ``Authorization`` header value.

    Args:
        authorization: Raw header value, expected in the form ``"Bearer <token>"``.
            May be ``None`` when the header was not sent.

    Returns:
        bool: ``True`` when the token manager accepts the token.

    Raises:
        HTTPException: Status 401 when the header is missing or does not start
            with ``"Bearer "``, or when the token manager rejects the token.
    """
    bearer_prefix = "Bearer "
    if not authorization or not authorization.startswith(bearer_prefix):
        raise HTTPException(status_code=401, detail="未提供有效的认证信息")

    # Strip only the leading scheme prefix. Using str.replace here would also
    # delete any later "Bearer " substring and silently alter the token.
    token = authorization[len(bearer_prefix):]
    token_manager = get_token_manager()
    if not token_manager.verify_token(token):
        raise HTTPException(status_code=401, detail="Token 无效或已过期")

    return True
@ -112,64 +121,58 @@ async def get_expression_list(
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style, context)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) |
(Expression.style.contains(search)) |
(Expression.context.contains(search))
(Expression.situation.contains(search))
| (Expression.style.contains(search))
| (Expression.context.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
Expression.last_active_time.desc()
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@ -178,33 +181,27 @@ async def get_expression_list(
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def get_expression_detail(expression_id: int, authorization: Optional[str] = Header(None)):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(
success=True,
data=expression_to_response(expression)
)
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
@ -213,25 +210,22 @@ async def get_expression_detail(
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
authorization: Optional[str] = Header(None)
):
async def create_expression(request: ExpressionCreateRequest, authorization: Optional[str] = Header(None)):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
@ -242,15 +236,13 @@ async def create_expression(
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True,
message="表达方式创建成功",
data=expression_to_response(expression)
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
@ -260,52 +252,48 @@ async def create_expression(
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
authorization: Optional[str] = Header(None)
expression_id: int, request: ExpressionUpdateRequest, authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式只更新提供的字段
Args:
expression_id: 表达方式ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后活跃时间
update_data['last_active_time'] = time.time()
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=expression_to_response(expression)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
@ -314,41 +302,35 @@ async def update_expression(
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int,
authorization: Optional[str] = Header(None)
):
async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(
success=True,
message=f"成功删除表达方式: {situation}"
)
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
@ -357,46 +339,45 @@ async def delete_expression(
@router.get("/stats/summary")
async def get_expression_stats(
authorization: Optional[str] = Header(None)
):
async def get_expression_stats(authorization: Optional[str] = Header(None)):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = Expression.select().where(
(Expression.create_date.is_null(False)) &
(Expression.create_date >= seven_days_ago)
).count()
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
}
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:

View File

@ -1,4 +1,5 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
@ -15,6 +16,7 @@ logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
@ -23,6 +25,7 @@ def set_update_progress_callback(callback):
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
@ -34,10 +37,10 @@ class MirrorType(str, Enum):
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
@ -47,7 +50,7 @@ class GitMirrorConfig:
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None
"created_at": None,
},
{
"id": "hk-gh-proxy",
@ -56,7 +59,7 @@ class GitMirrorConfig:
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None
"created_at": None,
},
{
"id": "cdn-gh-proxy",
@ -65,7 +68,7 @@ class GitMirrorConfig:
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
@ -74,7 +77,7 @@ class GitMirrorConfig:
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None
"created_at": None,
},
{
"id": "meyzh-github",
@ -83,7 +86,7 @@ class GitMirrorConfig:
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None
"created_at": None,
},
{
"id": "github",
@ -92,23 +95,23 @@ class GitMirrorConfig:
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None
}
"created_at": None,
},
]
def __init__(self):
    """Create the configuration manager and load its mirrors.

    The instance starts with an empty mirror list and the class-level
    ``CONFIG_FILE`` path, then hands over to ``_load_config`` to fill
    the list.
    """
    self.mirrors: List[Dict[str, Any]] = []
    self.config_file = self.CONFIG_FILE
    self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
@ -122,59 +125,59 @@ class GitMirrorConfig:
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
    """Reset the mirror list to the built-in defaults and save it.

    Every entry of ``DEFAULT_MIRRORS`` is copied (the class-level templates
    are left untouched) and all copies receive the same ``created_at``
    ISO timestamp taken once at the start of the call.
    """
    stamp = datetime.now().isoformat()
    self.mirrors = [{**template, "created_at": stamp} for template in self.DEFAULT_MIRRORS]
    self._save_config()
    logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
    """Persist the mirror list into the config file.

    Only the top-level ``git_mirrors`` key is replaced; every other key that
    is already stored in the file is written back unchanged. Any failure is
    logged and swallowed, so callers never see an exception from saving.
    """
    try:
        # Make sure the target directory exists.
        self.config_file.parent.mkdir(parents=True, exist_ok=True)

        # Start from the current file content so unrelated sections survive.
        existing_data = {}
        if self.config_file.exists():
            with open(self.config_file, "r", encoding="utf-8") as f:
                existing_data = json.load(f)

        # Replace only the mirror section.
        existing_data["git_mirrors"] = self.mirrors

        # Serialize BEFORE opening the file for writing: mode "w" truncates
        # immediately, so a serialization error raised while streaming with
        # json.dump would leave the config file empty or half-written.
        payload = json.dumps(existing_data, indent=2, ensure_ascii=False)
        with open(self.config_file, "w", encoding="utf-8") as f:
            f.write(payload)

        logger.debug(f"配置已保存到 {self.config_file}")
    except Exception as e:
        logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
    """Return a snapshot of every configured mirror, in stored order.

    Each mirror dict is copied, so changing the returned list or any of its
    entries does not alter the manager's own state.

    Returns:
        A new list of per-mirror dict copies.
    """
    # A plain list copy would still share the inner dicts with self.mirrors.
    return [mirror.copy() for mirror in self.mirrors]
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
    """Return the enabled mirrors ordered by ascending ``priority``.

    A mirror without an ``enabled`` key counts as disabled; a mirror without
    a ``priority`` key sorts as 999. Mirrors with equal priority keep their
    stored order.
    """

    def priority_of(entry):
        return entry.get("priority", 999)

    active = filter(lambda entry: entry.get("enabled", False), self.mirrors)
    return sorted(active, key=priority_of)
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
    """Look up a mirror by its ``id`` field.

    Returns:
        A copy of the first mirror whose ``id`` equals ``mirror_id``, or
        ``None`` when no mirror matches.
    """
    found = next((entry for entry in self.mirrors if entry.get("id") == mirror_id), None)
    return None if found is None else found.copy()
def add_mirror(
self,
mirror_id: str,
@ -182,26 +185,26 @@ class GitMirrorConfig:
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
@ -209,15 +212,15 @@ class GitMirrorConfig:
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat()
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
@ -225,11 +228,11 @@ class GitMirrorConfig:
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置如果不存在则返回 None
"""
@ -245,19 +248,19 @@ class GitMirrorConfig:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
@ -267,9 +270,9 @@ class GitMirrorConfig:
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
@ -278,16 +281,11 @@ class GitMirrorConfig:
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(
self,
max_retries: int = 3,
timeout: int = 30,
config: Optional[GitMirrorConfig] = None
):
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间
@ -297,16 +295,16 @@ class GitMirrorService:
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
    """Return the mirror configuration manager held by this service."""
    config_manager = self.config
    return config_manager
@staticmethod
def check_git_installed() -> Dict[str, Any]:
    """
    Check whether Git is installed on this machine.

    Returns:
        Dict containing:
        - installed: bool - whether a working Git executable was found
        - version: str - stripped output of ``git --version`` (only on success)
        - path: str - resolved path of the Git executable (only on success)
        - error: str - failure description (only when ``installed`` is False)
    """
    import subprocess
    import shutil

    try:
        # Locate the git executable on PATH.
        git_path = shutil.which("git")

        if not git_path:
            logger.warning("未找到 Git 可执行文件")
            return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}

        # Query the version from the executable that was just resolved, so the
        # reported version always belongs to the reported path and the lookup
        # is not repeated with different rules by the subprocess layer.
        result = subprocess.run([git_path, "--version"], capture_output=True, text=True, timeout=5)

        if result.returncode == 0:
            version = result.stdout.strip()
            logger.info(f"检测到 Git: {version} at {git_path}")
            return {"installed": True, "version": version, "path": git_path}

        logger.warning(f"Git 命令执行失败: {result.stderr}")
        return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}

    except subprocess.TimeoutExpired:
        logger.error("Git 版本检测超时")
        return {"installed": False, "error": "Git 版本检测超时"}
    except Exception as e:
        # Best-effort probe: report the problem instead of raising.
        logger.error(f"检测 Git 时发生错误: {e}")
        return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
@ -371,11 +348,11 @@ class GitMirrorService:
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
@ -383,7 +360,7 @@ class GitMirrorService:
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
@ -393,29 +370,24 @@ class GitMirrorService:
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
@ -427,15 +399,13 @@ class GitMirrorService:
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(
owner, repo, branch, file_path, mirror
)
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
@ -445,15 +415,15 @@ class GitMirrorService:
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
@ -461,39 +431,29 @@ class GitMirrorService:
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
    self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
    """Fetch one raw file through a single mirror.

    The request URL is ``<raw_prefix>/<owner>/<repo>/<branch>/<file_path>``,
    where ``raw_prefix`` comes from ``mirror``. The download itself is
    delegated to ``_fetch_with_url``, tagged with the mirror's ``id``, and
    that call's result is returned unchanged.
    """
    segments = (mirror["raw_prefix"], owner, repo, branch, file_path)
    target_url = "/".join(map(str, segments))
    return await self._fetch_with_url(target_url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
@ -501,14 +461,14 @@ class GitMirrorService:
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
@ -519,15 +479,9 @@ class GitMirrorService:
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
@ -536,11 +490,11 @@ class GitMirrorService:
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
@ -549,7 +503,7 @@ class GitMirrorService:
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度浅克隆
Returns:
Dict 包含:
- success: bool - 是否成功
@ -559,44 +513,32 @@ class GitMirrorService:
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(
owner, repo, target_path, branch, depth, mirror
)
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源克隆均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
@ -604,52 +546,47 @@ class GitMirrorService:
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any]
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self,
url: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror_type: str
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
@ -657,24 +594,24 @@ class GitMirrorService:
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install"
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone():
return subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
@ -683,40 +620,34 @@ class GitMirrorService:
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default"
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例

View File

@ -1,4 +1,5 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
import json
@ -14,30 +15,30 @@ active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
@ -49,7 +50,9 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
@ -64,7 +67,7 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@ -72,35 +75,35 @@ def load_recent_logs(limit: int = 100) -> list[dict]:
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
@ -111,19 +114,19 @@ async def websocket_logs(websocket: WebSocket):
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
@ -131,7 +134,7 @@ async def broadcast_log(log_data: dict):
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)

View File

@ -1,4 +1,5 @@
"""WebUI 管理器 - 处理开发/生产环境的 WebUI 启动"""
import os
from pathlib import Path
from src.common.logger import get_logger
@ -10,10 +11,10 @@ logger = get_logger("webui")
def setup_webui(mode: str = "production") -> bool:
"""
设置 WebUI
Args:
mode: 运行模式"development" "production"
Returns:
bool: 是否成功设置
"""
@ -22,7 +23,7 @@ def setup_webui(mode: str = "production") -> bool:
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
if mode == "development":
return setup_dev_mode()
else:
@ -33,12 +34,12 @@ def setup_dev_mode() -> bool:
"""设置开发模式 - 仅启用 CORS前端自行启动"""
from src.common.server import get_global_server
from .logs_ws import router as logs_router
# 注册 WebSocket 日志路由(开发模式也需要)
server = get_global_server()
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
logger.info("💡 前端将运行在 http://localhost:7999")
@ -52,33 +53,33 @@ def setup_production_mode() -> bool:
from starlette.responses import FileResponse
from .logs_ws import router as logs_router
import mimetypes
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/javascript', '.mjs')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type('application/json', '.json')
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/json", ".json")
server = get_global_server()
# 注册 WebSocket 日志路由
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return False
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return False
# 处理 SPA 路由
@server.app.get("/{full_path:path}")
async def serve_spa(full_path: str):
@ -86,23 +87,23 @@ def setup_production_mode() -> bool:
# API 路由不处理
if full_path.startswith("api/"):
return None
# 检查文件是否存在
file_path = static_path / full_path
if file_path.is_file():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
# 返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html", media_type="text/html")
host = os.getenv("HOST", "127.0.0.1")
port = os.getenv("PORT", "8000")
logger.info("✅ WebUI 生产模式已挂载")
logger.info(f"🌐 访问 http://{host}:{port} 查看 WebUI")
return True
except Exception as e:
logger.error(f"挂载 WebUI 静态文件失败: {e}")
return False

View File

@ -1,4 +1,5 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
@ -16,6 +17,7 @@ router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
@ -33,6 +35,7 @@ class PersonInfoResponse(BaseModel):
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
@ -42,12 +45,14 @@ class PersonListResponse(BaseModel):
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
@ -57,6 +62,7 @@ class PersonUpdateRequest(BaseModel):
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
@ -64,6 +70,7 @@ class PersonUpdateResponse(BaseModel):
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
@ -72,13 +79,13 @@ def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
@ -118,11 +125,11 @@ async def get_person_list(
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
authorization: Optional[str] = Header(None)
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
@ -130,58 +137,50 @@ async def get_person_list(
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search)) |
(PersonInfo.nickname.contains(search)) |
(PersonInfo.user_id.contains(search))
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
PersonInfo.last_know.desc()
)
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
@ -190,33 +189,27 @@ async def get_person_list(
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def get_person_detail(person_id: str, authorization: Optional[str] = Header(None)):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(
success=True,
data=person_to_response(person)
)
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
except Exception as e:
@ -225,53 +218,47 @@ async def get_person_detail(
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_person(person_id: str, request: PersonUpdateRequest, authorization: Optional[str] = Header(None)):
"""
增量更新人物信息只更新提供的字段
Args:
person_id: 人物唯一 ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data['last_know'] = time.time()
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=person_to_response(person)
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
@ -280,41 +267,35 @@ async def update_person(
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str,
authorization: Optional[str] = Header(None)
):
async def delete_person(person_id: str, authorization: Optional[str] = Header(None)):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(
success=True,
message=f"成功删除人物信息: {person_name}"
)
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
except Exception as e:
@ -323,41 +304,31 @@ async def delete_person(
@router.get("/stats/summary")
async def get_person_stats(
authorization: Optional[str] = Header(None)
):
async def get_person_stats(authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {
"success": True,
"data": {
"total": total,
"known": known,
"unknown": unknown,
"platforms": platforms
}
}
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise
except Exception as e:

View File

@ -1,4 +1,5 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
import json
@ -22,7 +23,7 @@ current_progress: Dict[str, Any] = {
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0
"loaded_plugins": 0,
}
@ -30,20 +31,20 @@ async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
@ -57,10 +58,10 @@ async def update_progress(
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0
loaded_plugins: int = 0,
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
@ -80,9 +81,9 @@ async def update_progress(
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time()
"timestamp": asyncio.get_event_loop().time(),
}
await broadcast_progress(progress_data)
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@ -90,30 +91,30 @@ async def update_progress(
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")

File diff suppressed because it is too large Load Diff

View File

@ -3,6 +3,7 @@
提供系统重启状态查询等功能
"""
import os
import sys
import time
@ -19,12 +20,14 @@ _start_time = time.time()
class RestartResponse(BaseModel):
    """Response body for the restart endpoint."""

    # Outcome flag reported to the client.
    success: bool
    # Status text accompanying the outcome.
    message: str
class StatusResponse(BaseModel):
"""状态响应"""
running: bool
uptime: float
version: str
@ -35,74 +38,60 @@ class StatusResponse(BaseModel):
async def restart_maibot():
    """
    Restart the MaiBot main process.

    The current process is replaced in place with ``os.execv`` (same
    interpreter, same command line), so configuration changes take effect
    after the restart. The bot is briefly offline while this happens.

    ``os.execv`` never returns on success, so calling it inline made the
    ``return`` statement unreachable and the HTTP client never received a
    response. The exec is therefore scheduled on the event loop shortly after
    this handler returns, which gives the server time to send the response.

    Returns:
        RestartResponse: confirmation that the restart has been scheduled.

    Raises:
        HTTPException: status 500 if the restart could not be scheduled.
    """
    import asyncio  # local import so this block does not depend on file-level imports

    # Grace period (seconds) between answering the request and replacing the process.
    restart_delay = 0.5

    def _exec_new_process(executable, argv):
        """Replace the current process image; only returns when the exec itself fails."""
        try:
            os.execv(executable, argv)
        except OSError as exc:
            # The response has already been sent at this point, so the failure
            # can only be reported on the console.
            print(f"[{datetime.now()}] 重启失败: {exc}")

    try:
        # Record the restart request.
        print(f"[{datetime.now()}] WebUI 触发重启操作")

        python = sys.executable
        args = [python] + sys.argv

        asyncio.get_running_loop().call_later(restart_delay, _exec_new_process, python, args)
        return RestartResponse(success=True, message="麦麦正在重启中...")
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status():
    """
    Report the current running state of MaiBot.

    Returns:
        StatusResponse: ``running`` is always True, ``uptime`` is the number of
        seconds elapsed since the module-level ``_start_time``, ``version`` is
        ``MMC_VERSION`` and ``start_time`` is ``_start_time`` in ISO format.

    Raises:
        HTTPException: status 500 if building the response fails.
    """
    try:
        elapsed_seconds = time.time() - _start_time
        started_at = datetime.fromtimestamp(_start_time)
        payload = {
            "running": True,
            "uptime": elapsed_seconds,
            "version": MMC_VERSION,
            "start_time": started_at.isoformat(),
        }
        return StatusResponse(**payload)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
# 可选:添加更多系统控制功能
@router.post("/reload-config")
async def reload_config():
    """
    Hot-reload the configuration without restarting the process (placeholder).

    No reload is performed yet: the handler only returns a fixed payload whose
    message says the feature is still to be implemented. The real
    implementation needs a reload hook in the main program.

    Returns:
        dict: ``{"success": True, "message": ...}``.
    """
    # Placeholder result until the main program exposes a reload function.
    placeholder = dict(success=True, message="配置重载功能待实现")
    return placeholder

View File

@ -1,4 +1,5 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header
from pydantic import BaseModel, Field
from typing import Optional
@ -38,28 +39,33 @@ router.include_router(system_router)
class TokenVerifyRequest(BaseModel):
    """Request body for token verification."""

    # Required field: the access token to be checked.
    token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
    """Response body for token verification."""

    # Required field: whether the submitted token was accepted.
    valid: bool = Field(..., description="Token 是否有效")
    # Required field: text describing the verification result.
    message: str = Field(..., description="验证结果消息")
class TokenUpdateRequest(BaseModel):
    """Request body for replacing the access token."""

    # Required field: the replacement token; must be at least 10 characters long.
    new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
    """Response body for a token update."""

    # Required field: whether the update succeeded.
    success: bool = Field(..., description="是否更新成功")
    # Required field: text describing the update result.
    message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
    """Response body for token regeneration."""

    # Required field: whether generation succeeded.
    success: bool = Field(..., description="是否生成成功")
    # Required field: the newly generated token.
    token: str = Field(..., description="新生成的令牌")
    # Required field: text describing the generation result.
    message: str = Field(..., description="生成结果消息")
@ -67,18 +73,21 @@ class TokenRegenerateResponse(BaseModel):
class FirstSetupStatusResponse(BaseModel):
    """Response body reporting the first-time setup status."""

    # Required field: whether this is still the first-time setup.
    is_first_setup: bool = Field(..., description="是否为首次配置")
    # Required field: status text.
    message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
    """Response body for marking the setup as completed."""

    # Required field: whether the operation succeeded.
    success: bool = Field(..., description="是否成功")
    # Required field: result text.
    message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
    """Response body for resetting the setup status."""

    # Required field: whether the operation succeeded.
    success: bool = Field(..., description="是否成功")
    # Required field: result text.
    message: str = Field(..., description="结果消息")
@ -93,44 +102,35 @@ async def health_check():
async def verify_token(request: TokenVerifyRequest):
    """
    Verify an access token.

    Args:
        request: verification request carrying the token to check.

    Returns:
        TokenVerifyResponse: ``valid`` mirrors the token manager's verdict and
        ``message`` describes it.

    Raises:
        HTTPException: status 500 if the verification itself fails.
    """
    try:
        manager = get_token_manager()
        # Guard-clause form: the accepted case returns first, rejection falls through.
        if manager.verify_token(request.token):
            return TokenVerifyResponse(valid=True, message="Token 验证成功")
        return TokenVerifyResponse(valid=False, message="Token 无效或已过期")
    except Exception as e:
        logger.error(f"Token 验证失败: {e}")
        raise HTTPException(status_code=500, detail="Token 验证失败") from e
@router.post("/auth/update", response_model=TokenUpdateResponse)
async def update_token(
request: TokenUpdateRequest,
authorization: Optional[str] = Header(None)
):
async def update_token(request: TokenUpdateRequest, authorization: Optional[str] = Header(None)):
"""
更新访问令牌需要当前有效的 token
Args:
request: 包含新 token 的更新请求
authorization: Authorization header (Bearer token)
Returns:
更新结果
"""
@ -138,20 +138,17 @@ async def update_token(
# 验证当前 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 更新 token
success, message = token_manager.update_token(request.new_token)
return TokenUpdateResponse(
success=success,
message=message
)
return TokenUpdateResponse(success=success, message=message)
except HTTPException:
raise
except Exception as e:
@ -163,10 +160,10 @@ async def update_token(
async def regenerate_token(authorization: Optional[str] = Header(None)):
"""
重新生成访问令牌需要当前有效的 token
Args:
authorization: Authorization header (Bearer token)
Returns:
新生成的 token
"""
@ -174,21 +171,17 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
# 验证当前 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 重新生成 token
new_token = token_manager.regenerate_token()
return TokenRegenerateResponse(
success=True,
token=new_token,
message="Token 已重新生成"
)
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
except HTTPException:
raise
except Exception as e:
@ -200,10 +193,10 @@ async def regenerate_token(authorization: Optional[str] = Header(None)):
async def get_setup_status(authorization: Optional[str] = Header(None)):
"""
获取首次配置状态
Args:
authorization: Authorization header (Bearer token)
Returns:
首次配置状态
"""
@ -211,20 +204,17 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 检查是否为首次配置
is_first = token_manager.is_first_setup()
return FirstSetupStatusResponse(
is_first_setup=is_first,
message="首次配置" if is_first else "已完成配置"
)
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
except HTTPException:
raise
except Exception as e:
@ -236,10 +226,10 @@ async def get_setup_status(authorization: Optional[str] = Header(None)):
async def complete_setup(authorization: Optional[str] = Header(None)):
"""
标记首次配置完成
Args:
authorization: Authorization header (Bearer token)
Returns:
完成结果
"""
@ -247,20 +237,17 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 标记配置完成
success = token_manager.mark_setup_completed()
return CompleteSetupResponse(
success=success,
message="配置已完成" if success else "标记失败"
)
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
except HTTPException:
raise
except Exception as e:
@ -272,10 +259,10 @@ async def complete_setup(authorization: Optional[str] = Header(None)):
async def reset_setup(authorization: Optional[str] = Header(None)):
"""
重置首次配置状态允许重新进入配置向导
Args:
authorization: Authorization header (Bearer token)
Returns:
重置结果
"""
@ -283,20 +270,17 @@ async def reset_setup(authorization: Optional[str] = Header(None)):
# 验证 token
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
current_token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 重置配置状态
success = token_manager.reset_setup_status()
return ResetSetupResponse(
success=success,
message="配置状态已重置" if success else "重置失败"
)
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
except HTTPException:
raise
except Exception as e:

View File

@ -1,4 +1,5 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any, List
@ -15,6 +16,7 @@ router = APIRouter(prefix="/statistics", tags=["statistics"])
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
@ -28,6 +30,7 @@ class StatisticsSummary(BaseModel):
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
@ -37,6 +40,7 @@ class ModelStatistics(BaseModel):
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
@ -45,6 +49,7 @@ class TimeSeriesData(BaseModel):
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
@ -56,39 +61,39 @@ class DashboardData(BaseModel):
async def get_dashboard_data(hours: int = 24):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
@ -98,100 +103,84 @@ async def get_dashboard_data(hours: int = 24):
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
    """
    Build the summary statistics for the window ``[start_time, end_time]``.

    Aggregates LLM usage (request count, cost, tokens, mean response time),
    online time clipped to the window, message and reply counts, and the
    per-online-hour cost and token rates.

    Args:
        start_time: inclusive lower bound of the window.
        end_time: inclusive upper bound of the window.

    Returns:
        StatisticsSummary: the populated summary.
    """
    stats = StatisticsSummary()

    # --- LLM usage inside the window ---
    usage_rows = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))

    timed_total = 0.0
    timed_count = 0
    for usage in usage_rows:
        stats.total_requests += 1
        stats.total_cost += usage.cost or 0.0
        stats.total_tokens += (usage.prompt_tokens or 0) + (usage.completion_tokens or 0)
        # Only records with a positive time cost count towards the average.
        if usage.time_cost and usage.time_cost > 0:
            timed_total += usage.time_cost
            timed_count += 1

    if timed_count > 0:
        stats.avg_response_time = timed_total / timed_count

    # --- Online time: sessions that start or end after the window opens ---
    sessions = list(
        OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
    )
    for session in sessions:
        # Clip each session to the window before measuring it.
        clipped_start = max(session.start_timestamp, start_time)
        clipped_end = min(session.end_timestamp, end_time)
        if clipped_end > clipped_start:
            stats.online_time += (clipped_end - clipped_start).total_seconds()

    # --- Messages: Messages.time is compared against POSIX timestamps ---
    window_messages = list(
        Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
    )
    stats.total_messages = len(window_messages)
    # Simple heuristic: a message with a non-empty reply_to counts as a reply.
    stats.total_replies = sum(1 for msg in window_messages if msg.reply_to)

    # --- Derived per-hour rates (only when some online time was recorded) ---
    if stats.online_time > 0:
        hours_online = stats.online_time / 3600.0
        stats.cost_per_hour = stats.total_cost / hours_online
        stats.tokens_per_hour = stats.total_tokens / hours_online

    return stats
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
    """
    Aggregate LLM usage per model for records at or after ``start_time``.

    Records are grouped by ``model_assign_name``, falling back to
    ``model_name`` and then to ``"unknown"``.

    Args:
        start_time: inclusive lower bound on the record timestamp.

    Returns:
        At most 10 ModelStatistics entries, ordered by request count descending.
    """
    per_model = {}

    for usage in list(LLMUsage.select().where(LLMUsage.timestamp >= start_time)):
        label = usage.model_assign_name or usage.model_name or "unknown"
        entry = per_model.setdefault(
            label, {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []}
        )
        entry["request_count"] += 1
        entry["total_cost"] += usage.cost or 0.0
        entry["total_tokens"] += (usage.prompt_tokens or 0) + (usage.completion_tokens or 0)
        # Only positive time costs take part in the average.
        if usage.time_cost and usage.time_cost > 0:
            entry["time_costs"].append(usage.time_cost)

    stats = []
    for label, entry in per_model.items():
        durations = entry["time_costs"]
        if durations:
            mean_duration = sum(durations) / len(durations)
        else:
            mean_duration = 0.0
        stats.append(
            ModelStatistics(
                model_name=label,
                request_count=entry["request_count"],
                total_cost=entry["total_cost"],
                total_tokens=entry["total_tokens"],
                avg_response_time=mean_duration,
            )
        )

    # Most requested models first; keep the top 10.
    ranked = sorted(stats, key=lambda item: item.request_count, reverse=True)
    return ranked[:10]
@ -200,96 +189,80 @@ async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """
    Build an hourly time series of LLM usage between ``start_time`` and ``end_time``.

    Each record is assigned to the hour its timestamp falls in. The result has
    one entry per hour, starting at ``start_time`` truncated to the hour and
    continuing while the hour start is not after ``end_time``. Hours without
    records appear with zero values.

    Args:
        start_time: inclusive lower bound on the record timestamp.
        end_time: inclusive upper bound on the record timestamp.

    Returns:
        TimeSeriesData entries in chronological order.
    """
    buckets = {}

    usage_rows = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
    for usage in usage_rows:
        # Truncate to the hour and use its ISO string as the bucket key.
        key = usage.timestamp.replace(minute=0, second=0, microsecond=0).isoformat()
        slot = buckets.setdefault(key, {"requests": 0, "cost": 0.0, "tokens": 0})
        slot["requests"] += 1
        slot["cost"] += usage.cost or 0.0
        slot["tokens"] += (usage.prompt_tokens or 0) + (usage.completion_tokens or 0)

    # Walk every hour in the range so that empty hours are present too.
    series = []
    one_hour = timedelta(hours=1)
    cursor = start_time.replace(minute=0, second=0, microsecond=0)
    while cursor <= end_time:
        key = cursor.isoformat()
        if key in buckets:
            slot = buckets[key]
        else:
            slot = {"requests": 0, "cost": 0.0, "tokens": 0}
        series.append(
            TimeSeriesData(timestamp=key, requests=slot["requests"], cost=slot["cost"], tokens=slot["tokens"])
        )
        cursor += one_hour

    return series
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
    """
    Build a daily time series of LLM usage between ``start_time`` and ``end_time``.

    Each record is assigned to the day (midnight) its timestamp falls in. The
    result has one entry per day, starting at midnight of ``start_time`` and
    continuing while the day start is not after ``end_time``. Days without
    records appear with zero values.

    Args:
        start_time: inclusive lower bound on the record timestamp.
        end_time: inclusive upper bound on the record timestamp.

    Returns:
        TimeSeriesData entries in chronological order.
    """

    def _midnight(moment):
        """Truncate a datetime to the start of its day."""
        return moment.replace(hour=0, minute=0, second=0, microsecond=0)

    totals = {}
    query = LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time)
    for usage in list(query):
        day_label = _midnight(usage.timestamp).isoformat()
        if day_label not in totals:
            totals[day_label] = {"requests": 0, "cost": 0.0, "tokens": 0}
        day_totals = totals[day_label]
        day_totals["requests"] += 1
        day_totals["cost"] += usage.cost or 0.0
        day_totals["tokens"] += (usage.prompt_tokens or 0) + (usage.completion_tokens or 0)

    # Emit every day in the range, including days without any record.
    points = []
    day = _midnight(start_time)
    while day <= end_time:
        day_label = day.isoformat()
        day_totals = totals.get(day_label, {"requests": 0, "cost": 0.0, "tokens": 0})
        points.append(
            TimeSeriesData(
                timestamp=day_label,
                requests=day_totals["requests"],
                cost=day_totals["cost"],
                tokens=day_totals["tokens"],
            )
        )
        day += timedelta(days=1)

    return points
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
    """
    Return the most recent LLM usage records as plain dictionaries.

    Args:
        limit: maximum number of records to return (newest first).

    Returns:
        One dict per record with the keys ``timestamp`` (ISO string), ``model``,
        ``request_type``, ``tokens`` (prompt + completion), ``cost``,
        ``time_cost`` and ``status``.
    """
    newest_first = LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit)

    return [
        {
            "timestamp": usage.timestamp.isoformat(),
            "model": usage.model_assign_name or usage.model_name,
            "request_type": usage.request_type,
            "tokens": (usage.prompt_tokens or 0) + (usage.completion_tokens or 0),
            "cost": usage.cost or 0.0,
            "time_cost": usage.time_cost or 0.0,
            "status": usage.status,
        }
        for usage in list(newest_first)
    ]
@ -297,7 +270,7 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
async def get_summary(hours: int = 24):
"""
获取统计摘要
Args:
hours: 统计时间范围小时
"""
@ -315,7 +288,7 @@ async def get_summary(hours: int = 24):
async def get_model_stats(hours: int = 24):
"""
获取模型统计
Args:
hours: 统计时间范围小时
"""

View File

@ -19,7 +19,7 @@ class TokenManager:
def __init__(self, config_path: Optional[Path] = None):
"""
初始化 Token 管理器
Args:
config_path: 配置文件路径默认为项目根目录的 data/webui.json
"""
@ -27,10 +27,10 @@ class TokenManager:
# 获取项目根目录 (src/webui -> src -> 根目录)
project_root = Path(__file__).parent.parent.parent
config_path = project_root / "data" / "webui.json"
self.config_path = config_path
self.config_path.parent.mkdir(parents=True, exist_ok=True)
# 确保配置文件存在并包含有效的 token
self._ensure_config()
@ -75,22 +75,23 @@ class TokenManager:
"""生成新的 64 位随机 token"""
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
token = secrets.token_hex(32)
config = {
"access_token": token,
"created_at": self._get_current_timestamp(),
"updated_at": self._get_current_timestamp(),
"first_setup_completed": False # 标记首次配置未完成
"first_setup_completed": False, # 标记首次配置未完成
}
self._save_config(config)
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
return token
def _get_current_timestamp(self) -> str:
"""获取当前时间戳字符串"""
from datetime import datetime
return datetime.now().isoformat()
def get_token(self) -> str:
@ -101,38 +102,38 @@ class TokenManager:
def verify_token(self, token: str) -> bool:
    """
    Check whether ``token`` matches the stored access token.

    Args:
        token: the token to verify.

    Returns:
        bool: True when the token equals the stored token, False when it is
        empty, when no token is stored, or when the two differ.
    """
    if not token:
        return False

    current_token = self.get_token()
    if not current_token:
        logger.error("系统中没有有效的 token")
        return False

    # secrets.compare_digest guards against timing attacks, but with str
    # operands it raises TypeError as soon as either side contains a non-ASCII
    # character. Tokens can come straight from client input, so compare the
    # UTF-8 bytes instead. "surrogatepass" keeps lone surrogates from raising
    # during encoding.
    is_valid = secrets.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        current_token.encode("utf-8", "surrogatepass"),
    )

    if is_valid:
        logger.debug("Token 验证成功")
    else:
        logger.warning("Token 验证失败")

    return is_valid
def update_token(self, new_token: str) -> tuple[bool, str]:
"""
更新 token
Args:
new_token: 新的 token (最少 10 必须包含大小写字母和特殊符号)
Returns:
tuple[bool, str]: (是否更新成功, 错误消息)
"""
@ -141,17 +142,17 @@ class TokenManager:
if not is_valid:
logger.error(f"Token 格式无效: {error_msg}")
return False, error_msg
try:
config = self._load_config()
old_token = config.get("access_token", "")[:8]
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
self._save_config(config)
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
return True, "Token 更新成功"
except Exception as e:
logger.error(f"更新 Token 失败: {e}")
@ -160,7 +161,7 @@ class TokenManager:
def regenerate_token(self) -> str:
"""
重新生成 token
Returns:
str: 新生成的 token
"""
@ -170,20 +171,20 @@ class TokenManager:
def _validate_token_format(self, token: str) -> bool:
"""
验证 token 格式是否正确旧的 64 位十六进制验证保留用于系统生成的 token
Args:
token: 待验证的 token
Returns:
bool: 格式是否正确
"""
if not token or not isinstance(token, str):
return False
# 必须是 64 位十六进制字符串
if len(token) != 64:
return False
# 验证是否为有效的十六进制字符串
try:
int(token, 16)
@ -194,48 +195,48 @@ class TokenManager:
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
"""
验证自定义 token 格式
要求:
- 最少 10
- 包含大写字母
- 包含小写字母
- 包含特殊符号
Args:
token: 待验证的 token
Returns:
tuple[bool, str]: (是否有效, 错误消息)
"""
if not token or not isinstance(token, str):
return False, "Token 不能为空"
# 检查长度
if len(token) < 10:
return False, "Token 长度至少为 10 位"
# 检查是否包含大写字母
has_upper = any(c.isupper() for c in token)
if not has_upper:
return False, "Token 必须包含大写字母"
# 检查是否包含小写字母
has_lower = any(c.islower() for c in token)
if not has_lower:
return False, "Token 必须包含小写字母"
# 检查是否包含特殊符号
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
has_special = any(c in special_chars for c in token)
if not has_special:
return False, f"Token 必须包含特殊符号 ({special_chars})"
return True, "Token 格式正确"
def is_first_setup(self) -> bool:
"""
检查是否为首次配置
Returns:
bool: 是否为首次配置
"""
@ -245,7 +246,7 @@ class TokenManager:
def mark_setup_completed(self) -> bool:
"""
标记首次配置已完成
Returns:
bool: 是否标记成功
"""
@ -263,7 +264,7 @@ class TokenManager:
def reset_setup_status(self) -> bool:
"""
重置首次配置状态允许重新进入配置向导
Returns:
bool: 是否重置成功
"""