remove webui

r-dev
UnCLAS-Prommer 2026-01-12 18:25:19 +08:00
parent 072939e8e9
commit afb993e481
No known key found for this signature in database
105 changed files with 0 additions and 13434 deletions

View File

@ -1 +0,0 @@
"""WebUI 模块"""

View File

@ -1,938 +0,0 @@
"""麦麦 2025 年度总结 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import (
LLMUsage,
OnlineTime,
Messages,
ChatStreams,
PersonInfo,
Emoji,
Expression,
ActionRecords,
Jargon,
)
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ==================== Pydantic 模型定义 ====================
class TimeFootprintData(BaseModel):
"""时光足迹数据"""
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
first_message_time: Optional[str] = Field(None, description="初次消息时间")
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
is_night_owl: bool = Field(False, description="是否是夜猫子")
class SocialNetworkData(BaseModel):
"""社交网络数据"""
total_groups: int = Field(0, description="加入的群组总数")
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
at_count: int = Field(0, description="被@次数")
mentioned_count: int = Field(0, description="被提及次数")
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
longest_companion_days: int = Field(0, description="陪伴天数")
class BrainPowerData(BaseModel):
"""最强大脑数据"""
total_tokens: int = Field(0, description="年度消耗Token总量")
total_cost: float = Field(0.0, description="年度总花费")
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
total_actions: int = Field(0, description="总动作数")
no_reply_count: int = Field(0, description="选择沉默的次数")
avg_interest_value: float = Field(0.0, description="平均兴趣值")
max_interest_value: float = Field(0.0, description="最高兴趣值")
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
max_reasoning_length: int = Field(0, description="最长思考长度")
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
class ExpressionVibeData(BaseModel):
"""个性与表达数据"""
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
checked_expression_count: int = Field(0, description="已检查的表达次数")
total_expressions: int = Field(0, description="表达总数")
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
image_processed_count: int = Field(0, description="处理的图片数量")
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
"""趣味成就数据"""
new_jargon_count: int = Field(0, description="新学到的黑话数量")
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
class AnnualReportData(BaseModel):
"""年度报告完整数据"""
year: int = Field(2025, description="报告年份")
bot_name: str = Field("麦麦", description="Bot名称")
generated_at: str = Field(..., description="报告生成时间")
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
achievements: AchievementData = Field(default_factory=AchievementData)
# ==================== 辅助函数 ====================
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
"""获取指定年份的时间戳范围"""
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
return start, end
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
"""获取指定年份的 datetime 范围"""
start = datetime(year, 1, 1, 0, 0, 0)
end = datetime(year, 12, 31, 23, 59, 59)
return start, end
# ==================== 维度一:时光足迹 ====================
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
"""获取时光足迹数据"""
data = TimeFootprintData()
start_ts, end_ts = get_year_time_range(year)
start_dt, end_dt = get_year_datetime_range(year)
try:
# 1. 年度在线时长
online_records = list(
OnlineTime.select().where(
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
)
)
total_seconds = 0
for record in online_records:
try:
start = max(record.start_timestamp, start_dt)
end = min(record.end_timestamp, end_dt)
if end > start:
total_seconds += (end - start).total_seconds()
except Exception:
continue
data.total_online_hours = round(total_seconds / 3600, 2)
# 2. 初次相遇 - 年度第一条消息
first_msg = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.order_by(Messages.time.asc())
.first()
)
if first_msg:
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
content = first_msg.processed_plain_text or first_msg.display_message or ""
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
# 3. 最忙碌的一天
# 使用 SQLite 的 date 函数按日期分组
busiest_query = (
Messages.select(
fn.date(Messages.time, "unixepoch").alias("day"),
fn.COUNT(Messages.id).alias("count"),
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.date(Messages.time, "unixepoch"))
.order_by(fn.COUNT(Messages.id).desc())
.limit(1)
)
busiest_result = list(busiest_query.dicts())
if busiest_result:
data.busiest_day = busiest_result[0].get("day")
data.busiest_day_count = busiest_result[0].get("count", 0)
# 4. 昼夜节律 - 24小时活跃分布
hourly_query = (
Messages.select(
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
fn.COUNT(Messages.id).alias("count"),
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
)
hourly_distribution = [0] * 24
for row in hourly_query.dicts():
try:
hour = int(row.get("hour", 0))
if 0 <= hour < 24:
hourly_distribution[hour] = row.get("count", 0)
except (ValueError, TypeError):
continue
data.hourly_distribution = hourly_distribution
# 5. 深夜食堂 (0-4点)
data.midnight_chat_count = sum(hourly_distribution[0:5])
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
morning_activity = sum(hourly_distribution[6:13])
data.is_night_owl = night_activity > morning_activity
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
return data
# ==================== 维度二:社交网络 ====================
async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 加入的群组总数
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
# 2. 话痨群组 TOP3
top_groups_query = (
Messages.select(
Messages.chat_info_group_id,
Messages.chat_info_group_name,
fn.COUNT(Messages.id).alias("count"),
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.chat_info_group_id.is_null(False))
)
.group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
data.top_groups = [
{
"group_id": row["chat_info_group_id"],
"group_name": row["chat_info_group_name"] or "未知群组",
"message_count": row["count"],
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
}
for row in top_groups_query.dicts()
]
# 3. 互动最多的用户 TOP5过滤 bot 自身)
top_users_query = (
Messages.select(
Messages.user_id,
Messages.user_nickname,
fn.COUNT(Messages.id).alias("count"),
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id.is_null(False))
& (Messages.user_id != bot_qq) # 过滤 bot 自身
)
.group_by(Messages.user_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
data.top_users = [
{
"user_id": row["user_id"],
"user_nickname": row["user_nickname"] or "未知用户",
"message_count": row["count"],
"is_webui": str(row["user_id"]).startswith("webui_"),
}
for row in top_users_query.dicts()
]
# 4. 被@次数
data.at_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_at == True)
)
.count()
)
# 5. 被提及次数
data.mentioned_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_mentioned == True)
)
.count()
)
# 6. 最长情陪伴的用户(过滤 bot 自身)
companion_query = (
ChatStreams.select(
ChatStreams.user_id,
ChatStreams.user_nickname,
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
)
.where(
(ChatStreams.user_id.is_null(False))
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
)
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1)
)
companion_result = list(companion_query.dicts())
if companion_result:
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
duration = companion_result[0].get("duration", 0) or 0
data.longest_companion_days = int(duration / 86400) # 转换为天
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
return data
# ==================== 维度三:最强大脑 ====================
async def get_brain_power(year: int = 2025) -> BrainPowerData:
"""获取最强大脑数据"""
data = BrainPowerData()
start_dt, end_dt = get_year_datetime_range(year)
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 年度消耗 Token 总量和总花费
token_query = LLMUsage.select(
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
result = token_query.dicts().get()
data.total_tokens = int(result.get("total_tokens", 0) or 0)
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
# 2. 最爱用的模型
model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10)
)
model_results = list(model_query.dicts())
if model_results:
data.favorite_model = model_results[0].get("model")
data.favorite_model_count = model_results[0].get("count", 0)
data.model_distribution = [
{
"model": row["model"],
"count": row["count"],
"tokens": row["tokens"],
"cost": round(row["cost"], 4),
}
for row in model_results
]
# 3. 最昂贵的一次思考
expensive_query = (
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.order_by(LLMUsage.cost.desc())
.limit(1)
)
expensive_result = expensive_query.first()
if expensive_result:
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
consumer_query = (
LLMUsage.select(
LLMUsage.user_id,
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (LLMUsage.user_id != "system") # 过滤 system 用户
& (LLMUsage.user_id.is_null(False))
)
.group_by(LLMUsage.user_id)
.order_by(fn.SUM(LLMUsage.cost).desc())
.limit(3)
)
data.top_token_consumers = [
{
"user_id": row["user_id"],
"cost": round(row["cost"], 4),
"tokens": row["tokens"],
}
for row in consumer_query.dicts()
]
# 5. 最喜欢的回复模型 TOP5按模型的回复次数统计只统计 replyer 调用)
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
reply_model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (
LLMUsage.model_assign_name.contains("replyer")
| LLMUsage.model_assign_name.contains("回复")
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
)
)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5)
)
data.top_reply_models = [
{"model": row["model"], "count": row["count"]}
for row in reply_model_query.dicts()
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = ActionRecords.select().where(
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
).count()
no_reply_count = ActionRecords.select().where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply")
).count()
data.total_actions = total_actions
data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
# 6. 情绪波动 (兴趣值)
interest_query = Messages.select(
fn.AVG(Messages.interest_value).alias("avg_interest"),
fn.MAX(Messages.interest_value).alias("max_interest"),
).where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value.is_null(False))
)
interest_result = interest_query.dicts().get()
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
# 找到最高兴趣值的时间
if data.max_interest_value > 0:
max_interest_msg = (
Messages.select(Messages.time)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value == data.max_interest_value)
)
.first()
)
if max_interest_msg:
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
"%Y-%m-%d %H:%M:%S"
)
# 7. 思考深度 (基于 action_reasoning 长度)
reasoning_records = (
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "")
)
)
reasoning_lengths = []
max_len = 0
max_len_time = None
for record in reasoning_records:
if record.action_reasoning:
length = len(record.action_reasoning)
reasoning_lengths.append(length)
if length > max_len:
max_len = length
max_len_time = record.time
if reasoning_lengths:
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
data.max_reasoning_length = max_len
if max_len_time:
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
return data
# ==================== 维度四:个性与表达 ====================
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
bot_qq = str(global_config.bot.qq_account or "")
try:
# 1. 表情包之王 - 使用次数最多的表情包
top_emoji_query = (
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
.where(Emoji.is_registered == True)
.order_by(Emoji.usage_count.desc())
.limit(5)
)
top_emojis = list(top_emoji_query.dicts())
if top_emojis:
data.top_emoji = {
"id": top_emojis[0].get("id"),
"path": top_emojis[0].get("full_path"),
"description": top_emojis[0].get("description"),
"usage_count": top_emojis[0].get("usage_count", 0),
"hash": top_emojis[0].get("emoji_hash"),
}
data.top_emojis = [
{
"id": e.get("id"),
"path": e.get("full_path"),
"description": e.get("description"),
"usage_count": e.get("usage_count", 0),
"hash": e.get("emoji_hash"),
}
for e in top_emojis
]
# 2. 百变麦麦 - 最常用的表达风格
expression_query = (
Expression.select(
Expression.style,
fn.SUM(Expression.count).alias("total_count"),
)
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.group_by(Expression.style)
.order_by(fn.SUM(Expression.count).desc())
.limit(5)
)
data.top_expressions = [
{"style": row["style"], "count": row["total_count"]}
for row in expression_query.dicts()
]
# 3. 被拒绝的表达
data.rejected_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.rejected == True)
)
.count()
)
# 4. 已检查的表达
data.checked_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.checked == True)
)
.count()
)
# 5. 表达总数
data.total_expressions = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.count()
)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [
"reply", "no_reply", "no_reply_until_call", "make_question",
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
]
action_query = (
ActionRecords.select(
ActionRecords.action_name,
fn.COUNT(ActionRecords.id).alias("count"),
)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name.not_in(excluded_actions))
)
.group_by(ActionRecords.action_name)
.order_by(fn.COUNT(ActionRecords.id).desc())
.limit(10)
)
data.action_types = [
{"action": row["action_name"], "count": row["count"]}
for row in action_query.dicts()
]
# 7. 处理的图片数量
data.image_processed_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_picid == True)
)
.count()
)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
import re
def clean_message_content(content: str) -> str:
"""清理消息内容,移除回复引用等标记"""
if not content:
return ""
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
content = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content)
# 移除 [图片] [表情] 等标记
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
# 移除多余的空白
content = re.sub(r'\s+', ' ', content).strip()
return content
# 使用 user_id 判断是否是 bot 发送的消息
late_night_messages = list(
Messages.select(
Messages.time,
Messages.processed_plain_text,
Messages.display_message,
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id == bot_qq) # bot 发送的消息
)
.order_by(Messages.time.desc())
)
# 筛选出0-6点的消息
late_night_filtered = []
for msg in late_night_messages:
msg_dt = datetime.fromtimestamp(msg.time)
hour = msg_dt.hour
if 0 <= hour < 6: # 0点到6点
raw_content = msg.processed_plain_text or msg.display_message or ""
cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append({
"time": msg.time,
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
})
if len(late_night_filtered) >= 10:
break
if late_night_filtered:
selected = random.choice(late_night_filtered)
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
data.late_night_reply = {
"time": selected["datetime_str"],
"content": content,
}
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
from collections import Counter
import json as json_lib
reply_records = (
ActionRecords.select(ActionRecords.action_data)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
)
reply_contents = []
for record in reply_records:
try:
action_data = record.action_data
if action_data:
content = None
# 尝试解析 JSON 格式
try:
parsed = json_lib.loads(action_data)
if isinstance(parsed, dict):
# 优先使用 reply_text其次使用 content
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (json_lib.JSONDecodeError, TypeError):
pass
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
# 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None:
import ast
try:
parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict):
content = parsed.get("reply_text") or parsed.get("content")
elif isinstance(parsed, str):
content = parsed
except (ValueError, SyntaxError):
# 无法解析,使用原始字符串
content = action_data
# 只统计有意义的回复长度大于2
if content and len(content) > 2:
reply_contents.append(content)
except Exception:
continue
if reply_contents:
content_counter = Counter(reply_contents)
most_common = content_counter.most_common(1)
if most_common:
fav_content, fav_count = most_common[0]
# 截断过长的内容
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
data.favorite_reply = {
"content": display_content,
"count": fav_count,
}
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
return data
# ==================== 维度五:趣味成就 ====================
async def get_achievements(year: int = 2025) -> AchievementData:
"""获取趣味成就数据"""
data = AchievementData()
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
# 2. 代表性黑话示例
jargon_samples = (
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
.where(Jargon.is_jargon == True)
.order_by(Jargon.count.desc())
.limit(5)
)
data.sample_jargons = [
{
"content": j.content,
"meaning": j.meaning,
"count": j.count,
}
for j in jargon_samples
]
# 3. 总消息数
data.total_messages = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.count()
)
# 4. 总回复数 (有 reply_to 的消息)
data.total_replies = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.reply_to.is_null(False))
)
.count()
)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
return data
# ==================== API 路由 ====================
@router.get("/full", response_model=AnnualReportData)
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
"""
获取完整年度报告数据
Args:
year: 报告年份默认2025
Returns:
完整的年度报告数据
"""
try:
from src.config.config import global_config
logger.info(f"开始生成 {year} 年度报告...")
# 获取 bot 名称
bot_name = global_config.bot.nickname or "麦麦"
# 并行获取各维度数据
time_footprint = await get_time_footprint(year)
social_network = await get_social_network(year)
brain_power = await get_brain_power(year)
expression_vibe = await get_expression_vibe(year)
achievements = await get_achievements(year)
report = AnnualReportData(
year=year,
bot_name=bot_name,
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
time_footprint=time_footprint,
social_network=social_network,
brain_power=brain_power,
expression_vibe=expression_vibe,
achievements=achievements,
)
logger.info(f"{year} 年度报告生成完成")
return report
except Exception as e:
logger.error(f"生成年度报告失败: {e}")
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
@router.get("/time-footprint", response_model=TimeFootprintData)
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取时光足迹数据"""
try:
return await get_time_footprint(year)
except Exception as e:
logger.error(f"获取时光足迹数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/social-network", response_model=SocialNetworkData)
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取社交网络数据"""
try:
return await get_social_network(year)
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/brain-power", response_model=BrainPowerData)
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取最强大脑数据"""
try:
return await get_brain_power(year)
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/expression-vibe", response_model=ExpressionVibeData)
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取个性与表达数据"""
try:
return await get_expression_vibe(year)
except Exception as e:
logger.error(f"获取个性与表达数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/achievements", response_model=AchievementData)
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
"""获取趣味成就数据"""
try:
return await get_achievements(year)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@ -1,795 +0,0 @@
"""
WebUI 防爬虫模块
提供爬虫检测和阻止功能保护 WebUI 不被搜索引擎和恶意爬虫访问
"""
import time
import ipaddress
import re
from collections import deque
from typing import Optional
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import PlainTextResponse
from src.common.logger import get_logger
logger = get_logger("webui.anti_crawler")
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
CRAWLER_USER_AGENTS = {
# 搜索引擎爬虫(精确匹配)
"googlebot",
"bingbot",
"baiduspider",
"yandexbot",
"slurp", # Yahoo
"duckduckbot",
"sogou",
"exabot",
"facebot",
"ia_archiver", # Internet Archive
# 通用爬虫(移除过于宽泛的关键词)
"crawler",
"spider",
"scraper",
"wget", # 保留wget因为通常用于自动化脚本
"scrapy", # 保留scrapy因为这是爬虫框架
# 安全扫描工具(这些是明确的扫描工具)
"masscan",
"nmap",
"nikto",
"sqlmap",
# 注意:移除了以下过于宽泛的关键词以避免误报:
# - "bot" (会误匹配GitHub-Robot等)
# - "curl" (正常工具)
# - "python-requests" (正常库)
# - "httpx" (正常库)
# - "aiohttp" (正常库)
}
# 资产测绘工具 User-Agent 标识
ASSET_SCANNER_USER_AGENTS = {
# 知名资产测绘平台
"shodan",
"censys",
"zoomeye",
"fofa",
"quake",
"hunter",
"binaryedge",
"onyphe",
"securitytrails",
"virustotal",
"passivetotal",
# 安全扫描工具
"acunetix",
"appscan",
"burpsuite",
"nessus",
"openvas",
"qualys",
"rapid7",
"tenable",
"veracode",
"zap",
"awvs", # Acunetix Web Vulnerability Scanner
"netsparker",
"skipfish",
"w3af",
"arachni",
# 其他扫描工具
"masscan",
"zmap",
"nmap",
"whatweb",
"wpscan",
"joomscan",
"dnsenum",
"subfinder",
"amass",
"sublist3r",
"theharvester",
}
# 资产测绘工具常用的HTTP头标识
ASSET_SCANNER_HEADERS = {
# 常见的扫描工具自定义头
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
"x-scanner": {"nmap", "masscan", "zmap"},
"x-probe": {"masscan", "zmap"},
# 其他可疑头(移除反向代理标准头)
"x-originating-ip": set(),
"x-remote-ip": set(),
"x-remote-addr": set(),
# 注意:移除了以下反向代理标准头以避免误报:
# - "x-forwarded-proto" (反向代理标准头)
# - "x-real-ip" (反向代理标准头已在_get_client_ip中使用)
}
# 仅检查特定HTTP头中的可疑模式收紧匹配范围
# 只检查这些特定头,不检查所有头
SCANNER_SPECIFIC_HEADERS = {
"x-scan",
"x-scanner",
"x-probe",
"x-originating-ip",
"x-remote-ip",
"x-remote-addr",
}
# 防爬虫模式配置
# false: 禁用
# strict: 严格模式(更严格的检测,更低的频率限制)
# loose: 宽松模式(较宽松的检测,较高的频率限制)
# basic: 基础模式只记录恶意访问不阻止不限制请求数不跟踪IP
# IP白名单配置从配置文件读取逗号分隔
# 支持格式:
# - 精确IP127.0.0.1, 192.168.1.100
# - CIDR格式192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
# - 通配符192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
# - IPv6::1, 2001:db8::/32
def _parse_allowed_ips(ip_string: str) -> list:
"""
解析IP白名单字符串支持精确IPCIDR格式和通配符
Args:
ip_string: 逗号分隔的IP字符串
Returns:
IP白名单列表每个元素可能是
- ipaddress.IPv4Network/IPv6Network对象CIDR格式
- ipaddress.IPv4Address/IPv6Address对象精确IP
- str通配符模式已转换为正则表达式
"""
allowed = []
if not ip_string:
return allowed
for ip_entry in ip_string.split(","):
ip_entry = ip_entry.strip() # 去除空格
if not ip_entry:
continue
# 跳过注释行(以#开头)
if ip_entry.startswith("#"):
continue
# 检查通配符格式(包含*
if "*" in ip_entry:
# 处理通配符
pattern = _convert_wildcard_to_regex(ip_entry)
if pattern:
allowed.append(pattern)
else:
logger.warning(f"无效的通配符IP格式已忽略: {ip_entry}")
continue
try:
# 尝试解析为CIDR格式包含/
if "/" in ip_entry:
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
else:
# 精确IP地址
allowed.append(ipaddress.ip_address(ip_entry))
except (ValueError, AttributeError) as e:
logger.warning(f"无效的IP白名单条目已忽略: {ip_entry} ({e})")
return allowed
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
"""
将通配符IP模式转换为正则表达式
支持的格式
- 192.168.*.* 192.168.*
- 10.*.*.* 10.*
- *.*.*.* *
Args:
wildcard_pattern: 通配符模式字符串
Returns:
正则表达式字符串如果格式无效则返回None
"""
# 去除空格
pattern = wildcard_pattern.strip()
# 处理单个*(匹配所有)
if pattern == "*":
return r".*"
# 处理IPv4通配符格式
# 支持192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
parts = pattern.split(".")
if len(parts) > 4:
return None # IPv4最多4段
# 构建正则表达式
regex_parts = []
for part in parts:
part = part.strip()
if part == "*":
regex_parts.append(r"\d+") # 匹配任意数字
elif part.isdigit():
# 验证数字范围0-255
num = int(part)
if 0 <= num <= 255:
regex_parts.append(re.escape(part))
else:
return None # 无效的数字
else:
return None # 无效的格式
# 如果部分少于4段补充.*
while len(regex_parts) < 4:
regex_parts.append(r"\d+")
# 组合成正则表达式
regex = r"^" + r"\.".join(regex_parts) + r"$"
return regex
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
def _get_anti_crawler_config():
"""获取防爬虫配置"""
from src.config.config import global_config
return {
'mode': global_config.webui.anti_crawler_mode,
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
'trust_xff': global_config.webui.trust_xff
}
# 初始化配置(将在模块加载时执行)
_config = _get_anti_crawler_config()
ANTI_CRAWLER_MODE = _config['mode']
ALLOWED_IPS = _config['allowed_ips']
TRUSTED_PROXIES = _config['trusted_proxies']
TRUST_XFF = _config['trust_xff']
def _get_mode_config(mode: str) -> dict:
"""
根据模式获取配置参数
Args:
mode: 防爬虫模式 (false/strict/loose/basic)
Returns:
配置字典包含所有相关参数
"""
mode = mode.lower()
if mode == "false":
return {
"enabled": False,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
"max_tracked_ips": 0,
"check_user_agent": False,
"check_asset_scanner": False,
"check_rate_limit": False,
"block_on_detect": False, # 不阻止
}
elif mode == "strict":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
"max_tracked_ips": 20000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
elif mode == "loose":
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
"max_tracked_ips": 5000,
"check_user_agent": True,
"check_asset_scanner": True,
"check_rate_limit": True,
"block_on_detect": True, # 阻止恶意访问
}
else: # basic (默认模式)
return {
"enabled": True,
"rate_limit_window": 60,
"rate_limit_max_requests": 1000, # 不限制请求数
"max_tracked_ips": 0, # 不跟踪IP
"check_user_agent": True, # 检测但不阻止
"check_asset_scanner": True, # 检测但不阻止
"check_rate_limit": False, # 不限制请求频率
"block_on_detect": False, # 只记录,不阻止
}
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
"""防爬虫中间件"""
def __init__(self, app, mode: str = "standard"):
"""
初始化防爬虫中间件
Args:
app: FastAPI 应用实例
mode: 防爬虫模式 (false/strict/loose/standard)
"""
super().__init__(app)
self.mode = mode.lower()
# 根据模式获取配置
config = _get_mode_config(self.mode)
self.enabled = config["enabled"]
self.rate_limit_window = config["rate_limit_window"]
self.rate_limit_max_requests = config["rate_limit_max_requests"]
self.max_tracked_ips = config["max_tracked_ips"]
self.check_user_agent = config["check_user_agent"]
self.check_asset_scanner = config["check_asset_scanner"]
self.check_rate_limit = config["check_rate_limit"]
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
# 用于存储每个IP的请求时间戳使用deque提高性能
self.request_times: dict[str, deque] = {}
# 上次清理时间
self.last_cleanup = time.time()
# 将关键词列表转换为集合以提高查找性能
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
"""
检测是否为爬虫 User-Agent
Args:
user_agent: User-Agent 字符串
Returns:
如果是爬虫则返回 True
"""
if not user_agent:
# 没有 User-Agent 的请求记录日志但不直接阻止
# 改为只记录,让频率限制来处理
logger.debug("请求缺少User-Agent")
return False # 不再直接阻止无User-Agent的请求
user_agent_lower = user_agent.lower()
# 使用集合查找提高性能(检查是否包含爬虫关键词)
for crawler_keyword in self.crawler_keywords_set:
if crawler_keyword in user_agent_lower:
return True
return False
def _is_asset_scanner_header(self, request: Request) -> bool:
"""
检测是否为资产测绘工具的HTTP头只检查特定头收紧匹配
Args:
request: 请求对象
Returns:
如果检测到资产测绘工具头则返回 True
"""
# 只检查特定的扫描工具头,不检查所有头
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
header_value_lower = header_value.lower() if header_value else ""
# 检查已知的扫描工具头
if header_name_lower in ASSET_SCANNER_HEADERS:
# 如果该头有特定的工具集合,检查值是否匹配
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
if expected_tools:
for tool in expected_tools:
if tool in header_value_lower:
return True
else:
# 如果没有特定工具集合,只要存在该头就视为可疑
if header_value_lower:
return True
# 只检查特定头中的可疑模式(收紧匹配)
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
# 检查头值中是否包含已知扫描工具名称
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
return True
return False
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
"""
检测资产测绘工具
Args:
request: 请求对象
Returns:
(是否检测到, 检测到的工具名称)
"""
user_agent = request.headers.get("User-Agent")
# 检查 User-Agent使用集合查找提高性能
if user_agent:
user_agent_lower = user_agent.lower()
for scanner_keyword in self.scanner_keywords_set:
if scanner_keyword in user_agent_lower:
return True, scanner_keyword
# 检查HTTP头
if self._is_asset_scanner_header(request):
# 尝试从User-Agent或头中提取工具名称
detected_tool = None
if user_agent:
user_agent_lower = user_agent.lower()
for tool in self.scanner_keywords_set:
if tool in user_agent_lower:
detected_tool = tool
break
# 检查HTTP头中的工具标识只检查特定头
if not detected_tool:
for header_name, header_value in request.headers.items():
header_name_lower = header_name.lower()
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
header_value_lower = (header_value or "").lower()
for tool in self.scanner_keywords_set:
if tool in header_value_lower:
detected_tool = tool
break
if detected_tool:
break
return True, detected_tool or "unknown_scanner"
return False, None
def _check_rate_limit(self, client_ip: str) -> bool:
"""
检查请求频率限制
Args:
client_ip: 客户端IP地址
Returns:
如果超过限制则返回 True需要阻止
"""
# 检查IP白名单
if self._is_ip_allowed(client_ip):
return False
current_time = time.time()
# 定期清理过期的请求记录每5分钟清理一次
if current_time - self.last_cleanup > 300:
self._cleanup_old_requests(current_time)
self.last_cleanup = current_time
# 限制跟踪的IP数量防止内存泄漏
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
# 清理最旧的记录删除最久未访问的IP
self._cleanup_oldest_ips()
# 获取或创建该IP的请求时间deque不使用maxlen避免限流变松
if client_ip not in self.request_times:
self.request_times[client_ip] = deque()
request_times = self.request_times[client_ip]
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
while request_times and current_time - request_times[0] >= self.rate_limit_window:
request_times.popleft()
# 检查是否超过限制
if len(request_times) >= self.rate_limit_max_requests:
return True
# 记录当前请求时间
request_times.append(current_time)
return False
def _cleanup_old_requests(self, current_time: float):
"""清理过期的请求记录只清理当前需要检查的IP不全量遍历"""
# 这个方法现在主要用于定期清理实际清理在_check_rate_limit中按需进行
# 清理最久未访问的IP记录
if len(self.request_times) > self.max_tracked_ips * 0.8:
self._cleanup_oldest_ips()
def _cleanup_oldest_ips(self):
"""清理最久未访问的IP记录全量遍历找真正的oldest"""
if not self.request_times:
return
# 先收集空deque的IP优先删除
empty_ips = []
# 找到最久未访问的IP最旧时间戳
oldest_ip = None
oldest_time = float("inf")
# 全量遍历找真正的oldest超限时性能可接受
for ip, times in self.request_times.items():
if not times:
# 空deque记录待删除
empty_ips.append(ip)
else:
# 找到最旧的时间戳
if times[0] < oldest_time:
oldest_time = times[0]
oldest_ip = ip
# 先删除空deque的IP
for ip in empty_ips:
del self.request_times[ip]
# 如果没有空deque可删除且仍需要清理删除最旧的一个IP
if not empty_ips and oldest_ip:
del self.request_times[oldest_ip]
def _is_trusted_proxy(self, ip: str) -> bool:
"""
检查IP是否在信任的代理列表中
Args:
ip: IP地址字符串
Returns:
如果是信任的代理则返回 True
"""
if not TRUSTED_PROXIES or ip == "unknown":
return False
# 检查代理列表中的每个条目
for trusted_entry in TRUSTED_PROXIES:
# 通配符模式(字符串,正则表达式)
if isinstance(trusted_entry, str):
try:
if re.match(trusted_entry, ip):
return True
except re.error:
continue
# CIDR格式网络对象
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in trusted_entry:
return True
except (ValueError, AttributeError):
continue
# 精确IP地址对象
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == trusted_entry:
return True
except (ValueError, AttributeError):
continue
return False
def _get_client_ip(self, request: Request) -> str:
"""
获取客户端真实IP地址带基本验证和代理信任检查
Args:
request: 请求对象
Returns:
客户端IP地址
"""
# 获取直接连接的客户端IP用于验证代理
direct_client_ip = None
if request.client:
direct_client_ip = request.client.host
# 检查是否信任X-Forwarded-For头
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
use_xff = False
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
use_xff = self._is_trusted_proxy(direct_client_ip)
# 如果信任代理,优先从 X-Forwarded-For 获取
if use_xff:
forwarded_for = request.headers.get("X-Forwarded-For")
if forwarded_for:
# X-Forwarded-For 可能包含多个IP取第一个
ip = forwarded_for.split(",")[0].strip()
# 基本验证IP格式
if self._validate_ip(ip):
return ip
# 从 X-Real-IP 获取(如果信任代理)
if use_xff:
real_ip = request.headers.get("X-Real-IP")
if real_ip:
ip = real_ip.strip()
if self._validate_ip(ip):
return ip
# 使用直接连接的客户端IP
if direct_client_ip and self._validate_ip(direct_client_ip):
return direct_client_ip
return "unknown"
def _validate_ip(self, ip: str) -> bool:
"""
验证IP地址格式
Args:
ip: IP地址字符串
Returns:
如果格式有效则返回 True
"""
try:
ipaddress.ip_address(ip)
return True
except (ValueError, AttributeError):
return False
def _is_ip_allowed(self, ip: str) -> bool:
"""
检查IP是否在白名单中支持精确IPCIDR格式和通配符
Args:
ip: 客户端IP地址
Returns:
如果IP在白名单中则返回 True
"""
if not ALLOWED_IPS or ip == "unknown":
return False
# 检查白名单中的每个条目
for allowed_entry in ALLOWED_IPS:
# 通配符模式(字符串,正则表达式)
if isinstance(allowed_entry, str):
try:
if re.match(allowed_entry, ip):
return True
except re.error:
# 正则表达式错误,跳过
continue
# CIDR格式网络对象
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj in allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
# 精确IP地址对象
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
try:
client_ip_obj = ipaddress.ip_address(ip)
if client_ip_obj == allowed_entry:
return True
except (ValueError, AttributeError):
# IP格式无效跳过
continue
return False
async def dispatch(self, request: Request, call_next):
"""
处理请求
Args:
request: 请求对象
call_next: 下一个中间件或路由处理函数
Returns:
响应对象
"""
# 如果未启用,直接通过
if not self.enabled:
return await call_next(request)
# 允许访问 robots.txt由专门的路由处理
if request.url.path == "/robots.txt":
return await call_next(request)
# 允许访问静态资源CSS、JS、图片等
# 注意:.json 已移除,避免 API 路径绕过防护
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/
static_extensions = {
".css",
".js",
".png",
".jpg",
".jpeg",
".gif",
".svg",
".ico",
".woff",
".woff2",
".ttf",
".eot",
}
static_prefixes = {"/static/", "/assets/", "/dist/"}
# 检查是否是静态资源路径(特定前缀下的静态文件)
path = request.url.path
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
path.endswith(ext) for ext in static_extensions
)
# 也允许根路径下的静态文件(如 /favicon.ico
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
if is_static_path or is_root_static:
return await call_next(request)
# 获取客户端IP只获取一次避免重复调用
client_ip = self._get_client_ip(request)
# 检查IP白名单优先检查白名单IP直接通过
if self._is_ip_allowed(client_ip):
return await call_next(request)
# 获取 User-Agent
user_agent = request.headers.get("User-Agent")
# 检测资产测绘工具(优先检测,因为更危险)
if self.check_asset_scanner:
is_scanner, scanner_name = self._detect_asset_scanner(request)
if is_scanner:
logger.warning(
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
f"User-Agent: {user_agent}, Path: {request.url.path}"
)
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Asset scanning tools are not allowed",
status_code=403,
)
# 检测爬虫 User-Agent
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
# 根据配置决定是否阻止
if self.block_on_detect:
return PlainTextResponse(
"Access Denied: Crawlers are not allowed",
status_code=403,
)
# 检查请求频率限制
if self.check_rate_limit and self._check_rate_limit(client_ip):
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
return PlainTextResponse(
"Too Many Requests: Rate limit exceeded",
status_code=429,
)
# 正常请求,继续处理
return await call_next(request)
def create_robots_txt_response() -> PlainTextResponse:
"""
创建 robots.txt 响应
Returns:
robots.txt 响应对象
"""
robots_content = """User-agent: *
Disallow: /
# 禁止所有爬虫访问
"""
return PlainTextResponse(
content=robots_content,
media_type="text/plain",
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
)

View File

@ -1,301 +0,0 @@
"""
规划器监控API
提供规划器日志数据的查询接口
性能优化
1. 聊天摘要只统计文件数量和最新时间戳不读取文件内容
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/planner", tags=["planner"])
# 规划器日志目录
PLAN_LOG_DIR = Path("logs/plan")
class ChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
plan_count: int
latest_timestamp: float
latest_filename: str
class PlanLogSummary(BaseModel):
"""规划日志摘要"""
chat_id: str
timestamp: float
filename: str
action_count: int
action_types: List[str] # 动作类型列表
total_plan_ms: float
llm_duration_ms: float
reasoning_preview: str
class PlanLogDetail(BaseModel):
"""规划日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
reasoning: str
raw_output: str
actions: List[Dict]
timing: Dict
extra: Optional[Dict] = None
class PlannerOverview(BaseModel):
"""规划器总览 - 轻量级统计"""
total_chats: int
total_plans: int
chats: List[ChatSummary]
class PaginatedChatLogs(BaseModel):
"""分页的聊天日志列表"""
data: List[PlanLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=PlannerOverview)
async def get_planner_overview():
"""
获取规划器总览 - 轻量级接口
只统计文件数量不读取文件内容
"""
if not PLAN_LOG_DIR.exists():
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
chats = []
total_plans = 0
for chat_dir in PLAN_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
plan_count = len(json_files)
total_plans += plan_count
if plan_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ChatSummary(
chat_id=chat_dir.name,
plan_count=plan_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return PlannerOverview(
total_chats=len(chats),
total_plans=total_plans,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
async def get_chat_plan_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
):
"""
获取指定聊天的规划日志列表分页
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = PLAN_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedChatLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
actions = data.get('actions', [])
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
logs.append(PlanLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
action_count=len(actions),
action_types=action_types,
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
reasoning_preview=reasoning[:100] if reasoning else ''
))
except Exception:
# 文件读取失败时使用文件名信息
logs.append(PlanLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
action_count=0,
action_types=[],
total_plan_ms=0,
llm_duration_ms=0,
reasoning_preview='[读取失败]'
))
return PaginatedChatLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
async def get_log_detail(chat_id: str, filename: str):
"""获取规划日志详情 - 按需加载完整内容"""
log_file = PLAN_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return PlanLogDetail(**data)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
# ========== 兼容旧接口 ==========
@router.get("/stats")
async def get_planner_stats():
"""获取规划器统计信息 - 兼容旧接口"""
overview = await get_planner_overview()
# 获取最近10条计划的摘要
recent_plans = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
try:
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
recent_plans.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
recent_plans = recent_plans[:10]
return {
"total_chats": overview.total_chats,
"total_plans": overview.total_plans,
"avg_plan_time_ms": 0,
"avg_llm_time_ms": 0,
"recent_plans": recent_plans
}
@router.get("/chats")
async def get_chat_list():
"""获取所有聊天ID列表 - 兼容旧接口"""
overview = await get_planner_overview()
return [chat.chat_id for chat in overview.chats]
@router.get("/all-logs")
async def get_all_logs(
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100)
):
"""获取所有规划日志 - 兼容旧接口"""
if not PLAN_LOG_DIR.exists():
return {"data": [], "total": 0, "page": page, "page_size": page_size}
# 收集所有文件
all_files = []
for chat_dir in PLAN_LOG_DIR.iterdir():
if chat_dir.is_dir():
for log_file in chat_dir.glob("*.json"):
all_files.append((chat_dir.name, log_file))
# 按时间戳排序
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
total = len(all_files)
offset = (page - 1) * page_size
page_files = all_files[offset:offset + page_size]
logs = []
for chat_id, log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
reasoning = data.get('reasoning', '')
logs.append({
"chat_id": data.get('chat_id', chat_id),
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
"filename": log_file.name,
"action_count": len(data.get('actions', [])),
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
"reasoning_preview": reasoning[:100] if reasoning else ''
})
except Exception:
continue
return {"data": logs, "total": total, "page": page, "page_size": page_size}

