Merge pull request #1364 from Mai-with-u/dev

Dev
pull/1381/head 0.11.3-beta
SengokuCola 2025-11-19 02:35:49 +08:00 committed by GitHub
commit 461167fa31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 5224 additions and 261 deletions

1
.gitignore vendored
View File

@ -69,7 +69,6 @@ elua.confirmed
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/

View File

@ -44,14 +44,16 @@
## 🔥 更新和安装
**最新版本: v0.11.0** ([更新日志](changelogs/changelog.md))
**最新版本: v0.11.3** ([更新日志](changelogs/changelog.md))
可前往 [Release](https://github.com/MaiM-with-u/MaiBot/releases/) 页面下载最新版本
可前往 [启动器发布页面](https://github.com/MaiM-with-u/mailauncher/releases/)下载最新启动器
**GitHub 分支说明:**
- `main`: 稳定发布版本(推荐)
- `dev`: 开发测试版本(不稳定)
- `classical`: 旧版本(停止维护)
- `classical`: 经典版本(停止维护)
### 最新版本部署教程
- [🚀 最新版本部署教程](https://docs.mai-mai.org/manual/deployment/mmc_deploy_windows.html) - 基于 MaiCore 的新版本部署方式(与旧版本不兼容)

17
bot.py
View File

@ -107,9 +107,6 @@ async def graceful_shutdown(): # sourcery skip: use-named-expression
logger.info("麦麦优雅关闭完成")
# 关闭日志系统,释放文件句柄
shutdown_logging()
except Exception as e:
logger.error(f"麦麦关闭失败: {e}", exc_info=True)
@ -215,6 +212,10 @@ if __name__ == "__main__":
# 创建事件循环
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 初始化 WebSocket 日志推送
from src.common.logger import initialize_ws_handler
initialize_ws_handler(loop)
try:
# 执行初始化和任务调度
@ -241,7 +242,7 @@ if __name__ == "__main__":
# 确保 loop 在任何情况下都尝试关闭(如果存在且未关闭)
if "loop" in locals() and loop and not loop.is_closed():
loop.close()
logger.info("事件循环已关闭")
print("[主程序] 事件循环已关闭")
# 关闭日志系统,释放文件句柄
try:
@ -249,6 +250,8 @@ if __name__ == "__main__":
except Exception as e:
print(f"关闭日志系统时出错: {e}")
# 在程序退出前暂停,让你有机会看到输出
# input("按 Enter 键退出...") # <--- 添加这行
sys.exit(exit_code) # <--- 使用记录的退出码
print("[主程序] 准备退出...")
# 使用 os._exit() 强制退出,避免被阻塞
# 由于已经在 graceful_shutdown() 中完成了所有清理工作,这是安全的
os._exit(exit_code)

View File

@ -1,6 +1,17 @@
# Changelog
## [0.11.2] - 2025-11-15
## [0.11.3] - 2025-11-17
### 功能更改和修复
- 优化记忆提取策略
- 优化黑话提取
- 优化表达方式学习
- 修改readme
- 加入测试版webui
提示:清理旧的记忆数据和表达方式,表现更好
方法:删除数据库中 expression jargon 和 thinking_back 的全部内容
## [0.11.2] - 2025-11-16
### 🌟 主要功能更改
- "海马体Agent"记忆系统上线最新最好的记忆系统默认已接入lpmm
- 添加黑话jargon学习系统

View File

@ -98,8 +98,13 @@ class QAManager:
return result, ppr_node_weights
async def get_knowledge(self, question: str) -> Optional[str]:
"""获取知识"""
async def get_knowledge(self, question: str, limit: int = 5) -> Optional[str]:
"""获取知识
Args:
question: 查询问题
limit: 返回的相关知识条数
"""
# 处理查询
processed_result = await self.process_query(question)
if processed_result is not None:
@ -109,6 +114,8 @@ class QAManager:
logger.debug("知识库查询结果为空,可能是知识库中没有相关内容")
return None
limit = max(1, limit) if isinstance(limit, int) else 5
knowledge = [
(
self.embed_manager.paragraphs_embedding_store.store[res[0]].str,
@ -116,9 +123,18 @@ class QAManager:
)
for res in query_res
]
found_knowledge = "\n".join(
[f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}" for i, k in enumerate(knowledge)]
)
# max_score = max([k[1] for k in knowledge]) if knowledge else None
selected_knowledge = knowledge[:limit]
formatted_knowledge = [
f"{i + 1}条知识:{k[0]}\n 该条知识对于问题的相关性:{k[1]}"
for i, k in enumerate(selected_knowledge)
]
# if max_score is not None:
# formatted_knowledge.insert(0, f"最高相关系数:{max_score}")
found_knowledge = "\n".join(formatted_knowledge)
if len(found_knowledge) > MAX_KNOWLEDGE_LENGTH:
found_knowledge = found_knowledge[:MAX_KNOWLEDGE_LENGTH] + "\n"
return found_knowledge

View File

@ -311,6 +311,8 @@ class Expression(BaseModel):
context = TextField(null=True)
up_content = TextField(null=True)
content_list = TextField(null=True)
count = IntegerField(default=1)
last_active_time = FloatField()
chat_id = TextField(index=True)
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据

View File

@ -19,6 +19,7 @@ PROJECT_ROOT = logger_file.parent.parent.parent.resolve()
# 全局handler实例避免重复创建
_file_handler = None
_console_handler = None
_ws_handler = None
def get_file_handler():
@ -59,6 +60,35 @@ def get_console_handler():
return _console_handler
def get_ws_handler():
"""获取 WebSocket handler 单例"""
global _ws_handler
if _ws_handler is None:
_ws_handler = WebSocketLogHandler()
# WebSocket handler 推送所有级别的日志
_ws_handler.setLevel(logging.DEBUG)
return _ws_handler
def initialize_ws_handler(loop):
"""初始化 WebSocket handler 的事件循环
Args:
loop: asyncio 事件循环
"""
handler = get_ws_handler()
handler.set_loop(loop)
# 为 WebSocket handler 设置 JSON 格式化器(与文件格式相同)
handler.setFormatter(file_formatter)
# 添加到根日志记录器
root_logger = logging.getLogger()
if handler not in root_logger.handlers:
root_logger.addHandler(handler)
print("[日志系统] ✅ WebSocket 日志推送已启用")
class TimestampedFileHandler(logging.Handler):
"""基于时间戳的文件处理器,简单的轮转份数限制"""
@ -145,12 +175,78 @@ class TimestampedFileHandler(logging.Handler):
super().close()
class WebSocketLogHandler(logging.Handler):
"""WebSocket 日志处理器 - 将日志实时推送到前端"""
_log_counter = 0 # 类级别计数器,确保 ID 唯一性
def __init__(self, loop=None):
super().__init__()
self.loop = loop
self._initialized = False
def set_loop(self, loop):
"""设置事件循环"""
self.loop = loop
self._initialized = True
def emit(self, record):
"""发送日志到 WebSocket 客户端"""
if not self._initialized or self.loop is None:
return
try:
# 获取格式化后的消息
# 对于 structlog,formatted message 包含完整的日志信息
formatted_msg = self.format(record) if self.formatter else record.getMessage()
# 如果是 JSON 格式(文件格式化器),解析它
message = formatted_msg
try:
import json
log_dict = json.loads(formatted_msg)
message = log_dict.get('event', formatted_msg)
except (json.JSONDecodeError, ValueError):
# 不是 JSON,直接使用消息
message = formatted_msg
# 生成唯一 ID: 时间戳毫秒 + 自增计数器
WebSocketLogHandler._log_counter += 1
log_id = f"{int(record.created * 1000)}_{WebSocketLogHandler._log_counter}"
# 格式化日志数据
log_data = {
"id": log_id,
"timestamp": datetime.fromtimestamp(record.created).strftime("%Y-%m-%d %H:%M:%S"),
"level": record.levelname,
"module": record.name,
"message": message,
}
# 异步广播日志(不阻塞日志记录)
try:
import asyncio
from src.webui.logs_ws import broadcast_log
asyncio.run_coroutine_threadsafe(
broadcast_log(log_data),
self.loop
)
except Exception:
# WebSocket 推送失败不影响日志记录
pass
except Exception:
# 不要让 WebSocket 错误影响日志系统
self.handleError(record)
# 旧的轮转文件处理器已移除,现在使用基于时间戳的处理器
def close_handlers():
"""安全关闭所有handler"""
global _file_handler, _console_handler
global _file_handler, _console_handler, _ws_handler
if _file_handler:
_file_handler.close()
@ -159,6 +255,10 @@ def close_handlers():
if _console_handler:
_console_handler.close()
_console_handler = None
if _ws_handler:
_ws_handler.close()
_ws_handler = None
def remove_duplicate_handlers(): # sourcery skip: for-append-to-extend, list-comprehension
@ -843,8 +943,8 @@ def start_log_cleanup_task():
def shutdown_logging():
"""优雅关闭日志系统,释放所有文件句柄"""
logger = get_logger("logger")
logger.info("正在关闭日志系统...")
# 先输出到控制台,避免日志系统关闭后无法输出
print("[logger] 正在关闭日志系统...")
# 关闭所有handler
root_logger = logging.getLogger()
@ -865,4 +965,5 @@ def shutdown_logging():
handler.close()
logger_obj.removeHandler(handler)
logger.info("日志系统已关闭")
# 使用 print 而不是 logger因为 logger 已经关闭
print("[logger] 日志系统已关闭")

View File

@ -2,6 +2,7 @@ from fastapi import FastAPI, APIRouter
from fastapi.middleware.cors import CORSMiddleware # 新增导入
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import asyncio
import os
from rich.traceback import install
@ -82,8 +83,17 @@ class Server:
"""安全关闭服务器"""
if self._server:
self._server.should_exit = True
await self._server.shutdown()
self._server = None
try:
# 添加 3 秒超时,避免 shutdown 永久挂起
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
except asyncio.TimeoutError:
# 超时就强制标记为 None让垃圾回收处理
pass
except Exception:
# 忽略其他异常
pass
finally:
self._server = None
def get_app(self) -> FastAPI:
"""获取 FastAPI 实例"""

View File

@ -56,7 +56,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.11.2"
MMC_VERSION = "0.11.3"
def get_key_comment(toml_table, key):

View File

@ -61,6 +61,37 @@ def format_create_date(timestamp: float) -> str:
return "未知时间"
def _compute_weights(population: List[Dict]) -> List[float]:
"""
根据表达的count计算权重范围限定在1~3之间
count越高权重越高但最多为基础权重的3倍
"""
if not population:
return []
counts = []
for item in population:
count = item.get("count", 1)
try:
count_value = float(count)
except (TypeError, ValueError):
count_value = 1.0
counts.append(max(count_value, 0.0))
min_count = min(counts)
max_count = max(counts)
if max_count == min_count:
return [1.0 for _ in counts]
weights = []
for count_value in counts:
# 线性映射到[1,3]区间
normalized = (count_value - min_count) / (max_count - min_count)
weights.append(1.0 + normalized * 2.0) # 1~3
return weights
def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
"""
随机抽样函数
@ -78,15 +109,24 @@ def weighted_sample(population: List[Dict], k: int) -> List[Dict]:
if len(population) <= k:
return population.copy()
# 使用随机抽样
selected = []
selected: List[Dict] = []
population_copy = population.copy()
for _ in range(k):
if not population_copy:
break
# 随机选择一个元素
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
for _ in range(min(k, len(population_copy))):
weights = _compute_weights(population_copy)
total_weight = sum(weights)
if total_weight <= 0:
# 回退到均匀随机
idx = random.randint(0, len(population_copy) - 1)
selected.append(population_copy.pop(idx))
continue
threshold = random.uniform(0, total_weight)
cumulative = 0.0
for idx, weight in enumerate(weights):
cumulative += weight
if threshold <= cumulative:
selected.append(population_copy.pop(idx))
break
return selected

View File

@ -77,6 +77,9 @@ class ExpressionLearner:
self.express_learn_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils, request_type="expression.learner"
)
self.summary_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.utils_small, request_type="expression.summary"
)
self.embedding_model: LLMRequest = LLMRequest(
model_set=model_config.model_task_config.embedding, request_type="expression.embedding"
)
@ -91,8 +94,8 @@ class ExpressionLearner:
_, self.enable_learning, self.learning_intensity = global_config.expression.get_expression_config_for_chat(
self.chat_id
)
self.min_messages_for_learning = 30 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 300 / self.learning_intensity
self.min_messages_for_learning = 15 / self.learning_intensity # 触发学习所需的最少消息数
self.min_learning_interval = 120 / self.learning_intensity
def should_trigger_learning(self) -> bool:
"""
@ -186,25 +189,13 @@ class ExpressionLearner:
context,
up_content,
) in learnt_expressions:
# 查找是否已存在相似表达方式
query = Expression.select().where(
(Expression.chat_id == self.chat_id) & (Expression.situation == situation) & (Expression.style == style)
await self._upsert_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
if query.exists():
# 表达方式完全相同,只更新时间戳
expr_obj = query.get()
expr_obj.last_active_time = current_time
expr_obj.save()
else:
Expression.create(
situation=situation,
style=style,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time, # 手动设置创建日期
context=context,
up_content=up_content,
)
return learnt_expressions
@ -362,6 +353,10 @@ class ExpressionLearner:
logger.error(f"学习表达方式失败,模型生成出错: {e}")
return None
expressions: List[Tuple[str, str]] = self.parse_expression_response(response)
expressions = self._filter_self_reference_styles(expressions)
if not expressions:
logger.info("过滤后没有可用的表达方式style 与机器人名称重复)")
return None
# logger.debug(f"学习{type_str}的response: {response}")
# 对表达方式溯源
@ -433,6 +428,153 @@ class ExpressionLearner:
expressions.append((situation, style))
return expressions
def _filter_self_reference_styles(self, expressions: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
过滤掉style与机器人名称/昵称重复的表达
"""
banned_names = set()
bot_nickname = (global_config.bot.nickname or "").strip()
if bot_nickname:
banned_names.add(bot_nickname)
alias_names = global_config.bot.alias_names or []
for alias in alias_names:
alias = alias.strip()
if alias:
banned_names.add(alias)
banned_casefold = {name.casefold() for name in banned_names if name}
filtered: List[Tuple[str, str]] = []
removed_count = 0
for situation, style in expressions:
normalized_style = (style or "").strip()
if normalized_style and normalized_style.casefold() not in banned_casefold:
filtered.append((situation, style))
else:
removed_count += 1
if removed_count:
logger.debug(f"已过滤 {removed_count} 条style与机器人名称重复的表达方式")
return filtered
async def _upsert_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
expr_obj = (
Expression.select()
.where((Expression.chat_id == self.chat_id) & (Expression.style == style))
.first()
)
if expr_obj:
await self._update_existing_expression(
expr_obj=expr_obj,
situation=situation,
context=context,
up_content=up_content,
current_time=current_time,
)
return
await self._create_expression_record(
situation=situation,
style=style,
context=context,
up_content=up_content,
current_time=current_time,
)
async def _create_expression_record(
self,
situation: str,
style: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = [situation]
formatted_situation = await self._compose_situation_text(content_list, 1, situation)
Expression.create(
situation=formatted_situation,
style=style,
content_list=json.dumps(content_list, ensure_ascii=False),
count=1,
last_active_time=current_time,
chat_id=self.chat_id,
create_date=current_time,
context=context,
up_content=up_content,
)
async def _update_existing_expression(
self,
expr_obj: Expression,
situation: str,
context: str,
up_content: str,
current_time: float,
) -> None:
content_list = self._parse_content_list(expr_obj.content_list)
content_list.append(situation)
expr_obj.content_list = json.dumps(content_list, ensure_ascii=False)
expr_obj.count = (expr_obj.count or 0) + 1
expr_obj.last_active_time = current_time
expr_obj.context = context
expr_obj.up_content = up_content
new_situation = await self._compose_situation_text(
content_list=content_list,
count=expr_obj.count,
fallback=expr_obj.situation,
)
expr_obj.situation = new_situation
expr_obj.save()
def _parse_content_list(self, stored_list: Optional[str]) -> List[str]:
if not stored_list:
return []
try:
data = json.loads(stored_list)
except json.JSONDecodeError:
return []
return [str(item) for item in data if isinstance(item, str)] if isinstance(data, list) else []
async def _compose_situation_text(self, content_list: List[str], count: int, fallback: str = "") -> str:
sanitized = [c.strip() for c in content_list if c.strip()]
summary = await self._summarize_situations(sanitized)
if summary:
return summary
return "/".join(sanitized) if sanitized else fallback
async def _summarize_situations(self, situations: List[str]) -> Optional[str]:
if not situations:
return None
prompt = (
"请阅读以下多个聊天情境描述,并将它们概括成一句简短的话,"
"长度不超过20个字保留共同特点\n"
f"{chr(10).join(f'- {s}' for s in situations[-10:])}\n只输出概括内容。"
)
try:
summary, _ = await self.summary_model.generate_response_async(prompt, temperature=0.2)
summary = summary.strip()
if summary:
return summary
except Exception as e:
logger.error(f"概括表达情境失败: {e}")
return None
def _build_bare_lines(self, messages: List) -> List[Tuple[int, str]]:
"""
为每条消息构建精简文本列表保留到原消息索引的映射

View File

@ -139,6 +139,7 @@ class ExpressionSelector:
"last_active_time": expr.last_active_time,
"source_id": expr.chat_id,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
"count": expr.count if getattr(expr, "count", None) is not None else 1,
}
for expr in style_query
]

View File

@ -23,6 +23,29 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
logger = get_logger("jargon")
def _contains_bot_self_name(content: str) -> bool:
"""
判断词条是否包含机器人的昵称或别名
"""
if not content:
return False
bot_config = getattr(global_config, "bot", None)
if not bot_config:
return False
target = content.strip().lower()
nickname = str(getattr(bot_config, "nickname", "") or "").strip().lower()
alias_names = [
str(alias or "").strip().lower()
for alias in getattr(bot_config, "alias_names", []) or []
]
candidates = [name for name in [nickname, *alias_names] if name]
return any(name in target for name in candidates if target)
def _init_prompt() -> None:
prompt_str = """
**聊天内容其中的SELF是你自己的发言**
@ -251,7 +274,7 @@ class JargonMiner:
self.chat_id = chat_id
self.last_learning_time: float = time.time()
# 频率控制,可按需调整
self.min_messages_for_learning: int = 15
self.min_messages_for_learning: int = 10
self.min_learning_interval: float = 20
self.llm = LLMRequest(
@ -434,7 +457,7 @@ class JargonMiner:
jargon_obj.is_complete = True
jargon_obj.save()
logger.info(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
logger.debug(f"jargon {content} 推断完成: is_jargon={is_jargon}, meaning={jargon_obj.meaning}, last_inference_count={jargon_obj.last_inference_count}, is_complete={jargon_obj.is_complete}")
# 固定输出推断结果,格式化为可读形式
if is_jargon:
@ -442,7 +465,7 @@ class JargonMiner:
meaning = jargon_obj.meaning or "无详细说明"
is_global = jargon_obj.is_global
if is_global:
logger.info(f"[通用黑话]{content}的含义是 {meaning}")
logger.info(f"[黑话]{content}的含义是 {meaning}")
else:
logger.info(f"[{self.stream_name}]{content}的含义是 {meaning}")
else:
@ -545,6 +568,9 @@ class JargonMiner:
raw_content_list = [raw_content_str]
if content and raw_content_list:
if _contains_bot_self_name(content):
logger.debug(f"解析阶段跳过包含机器人昵称/别名的词条: {content}")
continue
entries.append({
"content": content,
"raw_content": raw_content_list

View File

@ -11,10 +11,40 @@ from src.plugin_system.apis import llm_api
from src.common.database.database_model import ThinkingBack
from json_repair import repair_json
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.memory_system.retrieval_tools.query_lpmm_knowledge import query_lpmm_knowledge
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
logger = get_logger("memory_retrieval")
THINKING_BACK_NOT_FOUND_RETENTION_SECONDS = 3600 # 未找到答案记录保留时长
THINKING_BACK_CLEANUP_INTERVAL_SECONDS = 300 # 清理频率
_last_not_found_cleanup_ts: float = 0.0
def _cleanup_stale_not_found_thinking_back() -> None:
"""定期清理过期的未找到答案记录"""
global _last_not_found_cleanup_ts
now = time.time()
if now - _last_not_found_cleanup_ts < THINKING_BACK_CLEANUP_INTERVAL_SECONDS:
return
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
try:
deleted_rows = (
ThinkingBack.delete()
.where(
(ThinkingBack.found_answer == 0) &
(ThinkingBack.update_time < threshold_time)
)
.execute()
)
if deleted_rows:
logger.info(f"清理过期的未找到答案thinking_back记录 {deleted_rows}")
_last_not_found_cleanup_ts = now
except Exception as e:
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
def init_memory_retrieval_prompt():
"""初始化记忆检索相关的 prompt 模板和工具"""
# 首先注册所有工具
@ -34,20 +64,17 @@ def init_memory_retrieval_prompt():
1. 对话中是否提到了过去发生的事情人物事件或信息
2. 是否有需要回忆的内容比如"之前说过""上次""以前"
3. 是否有需要查找历史信息的问题
4. 是否需要查找某人的信息person: 如果对话中提到人名昵称用户ID等需要查询该人物的详细信息
5. 是否有问题可以搜集信息帮助你聊天
6. 对话中是否包含黑话俚语缩写等可能需要查询的概念
4. 是否有问题可以搜集信息帮助你聊天
5. 对话中是否包含黑话俚语缩写等可能需要查询的概念
重要提示
- **每次只能提出一个问题**选择最需要查询的关键问题
- 如果"最近已查询的问题和结果"中已经包含了类似的问题请避免重复生成相同或相似的问题
- 如果"最近已查询的问题和结果"中已经包含了类似的问题并得到了答案请避免重复生成相同或相似的问题不需要重复查询
- 如果之前已经查询过某个问题但未找到答案可以尝试用不同的方式提问或更具体的问题
- 如果之前已经查询过某个问题并找到了答案可以直接参考已有结果不需要重复查询
如果你认为需要从记忆中检索信息来回答
1. 先识别对话中可能需要查询的概念黑话/俚语/缩写/专有名词等关键词放入"concepts"字段
2. 识别对话中提到的人物名称人名昵称等放入"person"字段
3. 然后根据上下文提出**一个**最关键的问题来帮助你回复目标消息放入"questions"字段
1. 识别对话中可能需要查询的概念黑话/俚语/缩写/专有名词等关键词放入"concepts"字段
2. 根据上下文提出**一个**最关键的问题来帮助你回复目标消息放入"questions"字段
问题格式示例
- "xxx在前几天干了什么"
@ -55,17 +82,11 @@ def init_memory_retrieval_prompt():
- "xxxx和xxx的关系是什么"
- "xxx在某个时间点发生了什么"
请输出JSON格式包含三个字段
- "concepts": 需要检索的概念列表字符串数组如果不需要检索概念则输出空数组[]
- "person": 需要查询的人物名称列表字符串数组如果不需要查询人物信息则输出空数组[]
- "questions": 问题数组字符串数组如果不需要检索记忆则输出空数组[]如果需要检索则只输出包含一个问题的数组
输出格式示例需要检索时
```json
{{
"concepts": ["AAA", "BBB", "CCC"],
"person": ["张三", "李四"],
"questions": ["张三在前几天干了什么"]
"concepts": ["AAA", "BBB", "CCC"], #需要检索的概念列表(字符串数组),如果不需要检索概念则输出空数组[]
"questions": ["张三在前几天干了什么"] #问题数组(字符串数组),如果不需要检索记忆则输出空数组[],如果需要检索则只输出包含一个问题的数组
}}
```
@ -73,7 +94,6 @@ def init_memory_retrieval_prompt():
```json
{{
"concepts": [],
"person": [],
"questions": []
}}
```
@ -85,10 +105,8 @@ def init_memory_retrieval_prompt():
# 第二步ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""
你的名字是{bot_name}现在是{time_now}
"""你的名字是{bot_name}。现在是{time_now}
你正在参与聊天你需要搜集信息来回答问题帮助你参与聊天
你需要通过思考(Think)行动(Action)观察(Observation)的循环来回答问题
**重要限制**
- 最大查询轮数{max_iterations}当前第{current_iteration}剩余{remaining_iterations}
@ -101,76 +119,32 @@ def init_memory_retrieval_prompt():
{collected_info}
**执行步骤**
**第一步思考Think**
在思考中分析
- 当前信息是否足够回答问题
- **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
- 如果还需要继续查询说明最需要查询什么并输出为纯文本说明
- **如果需要尝试搜集更多信息进一步调用工具进入第二步行动环节
- **如果已有信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
**第二步行动Action**
根据思考结果立即行动
- 如果思考中已给出found_answer 无需调用工具直接结束
- 如果思考中已给出not_enough_info 无需调用工具直接结束
- 如果信息不足且需要继续查询 调用相应工具查询可并行调用多个工具
- 如果涉及过往事件可以使用聊天记录查询工具查询过往事件
- 如果涉及概念可以用jargon查询或根据关键词检索聊天记录
- 如果涉及人物可以使用人物信息查询工具查询人物信息
- 如果不确定查询类别也可以使用lpmm知识库查询
- 如果信息不足且需要继续查询说明最需要查询什么并输出为纯文本说明然后调用相应工具查询可并行调用多个工具
**重要规则**
- **只有在检索到明确有关的信息并得出答案时才使用found_answer**
- **如果信息不足无法确定找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须在思考中给出格式为 found_answer(answer="...") not_enough_info(reason="...")不要调用工具
""",
name="memory_retrieval_react_prompt",
)
# 第二步ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""
你的名字是{bot_name}现在是{time_now}
你正在参与聊天你需要搜集信息来回答问题帮助你参与聊天
你需要通过思考(Think)行动(Action)观察(Observation)的循环来回答问题
**重要限制**
- 最大查询轮数{max_iterations}当前第{current_iteration}剩余{remaining_iterations}
- 必须尽快得出答案避免不必要的查询
- 思考要简短直接切入要点
- 必须严格使用检索到的信息回答问题不要编造信息
当前问题{question}
**执行步骤**
**第一步思考Think**
在思考中分析
- 当前信息是否足够回答问题
- **如果信息足够且能找到明确答案**在思考中直接给出答案格式为found_answer(answer="你的答案内容")
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
- 如果还需要继续查询说明最需要查询什么并输出为纯文本说明
**第二步行动Action**
根据思考结果立即行动
- 如果思考中已给出found_answer 无需调用工具直接结束
- 如果思考中已给出not_enough_info 无需调用工具直接结束
- 如果信息不足且需要继续查询 调用相应工具查询可并行调用多个工具
**重要规则**
- **只有在检索到明确具体的答案时才使用found_answer**
- **如果信息不足无法确定找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须在思考中给出格式为 found_answer(answer="...") not_enough_info(reason="...")不要调用工具
- 答案必须在思考中给出格式为 found_answer(answer="...") not_enough_info(reason="...")
""",
name="memory_retrieval_react_prompt_head",
)
# 额外如果最后一轮迭代ReAct Agent prompt使用function calling要求先思考再行动
Prompt(
"""
你的名字是{bot_name}现在是{time_now}
你正在参与聊天你需要搜集信息来回答问题帮助你参与聊天
**重要限制**
- 你已经经过几轮查询尝试了信息搜集现在你需要总结信息选择回答问题或判断问题无法回答
- 思考要简短直接切入要点
- 必须严格使用检索到的信息回答问题不要编造信息
"""你的名字是{bot_name}。现在是{time_now}
你正在参与聊天你需要根据搜集到的信息判断问题是否可以回答问题
当前问题{question}
已收集的信息
@ -183,6 +157,9 @@ def init_memory_retrieval_prompt():
- **如果信息不足或无法找到答案**在思考中给出not_enough_info(reason="信息不足或无法找到答案的原因")
**重要规则**
- 你已经经过几轮查询尝试了信息搜集现在你需要总结信息选择回答问题或判断问题无法回答
- 必须严格使用检索到的信息回答问题不要编造信息
- 答案必须精简不要过多解释
- **只有在检索到明确具体的答案时才使用found_answer**
- **如果信息不足无法确定找不到相关信息必须使用not_enough_info不要使用found_answer**
- 答案必须给出格式为 found_answer(answer="...") not_enough_info(reason="...")
@ -312,8 +289,7 @@ async def _retrieve_concepts_with_jargon(
results.append("".join(output_parts) if len(output_parts) > 1 else output_parts[0])
logger.info(f"在jargon库中找到匹配精确匹配: {concept},找到{len(jargon_results)}条结果")
else:
# 未找到
results.append(f"未在jargon库中找到'{concept}'的解释")
# 未找到,不返回占位信息,只记录日志
logger.info(f"在jargon库中未找到匹配: {concept}")
if results:
@ -321,47 +297,6 @@ async def _retrieve_concepts_with_jargon(
return ""
async def _retrieve_persons_info(
persons: List[str],
chat_id: str
) -> str:
"""对人物列表进行信息检索
Args:
persons: 人物名称列表
chat_id: 聊天ID
Returns:
str: 检索结果字符串
"""
if not persons:
return ""
from src.memory_system.retrieval_tools.query_person_info import query_person_info
results = []
for person in persons:
person = person.strip()
if not person:
continue
try:
person_info = await query_person_info(person)
if person_info and "未找到" not in person_info:
results.append(f"{person}\n{person_info}")
logger.info(f"查询到人物信息: {person}")
else:
results.append(f"未找到人物'{person}'的信息")
logger.info(f"未找到人物信息: {person}")
except Exception as e:
logger.error(f"查询人物信息失败: {person}, 错误: {e}")
results.append(f"查询人物'{person}'信息时发生错误: {str(e)}")
if results:
return "【人物信息检索结果】\n" + "\n\n".join(results) + "\n"
return ""
async def _react_agent_solve_question(
question: str,
chat_id: str,
@ -408,36 +343,41 @@ async def _react_agent_solve_question(
remaining_iterations = max_iterations - current_iteration
is_final_iteration = current_iteration >= max_iterations
# 构建prompt不再需要工具文本描述
prompt_type = "memory_retrieval_react_prompt"
if is_final_iteration:
prompt_type = "memory_retrieval_react_final_prompt"
# 最后一次迭代使用最终prompt
tool_definitions = []
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: 0最后一次迭代不提供工具调用")
prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_final_prompt",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
)
else:
# 非最终迭代使用head_prompt
tool_definitions = tool_registry.get_tool_definitions()
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代,问题: {question}|可用工具数量: {len(tool_definitions)}")
prompt = await global_prompt_manager.format_prompt(
prompt_type,
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "暂无信息",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
)
if not is_final_iteration:
head_prompt = await global_prompt_manager.format_prompt(
"memory_retrieval_react_prompt_head",
bot_name=bot_name,
time_now=time_now,
question=question,
collected_info=collected_info if collected_info else "",
current_iteration=current_iteration,
remaining_iterations=remaining_iterations,
max_iterations=max_iterations,
@ -447,7 +387,6 @@ async def _react_agent_solve_question(
_client,
*,
_head_prompt: str = head_prompt,
_prompt: str = prompt,
_conversation_messages: List[Message] = conversation_messages,
) -> List[Message]:
messages: List[Message] = []
@ -455,14 +394,46 @@ async def _react_agent_solve_question(
system_builder = MessageBuilder()
system_builder.set_role(RoleType.System)
system_builder.add_text_content(_head_prompt)
if _prompt.strip():
system_builder.add_text_content(f"\n{_prompt}")
messages.append(system_builder.build())
messages.extend(_conversation_messages)
# for msg in messages:
# print(msg)
# 优化日志展示 - 合并所有消息到一条日志
log_lines = []
for idx, msg in enumerate(messages, 1):
role_name = msg.role.value if hasattr(msg.role, 'value') else str(msg.role)
# 处理内容 - 显示完整内容,不截断
if isinstance(msg.content, str):
full_content = msg.content
content_type = "文本"
elif isinstance(msg.content, list):
text_parts = [item for item in msg.content if isinstance(item, str)]
image_count = len([item for item in msg.content if isinstance(item, tuple)])
full_content = "".join(text_parts) if text_parts else ""
content_type = f"混合({len(text_parts)}段文本, {image_count}张图片)"
else:
full_content = str(msg.content)
content_type = "未知"
# 构建单条消息的日志信息
msg_info = f"\n[消息 {idx}] 角色: {role_name} 内容类型: {content_type}\n========================================"
if full_content:
msg_info += f"\n{full_content}"
if msg.tool_calls:
msg_info += f"\n 工具调用: {len(msg.tool_calls)}"
for tool_call in msg.tool_calls:
msg_info += f"\n - {tool_call}"
if msg.tool_call_id:
msg_info += f"\n 工具调用ID: {msg.tool_call_id}"
log_lines.append(msg_info)
# 合并所有消息为一条日志输出
logger.info(f"消息列表 (共{len(messages)}条):{''.join(log_lines)}")
return messages
@ -472,14 +443,6 @@ async def _react_agent_solve_question(
tool_options=tool_definitions,
request_type="memory.react",
)
else:
logger.info(f"ReAct Agent 第 {iteration + 1} 次Prompt: {prompt}")
success, response, reasoning_content, model_name, tool_calls = await llm_api.generate_with_model_with_tools(
prompt,
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
)
logger.info(f"ReAct Agent 第 {iteration + 1} 次迭代 模型: {model_name} ,调用工具数量: {len(tool_calls) if tool_calls else 0} ,调用工具响应: {response}")
@ -997,6 +960,24 @@ async def _process_single_question(
Optional[str]: 如果找到答案返回格式化的结果字符串否则返回None
"""
logger.info(f"开始处理问题: {question}")
_cleanup_stale_not_found_thinking_back()
question_initial_info = initial_info or ""
# 预先进行一次LPMM知识库查询作为后续ReAct Agent的辅助信息
if global_config.lpmm_knowledge.enable:
try:
lpmm_result = await query_lpmm_knowledge(question, limit=2)
if lpmm_result and lpmm_result.startswith("你从LPMM知识库中找到"):
if question_initial_info:
question_initial_info += "\n"
question_initial_info += f"【LPMM知识库预查询】\n{lpmm_result}"
logger.info(f"LPMM预查询命中问题: {question[:50]}...")
else:
logger.info(f"LPMM预查询未命中或未找到信息问题: {question[:50]}...")
except Exception as e:
logger.error(f"LPMM预查询失败问题: {question[:50]}... 错误: {e}")
# 先检查thinking_back数据库中是否有现成答案
cached_result = _query_thinking_back(chat_id, question)
@ -1005,26 +986,22 @@ async def _process_single_question(
if cached_result:
cached_found_answer, cached_answer = cached_result
# 根据found_answer的值决定是否重新查询
if cached_found_answer: # found_answer == 1 (True)
# found_answer == 120%概率重新查询
if random.random() < 0.2:
if random.random() < 0.5:
should_requery = True
logger.info(f"found_answer=1触发20%概率重新查询,问题: {question[:50]}...")
else: # found_answer == 0 (False)
# found_answer == 040%概率重新查询
if random.random() < 0.4:
should_requery = True
logger.info(f"found_answer=0触发40%概率重新查询,问题: {question[:50]}...")
# 如果不需要重新查询,使用缓存答案
if not should_requery:
if cached_answer:
if not should_requery and cached_answer:
logger.info(f"从thinking_back缓存中获取答案问题: {question[:50]}...")
return f"问题:{question}\n答案:{cached_answer}"
else:
# 缓存中没有答案,需要查询
elif not cached_answer:
should_requery = True
logger.info(f"found_answer=1 但缓存答案为空,重新查询,问题: {question[:50]}...")
else:
# found_answer == 0不使用缓存直接重新查询
should_requery = True
logger.info(f"thinking_back存在但未找到答案忽略缓存重新查询问题: {question[:50]}...")
# 如果没有缓存答案或需要重新查询使用ReAct Agent查询
if not cached_result or should_requery:
@ -1038,7 +1015,7 @@ async def _process_single_question(
chat_id=chat_id,
max_iterations=global_config.memory.max_agent_iterations,
timeout=120.0,
initial_info=initial_info
initial_info=question_initial_info
)
# 存储到数据库(超时时不存储)
@ -1119,10 +1096,9 @@ async def build_memory_retrieval_prompt(
logger.error(f"LLM生成问题失败: {response}")
return ""
# 解析概念列表、人物列表和问题列表
concepts, persons, questions = _parse_questions_json(response)
# 解析概念列表和问题列表
concepts, questions = _parse_questions_json(response)
logger.info(f"解析到 {len(concepts)} 个概念: {concepts}")
logger.info(f"解析到 {len(persons)} 个人物: {persons}")
logger.info(f"解析到 {len(questions)} 个问题: {questions}")
# 对概念进行jargon检索作为初始信息
@ -1136,22 +1112,13 @@ async def build_memory_retrieval_prompt(
else:
logger.info("概念检索未找到任何结果")
# 对人物进行信息检索,添加到初始信息
if persons:
logger.info(f"开始对 {len(persons)} 个人物进行信息检索")
person_info = await _retrieve_persons_info(persons, chat_id)
if person_info:
initial_info += person_info
logger.info(f"人物信息检索完成,结果: {person_info[:200]}...")
else:
logger.info("人物信息检索未找到任何结果")
# 获取缓存的记忆与question时使用相同的时间窗口和数量限制
cached_memories = _get_cached_memories(chat_id, time_window_seconds=300.0)
if not questions:
logger.debug("模型认为不需要检索记忆或解析失败")
# 即使没有当次查询,也返回缓存的记忆、概念检索结果和人物信息检索结果
# 即使没有当次查询,也返回缓存的记忆和概念检索结果
all_results = []
if initial_info:
all_results.append(initial_info.strip())
@ -1161,7 +1128,7 @@ async def build_memory_retrieval_prompt(
if all_results:
retrieved_memory = "\n\n".join(all_results)
end_time = time.time()
logger.info(f"无当次查询,返回缓存记忆、概念检索和人物信息检索结果,耗时: {(end_time - start_time):.3f}")
logger.info(f"无当次查询,返回缓存记忆和概念检索结果,耗时: {(end_time - start_time):.3f}")
return f"你回忆起了以下信息:\n{retrieved_memory}\n如果与回复内容相关,可以参考这些回忆的信息。\n"
else:
return ""
@ -1223,14 +1190,14 @@ async def build_memory_retrieval_prompt(
return ""
def _parse_questions_json(response: str) -> Tuple[List[str], List[str], List[str]]:
"""解析问题JSON返回概念列表、人物列表和问题列表
def _parse_questions_json(response: str) -> Tuple[List[str], List[str]]:
"""解析问题JSON返回概念列表和问题列表
Args:
response: LLM返回的响应
Returns:
Tuple[List[str], List[str], List[str]]: (概念列表, 人物列表, 问题列表)
Tuple[List[str], List[str]]: (概念列表, 问题列表)
"""
try:
# 尝试提取JSON可能包含在```json代码块中
@ -1249,30 +1216,26 @@ def _parse_questions_json(response: str) -> Tuple[List[str], List[str], List[str
# 解析JSON
parsed = json.loads(repaired_json)
# 只支持新格式包含concepts、person和questions的对象
# 只支持新格式包含concepts和questions的对象
if not isinstance(parsed, dict):
logger.warning(f"解析的JSON不是对象格式: {parsed}")
return [], [], []
return [], []
concepts_raw = parsed.get("concepts", [])
persons_raw = parsed.get("person", [])
questions_raw = parsed.get("questions", [])
# 确保是列表
if not isinstance(concepts_raw, list):
concepts_raw = []
if not isinstance(persons_raw, list):
persons_raw = []
if not isinstance(questions_raw, list):
questions_raw = []
# 确保所有元素都是字符串
concepts = [c for c in concepts_raw if isinstance(c, str) and c.strip()]
persons = [p for p in persons_raw if isinstance(p, str) and p.strip()]
questions = [q for q in questions_raw if isinstance(q, str) and q.strip()]
return concepts, persons, questions
return concepts, questions
except Exception as e:
logger.error(f"解析问题JSON失败: {e}, 响应内容: {response[:200]}...")
return [], [], []
return [], []

View File

@ -193,7 +193,7 @@ def register_tool():
"""注册工具"""
register_memory_retrieval_tool(
name="query_chat_history",
description="根据时间或关键词在chat_history表的聊天记录概述库中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
description="根据时间或关键词在聊天记录中查询。可以查询某个时间点发生了什么、某个时间范围内的事件,或根据关键词搜索消息概述。支持两种匹配模式:模糊匹配(默认,只要包含任意一个关键词即匹配)和全匹配(必须包含所有关键词才匹配)",
parameters=[
{
"name": "keyword",

View File

@ -10,7 +10,7 @@ from .tool_registry import register_memory_retrieval_tool
logger = get_logger("memory_retrieval_tools")
async def query_lpmm_knowledge(query: str) -> str:
async def query_lpmm_knowledge(query: str, limit: int = 5) -> str:
"""在LPMM知识库中查询相关信息
Args:
@ -24,6 +24,12 @@ async def query_lpmm_knowledge(query: str) -> str:
if not content:
return "查询关键词为空"
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
if not global_config.lpmm_knowledge.enable:
logger.debug("LPMM知识库未启用")
return "LPMM知识库未启用"
@ -33,7 +39,7 @@ async def query_lpmm_knowledge(query: str) -> str:
logger.debug("LPMM知识库未初始化跳过查询")
return "LPMM知识库未初始化"
knowledge_info = await qa_manager.get_knowledge(content)
knowledge_info = await qa_manager.get_knowledge(content, limit=limit_value)
logger.debug(f"LPMM知识库查询结果: {knowledge_info}")
if knowledge_info:
@ -57,7 +63,13 @@ def register_tool():
"type": "string",
"description": "需要查询的关键词或问题",
"required": True,
}
},
{
"name": "limit",
"type": "integer",
"description": "希望返回的相关知识条数默认为5",
"required": False,
},
],
execute_func=query_lpmm_knowledge,
)

View File

@ -15,6 +15,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
description = "从知识库中搜索相关信息,如果你需要知识,就使用这个工具"
parameters = [
("query", ToolParamType.STRING, "搜索查询关键词", True, None),
("limit", ToolParamType.INTEGER, "希望返回的相关知识条数默认5", False, 5),
]
available_for_llm = global_config.lpmm_knowledge.enable
@ -29,6 +30,12 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
"""
try:
query: str = function_args.get("query") # type: ignore
limit = function_args.get("limit", 5)
try:
limit_value = int(limit)
except (TypeError, ValueError):
limit_value = 5
limit_value = max(1, limit_value)
# threshold = function_args.get("threshold", 0.4)
# 检查LPMM知识库是否启用
@ -38,7 +45,7 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
# 调用知识库搜索
knowledge_info = await qa_manager.get_knowledge(query)
knowledge_info = await qa_manager.get_knowledge(query, limit=limit_value)
logger.debug(f"知识库查询结果: {knowledge_info}")

View File

@ -0,0 +1,366 @@
"""
配置管理API路由
"""
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body
from typing import Any
from src.common.logger import get_logger
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
MoodConfig,
VoiceConfig,
JargonConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
router = APIRouter(prefix="/config", tags=["config"])
# ===== 辅助函数 =====
def _update_dict_preserve_comments(target: Any, source: Any) -> None:
"""
递归合并字典保留 target 中的注释和格式
source 的值更新到 target 仅更新已存在的键
Args:
target: 目标字典tomlkit 对象包含注释
source: 源字典普通 dict list
"""
# 如果 source 是列表,直接替换(数组表没有注释保留的意义)
if isinstance(source, list):
return # 调用者需要直接赋值
# 如果都是字典,递归合并
if isinstance(source, dict) and isinstance(target, dict):
for key, value in source.items():
if key == "version":
continue # 跳过版本号
if key in target:
target_value = target[key]
# 递归处理嵌套字典
if isinstance(value, dict) and isinstance(target_value, dict):
_update_dict_preserve_comments(target_value, value)
else:
# 使用 tomlkit.item 保持类型
try:
target[key] = tomlkit.item(value)
except (TypeError, ValueError):
target[key] = value
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema():
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
schema = ConfigSchemaGenerator.generate_config_schema(Config)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}")
@router.get("/schema/model")
async def get_model_config_schema():
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}")
# ===== 子配置架构获取接口 =====
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str):
"""
获取指定配置节的架构
支持的section_name:
- bot: BotConfig
- personality: PersonalityConfig
- relationship: RelationshipConfig
- chat: ChatConfig
- message_receive: MessageReceiveConfig
- emoji: EmojiConfig
- expression: ExpressionConfig
- keyword_reaction: KeywordReactionConfig
- chinese_typo: ChineseTypoConfig
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- mood: MoodConfig
- voice: VoiceConfig
- jargon: JargonConfig
- model_task_config: ModelTaskConfig
- api_provider: APIProvider
- model_info: ModelInfo
"""
section_map = {
"bot": BotConfig,
"personality": PersonalityConfig,
"relationship": RelationshipConfig,
"chat": ChatConfig,
"message_receive": MessageReceiveConfig,
"emoji": EmojiConfig,
"expression": ExpressionConfig,
"keyword_reaction": KeywordReactionConfig,
"chinese_typo": ChineseTypoConfig,
"response_post_process": ResponsePostProcessConfig,
"response_splitter": ResponseSplitterConfig,
"telemetry": TelemetryConfig,
"experimental": ExperimentalConfig,
"maim_message": MaimMessageConfig,
"lpmm_knowledge": LPMMKnowledgeConfig,
"tool": ToolConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"mood": MoodConfig,
"voice": VoiceConfig,
"jargon": JargonConfig,
"model_task_config": ModelTaskConfig,
"api_provider": APIProvider,
"model_info": ModelInfo,
}
if section_name not in section_map:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
try:
config_class = section_map[section_name]
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}")
# ===== 配置读取接口 =====
@router.get("/bot")
async def get_bot_config():
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
@router.get("/model")
async def get_model_config():
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}")
# ===== 配置更新接口 =====
@router.post("/bot")
async def update_bot_config(config_data: dict[str, Any] = Body(...)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info("麦麦主程序配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
@router.post("/model")
async def update_model_config(config_data: dict[str, Any] = Body(...)):
"""更新模型配置"""
try:
# 验证配置数据
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info("模型配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}")
# ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: Any = Body(...)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组类型(如 platforms, aliases直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")
@router.post("/model/section/{section_name}")
async def update_model_config_section(section_name: str, section_data: Any = Body(...)):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_dict_preserve_comments(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}")
# 保存配置tomlkit.dump 会保留注释)
with open(config_path, "w", encoding="utf-8") as f:
tomlkit.dump(config_data, f)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}")

View File

@ -0,0 +1,336 @@
"""
配置架构生成器 - 自动从配置类生成前端表单架构
"""
import inspect
from dataclasses import fields, MISSING
from typing import Any, get_origin, get_args, Literal, Optional
from enum import Enum
from src.config.config_base import ConfigBase
class FieldType(str, Enum):
"""字段类型枚举"""
STRING = "string"
NUMBER = "number"
INTEGER = "integer"
BOOLEAN = "boolean"
SELECT = "select"
ARRAY = "array"
OBJECT = "object"
TEXTAREA = "textarea"
class FieldSchema:
"""字段架构"""
def __init__(
self,
name: str,
type: FieldType,
label: str,
description: str = "",
default: Any = None,
required: bool = True,
options: Optional[list[str]] = None,
min_value: Optional[float] = None,
max_value: Optional[float] = None,
items: Optional[dict] = None,
properties: Optional[dict] = None,
):
self.name = name
self.type = type
self.label = label
self.description = description
self.default = default
self.required = required
self.options = options
self.min_value = min_value
self.max_value = max_value
self.items = items
self.properties = properties
def to_dict(self) -> dict:
"""转换为字典"""
result = {
"name": self.name,
"type": self.type.value,
"label": self.label,
"description": self.description,
"required": self.required,
}
if self.default is not None:
result["default"] = self.default
if self.options is not None:
result["options"] = self.options
if self.min_value is not None:
result["minValue"] = self.min_value
if self.max_value is not None:
result["maxValue"] = self.max_value
if self.items is not None:
result["items"] = self.items
if self.properties is not None:
result["properties"] = self.properties
return result
class ConfigSchemaGenerator:
"""配置架构生成器"""
@staticmethod
def _extract_field_description(config_class: type, field_name: str) -> str:
"""
从类定义中提取字段的文档字符串描述
Args:
config_class: 配置类
field_name: 字段名
Returns:
str: 字段描述
"""
try:
# 获取源代码
source = inspect.getsource(config_class)
lines = source.split("\n")
# 查找字段定义
field_found = False
description_lines = []
for i, line in enumerate(lines):
# 匹配字段定义行,例如: platform: str
if f"{field_name}:" in line and "=" in line:
field_found = True
# 查找下一行的文档字符串
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
else:
# 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
elif f"{field_name}:" in line and "=" not in line:
# 没有默认值的字段
field_found = True
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.strip('"""').strip("'''").strip())
else:
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
if field_found and description_lines:
return " ".join(description_lines)
except Exception:
pass
return ""
@staticmethod
def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
"""
获取字段类型和选项
Args:
field_type: 字段类型
Returns:
tuple: (FieldType, options, items)
"""
origin = get_origin(field_type)
args = get_args(field_type)
# 处理 Literal 类型(枚举选项)
if origin is Literal:
return FieldType.SELECT, [str(arg) for arg in args], None
# 处理 list 类型
if origin is list:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
elif item_type is int:
items = {"type": "integer"}
elif item_type is float:
items = {"type": "number"}
elif item_type is bool:
items = {"type": "boolean"}
elif item_type is dict:
items = {"type": "object"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理 set 类型(与 list 类似)
if origin is set:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理基本类型
if field_type is bool or field_type == bool:
return FieldType.BOOLEAN, None, None
elif field_type is int or field_type == int:
return FieldType.INTEGER, None, None
elif field_type is float or field_type == float:
return FieldType.NUMBER, None, None
elif field_type is str or field_type == str:
return FieldType.STRING, None, None
elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None
# 默认为字符串
return FieldType.STRING, None, None
@staticmethod
def _format_field_name(name: str) -> str:
"""
格式化字段名为可读的标签
Args:
name: 原始字段名
Returns:
str: 格式化后的标签
"""
# 将下划线替换为空格,并首字母大写
return " ".join(word.capitalize() for word in name.split("_"))
@staticmethod
def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
"""
从配置类生成前端表单架构
Args:
config_class: 配置类必须继承自 ConfigBase
include_nested: 是否包含嵌套的配置对象
Returns:
dict: 前端表单架构
"""
if not issubclass(config_class, ConfigBase):
raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")
schema_fields = []
nested_schemas = {}
for field in fields(config_class):
# 跳过私有字段和内部字段
if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
continue
# 提取字段描述
description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)
# 判断是否必填
required = field.default is MISSING and field.default_factory is MISSING
# 获取默认值
default_value = None
if field.default is not MISSING:
default_value = field.default
elif field.default_factory is not MISSING:
try:
default_value = field.default_factory()
except Exception:
default_value = None
# 检查是否为嵌套的 ConfigBase
if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
if include_nested:
# 递归生成嵌套配置的架构
nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
nested_schemas[field.name] = nested_schema
field_schema = FieldSchema(
name=field.name,
type=FieldType.OBJECT,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description or field.type.__doc__ or "",
default=default_value,
required=required,
properties=nested_schema,
)
else:
continue
else:
# 获取字段类型和选项
field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)
# 特殊处理:长文本使用 textarea
if field_type == FieldType.STRING and field.name in [
"personality",
"reply_style",
"interest",
"plan_style",
"visual_style",
"private_plan_style",
"emotion_style",
"reaction",
"filtration_prompt",
]:
field_type = FieldType.TEXTAREA
field_schema = FieldSchema(
name=field.name,
type=field_type,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description,
default=default_value,
required=required,
options=options,
items=items,
)
schema_fields.append(field_schema.to_dict())
return {
"className": config_class.__name__,
"classDoc": config_class.__doc__ or "",
"fields": schema_fields,
"nested": nested_schemas if nested_schemas else None,
}
@staticmethod
def generate_config_schema(config_class: type[ConfigBase]) -> dict:
"""
生成完整的配置架构包含所有嵌套的子配置
Args:
config_class: 配置类
Returns:
dict: 完整的配置架构
"""
return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)

View File

@ -0,0 +1,483 @@
"""表情包管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Emoji
from .token_manager import get_token_manager
import json
import time
logger = get_logger("webui.emoji")
# 创建路由器
router = APIRouter(prefix="/emoji", tags=["Emoji"])
class EmojiResponse(BaseModel):
"""表情包响应"""
id: int
full_path: str
format: str
emoji_hash: str
description: str
query_count: int
is_registered: bool
is_banned: bool
emotion: Optional[List[str]] # 解析后的 JSON
record_time: float
register_time: Optional[float]
usage_count: int
last_used_time: Optional[float]
class EmojiListResponse(BaseModel):
"""表情包列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[EmojiResponse]
class EmojiDetailResponse(BaseModel):
"""表情包详情响应"""
success: bool
data: EmojiResponse
class EmojiUpdateRequest(BaseModel):
"""表情包更新请求"""
description: Optional[str] = None
is_registered: Optional[bool] = None
is_banned: Optional[bool] = None
emotion: Optional[List[str]] = None
class EmojiUpdateResponse(BaseModel):
"""表情包更新响应"""
success: bool
message: str
data: Optional[EmojiResponse] = None
class EmojiDeleteResponse(BaseModel):
"""表情包删除响应"""
success: bool
message: str
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def parse_emotion(emotion_str: Optional[str]) -> Optional[List[str]]:
"""解析情感标签 JSON 字符串"""
if not emotion_str:
return None
try:
return json.loads(emotion_str)
except (json.JSONDecodeError, TypeError):
return None
def emoji_to_response(emoji: Emoji) -> EmojiResponse:
"""将 Emoji 模型转换为响应对象"""
return EmojiResponse(
id=emoji.id,
full_path=emoji.full_path,
format=emoji.format,
emoji_hash=emoji.emoji_hash,
description=emoji.description,
query_count=emoji.query_count,
is_registered=emoji.is_registered,
is_banned=emoji.is_banned,
emotion=parse_emotion(emoji.emotion),
record_time=emoji.record_time,
register_time=emoji.register_time,
usage_count=emoji.usage_count,
last_used_time=emoji.last_used_time,
)
@router.get("/list", response_model=EmojiListResponse)
async def get_emoji_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_registered: Optional[bool] = Query(None, description="是否已注册筛选"),
is_banned: Optional[bool] = Query(None, description="是否被禁用筛选"),
format: Optional[str] = Query(None, description="格式筛选"),
authorization: Optional[str] = Header(None)
):
"""
获取表情包列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 description, emoji_hash)
is_registered: 是否已注册筛选
is_banned: 是否被禁用筛选
format: 格式筛选
authorization: Authorization header
Returns:
表情包列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Emoji.select()
# 搜索过滤
if search:
query = query.where(
(Emoji.description.contains(search)) |
(Emoji.emoji_hash.contains(search))
)
# 注册状态过滤
if is_registered is not None:
query = query.where(Emoji.is_registered == is_registered)
# 禁用状态过滤
if is_banned is not None:
query = query.where(Emoji.is_banned == is_banned)
# 格式过滤
if format:
query = query.where(Emoji.format == format)
# 排序:使用次数倒序,然后按记录时间倒序
from peewee import Case
query = query.order_by(
Emoji.usage_count.desc(),
Case(None, [(Emoji.record_time.is_null(), 1)], 0),
Emoji.record_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
emojis = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [emoji_to_response(emoji) for emoji in emojis]
return EmojiListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包列表失败: {str(e)}") from e
@router.get("/{emoji_id}", response_model=EmojiDetailResponse)
async def get_emoji_detail(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
"""
获取表情包详细信息
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
表情包详细信息
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
return EmojiDetailResponse(
success=True,
data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表情包详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表情包详情失败: {str(e)}") from e
@router.patch("/{emoji_id}", response_model=EmojiUpdateResponse)
async def update_emoji(
emoji_id: int,
request: EmojiUpdateRequest,
authorization: Optional[str] = Header(None)
):
"""
增量更新表情包只更新提供的字段
Args:
emoji_id: 表情包ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 处理情感标签(转换为 JSON
if 'emotion' in update_data:
if update_data['emotion'] is None:
update_data['emotion'] = None
else:
update_data['emotion'] = json.dumps(update_data['emotion'], ensure_ascii=False)
# 如果注册状态从 False 变为 True记录注册时间
if 'is_registered' in update_data and update_data['is_registered'] and not emoji.is_registered:
update_data['register_time'] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(emoji, field, value)
emoji.save()
logger.info(f"表情包已更新: ID={emoji_id}, 字段: {list(update_data.keys())}")
return EmojiUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表情包失败: {str(e)}") from e
@router.delete("/{emoji_id}", response_model=EmojiDeleteResponse)
async def delete_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
"""
删除表情包
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 记录删除信息
emoji_hash = emoji.emoji_hash
# 执行删除
emoji.delete_instance()
logger.info(f"表情包已删除: ID={emoji_id}, hash={emoji_hash}")
return EmojiDeleteResponse(
success=True,
message=f"成功删除表情包: {emoji_hash}"
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表情包失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_emoji_stats(
authorization: Optional[str] = Header(None)
):
"""
获取表情包统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Emoji.select().count()
registered = Emoji.select().where(Emoji.is_registered).count()
banned = Emoji.select().where(Emoji.is_banned).count()
# 按格式统计
formats = {}
for emoji in Emoji.select(Emoji.format):
fmt = emoji.format
formats[fmt] = formats.get(fmt, 0) + 1
# 获取最常用的表情包前10
top_used = Emoji.select().order_by(Emoji.usage_count.desc()).limit(10)
top_used_list = [
{
"id": emoji.id,
"emoji_hash": emoji.emoji_hash,
"description": emoji.description,
"usage_count": emoji.usage_count
}
for emoji in top_used
]
return {
"success": True,
"data": {
"total": total,
"registered": registered,
"banned": banned,
"unregistered": total - registered,
"formats": formats,
"top_used": top_used_list
}
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/{emoji_id}/register", response_model=EmojiUpdateResponse)
async def register_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
"""
注册表情包快捷操作
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
if emoji.is_registered:
raise HTTPException(status_code=400, detail="该表情包已经注册")
if emoji.is_banned:
raise HTTPException(status_code=400, detail="该表情包已被禁用,无法注册")
# 注册表情包
emoji.is_registered = True
emoji.register_time = time.time()
emoji.save()
logger.info(f"表情包已注册: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包注册成功",
data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"注册表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"注册表情包失败: {str(e)}") from e
@router.post("/{emoji_id}/ban", response_model=EmojiUpdateResponse)
async def ban_emoji(
emoji_id: int,
authorization: Optional[str] = Header(None)
):
"""
禁用表情包快捷操作
Args:
emoji_id: 表情包ID
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
emoji = Emoji.get_or_none(Emoji.id == emoji_id)
if not emoji:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {emoji_id} 的表情包")
# 禁用表情包(同时取消注册)
emoji.is_banned = True
emoji.is_registered = False
emoji.save()
logger.info(f"表情包已禁用: ID={emoji_id}")
return EmojiUpdateResponse(
success=True,
message="表情包禁用成功",
data=emoji_to_response(emoji)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"禁用表情包失败: {e}")
raise HTTPException(status_code=500, detail=f"禁用表情包失败: {str(e)}") from e

View File

@ -0,0 +1,404 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List
from src.common.logger import get_logger
from src.common.database.database_model import Expression
from .token_manager import get_token_manager
import time
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
context: Optional[str]
up_content: Optional[str]
last_active_time: float
chat_id: str
create_date: Optional[float]
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
context: Optional[str] = None
up_content: Optional[str] = None
chat_id: str
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
context: Optional[str] = None
up_content: Optional[str] = None
chat_id: Optional[str] = None
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
return ExpressionResponse(
id=expression.id,
situation=expression.situation,
style=expression.style,
context=expression.context,
up_content=expression.up_content,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
)
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
authorization: Optional[str] = Header(None)
):
"""
获取表达方式列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style, context)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) |
(Expression.style.contains(search)) |
(Expression.context.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0),
Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int,
authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(
success=True,
data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
authorization: Optional[str] = Header(None)
):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
context=request.context,
up_content=request.up_content,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True,
message="表达方式创建成功",
data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"创建表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
authorization: Optional[str] = Header(None)
):
"""
增量更新表达方式只更新提供的字段
Args:
expression_id: 表达方式ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后活跃时间
update_data['last_active_time'] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int,
authorization: Optional[str] = Header(None)
):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(
success=True,
message=f"成功删除表达方式: {situation}"
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = Expression.select().where(
(Expression.create_date.is_null(False)) &
(Expression.create_date >= seven_days_ago)
).count()
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10])
}
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e

View File

@ -0,0 +1,731 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
_update_progress = callback
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
GITHUB = "github" # GitHub 官方源(兜底)
CUSTOM = "custom" # 自定义镜像源
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
"id": "gh-proxy",
"name": "gh-proxy 镜像",
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None
},
{
"id": "hk-gh-proxy",
"name": "gh-proxy 香港节点",
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None
},
{
"id": "cdn-gh-proxy",
"name": "gh-proxy CDN 节点",
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None
},
{
"id": "edgeone-gh-proxy",
"name": "gh-proxy EdgeOne 节点",
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None
},
{
"id": "meyzh-github",
"name": "Meyzh GitHub 镜像",
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None
},
{
"id": "github",
"name": "GitHub 官方源(兜底)",
"raw_prefix": "https://raw.githubusercontent.com",
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None
}
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
self._init_default_mirrors()
else:
self.mirrors = data["git_mirrors"]
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
else:
logger.info("配置文件不存在,创建默认配置")
self._init_default_mirrors()
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, 'r', encoding='utf-8') as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, 'w', encoding='utf-8') as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
name: str,
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
"raw_prefix": raw_prefix,
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat()
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
name: Optional[str] = None,
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置如果不存在则返回 None
"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
for i, mirror in enumerate(self.mirrors):
if mirror.get("id") == mirror_id:
self.mirrors.pop(i)
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
return [m["id"] for m in enabled]
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(
self,
max_retries: int = 3,
timeout: int = 30,
config: Optional[GitMirrorConfig] = None
):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间
config: 镜像源配置管理器可选默认创建新实例
"""
self.max_retries = max_retries
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
- version: str - Git 版本号如果已安装
- path: str - Git 可执行文件路径如果已安装
- error: str - 错误信息如果未安装或检测失败
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {
"installed": False,
"error": "系统中未找到 Git请先安装 Git"
}
# 获取 Git 版本
result = subprocess.run(
["git", "--version"],
capture_output=True,
text=True,
timeout=5
)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {
"installed": True,
"version": version,
"path": git_path
}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {
"installed": False,
"error": f"Git 命令执行失败: {result.stderr}"
}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {
"installed": False,
"error": "Git 版本检测超时"
}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {
"installed": False,
"error": f"检测 Git 时发生错误: {str(e)}"
}
async def fetch_raw_file(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
branch: 分支名称
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
- data: str - 文件内容成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
if _update_progress:
try:
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
await _update_progress(
stage="loading",
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(
owner, repo, branch, file_path, mirror
)
if result["success"]:
# 成功,推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
stage="loading",
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
async def _fetch_raw_from_mirror(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
logger.debug(f"尝试 #{attempt + 1}: {url}")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except httpx.TimeoutException as e:
last_error = f"请求超时: {e}"
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
async def clone_repository(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
target_path: 目标路径
branch: 分支名称可选
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度浅克隆
Returns:
Dict 包含:
- success: bool - 是否成功
- path: str - 克隆路径成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {
"success": False,
"error": f"未找到镜像源: {mirror_id}",
"mirror_used": None,
"attempts": 0
}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(
owner, repo, target_path, branch, depth, mirror
)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {
"success": False,
"error": "所有镜像源克隆均失败",
"mirror_used": None,
"attempts": len(mirrors_to_try)
}
async def _clone_from_mirror(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self,
url: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install"
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone():
return subprocess.run(
cmd,
capture_output=True,
text=True,
timeout=300 # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
"success": True,
"path": str(target_path),
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default"
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {
"success": False,
"error": last_error,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url
}
# 全局服务实例
_git_mirror_service: Optional[GitMirrorService] = None
def get_git_mirror_service() -> GitMirrorService:
"""获取 Git 镜像源服务实例(单例)"""
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service

View File

View File

View File

@ -0,0 +1,138 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set
import json
from pathlib import Path
from src.common.logger import get_logger
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
# 从文件末尾开始读取
for line in reversed(lines):
if len(logs) >= limit:
break
try:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
"level": log_entry.get("level", "INFO").upper(),
"module": log_entry.get("logger_name", ""),
"message": log_entry.get("event", ""),
}
logs.append(formatted_log)
log_counter += 1
except (json.JSONDecodeError, KeyError):
continue
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")

View File

@ -31,6 +31,14 @@ def setup_webui(mode: str = "production") -> bool:
def setup_dev_mode() -> bool:
"""设置开发模式 - 仅启用 CORS前端自行启动"""
from src.common.server import get_global_server
from .logs_ws import router as logs_router
# 注册 WebSocket 日志路由(开发模式也需要)
server = get_global_server()
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
logger.info("📝 WebUI 开发模式已启用")
logger.info("🌐 请手动启动前端开发服务器: cd webui && npm run dev")
logger.info("💡 前端将运行在 http://localhost:7999")
@ -41,10 +49,23 @@ def setup_production_mode() -> bool:
"""设置生产模式 - 挂载静态文件"""
try:
from src.common.server import get_global_server
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from starlette.responses import FileResponse
from .logs_ws import router as logs_router
import mimetypes
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type('application/javascript', '.js')
mimetypes.add_type('application/javascript', '.mjs')
mimetypes.add_type('text/css', '.css')
mimetypes.add_type('application/json', '.json')
server = get_global_server()
# 注册 WebSocket 日志路由
server.register_router(logs_router)
logger.info("✅ WebSocket 日志推送路由已注册")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
@ -58,14 +79,6 @@ def setup_production_mode() -> bool:
logger.warning("💡 请确认前端已正确构建")
return False
# 挂载静态资源
if (static_path / "assets").exists():
server.app.mount(
"/assets",
StaticFiles(directory=str(static_path / "assets")),
name="assets"
)
# 处理 SPA 路由
@server.app.get("/{full_path:path}")
async def serve_spa(full_path: str):
@ -77,10 +90,12 @@ def setup_production_mode() -> bool:
# 检查文件是否存在
file_path = static_path / full_path
if file_path.is_file():
return FileResponse(file_path)
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
return FileResponse(file_path, media_type=media_type)
# 返回 index.htmlSPA 路由)
return FileResponse(static_path / "index.html")
return FileResponse(static_path / "index.html", media_type="text/html")
host = os.getenv("HOST", "127.0.0.1")
port = os.getenv("PORT", "8000")

View File

@ -0,0 +1,365 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from .token_manager import get_token_manager
import json
import time
logger = get_logger("webui.person")
# 创建路由器
router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
person_name: Optional[str]
name_reason: Optional[str]
platform: str
user_id: str
nickname: Optional[str]
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
memory_points: Optional[str]
know_times: Optional[float]
know_since: Optional[float]
last_know: Optional[float]
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[PersonInfoResponse]
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
memory_points: Optional[str] = None
is_known: Optional[bool] = None
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
def verify_auth_token(authorization: Optional[str]) -> bool:
"""验证认证 Token"""
if not authorization or not authorization.startswith("Bearer "):
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token = authorization.replace("Bearer ", "")
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
"""解析群昵称 JSON 字符串"""
if not group_nick_name_str:
return None
try:
return json.loads(group_nick_name_str)
except (json.JSONDecodeError, TypeError):
return None
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
"""将 PersonInfo 模型转换为响应对象"""
return PersonInfoResponse(
id=person.id,
is_known=person.is_known,
person_id=person.person_id,
person_name=person.person_name,
name_reason=person.name_reason,
platform=person.platform,
user_id=person.user_id,
nickname=person.nickname,
group_nick_name=parse_group_nick_name(person.group_nick_name),
memory_points=person.memory_points,
know_times=person.know_times,
know_since=person.know_since,
last_know=person.last_know,
)
@router.get("/list", response_model=PersonListResponse)
async def get_person_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
authorization: Optional[str] = Header(None)
):
"""
获取人物信息列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 person_name, nickname, user_id)
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search)) |
(PersonInfo.nickname.contains(search)) |
(PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(
Case(None, [(PersonInfo.last_know.is_null(), 1)], 0),
PersonInfo.last_know.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str,
authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(
success=True,
data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
authorization: Optional[str] = Header(None)
):
"""
增量更新人物信息只更新提供的字段
Args:
person_id: 人物唯一 ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data['last_know'] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True,
message=f"成功更新 {len(update_data)} 个字段",
data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str,
authorization: Optional[str] = Header(None)
):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(
success=True,
message=f"成功删除人物信息: {person_name}"
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_person_stats(
authorization: Optional[str] = Header(None)
):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {
"success": True,
"data": {
"total": total,
"known": known,
"unknown": unknown,
"platforms": platforms
}
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e

View File

@ -0,0 +1,127 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from typing import Set, Dict, Any
import json
import asyncio
from src.common.logger import get_logger
logger = get_logger("webui.plugin_progress")
# 创建路由器
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
# 当前加载进度状态
current_progress: Dict[str, Any] = {
"operation": "idle", # idle, fetch, install, uninstall, update
"stage": "idle", # idle, loading, success, error
"progress": 0, # 0-100
"message": "",
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0
}
async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
async def update_progress(
stage: str,
progress: int,
message: str,
operation: str = "fetch",
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
message: 当前消息
operation: 操作类型 (fetch, install, uninstall, update)
error: 错误信息可选
plugin_id: 当前操作的插件 ID
total_plugins: 总插件数
loaded_plugins: 已加载插件数
"""
progress_data = {
"operation": operation,
"stage": stage,
"progress": progress,
"message": message,
"error": error,
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time()
}
await broadcast_progress(progress_data)
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
"""
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接,当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取插件进度 WebSocket 路由器"""
return router

File diff suppressed because it is too large Load Diff

View File

@ -4,12 +4,34 @@ from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
from .expression_routes import router as expression_router
from .emoji_routes import router as emoji_router
from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
logger = get_logger("webui.api")
# 创建路由器
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
# 注册配置管理路由
router.include_router(config_router)
# 注册统计数据路由
router.include_router(statistics_router)
# 注册人物信息管理路由
router.include_router(person_router)
# 注册表达方式管理路由
router.include_router(expression_router)
# 注册表情包管理路由
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
class TokenVerifyRequest(BaseModel):
"""Token 验证请求"""

View File

@ -0,0 +1,329 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any, List
from datetime import datetime, timedelta
from collections import defaultdict
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
online_time: float = Field(0.0, description="在线时间(秒)")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
avg_response_time: float = Field(0.0, description="平均响应时间")
cost_per_hour: float = Field(0.0, description="每小时花费")
tokens_per_hour: float = Field(0.0, description="每小时token数")
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
total_tokens: int
avg_response_time: float
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
tokens: int = 0
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
daily_data: List[TimeSeriesData]
recent_activity: List[Dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
"""获取摘要统计数据"""
summary = StatisticsSummary()
# 查询 LLM 使用记录
llm_records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
total_time_cost = 0.0
time_cost_count = 0
for record in llm_records:
summary.total_requests += 1
summary.total_cost += record.cost or 0.0
summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
if record.time_cost and record.time_cost > 0:
total_time_cost += record.time_cost
time_cost_count += 1
# 计算平均响应时间
if time_cost_count > 0:
summary.avg_response_time = total_time_cost / time_cost_count
# 查询在线时间
online_records = list(
OnlineTime.select()
.where(
(OnlineTime.start_timestamp >= start_time) |
(OnlineTime.end_timestamp >= start_time)
)
)
for record in online_records:
start = max(record.start_timestamp, start_time)
end = min(record.end_timestamp, end_time)
if end > start:
summary.online_time += (end - start).total_seconds()
# 查询消息数量
messages = list(
Messages.select()
.where(Messages.time >= start_time.timestamp())
.where(Messages.time <= end_time.timestamp())
)
summary.total_messages = len(messages)
# 简单统计:如果 reply_to 不为空,则认为是回复
summary.total_replies = len([m for m in messages if m.reply_to])
# 计算派生指标
if summary.online_time > 0:
online_hours = summary.online_time / 3600.0
summary.cost_per_hour = summary.total_cost / online_hours
summary.tokens_per_hour = summary.total_tokens / online_hours
return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据"""
model_data = defaultdict(lambda: {
'request_count': 0,
'total_cost': 0.0,
'total_tokens': 0,
'time_costs': []
})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
)
for record in records:
model_name = record.model_assign_name or record.model_name or "unknown"
model_data[model_name]['request_count'] += 1
model_data[model_name]['total_cost'] += record.cost or 0.0
model_data[model_name]['total_tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
if record.time_cost and record.time_cost > 0:
model_data[model_name]['time_costs'].append(record.time_cost)
# 转换为列表并排序
result = []
for model_name, data in model_data.items():
avg_time = sum(data['time_costs']) / len(data['time_costs']) if data['time_costs'] else 0.0
result.append(ModelStatistics(
model_name=model_name,
request_count=data['request_count'],
total_cost=data['total_cost'],
total_tokens=data['total_tokens'],
avg_response_time=avg_time
))
# 按请求数排序
result.sort(key=lambda x: x.request_count, reverse=True)
return result[:10] # 返回前10个
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据"""
# 创建小时桶
hourly_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
for record in records:
# 获取小时键(去掉分钟和秒)
hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
hour_str = hour_key.isoformat()
hourly_buckets[hour_str]['requests'] += 1
hourly_buckets[hour_str]['cost'] += record.cost or 0.0
hourly_buckets[hour_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 填充所有小时(包括没有数据的)
result = []
current = start_time.replace(minute=0, second=0, microsecond=0)
while current <= end_time:
hour_str = current.isoformat()
data = hourly_buckets.get(hour_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=hour_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
current += timedelta(hours=1)
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据"""
daily_buckets = defaultdict(lambda: {'requests': 0, 'cost': 0.0, 'tokens': 0})
records = list(
LLMUsage.select()
.where(LLMUsage.timestamp >= start_time)
.where(LLMUsage.timestamp <= end_time)
)
for record in records:
# 获取日期键
day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
day_str = day_key.isoformat()
daily_buckets[day_str]['requests'] += 1
daily_buckets[day_str]['cost'] += record.cost or 0.0
daily_buckets[day_str]['tokens'] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
# 填充所有天
result = []
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
while current <= end_time:
day_str = current.isoformat()
data = daily_buckets.get(day_str, {'requests': 0, 'cost': 0.0, 'tokens': 0})
result.append(TimeSeriesData(
timestamp=day_str,
requests=data['requests'],
cost=data['cost'],
tokens=data['tokens']
))
current += timedelta(days=1)
return result
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
records = list(
LLMUsage.select()
.order_by(LLMUsage.timestamp.desc())
.limit(limit)
)
activities = []
for record in records:
activities.append({
'timestamp': record.timestamp.isoformat(),
'model': record.model_assign_name or record.model_name,
'request_type': record.request_type,
'tokens': (record.prompt_tokens or 0) + (record.completion_tokens or 0),
'cost': record.cost or 0.0,
'time_cost': record.time_cost or 0.0,
'status': record.status
})
return activities
@router.get("/summary")
async def get_summary(hours: int = 24):
"""
获取统计摘要
Args:
hours: 统计时间范围小时
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
summary = await _get_summary_statistics(start_time, now)
return summary
except Exception as e:
logger.error(f"获取统计摘要失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models")
async def get_model_stats(hours: int = 24):
"""
获取模型统计
Args:
hours: 统计时间范围小时
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
stats = await _get_model_statistics(start_time)
return stats
except Exception as e:
logger.error(f"获取模型统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@ -1,5 +1,5 @@
[inner]
version = "6.21.4"
version = "6.21.6"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请递增version的值
@ -104,7 +104,7 @@ talk_value_rules = [
include_planner_reasoning = false # 是否将planner推理加入replyer默认关闭不加入
[memory]
max_agent_iterations = 5 # 记忆思考深度最低为1不深入思考
max_agent_iterations = 3 # 记忆思考深度最低为1不深入思考
[jargon]
all_global = true # 是否开启全局黑话模式,注意,此功能关闭后,已经记录的全局黑话不会改变,需要手动删除

View File

@ -1,5 +1,5 @@
[inner]
version = "1.7.7"
version = "1.7.8"
# 配置文件版本号迭代规则同bot_config.toml

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

14
webui/dist/index.html vendored 100644
View File

@ -0,0 +1,14 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" type="image/x-icon" href="/maimai.ico" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>MaiBot Dashboard</title>
<script type="module" crossorigin src="/assets/index-DYT0dd6E.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-BjjI9czp.css">
</head>
<body>
<div id="root"></div>
</body>
</html>

BIN
webui/dist/maimai.ico vendored 100644

Binary file not shown.

After

Width:  |  Height:  |  Size: 66 KiB

1
webui/dist/vite.svg vendored 100644
View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>

After

Width:  |  Height:  |  Size: 1.5 KiB