View File

@ -1,269 +0,0 @@
"""
回复器监控API
提供回复器日志数据的查询接口
性能优化
1. 聊天摘要只统计文件数量和最新时间戳不读取文件内容
2. 日志列表使用文件名解析时间戳只在需要时读取完整内容
3. 详情按需加载
"""
import json
from pathlib import Path
from typing import List, Dict, Optional
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel
router = APIRouter(prefix="/api/replier", tags=["replier"])
# 回复器日志目录
REPLY_LOG_DIR = Path("logs/reply")
class ReplierChatSummary(BaseModel):
"""聊天摘要 - 轻量级,不读取文件内容"""
chat_id: str
reply_count: int
latest_timestamp: float
latest_filename: str
class ReplyLogSummary(BaseModel):
"""回复日志摘要"""
chat_id: str
timestamp: float
filename: str
model: str
success: bool
llm_ms: float
overall_ms: float
output_preview: str
class ReplyLogDetail(BaseModel):
"""回复日志详情"""
type: str
chat_id: str
timestamp: float
prompt: str
output: str
processed_output: List[str]
model: str
reasoning: str
think_level: int
timing: Dict
error: Optional[str] = None
success: bool
class ReplierOverview(BaseModel):
"""回复器总览 - 轻量级统计"""
total_chats: int
total_replies: int
chats: List[ReplierChatSummary]
class PaginatedReplyLogs(BaseModel):
"""分页的回复日志列表"""
data: List[ReplyLogSummary]
total: int
page: int
page_size: int
chat_id: str
def parse_timestamp_from_filename(filename: str) -> float:
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
try:
timestamp_str = filename.split('_')[0]
# 时间戳是毫秒级,需要转换为秒
return float(timestamp_str) / 1000
except (ValueError, IndexError):
return 0
@router.get("/overview", response_model=ReplierOverview)
async def get_replier_overview():
"""
获取回复器总览 - 轻量级接口
只统计文件数量不读取文件内容
"""
if not REPLY_LOG_DIR.exists():
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
chats = []
total_replies = 0
for chat_dir in REPLY_LOG_DIR.iterdir():
if not chat_dir.is_dir():
continue
# 只统计json文件数量
json_files = list(chat_dir.glob("*.json"))
reply_count = len(json_files)
total_replies += reply_count
if reply_count == 0:
continue
# 从文件名获取最新时间戳
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
chats.append(ReplierChatSummary(
chat_id=chat_dir.name,
reply_count=reply_count,
latest_timestamp=latest_timestamp,
latest_filename=latest_file.name
))
# 按最新时间戳排序
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
return ReplierOverview(
total_chats=len(chats),
total_replies=total_replies,
chats=chats
)
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
async def get_chat_reply_logs(
chat_id: str,
page: int = Query(1, ge=1),
page_size: int = Query(20, ge=1, le=100),
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
):
"""
获取指定聊天的回复日志列表分页
需要读取文件内容获取摘要信息
支持搜索提示词内容
"""
chat_dir = REPLY_LOG_DIR / chat_id
if not chat_dir.exists():
return PaginatedReplyLogs(
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
)
# 先获取所有文件并按时间戳排序
json_files = list(chat_dir.glob("*.json"))
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
# 如果有搜索关键词,需要过滤文件
if search:
search_lower = search.lower()
filtered_files = []
for log_file in json_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
prompt = data.get('prompt', '')
if search_lower in prompt.lower():
filtered_files.append(log_file)
except Exception:
continue
json_files = filtered_files
total = len(json_files)
# 分页 - 只读取当前页的文件
offset = (page - 1) * page_size
page_files = json_files[offset:offset + page_size]
logs = []
for log_file in page_files:
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
output = data.get('output', '')
logs.append(ReplyLogSummary(
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
filename=log_file.name,
model=data.get('model', ''),
success=data.get('success', True),
llm_ms=data.get('timing', {}).get('llm_ms', 0),
overall_ms=data.get('timing', {}).get('overall_ms', 0),
output_preview=output[:100] if output else ''
))
except Exception:
# 文件读取失败时使用文件名信息
logs.append(ReplyLogSummary(
chat_id=chat_id,
timestamp=parse_timestamp_from_filename(log_file.name),
filename=log_file.name,
model='',
success=False,
llm_ms=0,
overall_ms=0,
output_preview='[读取失败]'
))
return PaginatedReplyLogs(
data=logs,
total=total,
page=page,
page_size=page_size,
chat_id=chat_id
)
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
async def get_reply_log_detail(chat_id: str, filename: str):
"""获取回复日志详情 - 按需加载完整内容"""
log_file = REPLY_LOG_DIR / chat_id / filename
if not log_file.exists():
raise HTTPException(status_code=404, detail="日志文件不存在")
try:
with open(log_file, 'r', encoding='utf-8') as f:
data = json.load(f)
return ReplyLogDetail(
type=data.get('type', 'reply'),
chat_id=data.get('chat_id', chat_id),
timestamp=data.get('timestamp', 0),
prompt=data.get('prompt', ''),
output=data.get('output', ''),
processed_output=data.get('processed_output', []),
model=data.get('model', ''),
reasoning=data.get('reasoning', ''),
think_level=data.get('think_level', 0),
timing=data.get('timing', {}),
error=data.get('error'),
success=data.get('success', True)
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
# ========== 兼容接口 ==========
@router.get("/stats")
async def get_replier_stats():
"""获取回复器统计信息"""
overview = await get_replier_overview()
# 获取最近10条回复的摘要
recent_replies = []
for chat in overview.chats[:5]: # 从最近5个聊天中获取
try:
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
recent_replies.extend(chat_logs.data)
except Exception:
continue
# 按时间排序取前10
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
recent_replies = recent_replies[:10]
return {
"total_chats": overview.total_chats,
"total_replies": overview.total_replies,
"recent_replies": recent_replies
}
@router.get("/chats")
async def get_replier_chat_list():
"""获取所有聊天ID列表"""
overview = await get_replier_overview()
return [chat.chat_id for chat in overview.chats]

View File

@ -1,183 +0,0 @@
"""
WebUI 认证模块
提供统一的认证依赖支持 Cookie Header 两种方式
"""
from typing import Optional
from fastapi import HTTPException, Cookie, Header, Response, Request
from src.common.logger import get_logger
from src.config.config import global_config
from .token_manager import get_token_manager
logger = get_logger("webui.auth")
# Cookie 配置
COOKIE_NAME = "maibot_session"
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
def _is_secure_environment() -> bool:
"""
检测是否应该启用安全 CookieHTTPS
Returns:
bool: 如果应该使用 secure cookie 则返回 True
"""
# 从配置读取
if global_config.webui.secure_cookie:
logger.info("配置中启用了 secure_cookie")
return True
# 检查是否是生产环境
if global_config.webui.mode == "production":
logger.info("WebUI运行在生产模式启用 secure cookie")
return True
# 默认:开发环境不启用(因为通常是 HTTP
logger.debug("WebUI运行在开发模式禁用 secure cookie")
return False
def get_current_token(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> str:
"""
获取当前请求的 token优先从 Cookie 获取其次从 Header 获取
Args:
request: FastAPI Request 对象
maibot_session: Cookie 中的 token
authorization: Authorization Header (Bearer token)
Returns:
验证通过的 token
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return token
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
"""
设置认证 Cookie
Args:
response: FastAPI Response 对象
token: 要设置的 token
request: FastAPI Request 对象可选用于检测协议
"""
# 根据环境和实际请求协议决定安全设置
is_secure = _is_secure_environment()
# 如果提供了 request检测实际使用的协议
if request:
# 检查 X-Forwarded-Proto header代理/负载均衡器)
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
if forwarded_proto:
is_https = forwarded_proto == "https"
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
else:
# 检查 request.url.scheme
is_https = request.url.scheme == "https"
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
# 如果是 HTTP 连接,强制禁用 secure 标志
if not is_https and is_secure:
logger.warning("=" * 80)
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
logger.warning("=" * 80)
is_secure = False
# 设置 Cookie
response.set_cookie(
key=COOKIE_NAME,
value=token,
max_age=COOKIE_MAX_AGE,
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
secure=is_secure, # 根据实际协议决定
path="/", # 确保 Cookie 在所有路径下可用
)
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
logger.debug(f"完整 token 前缀: {token[:20]}...")
def clear_auth_cookie(response: Response) -> None:
"""
清除认证 Cookie
Args:
response: FastAPI Response 对象
"""
# 保持与 set_auth_cookie 相同的安全设置
is_secure = _is_secure_environment()
response.delete_cookie(
key=COOKIE_NAME,
httponly=True,
samesite="strict" if is_secure else "lax",
secure=is_secure,
path="/",
)
logger.debug("已清除认证 Cookie")
def verify_auth_token_from_cookie_or_header(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""
验证认证 Token支持从 Cookie Header 获取
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
验证成功返回 True
Raises:
HTTPException: 认证失败时抛出 401 错误
"""
token = None
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
# 其次从 Header 获取(兼容旧版本)
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
if not token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
# 验证 token
token_manager = get_token_manager()
if not token_manager.verify_token(token):
raise HTTPException(status_code=401, detail="Token 无效或已过期")
return True

View File

@ -1,782 +0,0 @@
"""本地聊天室路由 - WebUI 与麦麦直接对话
支持两种模式
1. WebUI 模式使用 WebUI 平台独立身份聊天
2. 虚拟身份模式使用真实平台用户的身份在虚拟群聊中与麦麦对话
"""
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.chat")
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# WebUI 聊天的虚拟群组 ID
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
WEBUI_CHAT_PLATFORM = "webui"
# 虚拟身份模式的群 ID 前缀
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
# 固定的 WebUI 用户 ID 前缀
WEBUI_USER_ID_PREFIX = "webui_user_"
class VirtualIdentityConfig(BaseModel):
"""虚拟身份配置"""
enabled: bool = False # 是否启用虚拟身份模式
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
person_id: Optional[str] = None # PersonInfo 的 person_id
user_id: Optional[str] = None # 原始平台用户 ID
user_nickname: Optional[str] = None # 用户昵称
group_id: Optional[str] = None # 虚拟群 ID自动生成或用户指定
group_name: Optional[str] = None # 虚拟群名(用户自定义)
class ChatHistoryMessage(BaseModel):
"""聊天历史消息"""
id: str
type: str # 'user' | 'bot' | 'system'
content: str
timestamp: float
sender_name: str
sender_id: Optional[str] = None
is_bot: bool = False
class ChatHistoryManager:
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
def __init__(self, max_messages: int = 200):
self.max_messages = max_messages
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
"""将数据库消息转换为前端格式
Args:
msg: 数据库消息对象
group_id: ID用于判断是否是虚拟群
"""
# 判断是否是机器人消息
user_id = msg.user_id or ""
# 对于虚拟群,通过比较机器人 QQ 账号来判断
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
# 虚拟群user_id 等于机器人 QQ 账号的是机器人消息
bot_qq = str(global_config.bot.qq_account)
is_bot = user_id == bot_qq
else:
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
return {
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"timestamp": msg.time,
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
}
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""从数据库获取最近的历史记录
Args:
limit: 获取的消息数量
group_id: ID默认为 WEBUI_CHAT_GROUP_ID
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
# 查询指定群的消息,按时间排序
messages = (
Messages.select()
.where(Messages.chat_info_group_id == target_group_id)
.order_by(Messages.time.desc())
.limit(limit)
)
# 转换为列表并反转(使最旧的消息在前)
# 传递 group_id 以便正确判断虚拟群中的机器人消息
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
result.reverse()
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
return result
except Exception as e:
logger.error(f"从数据库加载聊天记录失败: {e}")
return []
def clear_history(self, group_id: Optional[str] = None) -> int:
"""清空聊天历史记录
Args:
group_id: ID默认清空 WebUI 默认聊天室
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
logger.error(f"清空聊天记录失败: {e}")
return 0
# 全局聊天历史管理器
chat_history = ChatHistoryManager()
# 存储 WebSocket 连接
class ChatConnectionManager:
"""聊天连接管理器"""
def __init__(self):
self.active_connections: Dict[str, WebSocket] = {}
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
await websocket.accept()
self.active_connections[session_id] = websocket
self.user_sessions[user_id] = session_id
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
def disconnect(self, session_id: str, user_id: str):
if session_id in self.active_connections:
del self.active_connections[session_id]
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict):
if session_id in self.active_connections:
try:
await self.active_connections[session_id].send_json(message)
except Exception as e:
logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict):
"""广播消息给所有连接"""
for session_id in list(self.active_connections.keys()):
await self.send_message(session_id, message)
chat_manager = ChatConnectionManager()
def create_message_data(
content: str,
user_id: str,
user_name: str,
message_id: Optional[str] = None,
is_at_bot: bool = True,
virtual_config: Optional[VirtualIdentityConfig] = None,
) -> Dict[str, Any]:
"""创建符合麦麦消息格式的消息数据
Args:
content: 消息内容
user_id: 用户 ID
user_name: 用户昵称
message_id: 消息 ID可选自动生成
is_at_bot: 是否 @ 机器人
virtual_config: 虚拟身份配置可选启用后使用真实平台身份
"""
if message_id is None:
message_id = str(uuid.uuid4())
# 确定使用的平台、群信息和用户信息
if virtual_config and virtual_config.enabled:
# 虚拟身份模式:使用真实平台身份
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
group_name = virtual_config.group_name or "WebUI虚拟群聊"
actual_user_id = virtual_config.user_id or user_id
actual_user_name = virtual_config.user_nickname or user_name
else:
# 标准 WebUI 模式
platform = WEBUI_CHAT_PLATFORM
group_id = WEBUI_CHAT_GROUP_ID
group_name = "WebUI本地聊天室"
actual_user_id = user_id
actual_user_name = user_name
return {
"message_info": {
"platform": platform,
"message_id": message_id,
"time": time.time(),
"group_info": {
"group_id": group_id,
"group_name": group_name,
"platform": platform,
},
"user_info": {
"user_id": actual_user_id,
"user_nickname": actual_user_name,
"user_cardname": actual_user_name,
"platform": platform,
},
"additional_config": {
"at_bot": is_at_bot,
},
},
"message_segment": {
"type": "seglist",
"data": [
{
"type": "text",
"data": content,
},
{
"type": "mention_bot",
"data": "1.0",
},
],
},
"raw_message": content,
"processed_plain_text": content,
}
@router.get("/history")
async def get_chat_history(
limit: int = Query(default=50, ge=1, le=200),
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
_auth: bool = Depends(require_auth),
):
"""获取聊天历史记录
所有 WebUI 用户共享同一个聊天室因此返回所有历史记录
如果指定了 group_id则获取该虚拟群的历史记录
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
history = chat_history.get_history(limit, target_group_id)
return {
"success": True,
"messages": history,
"total": len(history),
}
@router.get("/platforms")
async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""获取可用平台列表
PersonInfo 表中获取所有已知的平台
"""
try:
from peewee import fn
# 查询所有不同的平台
platforms = (
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
.group_by(PersonInfo.platform)
.order_by(fn.COUNT(PersonInfo.id).desc())
)
result = []
for p in platforms:
if p.platform: # 排除空平台
result.append({"platform": p.platform, "count": p.count})
return {"success": True, "platforms": result}
except Exception as e:
logger.error(f"获取平台列表失败: {e}")
return {"success": False, "error": str(e), "platforms": []}
@router.get("/persons")
async def get_persons_by_platform(
platform: str = Query(..., description="平台名称"),
search: Optional[str] = Query(default=None, description="搜索关键词"),
limit: int = Query(default=50, ge=1, le=200),
_auth: bool = Depends(require_auth),
):
"""获取指定平台的用户列表
Args:
platform: 平台名称 qq, discord
search: 搜索关键词匹配昵称用户名user_id
limit: 返回数量限制
"""
try:
# 构建查询
query = PersonInfo.select().where(PersonInfo.platform == platform)
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 按最后交互时间排序,优先显示活跃用户
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
query = query.limit(limit)
result = []
for person in query:
result.append(
{
"person_id": person.person_id,
"user_id": person.user_id,
"person_name": person.person_name,
"nickname": person.nickname,
"is_known": person.is_known,
"platform": person.platform,
"display_name": person.person_name or person.nickname or person.user_id,
}
)
return {"success": True, "persons": result, "total": len(result)}
except Exception as e:
logger.error(f"获取用户列表失败: {e}")
return {"success": False, "error": str(e), "persons": []}
@router.delete("/history")
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
"""清空聊天历史记录
Args:
group_id: 可选指定要清空的群 ID默认清空 WebUI 默认聊天室
"""
deleted = chat_history.clear_history(group_id)
return {
"success": True,
"message": f"已清空 {deleted} 条聊天记录",
}
@router.websocket("/ws")
async def websocket_chat(
websocket: WebSocket,
user_id: Optional[str] = Query(default=None),
user_name: Optional[str] = Query(default="WebUI用户"),
platform: Optional[str] = Query(default=None),
person_id: Optional[str] = Query(default=None),
group_name: Optional[str] = Query(default=None),
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
token: Optional[str] = Query(default=None), # 认证 token
):
"""WebSocket 聊天端点
Args:
user_id: 用户唯一标识由前端生成并持久化
user_name: 用户显示昵称可修改
platform: 虚拟身份模式的平台可选
person_id: 虚拟身份模式的用户 person_id可选
group_name: 虚拟身份模式的群名可选
group_id: 虚拟身份模式的群 ID可选由前端生成并持久化
token: 认证 token可选也可从 Cookie 获取
虚拟身份模式可通过 URL 参数直接配置或通过消息中的 set_virtual_identity 配置
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/api/chat/ws?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("聊天 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
# 生成会话 ID每次连接都是新的
session_id = str(uuid.uuid4())
# 如果没有提供 user_id生成一个新的
if not user_id:
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
# 确保 user_id 有正确的前缀
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
# 当前会话的虚拟身份配置(可通过消息动态更新)
current_virtual_config: Optional[VirtualIdentityConfig] = None
# 如果 URL 参数中提供了虚拟身份信息,自动配置
if platform and person_id:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
# 使用前端传递的 group_id如果没有则生成一个稳定的
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
logger.info(
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
)
except Exception as e:
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
await chat_manager.connect(websocket, session_id, user_id)
try:
# 构建会话信息
session_info_data = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
"user_name": user_name,
"bot_name": global_config.bot.nickname,
}
# 如果有虚拟身份配置,添加到会话信息中
if current_virtual_config and current_virtual_config.enabled:
session_info_data["virtual_mode"] = True
session_info_data["group_id"] = current_virtual_config.group_id
session_info_data["virtual_identity"] = {
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_name": current_virtual_config.group_name,
}
# 发送会话信息(包含用户 ID前端需要保存
await chat_manager.send_message(session_id, session_info_data)
# 发送历史记录(根据模式选择不同的群)
if current_virtual_config and current_virtual_config.enabled:
history = chat_history.get_history(50, current_virtual_config.group_id)
else:
history = chat_history.get_history(50)
if history:
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": history,
},
)
# 发送欢迎消息(不保存到历史)
if current_virtual_config and current_virtual_config.enabled:
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
else:
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": welcome_msg,
"timestamp": time.time(),
},
)
while True:
data = await websocket.receive_json()
if data.get("type") == "message":
content = data.get("content", "").strip()
if not content:
continue
# 用户可以更新昵称
current_user_name = data.get("user_name", user_name)
message_id = str(uuid.uuid4())
timestamp = time.time()
# 确定发送者信息(根据是否使用虚拟身份)
if current_virtual_config and current_virtual_config.enabled:
sender_name = current_virtual_config.user_nickname or current_user_name
sender_user_id = current_virtual_config.user_id or user_id
else:
sender_name = current_user_name
sender_user_id = user_id
# 广播用户消息给所有连接(包括发送者)
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
await chat_manager.broadcast(
{
"type": "user_message",
"content": content,
"message_id": message_id,
"timestamp": timestamp,
"sender": {
"name": sender_name,
"user_id": sender_user_id,
"is_bot": False,
},
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
}
)
# 创建麦麦消息格式
message_data = create_message_data(
content=content,
user_id=user_id,
user_name=current_user_name,
message_id=message_id,
is_at_bot=True,
virtual_config=current_virtual_config,
)
try:
# 显示正在输入状态
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": True,
}
)
# 调用麦麦的消息处理
await chat_bot.message_process(message_data)
except Exception as e:
logger.error(f"处理消息时出错: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"处理消息时出错: {str(e)}",
"timestamp": time.time(),
},
)
finally:
await chat_manager.broadcast(
{
"type": "typing",
"is_typing": False,
}
)
elif data.get("type") == "ping":
await chat_manager.send_message(
session_id,
{
"type": "pong",
"timestamp": time.time(),
},
)
elif data.get("type") == "update_nickname":
# 允许用户更新昵称
if new_name := data.get("user_name", "").strip():
current_user_name = new_name
await chat_manager.send_message(
session_id,
{
"type": "nickname_updated",
"user_name": current_user_name,
"timestamp": time.time(),
},
)
elif data.get("type") == "set_virtual_identity":
# 设置或更新虚拟身份配置
virtual_data = data.get("config", {})
if virtual_data.get("enabled"):
# 验证必要字段
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
"timestamp": time.time(),
},
)
continue
# 获取用户信息
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
if not person:
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"找不到用户: {virtual_data.get('person_id')}",
"timestamp": time.time(),
},
)
continue
# 生成虚拟群 ID
custom_group_id = virtual_data.get("group_id")
if custom_group_id:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
else:
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
current_virtual_config = VirtualIdentityConfig(
enabled=True,
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
group_id=group_id,
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
)
# 发送虚拟身份已激活的消息
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {
"enabled": True,
"platform": current_virtual_config.platform,
"user_id": current_virtual_config.user_id,
"user_nickname": current_virtual_config.user_nickname,
"group_id": current_virtual_config.group_id,
"group_name": current_virtual_config.group_name,
},
"timestamp": time.time(),
},
)
# 加载虚拟群的历史记录
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": virtual_history,
"group_id": current_virtual_config.group_id,
},
)
# 发送系统消息
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
"timestamp": time.time(),
},
)
except Exception as e:
logger.error(f"设置虚拟身份失败: {e}")
await chat_manager.send_message(
session_id,
{
"type": "error",
"content": f"设置虚拟身份失败: {str(e)}",
"timestamp": time.time(),
},
)
else:
# 禁用虚拟身份模式
current_virtual_config = None
await chat_manager.send_message(
session_id,
{
"type": "virtual_identity_set",
"config": {"enabled": False},
"timestamp": time.time(),
},
)
# 重新加载默认聊天室历史
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
await chat_manager.send_message(
session_id,
{
"type": "history",
"messages": default_history,
"group_id": WEBUI_CHAT_GROUP_ID,
},
)
await chat_manager.send_message(
session_id,
{
"type": "system",
"content": "已切换回 WebUI 独立用户模式",
"timestamp": time.time(),
},
)
except WebSocketDisconnect:
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
except Exception as e:
logger.error(f"WebSocket 错误: {e}")
finally:
chat_manager.disconnect(session_id, user_id)
@router.get("/info")
async def get_chat_info(_auth: bool = Depends(require_auth)):
"""获取聊天室信息"""
return {
"bot_name": global_config.bot.nickname,
"platform": WEBUI_CHAT_PLATFORM,
"group_id": WEBUI_CHAT_GROUP_ID,
"active_sessions": len(chat_manager.active_connections),
}
def get_webui_chat_broadcaster() -> tuple:
"""获取 WebUI 聊天广播器,供外部模块使用
Returns:
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
"""
return (chat_manager, WEBUI_CHAT_PLATFORM)

View File

@ -1,597 +0,0 @@
"""
配置管理API路由
"""
import os
import tomlkit
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
from typing import Any, Annotated, Optional
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
from src.config.official_configs import (
BotConfig,
PersonalityConfig,
RelationshipConfig,
ChatConfig,
MessageReceiveConfig,
EmojiConfig,
ExpressionConfig,
KeywordReactionConfig,
ChineseTypoConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
TelemetryConfig,
ExperimentalConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
ToolConfig,
MemoryConfig,
DebugConfig,
VoiceConfig,
)
from src.config.api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
)
from src.webui.config_schema import ConfigSchemaGenerator
logger = get_logger("webui")
# 模块级别的类型别名(解决 B008 ruff 错误)
ConfigBody = Annotated[dict[str, Any], Body()]
SectionBody = Annotated[Any, Body()]
RawContentBody = Annotated[str, Body(embed=True)]
PathBody = Annotated[dict[str, str], Body()]
router = APIRouter(prefix="/config", tags=["config"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# ===== 架构获取接口 =====
@router.get("/schema/bot")
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置架构"""
try:
# Config 类包含所有子配置
schema = ConfigSchemaGenerator.generate_config_schema(Config)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
@router.get("/schema/model")
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
"""获取模型配置架构(包含提供商和模型任务配置)"""
try:
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取模型配置架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
# ===== 子配置架构获取接口 =====
@router.get("/schema/section/{section_name}")
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
"""
获取指定配置节的架构
支持的section_name:
- bot: BotConfig
- personality: PersonalityConfig
- relationship: RelationshipConfig
- chat: ChatConfig
- message_receive: MessageReceiveConfig
- emoji: EmojiConfig
- expression: ExpressionConfig
- keyword_reaction: KeywordReactionConfig
- chinese_typo: ChineseTypoConfig
- response_post_process: ResponsePostProcessConfig
- response_splitter: ResponseSplitterConfig
- telemetry: TelemetryConfig
- experimental: ExperimentalConfig
- maim_message: MaimMessageConfig
- lpmm_knowledge: LPMMKnowledgeConfig
- tool: ToolConfig
- memory: MemoryConfig
- debug: DebugConfig
- voice: VoiceConfig
- jargon: JargonConfig
- model_task_config: ModelTaskConfig
- api_provider: APIProvider
- model_info: ModelInfo
"""
section_map = {
"bot": BotConfig,
"personality": PersonalityConfig,
"relationship": RelationshipConfig,
"chat": ChatConfig,
"message_receive": MessageReceiveConfig,
"emoji": EmojiConfig,
"expression": ExpressionConfig,
"keyword_reaction": KeywordReactionConfig,
"chinese_typo": ChineseTypoConfig,
"response_post_process": ResponsePostProcessConfig,
"response_splitter": ResponseSplitterConfig,
"telemetry": TelemetryConfig,
"experimental": ExperimentalConfig,
"maim_message": MaimMessageConfig,
"lpmm_knowledge": LPMMKnowledgeConfig,
"tool": ToolConfig,
"memory": MemoryConfig,
"debug": DebugConfig,
"voice": VoiceConfig,
"model_task_config": ModelTaskConfig,
"api_provider": APIProvider,
"model_info": ModelInfo,
}
if section_name not in section_map:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
try:
config_class = section_map[section_name]
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
return {"success": True, "schema": schema}
except Exception as e:
logger.error(f"获取配置节架构失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
# ===== 配置读取接口 =====
@router.get("/bot")
async def get_bot_config(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.get("/model")
async def get_model_config(_auth: bool = Depends(require_auth)):
"""获取模型配置(包含提供商和模型任务配置)"""
try:
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
return {"success": True, "config": config_data}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
# ===== 配置更新接口 =====
@router.post("/bot")
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置"""
try:
# 验证配置数据
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
save_toml_with_format(config_data, config_path)
logger.info("麦麦主程序配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model")
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
"""更新模型配置"""
try:
# 验证配置数据
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件(自动保留注释和格式)
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
save_toml_with_format(config_data, config_path)
logger.info("模型配置已更新")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
# ===== 配置节更新接口 =====
@router.post("/bot/section/{section_name}")
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组类型(如 platforms, aliases直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 原始 TOML 文件操作接口 =====
@router.get("/bot/raw")
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
"""获取麦麦主程序配置的原始 TOML 内容"""
try:
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
raw_content = f.read()
return {"success": True, "content": raw_content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
@router.post("/bot/raw")
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
try:
# 验证 TOML 格式
try:
config_data = tomlkit.loads(raw_content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 验证配置数据结构
try:
Config.from_dict(config_data)
except Exception as e:
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置文件
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
with open(config_path, "w", encoding="utf-8") as f:
f.write(raw_content)
logger.info("麦麦主程序配置已更新(原始模式)")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
@router.post("/model/section/{section_name}")
async def update_model_config_section(
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
):
"""更新模型配置的指定节(保留注释和格式)"""
try:
# 读取现有配置
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 更新指定节
if section_name not in config_data:
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
# 使用递归合并保留注释(对于字典类型)
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
if isinstance(section_data, list):
# 列表直接替换
config_data[section_name] = section_data
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
# 字典递归合并
_update_toml_doc(config_data[section_name], section_data)
else:
# 其他类型直接替换
config_data[section_name] = section_data
# 验证完整配置
try:
APIAdapterConfig.from_dict(config_data)
except Exception as e:
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
# 特殊处理:如果是更新 api_providers检查是否有模型引用了已删除的provider
if section_name == "api_providers" and "api_provider" in str(e):
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
models = config_data.get("models", [])
orphaned_models = [
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
]
if orphaned_models:
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
raise HTTPException(status_code=400, detail=error_msg) from e
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
# 保存配置(格式化数组为多行,保留注释)
save_toml_with_format(config_data, config_path)
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"更新配置节失败: {e}")
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
# ===== 适配器配置管理接口 =====
def _normalize_adapter_path(path: str) -> str:
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
if not path:
return path
# 如果已经是绝对路径,直接返回
if os.path.isabs(path):
return path
# 相对路径,转换为相对于项目根目录的绝对路径
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
def _to_relative_path(path: str) -> str:
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
if not path or not os.path.isabs(path):
return path
try:
# 尝试获取相对路径
rel_path = os.path.relpath(path, PROJECT_ROOT)
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
if not rel_path.startswith(".."):
return rel_path
except (ValueError, TypeError):
# 在 Windows 上如果路径在不同驱动器relpath 会抛出 ValueError
pass
# 无法转换为相对路径,返回绝对路径
return path
@router.get("/adapter-config/path")
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
"""获取保存的适配器配置文件路径"""
try:
# 从 data/webui.json 读取路径偏好
webui_data_path = os.path.join("data", "webui.json")
if not os.path.exists(webui_data_path):
return {"success": True, "path": None}
import json
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
adapter_config_path = webui_data.get("adapter_config_path")
if not adapter_config_path:
return {"success": True, "path": None}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(adapter_config_path)
# 检查文件是否存在并返回最后修改时间
if os.path.exists(abs_path):
import datetime
mtime = os.path.getmtime(abs_path)
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
# 返回相对路径(如果可能)
display_path = _to_relative_path(abs_path)
return {"success": True, "path": display_path, "lastModified": last_modified}
else:
# 文件不存在,返回原路径
return {"success": True, "path": adapter_config_path, "lastModified": None}
except Exception as e:
logger.error(f"获取适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
@router.post("/adapter-config/path")
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置文件路径偏好"""
try:
path = data.get("path")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
# 保存到 data/webui.json
webui_data_path = os.path.join("data", "webui.json")
import json
# 读取现有数据
if os.path.exists(webui_data_path):
with open(webui_data_path, "r", encoding="utf-8") as f:
webui_data = json.load(f)
else:
webui_data = {}
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 尝试转换为相对路径保存(如果文件在项目目录内)
save_path = _to_relative_path(abs_path)
# 更新路径
webui_data["adapter_config_path"] = save_path
# 保存
os.makedirs("data", exist_ok=True)
with open(webui_data_path, "w", encoding="utf-8") as f:
json.dump(webui_data, f, ensure_ascii=False, indent=2)
logger.info(f"适配器配置路径已保存: {save_path}(绝对路径: {abs_path}")
return {"success": True, "message": "路径已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置路径失败: {e}")
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
@router.get("/adapter-config")
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
"""从指定路径读取适配器配置文件"""
try:
if not path:
raise HTTPException(status_code=400, detail="路径参数不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件是否存在
if not os.path.exists(abs_path):
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 读取文件内容
with open(abs_path, "r", encoding="utf-8") as f:
content = f.read()
logger.info(f"已读取适配器配置: {path} (绝对路径: {abs_path})")
return {"success": True, "content": content}
except HTTPException:
raise
except Exception as e:
logger.error(f"读取适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
@router.post("/adapter-config")
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
"""保存适配器配置到指定路径"""
try:
path = data.get("path")
content = data.get("content")
if not path:
raise HTTPException(status_code=400, detail="路径不能为空")
if content is None:
raise HTTPException(status_code=400, detail="配置内容不能为空")
# 将路径规范化为绝对路径
abs_path = _normalize_adapter_path(path)
# 检查文件扩展名
if not abs_path.endswith(".toml"):
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
# 验证 TOML 格式
try:
tomlkit.loads(content)
except Exception as e:
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
# 确保目录存在
dir_path = os.path.dirname(abs_path)
if dir_path:
os.makedirs(dir_path, exist_ok=True)
# 保存文件
with open(abs_path, "w", encoding="utf-8") as f:
f.write(content)
logger.info(f"适配器配置已保存: {path} (绝对路径: {abs_path})")
return {"success": True, "message": "配置已保存"}
except HTTPException:
raise
except Exception as e:
logger.error(f"保存适配器配置失败: {e}")
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e

View File

@ -1,335 +0,0 @@
"""
配置架构生成器 - 自动从配置类生成前端表单架构
"""
import inspect
from dataclasses import fields, MISSING
from typing import Any, get_origin, get_args, Literal, Optional
from enum import Enum
from src.config.config_base import ConfigBase
class FieldType(str, Enum):
"""字段类型枚举"""
STRING = "string"
NUMBER = "number"
INTEGER = "integer"
BOOLEAN = "boolean"
SELECT = "select"
ARRAY = "array"
OBJECT = "object"
TEXTAREA = "textarea"
class FieldSchema:
"""字段架构"""
def __init__(
self,
name: str,
type: FieldType,
label: str,
description: str = "",
default: Any = None,
required: bool = True,
options: Optional[list[str]] = None,
min_value: Optional[float] = None,
max_value: Optional[float] = None,
items: Optional[dict] = None,
properties: Optional[dict] = None,
):
self.name = name
self.type = type
self.label = label
self.description = description
self.default = default
self.required = required
self.options = options
self.min_value = min_value
self.max_value = max_value
self.items = items
self.properties = properties
def to_dict(self) -> dict:
"""转换为字典"""
result = {
"name": self.name,
"type": self.type.value,
"label": self.label,
"description": self.description,
"required": self.required,
}
if self.default is not None:
result["default"] = self.default
if self.options is not None:
result["options"] = self.options
if self.min_value is not None:
result["minValue"] = self.min_value
if self.max_value is not None:
result["maxValue"] = self.max_value
if self.items is not None:
result["items"] = self.items
if self.properties is not None:
result["properties"] = self.properties
return result
class ConfigSchemaGenerator:
"""配置架构生成器"""
@staticmethod
def _extract_field_description(config_class: type, field_name: str) -> str:
"""
从类定义中提取字段的文档字符串描述
Args:
config_class: 配置类
field_name: 字段名
Returns:
str: 字段描述
"""
try:
# 获取源代码
source = inspect.getsource(config_class)
lines = source.split("\n")
# 查找字段定义
field_found = False
description_lines = []
for i, line in enumerate(lines):
# 匹配字段定义行,例如: platform: str
if f"{field_name}:" in line and "=" in line:
field_found = True
# 查找下一行的文档字符串
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
# 单行文档字符串
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else:
# 多行文档字符串
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
elif f"{field_name}:" in line and "=" not in line:
# 没有默认值的字段
field_found = True
if i + 1 < len(lines):
next_line = lines[i + 1].strip()
if next_line.startswith('"""') or next_line.startswith("'''"):
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
else:
quote = '"""' if next_line.startswith('"""') else "'''"
description_lines.append(next_line.strip(quote).strip())
for j in range(i + 2, len(lines)):
if quote in lines[j]:
description_lines.append(lines[j].split(quote)[0].strip())
break
description_lines.append(lines[j].strip())
break
if field_found and description_lines:
return " ".join(description_lines)
except Exception:
pass
return ""
@staticmethod
def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
"""
获取字段类型和选项
Args:
field_type: 字段类型
Returns:
tuple: (FieldType, options, items)
"""
origin = get_origin(field_type)
args = get_args(field_type)
# 处理 Literal 类型(枚举选项)
if origin is Literal:
return FieldType.SELECT, [str(arg) for arg in args], None
# 处理 list 类型
if origin is list:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
elif item_type is int:
items = {"type": "integer"}
elif item_type is float:
items = {"type": "number"}
elif item_type is bool:
items = {"type": "boolean"}
elif item_type is dict:
items = {"type": "object"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理 set 类型(与 list 类似)
if origin is set:
item_type = args[0] if args else str
if item_type is str:
items = {"type": "string"}
else:
items = {"type": "string"}
return FieldType.ARRAY, None, items
# 处理基本类型
if field_type is bool:
return FieldType.BOOLEAN, None, None
elif field_type is int:
return FieldType.INTEGER, None, None
elif field_type is float:
return FieldType.NUMBER, None, None
elif field_type is str:
return FieldType.STRING, None, None
elif field_type is dict or origin is dict:
return FieldType.OBJECT, None, None
# 默认为字符串
return FieldType.STRING, None, None
@staticmethod
def _format_field_name(name: str) -> str:
"""
格式化字段名为可读的标签
Args:
name: 原始字段名
Returns:
str: 格式化后的标签
"""
# 将下划线替换为空格,并首字母大写
return " ".join(word.capitalize() for word in name.split("_"))
@staticmethod
def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
"""
从配置类生成前端表单架构
Args:
config_class: 配置类必须继承自 ConfigBase
include_nested: 是否包含嵌套的配置对象
Returns:
dict: 前端表单架构
"""
if not issubclass(config_class, ConfigBase):
raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")
schema_fields = []
nested_schemas = {}
for field in fields(config_class):
# 跳过私有字段和内部字段
if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
continue
# 提取字段描述
description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)
# 判断是否必填
required = field.default is MISSING and field.default_factory is MISSING
# 获取默认值
default_value = None
if field.default is not MISSING:
default_value = field.default
elif field.default_factory is not MISSING:
try:
default_value = field.default_factory()
except Exception:
default_value = None
# 检查是否为嵌套的 ConfigBase
if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
if include_nested:
# 递归生成嵌套配置的架构
nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
nested_schemas[field.name] = nested_schema
field_schema = FieldSchema(
name=field.name,
type=FieldType.OBJECT,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description or field.type.__doc__ or "",
default=default_value,
required=required,
properties=nested_schema,
)
else:
continue
else:
# 获取字段类型和选项
field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)
# 特殊处理:长文本使用 textarea
if field_type == FieldType.STRING and field.name in [
"personality",
"reply_style",
"interest",
"plan_style",
"visual_style",
"private_plan_style",
"reaction",
"filtration_prompt",
]:
field_type = FieldType.TEXTAREA
field_schema = FieldSchema(
name=field.name,
type=field_type,
label=ConfigSchemaGenerator._format_field_name(field.name),
description=description,
default=default_value,
required=required,
options=options,
items=items,
)
schema_fields.append(field_schema.to_dict())
return {
"className": config_class.__name__,
"classDoc": config_class.__doc__ or "",
"fields": schema_fields,
"nested": nested_schemas if nested_schemas else None,
}
@staticmethod
def generate_config_schema(config_class: type[ConfigBase]) -> dict:
"""
生成完整的配置架构包含所有嵌套的子配置
Args:
config_class: 配置类
Returns:
dict: 完整的配置架构
"""
return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)

File diff suppressed because it is too large Load Diff

View File

@ -1,790 +0,0 @@
"""表达方式管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from .auth import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
# 创建路由器
router = APIRouter(prefix="/expression", tags=["Expression"])
class ExpressionResponse(BaseModel):
"""表达方式响应"""
id: int
situation: str
style: str
last_active_time: float
chat_id: str
create_date: Optional[float]
checked: bool
rejected: bool
modified_by: Optional[str] = None # 'ai' 或 'user' 或 None
class ExpressionListResponse(BaseModel):
"""表达方式列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
class ExpressionDetailResponse(BaseModel):
"""表达方式详情响应"""
success: bool
data: ExpressionResponse
class ExpressionCreateRequest(BaseModel):
"""表达方式创建请求"""
situation: str
style: str
chat_id: str
class ExpressionUpdateRequest(BaseModel):
"""表达方式更新请求"""
situation: Optional[str] = None
style: Optional[str] = None
chat_id: Optional[str] = None
checked: Optional[bool] = None
rejected: Optional[bool] = None
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
class ExpressionUpdateResponse(BaseModel):
"""表达方式更新响应"""
success: bool
message: str
data: Optional[ExpressionResponse] = None
class ExpressionDeleteResponse(BaseModel):
"""表达方式删除响应"""
success: bool
message: str
class ExpressionCreateResponse(BaseModel):
"""表达方式创建响应"""
success: bool
message: str
data: ExpressionResponse
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def expression_to_response(expression: Expression) -> ExpressionResponse:
"""将 Expression 模型转换为响应对象"""
return ExpressionResponse(
id=expression.id,
situation=expression.situation,
style=expression.style,
last_active_time=expression.last_active_time,
chat_id=expression.chat_id,
create_date=expression.create_date,
checked=expression.checked,
rejected=expression.rejected,
modified_by=expression.modified_by,
)
def get_chat_name(chat_id: str) -> str:
"""根据 chat_id 获取聊天名称"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream:
# 优先使用群聊名称,否则使用用户昵称
if chat_stream.group_name:
return chat_stream.group_name
elif chat_stream.user_nickname:
return chat_stream.user_nickname
return chat_id # 找不到时返回原始ID
except Exception:
return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
"""批量获取聊天名称"""
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
try:
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
for cs in chat_streams:
if cs.group_name:
result[cs.stream_id] = cs.group_name
elif cs.user_nickname:
result[cs.stream_id] = cs.user_nickname
except Exception as e:
logger.warning(f"批量获取聊天名称失败: {e}")
return result
class ChatInfo(BaseModel):
"""聊天信息"""
chat_id: str
chat_name: str
platform: Optional[str] = None
is_group: bool = False
class ChatListResponse(BaseModel):
"""聊天列表响应"""
success: bool
data: List[ChatInfo]
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取所有聊天列表用于下拉选择
Args:
authorization: Authorization header
Returns:
聊天列表
"""
try:
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
)
)
# 按名称排序
chat_list.sort(key=lambda x: x.chat_name)
return ChatListResponse(success=True, data=chat_list)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/list", response_model=ExpressionListResponse)
async def get_expression_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取表达方式列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 situation, style)
chat_id: 聊天ID筛选
authorization: Authorization header
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search))
| (Expression.style.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
async def get_expression_detail(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式详细信息
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
表达方式详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取表达方式详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
@router.post("/", response_model=ExpressionCreateResponse)
async def create_expression(
request: ExpressionCreateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
创建新的表达方式
Args:
request: 创建请求
authorization: Authorization header
Returns:
创建结果
"""
try:
verify_auth_token(maibot_session, authorization)
current_time = time.time()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
return ExpressionCreateResponse(
success=True, message="表达方式创建成功", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"创建表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
async def update_expression(
expression_id: int,
request: ExpressionUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新表达方式只更新提供的字段
Args:
expression_id: 表达方式ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and expression.checked:
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表"
)
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
# 移除 require_unchecked它不是数据库字段
update_data.pop('require_unchecked', None)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 如果更新了 checked 或 rejected标记为用户修改
if 'checked' in update_data or 'rejected' in update_data:
update_data['modified_by'] = 'user'
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
return ExpressionUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
async def delete_expression(
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除表达方式
Args:
expression_id: 表达方式ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 记录删除信息
situation = expression.situation
# 执行删除
expression.delete_instance()
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int]
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
async def batch_delete_expressions(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除表达方式
Args:
request: 包含要删除的ID列表的请求
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.ids:
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
if not_found_ids:
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
logger.info(f"批量删除了 {deleted_count} 个表达方式")
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除表达方式失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_expression_stats(
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取表达方式统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
return {
"success": True,
"data": {
"total": total,
"recent_7days": recent,
"chat_count": len(chat_stats),
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
},
}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
# ============ 审核相关接口 ============
class ReviewStatsResponse(BaseModel):
"""审核统计响应"""
total: int
unchecked: int
passed: int
rejected: int
ai_checked: int
user_checked: int
@router.get("/review/stats", response_model=ReviewStatsResponse)
async def get_review_stats(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None)
):
"""
获取审核统计数据
Returns:
审核统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where(
(Expression.checked == True) & (Expression.rejected == False)
).count()
rejected = Expression.select().where(
(Expression.checked == True) & (Expression.rejected == True)
).count()
ai_checked = Expression.select().where(Expression.modified_by == 'ai').count()
user_checked = Expression.select().where(Expression.modified_by == 'user').count()
return ReviewStatsResponse(
total=total,
unchecked=unchecked,
passed=passed,
rejected=rejected,
ai_checked=ai_checked,
user_checked=user_checked
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
class ReviewListResponse(BaseModel):
"""审核列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[ExpressionResponse]
@router.get("/review/list", response_model=ReviewListResponse)
async def get_review_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取待审核/已审核的表达方式列表
Args:
page: 页码
page_size: 每页数量
filter_type: 筛选类型 (unchecked/passed/rejected/all)
search: 搜索关键词
chat_id: 聊天ID筛选
Returns:
表达方式列表
"""
try:
verify_auth_token(maibot_session, authorization)
query = Expression.select()
# 根据筛选类型过滤
if filter_type == "unchecked":
query = query.where(Expression.checked == False)
elif filter_type == "passed":
query = query.where((Expression.checked == True) & (Expression.rejected == False))
elif filter_type == "rejected":
query = query.where((Expression.checked == True) & (Expression.rejected == True))
# all 不需要额外过滤
# 搜索过滤
if search:
query = query.where(
(Expression.situation.contains(search)) | (Expression.style.contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
# 排序:创建时间倒序
from peewee import Case
query = query.order_by(
Case(None, [(Expression.create_date.is_null(), 1)], 0),
Expression.create_date.desc()
)
total = query.count()
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
return ReviewListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=[expression_to_response(expr) for expr in expressions]
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取审核列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
class BatchReviewItem(BaseModel):
"""批量审核项"""
id: int
rejected: bool
require_unchecked: bool = True # 默认要求未检查状态
class BatchReviewRequest(BaseModel):
"""批量审核请求"""
items: List[BatchReviewItem]
class BatchReviewResultItem(BaseModel):
"""批量审核结果项"""
id: int
success: bool
message: str
class BatchReviewResponse(BaseModel):
"""批量审核响应"""
success: bool
total: int
succeeded: int
failed: int
results: List[BatchReviewResultItem]
@router.post("/review/batch", response_model=BatchReviewResponse)
async def batch_review_expressions(
request: BatchReviewRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量审核表达方式
Args:
request: 批量审核请求
Returns:
批量审核结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.items:
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
results = []
succeeded = 0
failed = 0
for item in request.items:
try:
expression = Expression.get_or_none(Expression.id == item.id)
if not expression:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=f"未找到 ID 为 {item.id} 的表达方式"
))
failed += 1
continue
# 冲突检测
if item.require_unchecked and expression.checked:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查"
))
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = 'user'
expression.last_active_time = time.time()
expression.save()
results.append(BatchReviewResultItem(
id=item.id,
success=True,
message="通过" if not item.rejected else "拒绝"
))
succeeded += 1
except Exception as e:
results.append(BatchReviewResultItem(
id=item.id,
success=False,
message=str(e)
))
failed += 1
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
return BatchReviewResponse(
success=True,
total=len(request.items),
succeeded=succeeded,
failed=failed,
results=results
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量审核失败: {e}")
raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e

View File

@ -1,662 +0,0 @@
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
from typing import Optional, List, Dict, Any
from enum import Enum
import httpx
import json
import asyncio
import subprocess
import shutil
from pathlib import Path
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("webui.git_mirror")
# 导入进度更新函数(避免循环导入)
_update_progress = None
def set_update_progress_callback(callback):
"""设置进度更新回调函数"""
global _update_progress
_update_progress = callback
class MirrorType(str, Enum):
"""镜像源类型"""
GH_PROXY = "gh-proxy" # gh-proxy 主节点
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
GITHUB = "github" # GitHub 官方源(兜底)
CUSTOM = "custom" # 自定义镜像源
class GitMirrorConfig:
"""Git 镜像源配置管理"""
# 配置文件路径
CONFIG_FILE = Path("data/webui.json")
# 默认镜像源配置
DEFAULT_MIRRORS = [
{
"id": "gh-proxy",
"name": "gh-proxy 镜像",
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://gh-proxy.org/https://github.com",
"enabled": True,
"priority": 1,
"created_at": None,
},
{
"id": "hk-gh-proxy",
"name": "gh-proxy 香港节点",
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 2,
"created_at": None,
},
{
"id": "cdn-gh-proxy",
"name": "gh-proxy CDN 节点",
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 3,
"created_at": None,
},
{
"id": "edgeone-gh-proxy",
"name": "gh-proxy EdgeOne 节点",
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
"enabled": True,
"priority": 4,
"created_at": None,
},
{
"id": "meyzh-github",
"name": "Meyzh GitHub 镜像",
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
"clone_prefix": "https://meyzh.github.io/https://github.com",
"enabled": True,
"priority": 5,
"created_at": None,
},
{
"id": "github",
"name": "GitHub 官方源(兜底)",
"raw_prefix": "https://raw.githubusercontent.com",
"clone_prefix": "https://github.com",
"enabled": True,
"priority": 999,
"created_at": None,
},
]
def __init__(self):
"""初始化配置管理器"""
self.config_file = self.CONFIG_FILE
self.mirrors: List[Dict[str, Any]] = []
self._load_config()
def _load_config(self) -> None:
"""加载配置文件"""
try:
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
data = json.load(f)
# 检查是否有镜像源配置
if "git_mirrors" not in data or not data["git_mirrors"]:
logger.info("配置文件中未找到镜像源配置,使用默认配置")
self._init_default_mirrors()
else:
self.mirrors = data["git_mirrors"]
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
else:
logger.info("配置文件不存在,创建默认配置")
self._init_default_mirrors()
except Exception as e:
logger.error(f"加载配置文件失败: {e}")
self._init_default_mirrors()
def _init_default_mirrors(self) -> None:
"""初始化默认镜像源"""
current_time = datetime.now().isoformat()
self.mirrors = []
for mirror in self.DEFAULT_MIRRORS:
mirror_copy = mirror.copy()
mirror_copy["created_at"] = current_time
self.mirrors.append(mirror_copy)
self._save_config()
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
def _save_config(self) -> None:
"""保存配置到文件"""
try:
# 确保目录存在
self.config_file.parent.mkdir(parents=True, exist_ok=True)
# 读取现有配置
existing_data = {}
if self.config_file.exists():
with open(self.config_file, "r", encoding="utf-8") as f:
existing_data = json.load(f)
# 更新镜像源配置
existing_data["git_mirrors"] = self.mirrors
# 写入文件
with open(self.config_file, "w", encoding="utf-8") as f:
json.dump(existing_data, f, indent=2, ensure_ascii=False)
logger.debug(f"配置已保存到 {self.config_file}")
except Exception as e:
logger.error(f"保存配置文件失败: {e}")
def get_all_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有镜像源"""
return self.mirrors.copy()
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
"""获取所有启用的镜像源,按优先级排序"""
enabled = [m for m in self.mirrors if m.get("enabled", False)]
return sorted(enabled, key=lambda x: x.get("priority", 999))
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
"""根据 ID 获取镜像源"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
return mirror.copy()
return None
def add_mirror(
self,
mirror_id: str,
name: str,
raw_prefix: str,
clone_prefix: str,
enabled: bool = True,
priority: Optional[int] = None,
) -> Dict[str, Any]:
"""
添加新的镜像源
Returns:
添加的镜像源配置
Raises:
ValueError: 如果镜像源 ID 已存在
"""
# 检查 ID 是否已存在
if self.get_mirror_by_id(mirror_id):
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
# 如果未指定优先级,使用最大优先级 + 1
if priority is None:
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
priority = max_priority + 1
new_mirror = {
"id": mirror_id,
"name": name,
"raw_prefix": raw_prefix,
"clone_prefix": clone_prefix,
"enabled": enabled,
"priority": priority,
"created_at": datetime.now().isoformat(),
}
self.mirrors.append(new_mirror)
self._save_config()
logger.info(f"已添加镜像源: {mirror_id} - {name}")
return new_mirror.copy()
def update_mirror(
self,
mirror_id: str,
name: Optional[str] = None,
raw_prefix: Optional[str] = None,
clone_prefix: Optional[str] = None,
enabled: Optional[bool] = None,
priority: Optional[int] = None,
) -> Optional[Dict[str, Any]]:
"""
更新镜像源配置
Returns:
更新后的镜像源配置如果不存在则返回 None
"""
for mirror in self.mirrors:
if mirror.get("id") == mirror_id:
if name is not None:
mirror["name"] = name
if raw_prefix is not None:
mirror["raw_prefix"] = raw_prefix
if clone_prefix is not None:
mirror["clone_prefix"] = clone_prefix
if enabled is not None:
mirror["enabled"] = enabled
if priority is not None:
mirror["priority"] = priority
mirror["updated_at"] = datetime.now().isoformat()
self._save_config()
logger.info(f"已更新镜像源: {mirror_id}")
return mirror.copy()
return None
def delete_mirror(self, mirror_id: str) -> bool:
"""
删除镜像源
Returns:
True 如果删除成功False 如果镜像源不存在
"""
for i, mirror in enumerate(self.mirrors):
if mirror.get("id") == mirror_id:
self.mirrors.pop(i)
self._save_config()
logger.info(f"已删除镜像源: {mirror_id}")
return True
return False
def get_default_priority_list(self) -> List[str]:
"""获取默认优先级列表(仅启用的镜像源 ID"""
enabled = self.get_enabled_mirrors()
return [m["id"] for m in enabled]
class GitMirrorService:
"""Git 镜像源服务"""
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
"""
初始化 Git 镜像源服务
Args:
max_retries: 最大重试次数
timeout: 请求超时时间
config: 镜像源配置管理器可选默认创建新实例
"""
self.max_retries = max_retries
self.timeout = timeout
self.config = config or GitMirrorConfig()
logger.info(f"Git镜像源服务初始化完成已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
def get_mirror_config(self) -> GitMirrorConfig:
"""获取镜像源配置管理器"""
return self.config
@staticmethod
def check_git_installed() -> Dict[str, Any]:
"""
检查本机是否安装了 Git
Returns:
Dict 包含:
- installed: bool - 是否已安装 Git
- version: str - Git 版本号如果已安装
- path: str - Git 可执行文件路径如果已安装
- error: str - 错误信息如果未安装或检测失败
"""
import subprocess
import shutil
try:
# 查找 git 可执行文件路径
git_path = shutil.which("git")
if not git_path:
logger.warning("未找到 Git 可执行文件")
return {"installed": False, "error": "系统中未找到 Git请先安装 Git"}
# 获取 Git 版本
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
if result.returncode == 0:
version = result.stdout.strip()
logger.info(f"检测到 Git: {version} at {git_path}")
return {"installed": True, "version": version, "path": git_path}
else:
logger.warning(f"Git 命令执行失败: {result.stderr}")
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
except subprocess.TimeoutExpired:
logger.error("Git 版本检测超时")
return {"installed": False, "error": "Git 版本检测超时"}
except Exception as e:
logger.error(f"检测 Git 时发生错误: {e}")
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
async def fetch_raw_file(
self,
owner: str,
repo: str,
branch: str,
file_path: str,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
) -> Dict[str, Any]:
"""
获取 GitHub 仓库的 Raw 文件内容
Args:
owner: 仓库所有者
repo: 仓库名称
branch: 分支名称
file_path: 文件路径
mirror_id: 指定的镜像源 ID
custom_url: 自定义完整 URL如果提供将忽略其他参数
Returns:
Dict 包含:
- success: bool - 是否成功
- data: str - 文件内容成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
if custom_url:
# 使用自定义 URL
return await self._fetch_with_url(custom_url, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
total_mirrors = len(mirrors_to_try)
# 依次尝试每个镜像源
for index, mirror in enumerate(mirrors_to_try, 1):
# 推送进度:正在尝试第 N 个镜像源
if _update_progress:
try:
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
await _update_progress(
stage="loading",
progress=progress,
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
if result["success"]:
# 成功,推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=70,
message=f"成功从 {mirror['name']} 获取数据",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
return result
# 失败,记录日志并推送失败信息
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
if _update_progress and index < total_mirrors:
try:
await _update_progress(
stage="loading",
progress=30 + int(index / total_mirrors * 40),
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
total_plugins=0,
loaded_plugins=0,
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _fetch_raw_from_mirror(
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
) -> Dict[str, Any]:
"""从指定镜像源获取文件"""
# 构建 URL
raw_prefix = mirror["raw_prefix"]
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
return await self._fetch_with_url(url, mirror["id"])
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
"""使用指定 URL 获取文件,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
logger.debug(f"尝试 #{attempt + 1}: {url}")
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.get(url)
response.raise_for_status()
logger.info(f"成功获取文件: {url}")
return {
"success": True,
"data": response.text,
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
}
except httpx.HTTPStatusError as e:
last_error = f"HTTP {e.response.status_code}: {e}"
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except httpx.TimeoutException as e:
last_error = f"请求超时: {e}"
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
async def clone_repository(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str] = None,
mirror_id: Optional[str] = None,
custom_url: Optional[str] = None,
depth: Optional[int] = None,
) -> Dict[str, Any]:
"""
克隆 GitHub 仓库
Args:
owner: 仓库所有者
repo: 仓库名称
target_path: 目标路径
branch: 分支名称可选
mirror_id: 指定的镜像源 ID
custom_url: 自定义克隆 URL
depth: 克隆深度浅克隆
Returns:
Dict 包含:
- success: bool - 是否成功
- path: str - 克隆路径成功时
- error: str - 错误信息失败时
- mirror_used: str - 使用的镜像源
- attempts: int - 尝试次数
"""
logger.info(f"开始克隆仓库: {owner}/{repo}{target_path}")
if custom_url:
# 使用自定义 URL
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
# 确定要使用的镜像源列表
if mirror_id:
# 使用指定的镜像源
mirror = self.config.get_mirror_by_id(mirror_id)
if not mirror:
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
mirrors_to_try = [mirror]
else:
# 使用所有启用的镜像源
mirrors_to_try = self.config.get_enabled_mirrors()
# 依次尝试每个镜像源
for mirror in mirrors_to_try:
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
if result["success"]:
return result
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
# 所有镜像源都失败
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
async def _clone_from_mirror(
self,
owner: str,
repo: str,
target_path: Path,
branch: Optional[str],
depth: Optional[int],
mirror: Dict[str, Any],
) -> Dict[str, Any]:
"""从指定镜像源克隆仓库"""
# 构建克隆 URL
clone_prefix = mirror["clone_prefix"]
url = f"{clone_prefix}/{owner}/{repo}.git"
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
async def _clone_with_url(
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
) -> Dict[str, Any]:
"""使用指定 URL 克隆仓库,支持重试"""
attempts = 0
last_error = None
for attempt in range(self.max_retries):
attempts += 1
try:
# 确保目标路径不存在
if target_path.exists():
logger.warning(f"目标路径已存在,删除: {target_path}")
shutil.rmtree(target_path, ignore_errors=True)
# 构建 git clone 命令
cmd = ["git", "clone"]
# 添加分支参数
if branch:
cmd.extend(["-b", branch])
# 添加深度参数(浅克隆)
if depth:
cmd.extend(["--depth", str(depth)])
# 添加 URL 和目标路径
cmd.extend([url, str(target_path)])
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
# 推送进度
if _update_progress:
try:
await _update_progress(
stage="loading",
progress=20 + attempt * 10,
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
operation="install",
)
except Exception as e:
logger.warning(f"推送进度失败: {e}")
# 执行 git clone在线程池中运行以避免阻塞
loop = asyncio.get_event_loop()
def run_git_clone(clone_cmd=cmd):
return subprocess.run(
clone_cmd,
capture_output=True,
text=True,
timeout=300, # 5分钟超时
)
process = await loop.run_in_executor(None, run_git_clone)
if process.returncode == 0:
logger.info(f"成功克隆仓库: {url} -> {target_path}")
return {
"success": True,
"path": str(target_path),
"mirror_used": mirror_type,
"attempts": attempts,
"url": url,
"branch": branch or "default",
}
else:
last_error = f"Git 克隆失败: {process.stderr}"
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
except subprocess.TimeoutExpired:
last_error = "克隆超时(超过 5 分钟)"
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
except FileNotFoundError:
last_error = "Git 未安装或不在 PATH 中"
logger.error(f"Git 未找到: {last_error}")
break # Git 不存在,不需要重试
except Exception as e:
last_error = f"未知错误: {e}"
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
# 清理可能的部分克隆
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
# 全局服务实例
_git_mirror_service: Optional[GitMirrorService] = None
def get_git_mirror_service() -> GitMirrorService:
"""获取 Git 镜像源服务实例(单例)"""
global _git_mirror_service
if _git_mirror_service is None:
_git_mirror_service = GitMirrorService()
return _git_mirror_service

View File

@ -1,532 +0,0 @@
"""黑话(俚语)管理路由"""
import json
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams
logger = get_logger("webui.jargon")
router = APIRouter(prefix="/jargon", tags=["Jargon"])
# ==================== 辅助函数 ====================
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
"""
解析 chat_id 字段提取所有 stream_id
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
"""
if not chat_id_str:
return []
try:
# 尝试解析为 JSON
parsed = json.loads(chat_id_str)
if isinstance(parsed, list):
# 格式: [["stream_id", user_id], ...]
stream_ids = []
for item in parsed:
if isinstance(item, list) and len(item) >= 1:
stream_ids.append(str(item[0]))
return stream_ids
else:
# 其他格式,返回原始字符串
return [chat_id_str]
except (json.JSONDecodeError, TypeError):
# 不是有效的 JSON可能是直接的 stream_id
return [chat_id_str]
def get_display_name_for_chat_id(chat_id_str: str) -> str:
"""
获取 chat_id 的显示名称
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
"""
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
if not stream_ids:
return chat_id_str
# 查询所有 stream_id 对应的名称
names = []
for stream_id in stream_ids:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream and chat_stream.group_name:
names.append(chat_stream.group_name)
else:
# 如果没找到,显示截断的 stream_id
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
return ", ".join(names) if names else chat_id_str
# ==================== 请求/响应模型 ====================
class JargonResponse(BaseModel):
"""黑话信息响应"""
id: int
content: str
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: str
stream_id: Optional[str] = None # 解析后的 stream_id用于前端编辑时匹配
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
is_global: bool = False
count: int = 0
is_jargon: Optional[bool] = None
is_complete: bool = False
inference_with_context: Optional[str] = None
inference_content_only: Optional[str] = None
class JargonListResponse(BaseModel):
"""黑话列表响应"""
success: bool = True
total: int
page: int
page_size: int
data: List[JargonResponse]
class JargonDetailResponse(BaseModel):
"""黑话详情响应"""
success: bool = True
data: JargonResponse
class JargonCreateRequest(BaseModel):
"""黑话创建请求"""
content: str = Field(..., description="黑话内容")
raw_content: Optional[str] = Field(None, description="原始内容")
meaning: Optional[str] = Field(None, description="含义")
chat_id: str = Field(..., description="聊天ID")
is_global: bool = Field(False, description="是否全局")
class JargonUpdateRequest(BaseModel):
"""黑话更新请求"""
content: Optional[str] = None
raw_content: Optional[str] = None
meaning: Optional[str] = None
chat_id: Optional[str] = None
is_global: Optional[bool] = None
is_jargon: Optional[bool] = None
class JargonCreateResponse(BaseModel):
"""黑话创建响应"""
success: bool = True
message: str
data: JargonResponse
class JargonUpdateResponse(BaseModel):
"""黑话更新响应"""
success: bool = True
message: str
data: Optional[JargonResponse] = None
class JargonDeleteResponse(BaseModel):
"""黑话删除响应"""
success: bool = True
message: str
deleted_count: int = 0
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
ids: List[int] = Field(..., description="要删除的黑话ID列表")
class JargonStatsResponse(BaseModel):
"""黑话统计响应"""
success: bool = True
data: dict
class ChatInfoResponse(BaseModel):
"""聊天信息响应"""
chat_id: str
chat_name: str
platform: Optional[str] = None
is_group: bool = False
class ChatListResponse(BaseModel):
"""聊天列表响应"""
success: bool = True
data: List[ChatInfoResponse]
# ==================== 工具函数 ====================
def jargon_to_dict(jargon: Jargon) -> dict:
"""将 Jargon ORM 对象转换为字典"""
# 解析 chat_id 获取显示名称和 stream_id
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
stream_id = stream_ids[0] if stream_ids else None
return {
"id": jargon.id,
"content": jargon.content,
"raw_content": jargon.raw_content,
"meaning": jargon.meaning,
"chat_id": jargon.chat_id,
"stream_id": stream_id,
"chat_name": chat_name,
"is_global": jargon.is_global,
"count": jargon.count,
"is_jargon": jargon.is_jargon,
"is_complete": jargon.is_complete,
"inference_with_context": jargon.inference_with_context,
"inference_content_only": jargon.inference_content_only,
}
# ==================== API 端点 ====================
@router.get("/list", response_model=JargonListResponse)
async def get_jargon_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
):
"""获取黑话列表"""
try:
# 构建查询
query = Jargon.select()
# 搜索过滤
if search:
query = query.where(
(Jargon.content.contains(search))
| (Jargon.meaning.contains(search))
| (Jargon.raw_content.contains(search))
)
# 按聊天ID筛选使用 contains 匹配,因为 chat_id 是 JSON 格式)
if chat_id:
# 从传入的 chat_id 中解析出 stream_id
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
# 使用第一个 stream_id 进行模糊匹配
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
else:
# 如果无法解析,使用精确匹配
query = query.where(Jargon.chat_id == chat_id)
# 按是否是黑话筛选
if is_jargon is not None:
query = query.where(Jargon.is_jargon == is_jargon)
# 按是否全局筛选
if is_global is not None:
query = query.where(Jargon.is_global == is_global)
# 获取总数
total = query.count()
# 分页和排序(按使用次数降序)
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
query = query.paginate(page, page_size)
# 转换为响应格式
data = [jargon_to_dict(j) for j in query]
return JargonListResponse(
success=True,
total=total,
page=page,
page_size=page_size,
data=data,
)
except Exception as e:
logger.error(f"获取黑话列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
@router.get("/chats", response_model=ChatListResponse)
async def get_chat_list():
"""获取所有有黑话记录的聊天列表"""
try:
# 获取所有不同的 chat_id
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
# 用于按 stream_id 去重
seen_stream_ids: set[str] = set()
for chat_id in chat_id_list:
stream_ids = parse_chat_id_to_stream_ids(chat_id)
if stream_ids:
seen_stream_ids.add(stream_ids[0])
result = []
for stream_id in seen_stream_ids:
# 尝试从 ChatStreams 表获取聊天名称
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
if chat_stream:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id方便筛选匹配
chat_name=chat_stream.group_name or stream_id,
platform=chat_stream.platform,
is_group=True,
)
)
else:
result.append(
ChatInfoResponse(
chat_id=stream_id, # 使用 stream_id
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
platform=None,
is_group=False,
)
)
return ChatListResponse(success=True, data=result)
except Exception as e:
logger.error(f"获取聊天列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
@router.get("/stats/summary", response_model=JargonStatsResponse)
async def get_jargon_stats():
"""获取黑话统计数据"""
try:
# 总数量
total = Jargon.select().count()
# 已确认是黑话的数量
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
# 已确认不是黑话的数量
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
# 未判定的数量
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
# 全局黑话数量
global_count = Jargon.select().where(Jargon.is_global).count()
# 已完成推断的数量
complete_count = Jargon.select().where(Jargon.is_complete).count()
# 关联的聊天数量
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
# 按聊天统计 TOP 5
top_chats = (
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
.group_by(Jargon.chat_id)
.order_by(fn.COUNT(Jargon.id).desc())
.limit(5)
)
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
return JargonStatsResponse(
success=True,
data={
"total": total,
"confirmed_jargon": confirmed_jargon,
"confirmed_not_jargon": confirmed_not_jargon,
"pending": pending,
"global_count": global_count,
"complete_count": complete_count,
"chat_count": chat_count,
"top_chats": top_chats_dict,
},
)
except Exception as e:
logger.error(f"获取黑话统计失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
async def get_jargon_detail(jargon_id: int):
"""获取黑话详情"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
except HTTPException:
raise
except Exception as e:
logger.error(f"获取黑话详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
@router.post("/", response_model=JargonCreateResponse)
async def create_jargon(request: JargonCreateRequest):
"""创建黑话"""
try:
# 检查是否已存在相同内容的黑话
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
if existing:
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
# 创建黑话
jargon = Jargon.create(
content=request.content,
raw_content=request.raw_content,
meaning=request.meaning,
chat_id=request.chat_id,
is_global=request.is_global,
count=0,
is_jargon=None,
is_complete=False,
)
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
return JargonCreateResponse(
success=True,
message="创建成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"创建黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
"""更新黑话(增量更新)"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
# 增量更新字段
update_data = request.model_dump(exclude_unset=True)
if update_data:
for field, value in update_data.items():
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
setattr(jargon, field, value)
jargon.save()
logger.info(f"更新黑话成功: id={jargon_id}")
return JargonUpdateResponse(
success=True,
message="更新成功",
data=jargon_to_dict(jargon),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"更新黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
async def delete_jargon(jargon_id: int):
"""删除黑话"""
try:
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
if not jargon:
raise HTTPException(status_code=404, detail="黑话不存在")
content = jargon.content
jargon.delete_instance()
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
return JargonDeleteResponse(
success=True,
message="删除成功",
deleted_count=1,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
@router.post("/batch/delete", response_model=JargonDeleteResponse)
async def batch_delete_jargons(request: BatchDeleteRequest):
"""批量删除黑话"""
try:
if not request.ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
return JargonDeleteResponse(
success=True,
message=f"成功删除 {deleted_count} 条黑话",
deleted_count=deleted_count,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量删除黑话失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
async def batch_set_jargon_status(
ids: Annotated[List[int], Query(description="黑话ID列表")],
is_jargon: Annotated[bool, Query(description="是否是黑话")],
):
"""批量设置黑话状态"""
try:
if not ids:
raise HTTPException(status_code=400, detail="ID列表不能为空")
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录is_jargon={is_jargon}")
return JargonUpdateResponse(
success=True,
message=f"成功更新 {updated_count} 条黑话状态",
)
except HTTPException:
raise
except Exception as e:
logger.error(f"批量更新黑话状态失败: {e}")
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e

View File

@ -1,298 +0,0 @@
"""知识库图谱可视化 API 路由"""
from typing import List, Optional
from fastapi import APIRouter, Query, Depends, Cookie, Header
from pydantic import BaseModel
import logging
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class KnowledgeNode(BaseModel):
"""知识节点"""
id: str
type: str # 'entity' or 'paragraph'
content: str
create_time: Optional[float] = None
class KnowledgeEdge(BaseModel):
"""知识边"""
source: str
target: str
weight: float
create_time: Optional[float] = None
update_time: Optional[float] = None
class KnowledgeGraph(BaseModel):
"""知识图谱"""
nodes: List[KnowledgeNode]
edges: List[KnowledgeEdge]
class KnowledgeStats(BaseModel):
"""知识库统计信息"""
total_nodes: int
total_edges: int
entity_nodes: int
paragraph_nodes: int
avg_connections: float
def _load_kg_manager():
"""延迟加载 KGManager"""
try:
from src.chat.knowledge.kg_manager import KGManager
kg_manager = KGManager()
kg_manager.load_from_file()
return kg_manager
except Exception as e:
logger.error(f"加载 KGManager 失败: {e}")
return None
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
"""将 DiGraph 转换为 JSON 格式"""
if kg_manager is None or kg_manager.graph is None:
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
nodes = []
edges = []
# 转换节点
node_list = graph.get_node_list()
for node_id in node_list:
try:
node_data = graph[node_id]
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 转换边
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
try:
# edge_tuple 是 (source, target) 元组
source, target = edge_tuple[0], edge_tuple[1]
# 通过 graph[source, target] 获取边的属性数据
edge_data = graph[source, target]
# edge_data 支持 [] 操作符但不支持 .get()
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
return KnowledgeGraph(nodes=nodes, edges=edges)
@router.get("/graph", response_model=KnowledgeGraph)
async def get_knowledge_graph(
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
_auth: bool = Depends(require_auth),
):
"""获取知识图谱(限制节点数量)
Args:
limit: 返回的最大节点数,默认 100,最大 10000
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
Returns:
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None:
logger.warning("KGManager 未初始化,返回空图谱")
return KnowledgeGraph(nodes=[], edges=[])
graph = kg_manager.graph
all_node_list = graph.get_node_list()
# 按类型过滤节点
if node_type == "entity":
all_node_list = [
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
]
elif node_type == "paragraph":
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
# 限制节点数量
total_nodes = len(all_node_list)
if len(all_node_list) > limit:
node_list = all_node_list[:limit]
else:
node_list = all_node_list
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
# 转换节点
nodes = []
node_ids = set()
for node_id in node_list:
try:
node_data = graph[node_id]
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
content = node_data["content"] if "content" in node_data else node_id
create_time = node_data["create_time"] if "create_time" in node_data else None
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
node_ids.add(node_id)
except Exception as e:
logger.warning(f"跳过节点 {node_id}: {e}")
continue
# 只获取涉及当前节点集的边(保证图的完整性)
edges = []
edge_list = graph.get_edge_list()
for edge_tuple in edge_list:
try:
source, target = edge_tuple[0], edge_tuple[1]
# 只包含两端都在当前节点集中的边
if source not in node_ids or target not in node_ids:
continue
edge_data = graph[source, target]
weight = edge_data["weight"] if "weight" in edge_data else 1.0
create_time = edge_data["create_time"] if "create_time" in edge_data else None
update_time = edge_data["update_time"] if "update_time" in edge_data else None
edges.append(
KnowledgeEdge(
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
)
)
except Exception as e:
logger.warning(f"跳过边 {edge_tuple}: {e}")
continue
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
return graph_data
except Exception as e:
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
return KnowledgeGraph(nodes=[], edges=[])
@router.get("/stats", response_model=KnowledgeStats)
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
"""获取知识库统计信息
Returns:
KnowledgeStats: 统计信息
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
graph = kg_manager.graph
node_list = graph.get_node_list()
edge_list = graph.get_edge_list()
total_nodes = len(node_list)
total_edges = len(edge_list)
# 统计节点类型
entity_nodes = 0
paragraph_nodes = 0
for node_id in node_list:
try:
node_data = graph[node_id]
node_type = node_data["type"] if "type" in node_data else "ent"
if node_type == "ent":
entity_nodes += 1
elif node_type == "pg":
paragraph_nodes += 1
except Exception:
continue
# 计算平均连接数
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
return KnowledgeStats(
total_nodes=total_nodes,
total_edges=total_edges,
entity_nodes=entity_nodes,
paragraph_nodes=paragraph_nodes,
avg_connections=round(avg_connections, 2),
)
except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True)
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
@router.get("/search", response_model=List[KnowledgeNode])
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
"""搜索知识节点
Args:
query: 搜索关键词
Returns:
List[KnowledgeNode]: 匹配的节点列表
"""
try:
kg_manager = _load_kg_manager()
if kg_manager is None or kg_manager.graph is None:
return []
graph = kg_manager.graph
node_list = graph.get_node_list()
results = []
query_lower = query.lower()
# 在节点内容中搜索
for node_id in node_list:
try:
node_data = graph[node_id]
content = node_data["content"] if "content" in node_data else node_id
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
if query_lower in content.lower() or query_lower in node_id.lower():
create_time = node_data["create_time"] if "create_time" in node_data else None
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
except Exception:
continue
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
return results[:50] # 限制返回数量
except Exception as e:
logger.error(f"搜索节点失败: {e}", exc_info=True)
return []

View File

@ -1,177 +0,0 @@
"""WebSocket 日志推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Optional
import json
from pathlib import Path
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.logs_ws")
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
def load_recent_logs(limit: int = 100) -> list[dict]:
"""从日志文件中加载最近的日志
Args:
limit: 返回的最大日志条数
Returns:
日志列表
"""
logs = []
log_dir = Path("logs")
if not log_dir.exists():
return logs
# 获取所有日志文件,按修改时间排序
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
# 用于生成唯一 ID 的计数器
log_counter = 0
# 从最新的文件开始读取
for log_file in log_files:
if len(logs) >= limit:
break
try:
with open(log_file, "r", encoding="utf-8") as f:
lines = f.readlines()
# 从文件末尾开始读取
for line in reversed(lines):
if len(logs) >= limit:
break
try:
log_entry = json.loads(line.strip())
# 转换为前端期望的格式
# 使用时间戳 + 计数器生成唯一 ID
timestamp_id = (
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
)
formatted_log = {
"id": f"{timestamp_id}_{log_counter}",
"timestamp": log_entry.get("timestamp", ""),
"level": log_entry.get("level", "INFO").upper(),
"module": log_entry.get("logger_name", ""),
"message": log_entry.get("event", ""),
}
logs.append(formatted_log)
log_counter += 1
except (json.JSONDecodeError, KeyError):
continue
except Exception as e:
logger.error(f"读取日志文件失败 {log_file}: {e}")
continue
# 反转列表,使其按时间顺序排列(旧到新)
return list(reversed(logs))
@router.websocket("/ws/logs")
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 日志推送端点
客户端连接后会持续接收服务器端的日志消息
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/logs?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
# 连接建立后,立即发送历史日志
try:
recent_logs = load_recent_logs(limit=100)
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
for log_entry in recent_logs:
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
except Exception as e:
logger.error(f"发送历史日志失败: {e}")
try:
# 保持连接,等待客户端消息或断开
while True:
# 接收客户端消息(用于心跳或控制指令)
data = await websocket.receive_text()
# 可以处理客户端的控制消息,例如:
# - "ping" -> 心跳检测
# - {"filter": "ERROR"} -> 设置日志级别过滤
if data == "ping":
await websocket.send_text("pong")
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
async def broadcast_log(log_data: dict):
"""广播日志到所有连接的 WebSocket 客户端
Args:
log_data: 日志数据字典
"""
if not active_connections:
return
# 格式化为 JSON
message = json.dumps(log_data, ensure_ascii=False)
# 记录需要断开的连接
disconnected = set()
# 广播到所有客户端
for connection in active_connections:
try:
await connection.send_text(message)
except Exception:
# 发送失败,标记为断开
disconnected.add(connection)
# 清理断开的连接
if disconnected:
active_connections.difference_update(disconnected)
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")

View File

@ -1,383 +0,0 @@
"""
模型列表获取API路由
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
"""
import os
import httpx
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
from typing import Optional
import tomlkit
from src.common.logger import get_logger
from src.config.config import CONFIG_DIR
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui")
router = APIRouter(prefix="/models", tags=["models"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
# 模型获取器配置
MODEL_FETCHER_CONFIG = {
# OpenAI 兼容格式的提供商
"openai": {
"endpoint": "/models",
"parser": "openai",
},
# Gemini 格式
"gemini": {
"endpoint": "/models",
"parser": "gemini",
},
}
def _normalize_url(url: str) -> str:
"""规范化 URL去掉尾部斜杠"""
if not url:
return ""
return url.rstrip("/")
def _parse_openai_response(data: dict) -> list[dict]:
"""
解析 OpenAI 格式的模型列表响应
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
"""
models = []
if "data" in data and isinstance(data["data"], list):
for model in data["data"]:
if isinstance(model, dict) and "id" in model:
models.append(
{
"id": model["id"],
"name": model.get("name") or model["id"],
"owned_by": model.get("owned_by", ""),
}
)
return models
def _parse_gemini_response(data: dict) -> list[dict]:
"""
解析 Gemini 格式的模型列表响应
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
"""
models = []
if "models" in data and isinstance(data["models"], list):
for model in data["models"]:
if isinstance(model, dict) and "name" in model:
# Gemini 的 name 格式是 "models/gemini-pro",我们只取后面部分
model_id = model["name"]
if model_id.startswith("models/"):
model_id = model_id[7:] # 去掉 "models/" 前缀
models.append(
{
"id": model_id,
"name": model.get("displayName") or model_id,
"owned_by": "google",
}
)
return models
async def _fetch_models_from_provider(
base_url: str,
api_key: str,
endpoint: str,
parser: str,
client_type: str = "openai",
) -> list[dict]:
"""
从提供商 API 获取模型列表
Args:
base_url: 提供商的基础 URL
api_key: API 密钥
endpoint: 获取模型列表的端点
parser: 响应解析器类型 ('openai' | 'gemini')
client_type: 客户端类型 ('openai' | 'gemini')
Returns:
模型列表
"""
url = f"{_normalize_url(base_url)}{endpoint}"
# 根据客户端类型设置请求头
headers = {}
params = {}
if client_type == "gemini":
# Gemini 使用 URL 参数传递 API Key
params["key"] = api_key
else:
# OpenAI 兼容格式使用 Authorization 头
headers["Authorization"] = f"Bearer {api_key}"
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(url, headers=headers, params=params)
response.raise_for_status()
data = response.json()
except httpx.TimeoutException as e:
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
except httpx.HTTPStatusError as e:
# 注意:使用 502 Bad Gateway 而不是原始的 401/403
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
if e.response.status_code == 401:
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
elif e.response.status_code == 403:
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
elif e.response.status_code == 404:
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
else:
raise HTTPException(
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
) from e
except Exception as e:
logger.error(f"获取模型列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
# 根据解析器类型解析响应
if parser == "openai":
return _parse_openai_response(data)
elif parser == "gemini":
return _parse_gemini_response(data)
else:
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
def _get_provider_config(provider_name: str) -> Optional[dict]:
"""
model_config.toml 获取指定提供商的配置
Args:
provider_name: 提供商名称
Returns:
提供商配置如果未找到则返回 None
"""
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(config_path):
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config_data = tomlkit.load(f)
providers = config_data.get("api_providers", [])
for provider in providers:
if provider.get("name") == provider_name:
return dict(provider)
return None
except Exception as e:
logger.error(f"读取提供商配置失败: {e}")
return None
@router.get("/list")
async def get_provider_models(
provider_name: str = Query(..., description="提供商名称"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
_auth: bool = Depends(require_auth),
):
"""
获取指定提供商的可用模型列表
通过提供商名称查找配置然后请求对应的模型列表端点
"""
# 获取提供商配置
provider_config = _get_provider_config(provider_name)
if not provider_config:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider_config.get("base_url")
api_key = provider_config.get("api_key")
client_type = provider_config.get("client_type", "openai")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
if not api_key:
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
# 获取模型列表
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"provider": provider_name,
"count": len(models),
}
@router.get("/list-by-url")
async def get_models_by_url(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: str = Query(..., description="API Key"),
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
endpoint: str = Query("/models", description="获取模型列表的端点"),
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
_auth: bool = Depends(require_auth),
):
"""
通过 URL 直接获取模型列表用于自定义提供商
"""
models = await _fetch_models_from_provider(
base_url=base_url,
api_key=api_key,
endpoint=endpoint,
parser=parser,
client_type=client_type,
)
return {
"success": True,
"models": models,
"count": len(models),
}
@router.get("/test-connection")
async def test_provider_connection(
base_url: str = Query(..., description="提供商的基础 URL"),
api_key: Optional[str] = Query(None, description="API Key可选用于验证 Key 有效性)"),
_auth: bool = Depends(require_auth),
):
"""
测试提供商连接状态
分两步测试
1. 网络连通性测试 base_url 发送请求检查是否能连接
2. API Key 验证可选如果提供了 api_key尝试获取模型列表验证 Key 是否有效
返回
- network_ok: 网络是否连通
- api_key_valid: API Key 是否有效仅在提供 api_key 时返回
- latency_ms: 响应延迟毫秒
- error: 错误信息如果有
"""
import time
base_url = _normalize_url(base_url)
if not base_url:
raise HTTPException(status_code=400, detail="base_url 不能为空")
result = {
"network_ok": False,
"api_key_valid": None,
"latency_ms": None,
"error": None,
"http_status": None,
}
# 第一步:测试网络连通性
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
# 尝试 GET 请求 base_url不需要 API Key
response = await client.get(base_url)
latency = (time.time() - start_time) * 1000
result["network_ok"] = True
result["latency_ms"] = round(latency, 2)
result["http_status"] = response.status_code
except httpx.ConnectError as e:
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
return result
except httpx.TimeoutException:
result["error"] = "连接超时:服务器响应时间过长"
return result
except httpx.RequestError as e:
result["error"] = f"请求错误:{str(e)}"
return result
except Exception as e:
result["error"] = f"未知错误:{str(e)}"
return result
# 第二步:如果提供了 API Key验证其有效性
if api_key:
try:
start_time = time.time()
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}
# 尝试获取模型列表
models_url = f"{base_url}/models"
response = await client.get(models_url, headers=headers)
if response.status_code == 200:
result["api_key_valid"] = True
elif response.status_code in (401, 403):
result["api_key_valid"] = False
result["error"] = "API Key 无效或已过期"
else:
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
result["api_key_valid"] = None
except Exception as e:
# API Key 验证失败不影响网络连通性结果
logger.warning(f"API Key 验证失败: {e}")
result["api_key_valid"] = None
return result
@router.post("/test-connection-by-name")
async def test_provider_connection_by_name(
provider_name: str = Query(..., description="提供商名称"),
_auth: bool = Depends(require_auth),
):
"""
通过提供商名称测试连接从配置文件读取信息
"""
# 读取配置文件
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
if not os.path.exists(model_config_path):
raise HTTPException(status_code=404, detail="配置文件不存在")
with open(model_config_path, "r", encoding="utf-8") as f:
config = tomlkit.load(f)
# 查找提供商
providers = config.get("api_providers", [])
provider = None
for p in providers:
if p.get("name") == provider_name:
provider = p
break
if not provider:
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
base_url = provider.get("base_url", "")
api_key = provider.get("api_key", "")
if not base_url:
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
# 调用测试接口
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)

View File

@ -1,416 +0,0 @@
"""人物信息管理 API 路由"""
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from .auth import verify_auth_token_from_cookie_or_header
import json
import time
logger = get_logger("webui.person")
# 创建路由器
router = APIRouter(prefix="/person", tags=["Person"])
class PersonInfoResponse(BaseModel):
"""人物信息响应"""
id: int
is_known: bool
person_id: str
person_name: Optional[str]
name_reason: Optional[str]
platform: str
user_id: str
nickname: Optional[str]
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
memory_points: Optional[str]
know_times: Optional[float]
know_since: Optional[float]
last_know: Optional[float]
class PersonListResponse(BaseModel):
"""人物列表响应"""
success: bool
total: int
page: int
page_size: int
data: List[PersonInfoResponse]
class PersonDetailResponse(BaseModel):
"""人物详情响应"""
success: bool
data: PersonInfoResponse
class PersonUpdateRequest(BaseModel):
"""人物信息更新请求"""
person_name: Optional[str] = None
name_reason: Optional[str] = None
nickname: Optional[str] = None
memory_points: Optional[str] = None
is_known: Optional[bool] = None
class PersonUpdateResponse(BaseModel):
"""人物信息更新响应"""
success: bool
message: str
data: Optional[PersonInfoResponse] = None
class PersonDeleteResponse(BaseModel):
"""人物删除响应"""
success: bool
message: str
class BatchDeleteRequest(BaseModel):
"""批量删除请求"""
person_ids: List[str]
class BatchDeleteResponse(BaseModel):
"""批量删除响应"""
success: bool
message: str
deleted_count: int
failed_count: int
failed_ids: List[str] = []
def verify_auth_token(
maibot_session: Optional[str] = None,
authorization: Optional[str] = None,
) -> bool:
"""验证认证 Token支持 Cookie 和 Header"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
"""解析群昵称 JSON 字符串"""
if not group_nick_name_str:
return None
try:
return json.loads(group_nick_name_str)
except (json.JSONDecodeError, TypeError):
return None
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
"""将 PersonInfo 模型转换为响应对象"""
return PersonInfoResponse(
id=person.id,
is_known=person.is_known,
person_id=person.person_id,
person_name=person.person_name,
name_reason=person.name_reason,
platform=person.platform,
user_id=person.user_id,
nickname=person.nickname,
group_nick_name=parse_group_nick_name(person.group_nick_name),
memory_points=person.memory_points,
know_times=person.know_times,
know_since=person.know_since,
last_know=person.last_know,
)
@router.get("/list", response_model=PersonListResponse)
async def get_person_list(
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
platform: Optional[str] = Query(None, description="平台筛选"),
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取人物信息列表
Args:
page: 页码 ( 1 开始)
page_size: 每页数量 (1-100)
search: 搜索关键词 (匹配 person_name, nickname, user_id)
is_known: 是否已认识筛选
platform: 平台筛选
authorization: Authorization header
Returns:
人物信息列表
"""
try:
verify_auth_token(maibot_session, authorization)
# 构建查询
query = PersonInfo.select()
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物列表失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
@router.get("/{person_id}", response_model=PersonDetailResponse)
async def get_person_detail(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
获取人物详细信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
人物详细信息
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
return PersonDetailResponse(success=True, data=person_to_response(person))
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取人物详情失败: {e}")
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
async def update_person(
person_id: str,
request: PersonUpdateRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
增量更新人物信息只更新提供的字段
Args:
person_id: 人物唯一 ID
request: 更新请求只包含需要更新的字段
authorization: Authorization header
Returns:
更新结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 只更新提供的字段
update_data = request.model_dump(exclude_unset=True)
if not update_data:
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data["last_know"] = time.time()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
return PersonUpdateResponse(
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"更新人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
async def delete_person(
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
):
"""
删除人物信息
Args:
person_id: 人物唯一 ID
authorization: Authorization header
Returns:
删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
# 执行删除
person.delete_instance()
logger.info(f"人物信息已删除: {person_id} ({person_name})")
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
except HTTPException:
raise
except Exception as e:
logger.exception(f"删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
@router.get("/stats/summary")
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
"""
获取人物信息统计数据
Args:
authorization: Authorization header
Returns:
统计数据
"""
try:
verify_auth_token(maibot_session, authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
except HTTPException:
raise
except Exception as e:
logger.exception(f"获取统计数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
@router.post("/batch/delete", response_model=BatchDeleteResponse)
async def batch_delete_persons(
request: BatchDeleteRequest,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
批量删除人物信息
Args:
request: 包含person_ids列表的请求
authorization: Authorization header
Returns:
批量删除结果
"""
try:
verify_auth_token(maibot_session, authorization)
if not request.person_ids:
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
deleted_count = 0
failed_count = 0
failed_ids = []
for person_id in request.person_ids:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
person.delete_instance()
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
except Exception as e:
logger.error(f"删除 {person_id} 失败: {e}")
failed_count += 1
failed_ids.append(person_id)
message = f"成功删除 {deleted_count} 个人物"
if failed_count > 0:
message += f"{failed_count} 个失败"
return BatchDeleteResponse(
success=True,
message=message,
deleted_count=deleted_count,
failed_count=failed_count,
failed_ids=failed_ids,
)
except HTTPException:
raise
except Exception as e:
logger.exception(f"批量删除人物信息失败: {e}")
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e

View File

@ -1,164 +0,0 @@
"""WebSocket 插件加载进度推送模块"""
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
from typing import Set, Dict, Any, Optional
import json
import asyncio
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
from src.webui.ws_auth import verify_ws_token
logger = get_logger("webui.plugin_progress")
# 创建路由器
router = APIRouter()
# 全局 WebSocket 连接池
active_connections: Set[WebSocket] = set()
# 当前加载进度状态
current_progress: Dict[str, Any] = {
"operation": "idle", # idle, fetch, install, uninstall, update
"stage": "idle", # idle, loading, success, error
"progress": 0, # 0-100
"message": "",
"error": None,
"plugin_id": None, # 当前操作的插件 ID
"total_plugins": 0,
"loaded_plugins": 0,
}
async def broadcast_progress(progress_data: Dict[str, Any]):
"""广播进度更新到所有连接的客户端"""
global current_progress
current_progress = progress_data.copy()
if not active_connections:
return
message = json.dumps(progress_data, ensure_ascii=False)
disconnected = set()
for websocket in active_connections:
try:
await websocket.send_text(message)
except Exception as e:
logger.error(f"发送进度更新失败: {e}")
disconnected.add(websocket)
# 移除断开的连接
for websocket in disconnected:
active_connections.discard(websocket)
async def update_progress(
stage: str,
progress: int,
message: str,
operation: str = "fetch",
error: str = None,
plugin_id: str = None,
total_plugins: int = 0,
loaded_plugins: int = 0,
):
"""更新并广播进度
Args:
stage: 阶段 (idle, loading, success, error)
progress: 进度百分比 (0-100)
message: 当前消息
operation: 操作类型 (fetch, install, uninstall, update)
error: 错误信息可选
plugin_id: 当前操作的插件 ID
total_plugins: 总插件数
loaded_plugins: 已加载插件数
"""
progress_data = {
"operation": operation,
"stage": stage,
"progress": progress,
"message": message,
"error": error,
"plugin_id": plugin_id,
"total_plugins": total_plugins,
"loaded_plugins": loaded_plugins,
"timestamp": asyncio.get_event_loop().time(),
}
await broadcast_progress(progress_data)
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
@router.websocket("/ws/plugin-progress")
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
"""WebSocket 插件加载进度推送端点
客户端连接后会立即收到当前进度状态
支持三种认证方式按优先级
1. query 参数 token推荐通过 /api/webui/ws-token 获取临时 token
2. Cookie 中的 maibot_session
3. 直接使用 session token兼容
示例ws://host/ws/plugin-progress?token=xxx
"""
is_authenticated = False
# 方式 1: 尝试验证临时 WebSocket token推荐方式
if token and verify_ws_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
# 方式 2: 尝试从 Cookie 获取 session token
if not is_authenticated:
cookie_token = websocket.cookies.get("maibot_session")
if cookie_token:
token_manager = get_token_manager()
if token_manager.verify_token(cookie_token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
# 方式 3: 尝试直接验证 query 参数作为 session token兼容旧方式
if not is_authenticated and token:
token_manager = get_token_manager()
if token_manager.verify_token(token):
is_authenticated = True
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
if not is_authenticated:
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
await websocket.close(code=4001, reason="认证失败,请重新登录")
return
await websocket.accept()
active_connections.add(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
try:
# 发送当前进度状态
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
# 保持连接并处理客户端消息
while True:
try:
data = await websocket.receive_text()
# 处理客户端心跳
if data == "ping":
await websocket.send_text("pong")
except Exception as e:
logger.error(f"处理客户端消息时出错: {e}")
break
except WebSocketDisconnect:
active_connections.discard(websocket)
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
except Exception as e:
logger.error(f"❌ WebSocket 错误: {e}")
active_connections.discard(websocket)
def get_progress_router() -> APIRouter:
"""获取插件进度 WebSocket 路由器"""
return router

File diff suppressed because it is too large Load Diff

View File

@ -1,245 +0,0 @@
"""
WebUI 请求频率限制模块
防止暴力破解和 API 滥用
"""
import time
from collections import defaultdict
from typing import Dict, Tuple, Optional
from fastapi import Request, HTTPException
from src.common.logger import get_logger
logger = get_logger("webui.rate_limiter")
class RateLimiter:
"""
简单的内存请求频率限制器
使用滑动窗口算法实现
"""
def __init__(self):
# 存储格式: {key: [(timestamp, count), ...]}
self._requests: Dict[str, list] = defaultdict(list)
# 被封禁的 IP: {ip: unblock_timestamp}
self._blocked: Dict[str, float] = {}
def _get_client_ip(self, request: Request) -> str:
"""获取客户端 IP 地址"""
# 检查代理头
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
# 取第一个 IP最原始的客户端
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
# 直接连接的客户端
if request.client:
return request.client.host
return "unknown"
def _cleanup_old_requests(self, key: str, window_seconds: int):
"""清理过期的请求记录"""
now = time.time()
cutoff = now - window_seconds
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
def _cleanup_expired_blocks(self):
"""清理过期的封禁"""
now = time.time()
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
for ip in expired:
del self._blocked[ip]
logger.info(f"🔓 IP {ip} 封禁已解除")
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
"""
检查 IP 是否被封禁
Returns:
(是否被封禁, 剩余封禁秒数)
"""
self._cleanup_expired_blocks()
ip = self._get_client_ip(request)
if ip in self._blocked:
remaining = int(self._blocked[ip] - time.time())
return True, max(0, remaining)
return False, None
def check_rate_limit(
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
) -> Tuple[bool, int]:
"""
检查请求是否超过频率限制
Args:
request: FastAPI Request 对象
max_requests: 窗口期内允许的最大请求数
window_seconds: 窗口时间
key_suffix: 键后缀用于区分不同的限制规则
Returns:
(是否允许, 剩余请求数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:{key_suffix}" if key_suffix else ip
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前窗口内的请求数
current_count = sum(count for _, count in self._requests[key])
if current_count >= max_requests:
return False, 0
# 记录新请求
now = time.time()
self._requests[key].append((now, 1))
remaining = max_requests - current_count - 1
return True, remaining
def block_ip(self, request: Request, duration_seconds: int):
"""
封禁 IP
Args:
request: FastAPI Request 对象
duration_seconds: 封禁时长
"""
ip = self._get_client_ip(request)
self._blocked[ip] = time.time() + duration_seconds
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds}")
def record_failed_attempt(
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
) -> Tuple[bool, int]:
"""
记录失败尝试如登录失败
如果在窗口期内失败次数过多自动封禁 IP
Args:
request: FastAPI Request 对象
max_failures: 允许的最大失败次数
window_seconds: 统计窗口
block_duration: 封禁时长
Returns:
(是否被封禁, 剩余尝试次数)
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
# 清理过期记录
self._cleanup_old_requests(key, window_seconds)
# 计算当前失败次数
current_failures = sum(count for _, count in self._requests[key])
# 记录本次失败
now = time.time()
self._requests[key].append((now, 1))
current_failures += 1
remaining = max_failures - current_failures
# 检查是否需要封禁
if current_failures >= max_failures:
self.block_ip(request, block_duration)
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
return True, 0
if current_failures >= max_failures - 2:
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures}")
return False, max(0, remaining)
def reset_failures(self, request: Request):
"""
重置失败计数认证成功后调用
"""
ip = self._get_client_ip(request)
key = f"{ip}:auth_failures"
if key in self._requests:
del self._requests[key]
# 全局单例
_rate_limiter: Optional[RateLimiter] = None
def get_rate_limiter() -> RateLimiter:
"""获取 RateLimiter 单例"""
global _rate_limiter
if _rate_limiter is None:
_rate_limiter = RateLimiter()
return _rate_limiter
async def check_auth_rate_limit(request: Request):
"""
认证接口的频率限制依赖
规则
- 每个 IP 每分钟最多 10 次认证请求
- 连续失败 5 次后封禁 10 分钟
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, remaining = limiter.check_rate_limit(
request,
max_requests=10, # 每分钟 10 次
window_seconds=60,
key_suffix="auth",
)
if not allowed:
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
async def check_api_rate_limit(request: Request):
"""
普通 API 的频率限制依赖
规则每个 IP 每分钟最多 100 次请求
"""
limiter = get_rate_limiter()
# 检查是否被封禁
blocked, remaining_block = limiter.is_blocked(request)
if blocked:
raise HTTPException(
status_code=429,
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
headers={"Retry-After": str(remaining_block)},
)
# 检查频率限制
allowed, _ = limiter.check_rate_limit(
request,
max_requests=100, # 每分钟 100 次
window_seconds=60,
key_suffix="api",
)
if not allowed:
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})

View File

@ -1,113 +0,0 @@
"""
系统控制路由
提供系统重启状态查询等功能
"""
import os
import time
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel
from src.config.config import MMC_VERSION
from src.common.logger import get_logger
from src.webui.auth import verify_auth_token_from_cookie_or_header
router = APIRouter(prefix="/system", tags=["system"])
logger = get_logger("webui_system")
# 记录启动时间
_start_time = time.time()
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class RestartResponse(BaseModel):
"""重启响应"""
success: bool
message: str
class StatusResponse(BaseModel):
"""状态响应"""
running: bool
uptime: float
version: str
start_time: str
@router.post("/restart", response_model=RestartResponse)
async def restart_maibot(_auth: bool = Depends(require_auth)):
"""
重启麦麦主程序
请求重启当前进程配置更改将在重启后生效
注意此操作会使麦麦暂时离线
"""
import asyncio
try:
# 记录重启操作
logger.info("WebUI 触发重启操作")
# 定义延迟重启的异步任务
async def delayed_restart():
await asyncio.sleep(0.5) # 延迟0.5秒,确保响应已发送
# 使用 os._exit(42) 退出当前进程,配合外部 runner 脚本进行重启
# 42 是约定的重启状态码
logger.info("WebUI 请求重启,退出代码 42")
os._exit(42)
# 创建后台任务执行重启
asyncio.create_task(delayed_restart())
# 立即返回成功响应
return RestartResponse(success=True, message="麦麦正在重启中...")
except Exception as e:
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
@router.get("/status", response_model=StatusResponse)
async def get_maibot_status(_auth: bool = Depends(require_auth)):
"""
获取麦麦运行状态
返回麦麦的运行状态运行时长和版本信息
"""
try:
uptime = time.time() - _start_time
# 尝试获取版本信息(需要根据实际情况调整)
version = MMC_VERSION # 可以从配置或常量中读取
return StatusResponse(
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
# 可选:添加更多系统控制功能
@router.post("/reload-config")
async def reload_config(_auth: bool = Depends(require_auth)):
"""
热重载配置不重启进程
仅重新加载配置文件某些配置可能需要重启才能生效
此功能需要在主程序中实现配置热重载逻辑
"""
# 这里需要调用主程序的配置重载函数
# 示例await app_instance.reload_config()
return {"success": True, "message": "配置重载功能待实现"}

View File

@ -1,456 +0,0 @@
"""WebUI API 路由"""
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
from pydantic import BaseModel, Field
from typing import Optional
from src.common.logger import get_logger
from .token_manager import get_token_manager
from .auth import set_auth_cookie, clear_auth_cookie
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
from .config_routes import router as config_router
from .statistics_routes import router as statistics_router
from .person_routes import router as person_router
from .expression_routes import router as expression_router
from .jargon_routes import router as jargon_router
from .emoji_routes import router as emoji_router
from .plugin_routes import router as plugin_router
from .plugin_progress_ws import get_progress_router
from .routers.system import router as system_router
from .model_routes import router as model_router
from .ws_auth import router as ws_auth_router
from .annual_report_routes import router as annual_report_router
logger = get_logger("webui.api")
# 创建路由器
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
# 注册配置管理路由
router.include_router(config_router)
# 注册统计数据路由
router.include_router(statistics_router)
# 注册人物信息管理路由
router.include_router(person_router)
# 注册表达方式管理路由
router.include_router(expression_router)
# 注册黑话管理路由
router.include_router(jargon_router)
# 注册表情包管理路由
router.include_router(emoji_router)
# 注册插件管理路由
router.include_router(plugin_router)
# 注册插件进度 WebSocket 路由
router.include_router(get_progress_router())
# 注册系统控制路由
router.include_router(system_router)
# 注册模型列表获取路由
router.include_router(model_router)
# 注册 WebSocket 认证路由
router.include_router(ws_auth_router)
# 注册年度报告路由
router.include_router(annual_report_router)
class TokenVerifyRequest(BaseModel):
"""Token 验证请求"""
token: str = Field(..., description="访问令牌")
class TokenVerifyResponse(BaseModel):
"""Token 验证响应"""
valid: bool = Field(..., description="Token 是否有效")
message: str = Field(..., description="验证结果消息")
is_first_setup: bool = Field(False, description="是否为首次设置")
class TokenUpdateRequest(BaseModel):
"""Token 更新请求"""
new_token: str = Field(..., description="新的访问令牌", min_length=10)
class TokenUpdateResponse(BaseModel):
"""Token 更新响应"""
success: bool = Field(..., description="是否更新成功")
message: str = Field(..., description="更新结果消息")
class TokenRegenerateResponse(BaseModel):
"""Token 重新生成响应"""
success: bool = Field(..., description="是否生成成功")
token: str = Field(..., description="新生成的令牌")
message: str = Field(..., description="生成结果消息")
class FirstSetupStatusResponse(BaseModel):
"""首次配置状态响应"""
is_first_setup: bool = Field(..., description="是否为首次配置")
message: str = Field(..., description="状态消息")
class CompleteSetupResponse(BaseModel):
"""完成配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
class ResetSetupResponse(BaseModel):
"""重置配置响应"""
success: bool = Field(..., description="是否成功")
message: str = Field(..., description="结果消息")
@router.get("/health")
async def health_check():
"""健康检查"""
return {"status": "healthy", "service": "MaiBot WebUI"}
@router.post("/auth/verify", response_model=TokenVerifyResponse)
async def verify_token(
request_body: TokenVerifyRequest,
request: Request,
response: Response,
_rate_limit: None = Depends(check_auth_rate_limit),
):
"""
验证访问令牌验证成功后设置 HttpOnly Cookie
Args:
request_body: 包含 token 的验证请求
request: FastAPI Request 对象用于获取客户端 IP
response: FastAPI Response 对象
Returns:
验证结果包含首次配置状态
"""
try:
token_manager = get_token_manager()
rate_limiter = get_rate_limiter()
is_valid = token_manager.verify_token(request_body.token)
if is_valid:
# 认证成功,重置失败计数
rate_limiter.reset_failures(request)
# 设置 HttpOnly Cookie传入 request 以检测协议)
set_auth_cookie(response, request_body.token, request)
# 同时返回首次配置状态,避免额外请求
is_first_setup = token_manager.is_first_setup()
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
else:
# 记录失败尝试
blocked, remaining = rate_limiter.record_failed_attempt(
request,
max_failures=5, # 5 次失败
window_seconds=300, # 5 分钟窗口
block_duration=600, # 封禁 10 分钟
)
if blocked:
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
message = "Token 无效或已过期"
if remaining <= 2:
message += f"(剩余 {remaining} 次尝试机会)"
return TokenVerifyResponse(valid=False, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 验证失败: {e}")
raise HTTPException(status_code=500, detail="Token 验证失败") from e
@router.post("/auth/logout")
async def logout(response: Response):
"""
登出并清除认证 Cookie
Args:
response: FastAPI Response 对象
Returns:
登出结果
"""
clear_auth_cookie(response)
return {"success": True, "message": "已成功登出"}
@router.get("/auth/check")
async def check_auth_status(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
检查当前认证状态用于前端判断是否已登录
Returns:
认证状态
"""
try:
token = None
# 记录请求信息用于调试
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
# 优先从 Cookie 获取
if maibot_session:
token = maibot_session
logger.debug("使用 Cookie 中的 token")
# 其次从 Header 获取
elif authorization and authorization.startswith("Bearer "):
token = authorization.replace("Bearer ", "")
logger.debug("使用 Header 中的 token")
if not token:
logger.debug("未找到 token返回未认证")
return {"authenticated": False}
token_manager = get_token_manager()
is_valid = token_manager.verify_token(token)
logger.debug(f"Token 验证结果: {is_valid}")
if is_valid:
return {"authenticated": True}
else:
return {"authenticated": False}
except Exception as e:
logger.error(f"认证检查失败: {e}", exc_info=True)
return {"authenticated": False}
@router.post("/auth/update", response_model=TokenUpdateResponse)
async def update_token(
request: TokenUpdateRequest,
response: Response,
req: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
更新访问令牌需要当前有效的 token
Args:
request: 包含新 token 的更新请求
response: FastAPI Response 对象
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
更新结果
"""
try:
# 验证当前 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 更新 token
success, message = token_manager.update_token(request.new_token)
# 如果更新成功,清除 Cookie要求用户重新登录
if success:
clear_auth_cookie(response)
return TokenUpdateResponse(success=success, message=message)
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 更新失败: {e}")
raise HTTPException(status_code=500, detail="Token 更新失败") from e
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
async def regenerate_token(
response: Response,
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
重新生成访问令牌需要当前有效的 token
Args:
response: FastAPI Response 对象
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
新生成的 token
"""
try:
# 验证当前 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="当前 Token 无效")
# 重新生成 token
new_token = token_manager.regenerate_token()
# 清除 Cookie要求用户重新登录
clear_auth_cookie(response)
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
except HTTPException:
raise
except Exception as e:
logger.error(f"Token 重新生成失败: {e}")
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
async def get_setup_status(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取首次配置状态
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
首次配置状态
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 检查是否为首次配置
is_first = token_manager.is_first_setup()
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
except HTTPException:
raise
except Exception as e:
logger.error(f"获取配置状态失败: {e}")
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
@router.post("/setup/complete", response_model=CompleteSetupResponse)
async def complete_setup(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
标记首次配置完成
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
完成结果
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 标记配置完成
success = token_manager.mark_setup_completed()
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
except HTTPException:
raise
except Exception as e:
logger.error(f"标记配置完成失败: {e}")
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
@router.post("/setup/reset", response_model=ResetSetupResponse)
async def reset_setup(
request: Request,
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
重置首次配置状态允许重新进入配置向导
Args:
maibot_session: Cookie 中的 token
authorization: Authorization header (Bearer token)
Returns:
重置结果
"""
try:
# 验证 token优先 Cookie其次 Header
current_token = None
if maibot_session:
current_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
current_token = authorization.replace("Bearer ", "")
if not current_token:
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
token_manager = get_token_manager()
if not token_manager.verify_token(current_token):
raise HTTPException(status_code=401, detail="Token 无效")
# 重置配置状态
success = token_manager.reset_setup_status()
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
except HTTPException:
raise
except Exception as e:
logger.error(f"重置配置状态失败: {e}")
raise HTTPException(status_code=500, detail="重置配置状态失败") from e

View File

@ -1,319 +0,0 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.auth import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
router = APIRouter(prefix="/statistics", tags=["statistics"])
def require_auth(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
) -> bool:
"""认证依赖:验证用户是否已登录"""
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
class StatisticsSummary(BaseModel):
"""统计数据摘要"""
total_requests: int = Field(0, description="总请求数")
total_cost: float = Field(0.0, description="总花费")
total_tokens: int = Field(0, description="总token数")
online_time: float = Field(0.0, description="在线时间(秒)")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
avg_response_time: float = Field(0.0, description="平均响应时间")
cost_per_hour: float = Field(0.0, description="每小时花费")
tokens_per_hour: float = Field(0.0, description="每小时token数")
class ModelStatistics(BaseModel):
"""模型统计"""
model_name: str
request_count: int
total_cost: float
total_tokens: int
avg_response_time: float
class TimeSeriesData(BaseModel):
"""时间序列数据"""
timestamp: str
requests: int = 0
cost: float = 0.0
tokens: int = 0
class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
daily_data: List[TimeSeriesData]
recent_activity: List[Dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取仪表盘统计数据
Args:
hours: 统计时间范围小时默认24小时
Returns:
仪表盘数据
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
# 获取摘要数据
summary = await _get_summary_statistics(start_time, now)
# 获取模型统计
model_stats = await _get_model_statistics(start_time)
# 获取小时级时间序列数据
hourly_data = await _get_hourly_statistics(start_time, now)
# 获取日级时间序列数据最近7天
daily_start = now - timedelta(days=7)
daily_data = await _get_daily_statistics(daily_start, now)
# 获取最近活动
recent_activity = await _get_recent_activity(limit=10)
return DashboardData(
summary=summary,
model_stats=model_stats,
hourly_data=hourly_data,
daily_data=daily_data,
recent_activity=recent_activity,
)
except Exception as e:
logger.error(f"获取仪表盘数据失败: {e}")
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
"""获取摘要统计数据(优化:使用数据库聚合)"""
summary = StatisticsSummary()
# 使用聚合查询替代全量加载
query = LLMUsage.select(
fn.COUNT(LLMUsage.id).alias("total_requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
result = query.dicts().get()
summary.total_requests = result["total_requests"]
summary.total_cost = result["total_cost"]
summary.total_tokens = result["total_tokens"]
summary.avg_response_time = result["avg_response_time"] or 0.0
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
online_records = list(
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
)
for record in online_records:
start = max(record.start_timestamp, start_time)
end = min(record.end_timestamp, end_time)
if end > start:
summary.online_time += (end - start).total_seconds()
# 查询消息数量 - 使用聚合优化
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
)
summary.total_messages = messages_query.scalar() or 0
# 统计回复数量
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp())
& (Messages.time <= end_time.timestamp())
& (Messages.reply_to.is_null(False))
)
summary.total_replies = replies_query.scalar() or 0
# 计算派生指标
if summary.online_time > 0:
online_hours = summary.online_time / 3600.0
summary.cost_per_hour = summary.total_cost / online_hours
summary.tokens_per_hour = summary.total_tokens / online_hours
return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
# 使用GROUP BY聚合避免全量加载
query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
fn.COUNT(LLMUsage.id).alias("request_count"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
)
.where(LLMUsage.timestamp >= start_time)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10) # 只取前10个
)
result = []
for row in query.dicts():
result.append(
ModelStatistics(
model_name=row["model_name"],
request_count=row["request_count"],
total_cost=row["total_cost"],
total_tokens=row["total_tokens"],
avg_response_time=row["avg_response_time"] or 0.0,
)
)
return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取小时级统计数据(优化:使用数据库聚合)"""
# SQLite的日期时间函数进行小时分组
# 使用strftime将timestamp格式化为小时级别
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
)
# 转换为字典以快速查找
data_dict = {row["hour"]: row for row in query.dicts()}
# 填充所有小时(包括没有数据的)
result = []
current = start_time.replace(minute=0, second=0, microsecond=0)
while current <= end_time:
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
if hour_str in data_dict:
row = data_dict[hour_str]
result.append(
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
current += timedelta(hours=1)
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
"""获取日级统计数据(优化:使用数据库聚合)"""
# 使用strftime按日期分组
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
)
# 转换为字典
data_dict = {row["day"]: row for row in query.dicts()}
# 填充所有天
result = []
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
while current <= end_time:
day_str = current.strftime("%Y-%m-%dT00:00:00")
if day_str in data_dict:
row = data_dict[day_str]
result.append(
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
)
else:
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
current += timedelta(days=1)
return result
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"""获取最近活动"""
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
activities = []
for record in records:
activities.append(
{
"timestamp": record.timestamp.isoformat(),
"model": record.model_assign_name or record.model_name,
"request_type": record.request_type,
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
"cost": record.cost or 0.0,
"time_cost": record.time_cost or 0.0,
"status": record.status,
}
)
return activities
@router.get("/summary")
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取统计摘要
Args:
hours: 统计时间范围小时
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
summary = await _get_summary_statistics(start_time, now)
return summary
except Exception as e:
logger.error(f"获取统计摘要失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e
@router.get("/models")
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
"""
获取模型统计
Args:
hours: 统计时间范围小时
"""
try:
now = datetime.now()
start_time = now - timedelta(hours=hours)
stats = await _get_model_statistics(start_time)
return stats
except Exception as e:
logger.error(f"获取模型统计失败: {e}")
raise HTTPException(status_code=500, detail=str(e)) from e

View File

@ -1,309 +0,0 @@
"""
WebUI Token 管理模块
负责生成保存验证和更新访问令牌
"""
import json
import secrets
from pathlib import Path
from typing import Optional
from src.common.logger import get_logger
logger = get_logger("webui")
class TokenManager:
"""Token 管理器"""
def __init__(self, config_path: Optional[Path] = None):
"""
初始化 Token 管理器
Args:
config_path: 配置文件路径默认为项目根目录的 data/webui.json
"""
if config_path is None:
# 获取项目根目录 (src/webui -> src -> 根目录)
project_root = Path(__file__).parent.parent.parent
config_path = project_root / "data" / "webui.json"
self.config_path = config_path
self.config_path.parent.mkdir(parents=True, exist_ok=True)
# 确保配置文件存在并包含有效的 token
self._ensure_config()
def _ensure_config(self):
"""确保配置文件存在且包含有效的 token"""
if not self.config_path.exists():
logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
self._create_new_token()
else:
# 验证配置文件格式
try:
config = self._load_config()
if not config.get("access_token"):
logger.warning("WebUI 配置文件中缺少 access_token正在重新生成")
self._create_new_token()
else:
logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
except Exception as e:
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
self._create_new_token()
def _load_config(self) -> dict:
"""加载配置文件"""
try:
with open(self.config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"加载 WebUI 配置失败: {e}")
return {}
def _save_config(self, config: dict):
"""保存配置文件"""
try:
with open(self.config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
logger.info(f"WebUI 配置已保存到: {self.config_path}")
except Exception as e:
logger.error(f"保存 WebUI 配置失败: {e}")
raise
def _create_new_token(self) -> str:
"""生成新的 64 位随机 token"""
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
token = secrets.token_hex(32)
config = {
"access_token": token,
"created_at": self._get_current_timestamp(),
"updated_at": self._get_current_timestamp(),
"first_setup_completed": False, # 标记首次配置未完成
}
self._save_config(config)
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
return token
def _get_current_timestamp(self) -> str:
"""获取当前时间戳字符串"""
from datetime import datetime
return datetime.now().isoformat()
def get_token(self) -> str:
"""获取当前有效的 token"""
config = self._load_config()
return config.get("access_token", "")
def verify_token(self, token: str) -> bool:
"""
验证 token 是否有效
Args:
token: 待验证的 token
Returns:
bool: token 是否有效
"""
if not token:
return False
current_token = self.get_token()
if not current_token:
logger.error("系统中没有有效的 token")
return False
# 使用 secrets.compare_digest 防止时序攻击
is_valid = secrets.compare_digest(token, current_token)
if is_valid:
logger.debug("Token 验证成功")
else:
logger.warning("Token 验证失败")
return is_valid
def update_token(self, new_token: str) -> tuple[bool, str]:
"""
更新 token
Args:
new_token: 新的 token (最少 10 必须包含大小写字母和特殊符号)
Returns:
tuple[bool, str]: (是否更新成功, 错误消息)
"""
# 验证新 token 格式
is_valid, error_msg = self._validate_custom_token(new_token)
if not is_valid:
logger.error(f"Token 格式无效: {error_msg}")
return False, error_msg
try:
config = self._load_config()
old_token = config.get("access_token", "")[:8]
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
self._save_config(config)
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
return True, "Token 更新成功"
except Exception as e:
logger.error(f"更新 Token 失败: {e}")
return False, f"更新失败: {str(e)}"
def regenerate_token(self) -> str:
"""
重新生成 token保留 first_setup_completed 状态
Returns:
str: 新生成的 token
"""
logger.info("正在重新生成 WebUI Token...")
# 生成新的 64 位十六进制字符串
new_token = secrets.token_hex(32)
# 加载现有配置,保留 first_setup_completed 状态
config = self._load_config()
old_token = config.get("access_token", "")[:8] if config.get("access_token") else ""
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True表示已完成配置
config["access_token"] = new_token
config["updated_at"] = self._get_current_timestamp()
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
self._save_config(config)
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
return new_token
def _validate_token_format(self, token: str) -> bool:
"""
验证 token 格式是否正确旧的 64 位十六进制验证保留用于系统生成的 token
Args:
token: 待验证的 token
Returns:
bool: 格式是否正确
"""
if not token or not isinstance(token, str):
return False
# 必须是 64 位十六进制字符串
if len(token) != 64:
return False
# 验证是否为有效的十六进制字符串
try:
int(token, 16)
return True
except ValueError:
return False
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
"""
验证自定义 token 格式
要求:
- 最少 10
- 包含大写字母
- 包含小写字母
- 包含特殊符号
Args:
token: 待验证的 token
Returns:
tuple[bool, str]: (是否有效, 错误消息)
"""
if not token or not isinstance(token, str):
return False, "Token 不能为空"
# 检查长度
if len(token) < 10:
return False, "Token 长度至少为 10 位"
# 检查是否包含大写字母
has_upper = any(c.isupper() for c in token)
if not has_upper:
return False, "Token 必须包含大写字母"
# 检查是否包含小写字母
has_lower = any(c.islower() for c in token)
if not has_lower:
return False, "Token 必须包含小写字母"
# 检查是否包含特殊符号
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
has_special = any(c in special_chars for c in token)
if not has_special:
return False, f"Token 必须包含特殊符号 ({special_chars})"
return True, "Token 格式正确"
def is_first_setup(self) -> bool:
"""
检查是否为首次配置
Returns:
bool: 是否为首次配置
"""
config = self._load_config()
return not config.get("first_setup_completed", False)
def mark_setup_completed(self) -> bool:
"""
标记首次配置已完成
Returns:
bool: 是否标记成功
"""
try:
config = self._load_config()
config["first_setup_completed"] = True
config["setup_completed_at"] = self._get_current_timestamp()
self._save_config(config)
logger.info("首次配置已标记为完成")
return True
except Exception as e:
logger.error(f"标记首次配置完成失败: {e}")
return False
def reset_setup_status(self) -> bool:
"""
重置首次配置状态允许重新进入配置向导
Returns:
bool: 是否重置成功
"""
try:
config = self._load_config()
config["first_setup_completed"] = False
if "setup_completed_at" in config:
del config["setup_completed_at"]
self._save_config(config)
logger.info("首次配置状态已重置")
return True
except Exception as e:
logger.error(f"重置首次配置状态失败: {e}")
return False
# 全局单例
_token_manager_instance: Optional[TokenManager] = None
def get_token_manager() -> TokenManager:
"""获取 TokenManager 单例"""
global _token_manager_instance
if _token_manager_instance is None:
_token_manager_instance = TokenManager()
return _token_manager_instance

View File

@ -1,295 +0,0 @@
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
import asyncio
import mimetypes
from pathlib import Path
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from uvicorn import Config, Server as UvicornServer
from src.common.logger import get_logger
logger = get_logger("webui_server")
class WebUIServer:
"""独立的 WebUI 服务器"""
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
self.host = host
self.port = port
self.app = FastAPI(title="MaiBot WebUI")
self._server = None
# 配置防爬虫中间件需要在CORS之前注册
self._setup_anti_crawler()
# 配置 CORS支持开发环境跨域请求
self._setup_cors()
# 显示 Access Token
self._show_access_token()
# 重要:先注册 API 路由,再设置静态文件
self._register_api_routes()
self._setup_static_files()
# 注册robots.txt路由
self._setup_robots_txt()
def _setup_cors(self):
"""配置 CORS 中间件"""
# 开发环境需要允许前端开发服务器的跨域请求
self.app.add_middleware(
CORSMiddleware,
allow_origins=[
"http://localhost:5173", # Vite 开发服务器
"http://127.0.0.1:5173",
"http://localhost:7999", # 前端开发服务器备用端口
"http://127.0.0.1:7999",
"http://localhost:8001", # 生产环境
"http://127.0.0.1:8001",
],
allow_credentials=True, # 允许携带 Cookie
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
allow_headers=[
"Content-Type",
"Authorization",
"Accept",
"Origin",
"X-Requested-With",
], # 明确指定允许的头
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
)
logger.debug("✅ CORS 中间件已配置")
def _show_access_token(self):
"""显示 WebUI Access Token"""
try:
from src.webui.token_manager import get_token_manager
token_manager = get_token_manager()
current_token = token_manager.get_token()
logger.info(f"🔑 WebUI Access Token: {current_token}")
logger.info("💡 请使用此 Token 登录 WebUI")
except Exception as e:
logger.error(f"❌ 获取 Access Token 失败: {e}")
def _setup_static_files(self):
"""设置静态文件服务"""
# 确保正确的 MIME 类型映射
mimetypes.init()
mimetypes.add_type("application/javascript", ".js")
mimetypes.add_type("application/javascript", ".mjs")
mimetypes.add_type("text/css", ".css")
mimetypes.add_type("application/json", ".json")
base_dir = Path(__file__).parent.parent.parent
static_path = base_dir / "webui" / "dist"
if not static_path.exists():
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
logger.warning("💡 请先构建前端: cd webui && npm run build")
return
if not (static_path / "index.html").exists():
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
logger.warning("💡 请确认前端已正确构建")
return
# 处理 SPA 路由 - 注意:这个路由优先级最低
@self.app.get("/{full_path:path}", include_in_schema=False)
async def serve_spa(full_path: str):
"""服务单页应用 - 只处理非 API 请求"""
# 如果是根路径,直接返回 index.html
if not full_path or full_path == "/":
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 检查是否是静态文件
file_path = static_path / full_path
if file_path.is_file() and file_path.exists():
# 自动检测 MIME 类型
media_type = mimetypes.guess_type(str(file_path))[0]
response = FileResponse(file_path, media_type=media_type)
# HTML 文件添加防索引头
if str(file_path).endswith(".html"):
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
# 其他路径返回 index.htmlSPA 路由)
response = FileResponse(static_path / "index.html", media_type="text/html")
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
return response
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
def _setup_anti_crawler(self):
"""配置防爬虫中间件"""
try:
from src.webui.anti_crawler import AntiCrawlerMiddleware
from src.config.config import global_config
# 从配置读取防爬虫模式
anti_crawler_mode = global_config.webui.anti_crawler_mode
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
# 我们需要在CORS之前注册这样防爬虫检查会在CORS之前执行
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
except Exception as e:
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
def _setup_robots_txt(self):
"""设置robots.txt路由"""
try:
from src.webui.anti_crawler import create_robots_txt_response
@self.app.get("/robots.txt", include_in_schema=False)
async def robots_txt():
"""返回robots.txt禁止所有爬虫"""
return create_robots_txt_response()
logger.debug("✅ robots.txt 路由已注册")
except Exception as e:
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
def _register_api_routes(self):
"""注册所有 WebUI API 路由"""
try:
# 导入所有 WebUI 路由
from src.webui.routes import router as webui_router
from src.webui.logs_ws import router as logs_router
from src.webui.knowledge_routes import router as knowledge_router
# 导入本地聊天室路由
from src.webui.chat_routes import router as chat_router
# 导入规划器监控路由
from src.webui.api.planner import router as planner_router
# 导入回复器监控路由
from src.webui.api.replier import router as replier_router
# 注册路由
self.app.include_router(webui_router)
self.app.include_router(logs_router)
self.app.include_router(knowledge_router)
self.app.include_router(chat_router)
self.app.include_router(planner_router)
self.app.include_router(replier_router)
logger.info("✅ WebUI API 路由已注册")
except Exception as e:
logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
async def start(self):
"""启动服务器"""
# 预先检查端口是否可用
if not self._check_port_available():
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
logger.error(error_msg)
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
config = Config(
app=self.app,
host=self.host,
port=self.port,
log_config=None,
access_log=False,
)
self._server = UvicornServer(config=config)
logger.info("🌐 WebUI 服务器启动中...")
# 根据地址类型显示正确的访问地址
if ':' in self.host:
# IPv6 地址需要用方括号包裹
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
if self.host == "::":
logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
elif self.host == "::1":
logger.info("💡 仅支持 IPv6 本地访问")
else:
# IPv4 地址
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
if self.host == "0.0.0.0":
logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
try:
await self._server.serve()
except OSError as e:
# 处理端口绑定相关的错误
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
else:
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
raise
except Exception as e:
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
raise
def _check_port_available(self) -> bool:
"""检查端口是否可用(支持 IPv4 和 IPv6"""
import socket
# 判断使用 IPv4 还是 IPv6
if ':' in self.host:
# IPv6 地址
family = socket.AF_INET6
test_host = self.host if self.host != "::" else "::1"
else:
# IPv4 地址
family = socket.AF_INET
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
try:
with socket.socket(family, socket.SOCK_STREAM) as s:
s.settimeout(1)
# 尝试绑定端口
s.bind((test_host, self.port))
return True
except OSError:
return False
async def shutdown(self):
"""关闭服务器"""
if self._server:
logger.info("正在关闭 WebUI 服务器...")
self._server.should_exit = True
try:
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
logger.info("✅ WebUI 服务器已关闭")
except asyncio.TimeoutError:
logger.warning("⚠️ WebUI 服务器关闭超时")
except Exception as e:
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
finally:
self._server = None
# 全局 WebUI 服务器实例
_webui_server = None
def get_webui_server() -> WebUIServer:
"""获取全局 WebUI 服务器实例"""
global _webui_server
if _webui_server is None:
# 从环境变量读取
import os
host = os.getenv("WEBUI_HOST", "127.0.0.1")
port = int(os.getenv("WEBUI_PORT", "8001"))
_webui_server = WebUIServer(host=host, port=port)
return _webui_server

View File

@ -1,114 +0,0 @@
"""WebSocket 认证模块
提供所有 WebSocket 端点统一使用的临时 token 认证机制
临时 token 有效期 60 且只能使用一次用于解决 WebSocket 握手时 Cookie 不可用的问题
"""
from fastapi import APIRouter, Cookie, Header
from typing import Optional
import secrets
import time
from src.common.logger import get_logger
from src.webui.token_manager import get_token_manager
logger = get_logger("webui.ws_auth")
router = APIRouter()
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
_WS_TOKEN_EXPIRE_SECONDS = 60
def _cleanup_expired_ws_tokens():
"""清理过期的临时 token"""
now = time.time()
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
for t in expired:
del _ws_temp_tokens[t]
def generate_ws_token(session_token: str) -> str:
"""生成 WebSocket 临时 token
Args:
session_token: 原始的 session token
Returns:
临时 token 字符串
"""
_cleanup_expired_ws_tokens()
temp_token = secrets.token_urlsafe(32)
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
return temp_token
def verify_ws_token(temp_token: str) -> bool:
"""验证并消费 WebSocket 临时 token一次性使用
Args:
temp_token: 临时 token
Returns:
验证是否通过
"""
_cleanup_expired_ws_tokens()
if temp_token not in _ws_temp_tokens:
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
return False
expire_time, session_token = _ws_temp_tokens[temp_token]
if time.time() > expire_time:
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
return False
# 验证原始 session token 仍然有效
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
del _ws_temp_tokens[temp_token]
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
return False
# 消费 token一次性使用
del _ws_temp_tokens[temp_token]
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
return True
@router.get("/ws-token")
async def get_ws_token(
maibot_session: Optional[str] = Cookie(None),
authorization: Optional[str] = Header(None),
):
"""
获取 WebSocket 连接用的临时 token
此端点验证当前会话的 Cookie Authorization header
然后返回一个临时 token 用于 WebSocket 握手认证
临时 token 有效期 60 且只能使用一次
注意在未认证时返回 200 状态码但 success=False避免前端因 401 刷新页面
"""
# 获取当前 session token
session_token = None
if maibot_session:
session_token = maibot_session
elif authorization and authorization.startswith("Bearer "):
session_token = authorization.replace("Bearer ", "")
if not session_token:
# 返回 200 但 success=False避免前端因 401 刷新页面
# 这在登录页面是正常情况,不应该触发错误处理
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
# 验证 session token
token_manager = get_token_manager()
if not token_manager.verify_token(session_token):
# 同样返回 200 但 success=False避免前端刷新
logger.debug("ws-token 请求:认证已过期")
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
# 生成临时 WebSocket token
ws_token = generate_ws_token(session_token)
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Some files were not shown because too many files have changed in this diff Show More