mirror of https://github.com/Mai-with-u/MaiBot.git
remove webui
parent
072939e8e9
commit
afb993e481
|
|
@ -1 +0,0 @@
|
|||
"""WebUI 模块"""
|
||||
|
|
@ -1,938 +0,0 @@
|
|||
"""麦麦 2025 年度总结 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import (
|
||||
LLMUsage,
|
||||
OnlineTime,
|
||||
Messages,
|
||||
ChatStreams,
|
||||
PersonInfo,
|
||||
Emoji,
|
||||
Expression,
|
||||
ActionRecords,
|
||||
Jargon,
|
||||
)
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.annual_report")
|
||||
|
||||
router = APIRouter(prefix="/annual-report", tags=["annual-report"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ==================== Pydantic 模型定义 ====================
|
||||
|
||||
|
||||
class TimeFootprintData(BaseModel):
|
||||
"""时光足迹数据"""
|
||||
|
||||
total_online_hours: float = Field(0.0, description="年度在线总时长(小时)")
|
||||
first_message_time: Optional[str] = Field(None, description="初次消息时间")
|
||||
first_message_user: Optional[str] = Field(None, description="初次消息用户昵称")
|
||||
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
|
||||
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
|
||||
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
|
||||
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
|
||||
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
|
||||
is_night_owl: bool = Field(False, description="是否是夜猫子")
|
||||
|
||||
|
||||
class SocialNetworkData(BaseModel):
|
||||
"""社交网络数据"""
|
||||
|
||||
total_groups: int = Field(0, description="加入的群组总数")
|
||||
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
|
||||
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
|
||||
at_count: int = Field(0, description="被@次数")
|
||||
mentioned_count: int = Field(0, description="被提及次数")
|
||||
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
|
||||
longest_companion_days: int = Field(0, description="陪伴天数")
|
||||
|
||||
|
||||
class BrainPowerData(BaseModel):
|
||||
"""最强大脑数据"""
|
||||
|
||||
total_tokens: int = Field(0, description="年度消耗Token总量")
|
||||
total_cost: float = Field(0.0, description="年度总花费")
|
||||
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
|
||||
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
|
||||
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
|
||||
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
|
||||
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
|
||||
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
|
||||
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
|
||||
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
|
||||
total_actions: int = Field(0, description="总动作数")
|
||||
no_reply_count: int = Field(0, description="选择沉默的次数")
|
||||
avg_interest_value: float = Field(0.0, description="平均兴趣值")
|
||||
max_interest_value: float = Field(0.0, description="最高兴趣值")
|
||||
max_interest_time: Optional[str] = Field(None, description="最高兴趣值时间")
|
||||
avg_reasoning_length: float = Field(0.0, description="平均思考长度")
|
||||
max_reasoning_length: int = Field(0, description="最长思考长度")
|
||||
max_reasoning_time: Optional[str] = Field(None, description="最长思考的时间")
|
||||
|
||||
|
||||
class ExpressionVibeData(BaseModel):
|
||||
"""个性与表达数据"""
|
||||
|
||||
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
|
||||
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
|
||||
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
|
||||
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
|
||||
checked_expression_count: int = Field(0, description="已检查的表达次数")
|
||||
total_expressions: int = Field(0, description="表达总数")
|
||||
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
|
||||
image_processed_count: int = Field(0, description="处理的图片数量")
|
||||
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
|
||||
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
|
||||
|
||||
|
||||
class AchievementData(BaseModel):
|
||||
"""趣味成就数据"""
|
||||
|
||||
new_jargon_count: int = Field(0, description="新学到的黑话数量")
|
||||
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
|
||||
|
||||
class AnnualReportData(BaseModel):
|
||||
"""年度报告完整数据"""
|
||||
|
||||
year: int = Field(2025, description="报告年份")
|
||||
bot_name: str = Field("麦麦", description="Bot名称")
|
||||
generated_at: str = Field(..., description="报告生成时间")
|
||||
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
|
||||
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
|
||||
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
|
||||
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
|
||||
achievements: AchievementData = Field(default_factory=AchievementData)
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def get_year_time_range(year: int = 2025) -> tuple[float, float]:
|
||||
"""获取指定年份的时间戳范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0).timestamp()
|
||||
end = datetime(year, 12, 31, 23, 59, 59).timestamp()
|
||||
return start, end
|
||||
|
||||
|
||||
def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
|
||||
"""获取指定年份的 datetime 范围"""
|
||||
start = datetime(year, 1, 1, 0, 0, 0)
|
||||
end = datetime(year, 12, 31, 23, 59, 59)
|
||||
return start, end
|
||||
|
||||
|
||||
# ==================== 维度一:时光足迹 ====================
|
||||
|
||||
|
||||
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
||||
"""获取时光足迹数据"""
|
||||
data = TimeFootprintData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度在线时长
|
||||
online_records = list(
|
||||
OnlineTime.select().where(
|
||||
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
|
||||
)
|
||||
)
|
||||
total_seconds = 0
|
||||
for record in online_records:
|
||||
try:
|
||||
start = max(record.start_timestamp, start_dt)
|
||||
end = min(record.end_timestamp, end_dt)
|
||||
if end > start:
|
||||
total_seconds += (end - start).total_seconds()
|
||||
except Exception:
|
||||
continue
|
||||
data.total_online_hours = round(total_seconds / 3600, 2)
|
||||
|
||||
# 2. 初次相遇 - 年度第一条消息
|
||||
first_msg = (
|
||||
Messages.select()
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.order_by(Messages.time.asc())
|
||||
.first()
|
||||
)
|
||||
if first_msg:
|
||||
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
|
||||
content = first_msg.processed_plain_text or first_msg.display_message or ""
|
||||
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
|
||||
|
||||
# 3. 最忙碌的一天
|
||||
# 使用 SQLite 的 date 函数按日期分组
|
||||
busiest_query = (
|
||||
Messages.select(
|
||||
fn.date(Messages.time, "unixepoch").alias("day"),
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.group_by(fn.date(Messages.time, "unixepoch"))
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(1)
|
||||
)
|
||||
busiest_result = list(busiest_query.dicts())
|
||||
if busiest_result:
|
||||
data.busiest_day = busiest_result[0].get("day")
|
||||
data.busiest_day_count = busiest_result[0].get("count", 0)
|
||||
|
||||
# 4. 昼夜节律 - 24小时活跃分布
|
||||
hourly_query = (
|
||||
Messages.select(
|
||||
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
|
||||
)
|
||||
hourly_distribution = [0] * 24
|
||||
for row in hourly_query.dicts():
|
||||
try:
|
||||
hour = int(row.get("hour", 0))
|
||||
if 0 <= hour < 24:
|
||||
hourly_distribution[hour] = row.get("count", 0)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
data.hourly_distribution = hourly_distribution
|
||||
|
||||
# 5. 深夜食堂 (0-4点)
|
||||
data.midnight_chat_count = sum(hourly_distribution[0:5])
|
||||
|
||||
# 6. 判断是否夜猫子 (22点-4点活跃度 vs 6点-12点)
|
||||
night_activity = sum(hourly_distribution[22:24]) + sum(hourly_distribution[0:5])
|
||||
morning_activity = sum(hourly_distribution[6:13])
|
||||
data.is_night_owl = night_activity > morning_activity
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度二:社交网络 ====================
|
||||
|
||||
|
||||
async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||
"""获取社交网络数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = SocialNetworkData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于过滤
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 加入的群组总数
|
||||
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
|
||||
|
||||
# 2. 话痨群组 TOP3
|
||||
top_groups_query = (
|
||||
Messages.select(
|
||||
Messages.chat_info_group_id,
|
||||
Messages.chat_info_group_name,
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.chat_info_group_id.is_null(False))
|
||||
)
|
||||
.group_by(Messages.chat_info_group_id)
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_groups = [
|
||||
{
|
||||
"group_id": row["chat_info_group_id"],
|
||||
"group_name": row["chat_info_group_name"] or "未知群组",
|
||||
"message_count": row["count"],
|
||||
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
|
||||
}
|
||||
for row in top_groups_query.dicts()
|
||||
]
|
||||
|
||||
# 3. 互动最多的用户 TOP5(过滤 bot 自身)
|
||||
top_users_query = (
|
||||
Messages.select(
|
||||
Messages.user_id,
|
||||
Messages.user_nickname,
|
||||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.user_id.is_null(False))
|
||||
& (Messages.user_id != bot_qq) # 过滤 bot 自身
|
||||
)
|
||||
.group_by(Messages.user_id)
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_users = [
|
||||
{
|
||||
"user_id": row["user_id"],
|
||||
"user_nickname": row["user_nickname"] or "未知用户",
|
||||
"message_count": row["count"],
|
||||
"is_webui": str(row["user_id"]).startswith("webui_"),
|
||||
}
|
||||
for row in top_users_query.dicts()
|
||||
]
|
||||
|
||||
# 4. 被@次数
|
||||
data.at_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_at == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 5. 被提及次数
|
||||
data.mentioned_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_mentioned == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 6. 最长情陪伴的用户(过滤 bot 自身)
|
||||
companion_query = (
|
||||
ChatStreams.select(
|
||||
ChatStreams.user_id,
|
||||
ChatStreams.user_nickname,
|
||||
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
|
||||
)
|
||||
.where(
|
||||
(ChatStreams.user_id.is_null(False))
|
||||
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
|
||||
)
|
||||
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
|
||||
.limit(1)
|
||||
)
|
||||
companion_result = list(companion_query.dicts())
|
||||
if companion_result:
|
||||
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
|
||||
duration = companion_result[0].get("duration", 0) or 0
|
||||
data.longest_companion_days = int(duration / 86400) # 转换为天
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度三:最强大脑 ====================
|
||||
|
||||
|
||||
async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
||||
"""获取最强大脑数据"""
|
||||
data = BrainPowerData()
|
||||
start_dt, end_dt = get_year_datetime_range(year)
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 年度消耗 Token 总量和总花费
|
||||
token_query = LLMUsage.select(
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
result = token_query.dicts().get()
|
||||
data.total_tokens = int(result.get("total_tokens", 0) or 0)
|
||||
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
|
||||
|
||||
# 2. 最爱用的模型
|
||||
model_query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
||||
fn.COUNT(LLMUsage.id).alias("count"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
model_results = list(model_query.dicts())
|
||||
if model_results:
|
||||
data.favorite_model = model_results[0].get("model")
|
||||
data.favorite_model_count = model_results[0].get("count", 0)
|
||||
data.model_distribution = [
|
||||
{
|
||||
"model": row["model"],
|
||||
"count": row["count"],
|
||||
"tokens": row["tokens"],
|
||||
"cost": round(row["cost"], 4),
|
||||
}
|
||||
for row in model_results
|
||||
]
|
||||
|
||||
# 3. 最昂贵的一次思考
|
||||
expensive_query = (
|
||||
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
|
||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
||||
.order_by(LLMUsage.cost.desc())
|
||||
.limit(1)
|
||||
)
|
||||
expensive_result = expensive_query.first()
|
||||
if expensive_result:
|
||||
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
|
||||
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
|
||||
consumer_query = (
|
||||
LLMUsage.select(
|
||||
LLMUsage.user_id,
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_dt)
|
||||
& (LLMUsage.timestamp <= end_dt)
|
||||
& (LLMUsage.user_id != "system") # 过滤 system 用户
|
||||
& (LLMUsage.user_id.is_null(False))
|
||||
)
|
||||
.group_by(LLMUsage.user_id)
|
||||
.order_by(fn.SUM(LLMUsage.cost).desc())
|
||||
.limit(3)
|
||||
)
|
||||
data.top_token_consumers = [
|
||||
{
|
||||
"user_id": row["user_id"],
|
||||
"cost": round(row["cost"], 4),
|
||||
"tokens": row["tokens"],
|
||||
}
|
||||
for row in consumer_query.dicts()
|
||||
]
|
||||
|
||||
# 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用)
|
||||
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
|
||||
reply_model_query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
||||
fn.COUNT(LLMUsage.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(LLMUsage.timestamp >= start_dt)
|
||||
& (LLMUsage.timestamp <= end_dt)
|
||||
& (
|
||||
LLMUsage.model_assign_name.contains("replyer")
|
||||
| LLMUsage.model_assign_name.contains("回复")
|
||||
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
|
||||
)
|
||||
)
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_reply_models = [
|
||||
{"model": row["model"], "count": row["count"]}
|
||||
for row in reply_model_query.dicts()
|
||||
]
|
||||
|
||||
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
||||
total_actions = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
|
||||
).count()
|
||||
no_reply_count = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "no_reply")
|
||||
).count()
|
||||
data.total_actions = total_actions
|
||||
data.no_reply_count = no_reply_count
|
||||
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
||||
|
||||
# 6. 情绪波动 (兴趣值)
|
||||
interest_query = Messages.select(
|
||||
fn.AVG(Messages.interest_value).alias("avg_interest"),
|
||||
fn.MAX(Messages.interest_value).alias("max_interest"),
|
||||
).where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.interest_value.is_null(False))
|
||||
)
|
||||
interest_result = interest_query.dicts().get()
|
||||
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
|
||||
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
|
||||
|
||||
# 找到最高兴趣值的时间
|
||||
if data.max_interest_value > 0:
|
||||
max_interest_msg = (
|
||||
Messages.select(Messages.time)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.interest_value == data.max_interest_value)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if max_interest_msg:
|
||||
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
|
||||
# 7. 思考深度 (基于 action_reasoning 长度)
|
||||
reasoning_records = (
|
||||
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_reasoning.is_null(False))
|
||||
& (ActionRecords.action_reasoning != "")
|
||||
)
|
||||
)
|
||||
reasoning_lengths = []
|
||||
max_len = 0
|
||||
max_len_time = None
|
||||
for record in reasoning_records:
|
||||
if record.action_reasoning:
|
||||
length = len(record.action_reasoning)
|
||||
reasoning_lengths.append(length)
|
||||
if length > max_len:
|
||||
max_len = length
|
||||
max_len_time = record.time
|
||||
|
||||
if reasoning_lengths:
|
||||
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
||||
data.max_reasoning_length = max_len
|
||||
if max_len_time:
|
||||
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度四:个性与表达 ====================
|
||||
|
||||
|
||||
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||
"""获取个性与表达数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
data = ExpressionVibeData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
try:
|
||||
# 1. 表情包之王 - 使用次数最多的表情包
|
||||
top_emoji_query = (
|
||||
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
|
||||
.where(Emoji.is_registered == True)
|
||||
.order_by(Emoji.usage_count.desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_emojis = list(top_emoji_query.dicts())
|
||||
if top_emojis:
|
||||
data.top_emoji = {
|
||||
"id": top_emojis[0].get("id"),
|
||||
"path": top_emojis[0].get("full_path"),
|
||||
"description": top_emojis[0].get("description"),
|
||||
"usage_count": top_emojis[0].get("usage_count", 0),
|
||||
"hash": top_emojis[0].get("emoji_hash"),
|
||||
}
|
||||
data.top_emojis = [
|
||||
{
|
||||
"id": e.get("id"),
|
||||
"path": e.get("full_path"),
|
||||
"description": e.get("description"),
|
||||
"usage_count": e.get("usage_count", 0),
|
||||
"hash": e.get("emoji_hash"),
|
||||
}
|
||||
for e in top_emojis
|
||||
]
|
||||
|
||||
# 2. 百变麦麦 - 最常用的表达风格
|
||||
expression_query = (
|
||||
Expression.select(
|
||||
Expression.style,
|
||||
fn.SUM(Expression.count).alias("total_count"),
|
||||
)
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.group_by(Expression.style)
|
||||
.order_by(fn.SUM(Expression.count).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_expressions = [
|
||||
{"style": row["style"], "count": row["total_count"]}
|
||||
for row in expression_query.dicts()
|
||||
]
|
||||
|
||||
# 3. 被拒绝的表达
|
||||
data.rejected_expression_count = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
& (Expression.rejected == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 4. 已检查的表达
|
||||
data.checked_expression_count = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
& (Expression.checked == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 5. 表达总数
|
||||
data.total_expressions = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 6. 动作类型分布 (过滤无意义的动作)
|
||||
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
||||
excluded_actions = [
|
||||
"reply", "no_reply", "no_reply_until_call", "make_question",
|
||||
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
|
||||
]
|
||||
action_query = (
|
||||
ActionRecords.select(
|
||||
ActionRecords.action_name,
|
||||
fn.COUNT(ActionRecords.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name.not_in(excluded_actions))
|
||||
)
|
||||
.group_by(ActionRecords.action_name)
|
||||
.order_by(fn.COUNT(ActionRecords.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
data.action_types = [
|
||||
{"action": row["action_name"], "count": row["count"]}
|
||||
for row in action_query.dicts()
|
||||
]
|
||||
|
||||
# 7. 处理的图片数量
|
||||
data.image_processed_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_picid == True)
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
||||
import random
|
||||
import re
|
||||
|
||||
def clean_message_content(content: str) -> str:
|
||||
"""清理消息内容,移除回复引用等标记"""
|
||||
if not content:
|
||||
return ""
|
||||
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
|
||||
content = re.sub(r'\[回复<[^>]+>\s*的消息[::][^\]]*\]', '', content)
|
||||
# 移除 [图片] [表情] 等标记
|
||||
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
|
||||
# 移除多余的空白
|
||||
content = re.sub(r'\s+', ' ', content).strip()
|
||||
return content
|
||||
|
||||
# 使用 user_id 判断是否是 bot 发送的消息
|
||||
late_night_messages = list(
|
||||
Messages.select(
|
||||
Messages.time,
|
||||
Messages.processed_plain_text,
|
||||
Messages.display_message,
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.user_id == bot_qq) # bot 发送的消息
|
||||
)
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
# 筛选出0-6点的消息
|
||||
late_night_filtered = []
|
||||
for msg in late_night_messages:
|
||||
msg_dt = datetime.fromtimestamp(msg.time)
|
||||
hour = msg_dt.hour
|
||||
if 0 <= hour < 6: # 0点到6点
|
||||
raw_content = msg.processed_plain_text or msg.display_message or ""
|
||||
cleaned_content = clean_message_content(raw_content)
|
||||
# 只保留有意义的内容
|
||||
if cleaned_content and len(cleaned_content) > 2:
|
||||
late_night_filtered.append({
|
||||
"time": msg.time,
|
||||
"hour": hour,
|
||||
"minute": msg_dt.minute,
|
||||
"content": cleaned_content,
|
||||
"datetime_str": msg_dt.strftime("%H:%M"),
|
||||
})
|
||||
if len(late_night_filtered) >= 10:
|
||||
break
|
||||
|
||||
if late_night_filtered:
|
||||
selected = random.choice(late_night_filtered)
|
||||
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
|
||||
data.late_night_reply = {
|
||||
"time": selected["datetime_str"],
|
||||
"content": content,
|
||||
}
|
||||
|
||||
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
|
||||
from collections import Counter
|
||||
import json as json_lib
|
||||
|
||||
reply_records = (
|
||||
ActionRecords.select(ActionRecords.action_data)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "reply")
|
||||
& (ActionRecords.action_data.is_null(False))
|
||||
& (ActionRecords.action_data != "")
|
||||
)
|
||||
)
|
||||
|
||||
reply_contents = []
|
||||
for record in reply_records:
|
||||
try:
|
||||
action_data = record.action_data
|
||||
if action_data:
|
||||
content = None
|
||||
# 尝试解析 JSON 格式
|
||||
try:
|
||||
parsed = json_lib.loads(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
# 优先使用 reply_text,其次使用 content
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (json_lib.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
|
||||
# 例如: "{'reply_text': '墨白灵不知道哦'}"
|
||||
if content is None:
|
||||
import ast
|
||||
try:
|
||||
parsed = ast.literal_eval(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
content = parsed.get("reply_text") or parsed.get("content")
|
||||
elif isinstance(parsed, str):
|
||||
content = parsed
|
||||
except (ValueError, SyntaxError):
|
||||
# 无法解析,使用原始字符串
|
||||
content = action_data
|
||||
|
||||
# 只统计有意义的回复(长度大于2)
|
||||
if content and len(content) > 2:
|
||||
reply_contents.append(content)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if reply_contents:
|
||||
content_counter = Counter(reply_contents)
|
||||
most_common = content_counter.most_common(1)
|
||||
if most_common:
|
||||
fav_content, fav_count = most_common[0]
|
||||
# 截断过长的内容
|
||||
display_content = fav_content[:50] + "..." if len(fav_content) > 50 else fav_content
|
||||
data.favorite_reply = {
|
||||
"content": display_content,
|
||||
"count": fav_count,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== 维度五:趣味成就 ====================
|
||||
|
||||
|
||||
async def get_achievements(year: int = 2025) -> AchievementData:
|
||||
"""获取趣味成就数据"""
|
||||
data = AchievementData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
try:
|
||||
# 1. 新学到的黑话数量
|
||||
# Jargon 表没有时间字段,统计全部已确认的黑话
|
||||
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
|
||||
|
||||
# 2. 代表性黑话示例
|
||||
jargon_samples = (
|
||||
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
|
||||
.where(Jargon.is_jargon == True)
|
||||
.order_by(Jargon.count.desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.sample_jargons = [
|
||||
{
|
||||
"content": j.content,
|
||||
"meaning": j.meaning,
|
||||
"count": j.count,
|
||||
}
|
||||
for j in jargon_samples
|
||||
]
|
||||
|
||||
# 3. 总消息数
|
||||
data.total_messages = (
|
||||
Messages.select()
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 4. 总回复数 (有 reply_to 的消息)
|
||||
data.total_replies = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
# ==================== API 路由 ====================
|
||||
|
||||
|
||||
@router.get("/full", response_model=AnnualReportData)
|
||||
async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取完整年度报告数据
|
||||
|
||||
Args:
|
||||
year: 报告年份,默认2025
|
||||
|
||||
Returns:
|
||||
完整的年度报告数据
|
||||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
logger.info(f"开始生成 {year} 年度报告...")
|
||||
|
||||
# 获取 bot 名称
|
||||
bot_name = global_config.bot.nickname or "麦麦"
|
||||
|
||||
# 并行获取各维度数据
|
||||
time_footprint = await get_time_footprint(year)
|
||||
social_network = await get_social_network(year)
|
||||
brain_power = await get_brain_power(year)
|
||||
expression_vibe = await get_expression_vibe(year)
|
||||
achievements = await get_achievements(year)
|
||||
|
||||
report = AnnualReportData(
|
||||
year=year,
|
||||
bot_name=bot_name,
|
||||
generated_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||||
time_footprint=time_footprint,
|
||||
social_network=social_network,
|
||||
brain_power=brain_power,
|
||||
expression_vibe=expression_vibe,
|
||||
achievements=achievements,
|
||||
)
|
||||
|
||||
logger.info(f"{year} 年度报告生成完成")
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成年度报告失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"生成年度报告失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/time-footprint", response_model=TimeFootprintData)
|
||||
async def get_time_footprint_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取时光足迹数据"""
|
||||
try:
|
||||
return await get_time_footprint(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取时光足迹数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/social-network", response_model=SocialNetworkData)
|
||||
async def get_social_network_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取社交网络数据"""
|
||||
try:
|
||||
return await get_social_network(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取社交网络数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/brain-power", response_model=BrainPowerData)
|
||||
async def get_brain_power_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取最强大脑数据"""
|
||||
try:
|
||||
return await get_brain_power(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取最强大脑数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/expression-vibe", response_model=ExpressionVibeData)
|
||||
async def get_expression_vibe_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取个性与表达数据"""
|
||||
try:
|
||||
return await get_expression_vibe(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取个性与表达数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/achievements", response_model=AchievementData)
|
||||
async def get_achievements_api(year: int = 2025, _auth: bool = Depends(require_auth)):
|
||||
"""获取趣味成就数据"""
|
||||
try:
|
||||
return await get_achievements(year)
|
||||
except Exception as e:
|
||||
logger.error(f"获取趣味成就数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
@ -1,795 +0,0 @@
|
|||
"""
|
||||
WebUI 防爬虫模块
|
||||
提供爬虫检测和阻止功能,保护 WebUI 不被搜索引擎和恶意爬虫访问
|
||||
"""
|
||||
|
||||
import time
|
||||
import ipaddress
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Optional
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import PlainTextResponse
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.anti_crawler")
|
||||
|
||||
# 常见爬虫 User-Agent 列表(使用更精确的关键词,避免误报)
|
||||
CRAWLER_USER_AGENTS = {
|
||||
# 搜索引擎爬虫(精确匹配)
|
||||
"googlebot",
|
||||
"bingbot",
|
||||
"baiduspider",
|
||||
"yandexbot",
|
||||
"slurp", # Yahoo
|
||||
"duckduckbot",
|
||||
"sogou",
|
||||
"exabot",
|
||||
"facebot",
|
||||
"ia_archiver", # Internet Archive
|
||||
# 通用爬虫(移除过于宽泛的关键词)
|
||||
"crawler",
|
||||
"spider",
|
||||
"scraper",
|
||||
"wget", # 保留wget,因为通常用于自动化脚本
|
||||
"scrapy", # 保留scrapy,因为这是爬虫框架
|
||||
# 安全扫描工具(这些是明确的扫描工具)
|
||||
"masscan",
|
||||
"nmap",
|
||||
"nikto",
|
||||
"sqlmap",
|
||||
# 注意:移除了以下过于宽泛的关键词以避免误报:
|
||||
# - "bot" (会误匹配GitHub-Robot等)
|
||||
# - "curl" (正常工具)
|
||||
# - "python-requests" (正常库)
|
||||
# - "httpx" (正常库)
|
||||
# - "aiohttp" (正常库)
|
||||
}
|
||||
|
||||
# 资产测绘工具 User-Agent 标识
|
||||
ASSET_SCANNER_USER_AGENTS = {
|
||||
# 知名资产测绘平台
|
||||
"shodan",
|
||||
"censys",
|
||||
"zoomeye",
|
||||
"fofa",
|
||||
"quake",
|
||||
"hunter",
|
||||
"binaryedge",
|
||||
"onyphe",
|
||||
"securitytrails",
|
||||
"virustotal",
|
||||
"passivetotal",
|
||||
# 安全扫描工具
|
||||
"acunetix",
|
||||
"appscan",
|
||||
"burpsuite",
|
||||
"nessus",
|
||||
"openvas",
|
||||
"qualys",
|
||||
"rapid7",
|
||||
"tenable",
|
||||
"veracode",
|
||||
"zap",
|
||||
"awvs", # Acunetix Web Vulnerability Scanner
|
||||
"netsparker",
|
||||
"skipfish",
|
||||
"w3af",
|
||||
"arachni",
|
||||
# 其他扫描工具
|
||||
"masscan",
|
||||
"zmap",
|
||||
"nmap",
|
||||
"whatweb",
|
||||
"wpscan",
|
||||
"joomscan",
|
||||
"dnsenum",
|
||||
"subfinder",
|
||||
"amass",
|
||||
"sublist3r",
|
||||
"theharvester",
|
||||
}
|
||||
|
||||
# 资产测绘工具常用的HTTP头标识
|
||||
ASSET_SCANNER_HEADERS = {
|
||||
# 常见的扫描工具自定义头
|
||||
"x-scan": {"shodan", "censys", "zoomeye", "fofa"},
|
||||
"x-scanner": {"nmap", "masscan", "zmap"},
|
||||
"x-probe": {"masscan", "zmap"},
|
||||
# 其他可疑头(移除反向代理标准头)
|
||||
"x-originating-ip": set(),
|
||||
"x-remote-ip": set(),
|
||||
"x-remote-addr": set(),
|
||||
# 注意:移除了以下反向代理标准头以避免误报:
|
||||
# - "x-forwarded-proto" (反向代理标准头)
|
||||
# - "x-real-ip" (反向代理标准头,已在_get_client_ip中使用)
|
||||
}
|
||||
|
||||
# 仅检查特定HTTP头中的可疑模式(收紧匹配范围)
|
||||
# 只检查这些特定头,不检查所有头
|
||||
SCANNER_SPECIFIC_HEADERS = {
|
||||
"x-scan",
|
||||
"x-scanner",
|
||||
"x-probe",
|
||||
"x-originating-ip",
|
||||
"x-remote-ip",
|
||||
"x-remote-addr",
|
||||
}
|
||||
|
||||
# 防爬虫模式配置
|
||||
# false: 禁用
|
||||
# strict: 严格模式(更严格的检测,更低的频率限制)
|
||||
# loose: 宽松模式(较宽松的检测,较高的频率限制)
|
||||
# basic: 基础模式(只记录恶意访问,不阻止,不限制请求数,不跟踪IP)
|
||||
|
||||
# IP白名单配置(从配置文件读取,逗号分隔)
|
||||
# 支持格式:
|
||||
# - 精确IP:127.0.0.1, 192.168.1.100
|
||||
# - CIDR格式:192.168.1.0/24, 172.17.0.0/16 (适用于Docker网络)
|
||||
# - 通配符:192.168.*.*, 10.*.*.*, *.*.*.* (匹配所有)
|
||||
# - IPv6:::1, 2001:db8::/32
|
||||
def _parse_allowed_ips(ip_string: str) -> list:
|
||||
"""
|
||||
解析IP白名单字符串,支持精确IP、CIDR格式和通配符
|
||||
|
||||
Args:
|
||||
ip_string: 逗号分隔的IP字符串
|
||||
|
||||
Returns:
|
||||
IP白名单列表,每个元素可能是:
|
||||
- ipaddress.IPv4Network/IPv6Network对象(CIDR格式)
|
||||
- ipaddress.IPv4Address/IPv6Address对象(精确IP)
|
||||
- str(通配符模式,已转换为正则表达式)
|
||||
"""
|
||||
allowed = []
|
||||
if not ip_string:
|
||||
return allowed
|
||||
|
||||
for ip_entry in ip_string.split(","):
|
||||
ip_entry = ip_entry.strip() # 去除空格
|
||||
if not ip_entry:
|
||||
continue
|
||||
|
||||
# 跳过注释行(以#开头)
|
||||
if ip_entry.startswith("#"):
|
||||
continue
|
||||
|
||||
# 检查通配符格式(包含*)
|
||||
if "*" in ip_entry:
|
||||
# 处理通配符
|
||||
pattern = _convert_wildcard_to_regex(ip_entry)
|
||||
if pattern:
|
||||
allowed.append(pattern)
|
||||
else:
|
||||
logger.warning(f"无效的通配符IP格式,已忽略: {ip_entry}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# 尝试解析为CIDR格式(包含/)
|
||||
if "/" in ip_entry:
|
||||
allowed.append(ipaddress.ip_network(ip_entry, strict=False))
|
||||
else:
|
||||
# 精确IP地址
|
||||
allowed.append(ipaddress.ip_address(ip_entry))
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(f"无效的IP白名单条目,已忽略: {ip_entry} ({e})")
|
||||
|
||||
return allowed
|
||||
|
||||
|
||||
def _convert_wildcard_to_regex(wildcard_pattern: str) -> Optional[str]:
|
||||
"""
|
||||
将通配符IP模式转换为正则表达式
|
||||
|
||||
支持的格式:
|
||||
- 192.168.*.* 或 192.168.*
|
||||
- 10.*.*.* 或 10.*
|
||||
- *.*.*.* 或 *
|
||||
|
||||
Args:
|
||||
wildcard_pattern: 通配符模式字符串
|
||||
|
||||
Returns:
|
||||
正则表达式字符串,如果格式无效则返回None
|
||||
"""
|
||||
# 去除空格
|
||||
pattern = wildcard_pattern.strip()
|
||||
|
||||
# 处理单个*(匹配所有)
|
||||
if pattern == "*":
|
||||
return r".*"
|
||||
|
||||
# 处理IPv4通配符格式
|
||||
# 支持:192.168.*.*, 192.168.*, 10.*.*.*, 10.* 等
|
||||
parts = pattern.split(".")
|
||||
|
||||
if len(parts) > 4:
|
||||
return None # IPv4最多4段
|
||||
|
||||
# 构建正则表达式
|
||||
regex_parts = []
|
||||
for part in parts:
|
||||
part = part.strip()
|
||||
if part == "*":
|
||||
regex_parts.append(r"\d+") # 匹配任意数字
|
||||
elif part.isdigit():
|
||||
# 验证数字范围(0-255)
|
||||
num = int(part)
|
||||
if 0 <= num <= 255:
|
||||
regex_parts.append(re.escape(part))
|
||||
else:
|
||||
return None # 无效的数字
|
||||
else:
|
||||
return None # 无效的格式
|
||||
|
||||
# 如果部分少于4段,补充.*
|
||||
while len(regex_parts) < 4:
|
||||
regex_parts.append(r"\d+")
|
||||
|
||||
# 组合成正则表达式
|
||||
regex = r"^" + r"\.".join(regex_parts) + r"$"
|
||||
return regex
|
||||
|
||||
|
||||
# 从配置读取防爬虫设置(延迟导入避免循环依赖)
|
||||
def _get_anti_crawler_config():
|
||||
"""获取防爬虫配置"""
|
||||
from src.config.config import global_config
|
||||
return {
|
||||
'mode': global_config.webui.anti_crawler_mode,
|
||||
'allowed_ips': _parse_allowed_ips(global_config.webui.allowed_ips),
|
||||
'trusted_proxies': _parse_allowed_ips(global_config.webui.trusted_proxies),
|
||||
'trust_xff': global_config.webui.trust_xff
|
||||
}
|
||||
|
||||
# 初始化配置(将在模块加载时执行)
|
||||
_config = _get_anti_crawler_config()
|
||||
ANTI_CRAWLER_MODE = _config['mode']
|
||||
ALLOWED_IPS = _config['allowed_ips']
|
||||
TRUSTED_PROXIES = _config['trusted_proxies']
|
||||
TRUST_XFF = _config['trust_xff']
|
||||
|
||||
|
||||
def _get_mode_config(mode: str) -> dict:
|
||||
"""
|
||||
根据模式获取配置参数
|
||||
|
||||
Args:
|
||||
mode: 防爬虫模式 (false/strict/loose/basic)
|
||||
|
||||
Returns:
|
||||
配置字典,包含所有相关参数
|
||||
"""
|
||||
mode = mode.lower()
|
||||
|
||||
if mode == "false":
|
||||
return {
|
||||
"enabled": False,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 禁用时设置很高的值
|
||||
"max_tracked_ips": 0,
|
||||
"check_user_agent": False,
|
||||
"check_asset_scanner": False,
|
||||
"check_rate_limit": False,
|
||||
"block_on_detect": False, # 不阻止
|
||||
}
|
||||
elif mode == "strict":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 15, # 严格模式:更低的请求数
|
||||
"max_tracked_ips": 20000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
elif mode == "loose":
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 60, # 宽松模式:更高的请求数
|
||||
"max_tracked_ips": 5000,
|
||||
"check_user_agent": True,
|
||||
"check_asset_scanner": True,
|
||||
"check_rate_limit": True,
|
||||
"block_on_detect": True, # 阻止恶意访问
|
||||
}
|
||||
else: # basic (默认模式)
|
||||
return {
|
||||
"enabled": True,
|
||||
"rate_limit_window": 60,
|
||||
"rate_limit_max_requests": 1000, # 不限制请求数
|
||||
"max_tracked_ips": 0, # 不跟踪IP
|
||||
"check_user_agent": True, # 检测但不阻止
|
||||
"check_asset_scanner": True, # 检测但不阻止
|
||||
"check_rate_limit": False, # 不限制请求频率
|
||||
"block_on_detect": False, # 只记录,不阻止
|
||||
}
|
||||
|
||||
|
||||
class AntiCrawlerMiddleware(BaseHTTPMiddleware):
|
||||
"""防爬虫中间件"""
|
||||
|
||||
def __init__(self, app, mode: str = "standard"):
|
||||
"""
|
||||
初始化防爬虫中间件
|
||||
|
||||
Args:
|
||||
app: FastAPI 应用实例
|
||||
mode: 防爬虫模式 (false/strict/loose/standard)
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.mode = mode.lower()
|
||||
# 根据模式获取配置
|
||||
config = _get_mode_config(self.mode)
|
||||
self.enabled = config["enabled"]
|
||||
self.rate_limit_window = config["rate_limit_window"]
|
||||
self.rate_limit_max_requests = config["rate_limit_max_requests"]
|
||||
self.max_tracked_ips = config["max_tracked_ips"]
|
||||
self.check_user_agent = config["check_user_agent"]
|
||||
self.check_asset_scanner = config["check_asset_scanner"]
|
||||
self.check_rate_limit = config["check_rate_limit"]
|
||||
self.block_on_detect = config["block_on_detect"] # 是否阻止检测到的恶意访问
|
||||
|
||||
# 用于存储每个IP的请求时间戳(使用deque提高性能)
|
||||
self.request_times: dict[str, deque] = {}
|
||||
# 上次清理时间
|
||||
self.last_cleanup = time.time()
|
||||
# 将关键词列表转换为集合以提高查找性能
|
||||
self.crawler_keywords_set = set(CRAWLER_USER_AGENTS)
|
||||
self.scanner_keywords_set = set(ASSET_SCANNER_USER_AGENTS)
|
||||
|
||||
def _is_crawler_user_agent(self, user_agent: Optional[str]) -> bool:
|
||||
"""
|
||||
检测是否为爬虫 User-Agent
|
||||
|
||||
Args:
|
||||
user_agent: User-Agent 字符串
|
||||
|
||||
Returns:
|
||||
如果是爬虫则返回 True
|
||||
"""
|
||||
if not user_agent:
|
||||
# 没有 User-Agent 的请求记录日志但不直接阻止
|
||||
# 改为只记录,让频率限制来处理
|
||||
logger.debug("请求缺少User-Agent")
|
||||
return False # 不再直接阻止无User-Agent的请求
|
||||
|
||||
user_agent_lower = user_agent.lower()
|
||||
|
||||
# 使用集合查找提高性能(检查是否包含爬虫关键词)
|
||||
for crawler_keyword in self.crawler_keywords_set:
|
||||
if crawler_keyword in user_agent_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _is_asset_scanner_header(self, request: Request) -> bool:
|
||||
"""
|
||||
检测是否为资产测绘工具的HTTP头(只检查特定头,收紧匹配)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
如果检测到资产测绘工具头则返回 True
|
||||
"""
|
||||
# 只检查特定的扫描工具头,不检查所有头
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
header_value_lower = header_value.lower() if header_value else ""
|
||||
|
||||
# 检查已知的扫描工具头
|
||||
if header_name_lower in ASSET_SCANNER_HEADERS:
|
||||
# 如果该头有特定的工具集合,检查值是否匹配
|
||||
expected_tools = ASSET_SCANNER_HEADERS[header_name_lower]
|
||||
if expected_tools:
|
||||
for tool in expected_tools:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
else:
|
||||
# 如果没有特定工具集合,只要存在该头就视为可疑
|
||||
if header_value_lower:
|
||||
return True
|
||||
|
||||
# 只检查特定头中的可疑模式(收紧匹配)
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
# 检查头值中是否包含已知扫描工具名称
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _detect_asset_scanner(self, request: Request) -> tuple[bool, Optional[str]]:
|
||||
"""
|
||||
检测资产测绘工具
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
(是否检测到, 检测到的工具名称)
|
||||
"""
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检查 User-Agent(使用集合查找提高性能)
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for scanner_keyword in self.scanner_keywords_set:
|
||||
if scanner_keyword in user_agent_lower:
|
||||
return True, scanner_keyword
|
||||
|
||||
# 检查HTTP头
|
||||
if self._is_asset_scanner_header(request):
|
||||
# 尝试从User-Agent或头中提取工具名称
|
||||
detected_tool = None
|
||||
if user_agent:
|
||||
user_agent_lower = user_agent.lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in user_agent_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
|
||||
# 检查HTTP头中的工具标识(只检查特定头)
|
||||
if not detected_tool:
|
||||
for header_name, header_value in request.headers.items():
|
||||
header_name_lower = header_name.lower()
|
||||
if header_name_lower in SCANNER_SPECIFIC_HEADERS:
|
||||
header_value_lower = (header_value or "").lower()
|
||||
for tool in self.scanner_keywords_set:
|
||||
if tool in header_value_lower:
|
||||
detected_tool = tool
|
||||
break
|
||||
if detected_tool:
|
||||
break
|
||||
|
||||
return True, detected_tool or "unknown_scanner"
|
||||
|
||||
return False, None
|
||||
|
||||
def _check_rate_limit(self, client_ip: str) -> bool:
|
||||
"""
|
||||
检查请求频率限制
|
||||
|
||||
Args:
|
||||
client_ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果超过限制则返回 True(需要阻止)
|
||||
"""
|
||||
# 检查IP白名单
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return False
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 定期清理过期的请求记录(每5分钟清理一次)
|
||||
if current_time - self.last_cleanup > 300:
|
||||
self._cleanup_old_requests(current_time)
|
||||
self.last_cleanup = current_time
|
||||
|
||||
# 限制跟踪的IP数量,防止内存泄漏
|
||||
if self.max_tracked_ips > 0 and len(self.request_times) > self.max_tracked_ips:
|
||||
# 清理最旧的记录(删除最久未访问的IP)
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
# 获取或创建该IP的请求时间deque(不使用maxlen,避免限流变松)
|
||||
if client_ip not in self.request_times:
|
||||
self.request_times[client_ip] = deque()
|
||||
|
||||
request_times = self.request_times[client_ip]
|
||||
|
||||
# 移除时间窗口外的请求记录(从左侧弹出过期记录)
|
||||
while request_times and current_time - request_times[0] >= self.rate_limit_window:
|
||||
request_times.popleft()
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(request_times) >= self.rate_limit_max_requests:
|
||||
return True
|
||||
|
||||
# 记录当前请求时间
|
||||
request_times.append(current_time)
|
||||
return False
|
||||
|
||||
def _cleanup_old_requests(self, current_time: float):
|
||||
"""清理过期的请求记录(只清理当前需要检查的IP,不全量遍历)"""
|
||||
# 这个方法现在主要用于定期清理,实际清理在_check_rate_limit中按需进行
|
||||
# 清理最久未访问的IP记录
|
||||
if len(self.request_times) > self.max_tracked_ips * 0.8:
|
||||
self._cleanup_oldest_ips()
|
||||
|
||||
def _cleanup_oldest_ips(self):
|
||||
"""清理最久未访问的IP记录(全量遍历找真正的oldest)"""
|
||||
if not self.request_times:
|
||||
return
|
||||
|
||||
# 先收集空deque的IP(优先删除)
|
||||
empty_ips = []
|
||||
# 找到最久未访问的IP(最旧时间戳)
|
||||
oldest_ip = None
|
||||
oldest_time = float("inf")
|
||||
|
||||
# 全量遍历找真正的oldest(超限时性能可接受)
|
||||
for ip, times in self.request_times.items():
|
||||
if not times:
|
||||
# 空deque,记录待删除
|
||||
empty_ips.append(ip)
|
||||
else:
|
||||
# 找到最旧的时间戳
|
||||
if times[0] < oldest_time:
|
||||
oldest_time = times[0]
|
||||
oldest_ip = ip
|
||||
|
||||
# 先删除空deque的IP
|
||||
for ip in empty_ips:
|
||||
del self.request_times[ip]
|
||||
|
||||
# 如果没有空deque可删除,且仍需要清理,删除最旧的一个IP
|
||||
if not empty_ips and oldest_ip:
|
||||
del self.request_times[oldest_ip]
|
||||
|
||||
def _is_trusted_proxy(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在信任的代理列表中
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果是信任的代理则返回 True
|
||||
"""
|
||||
if not TRUSTED_PROXIES or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查代理列表中的每个条目
|
||||
for trusted_entry in TRUSTED_PROXIES:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(trusted_entry, str):
|
||||
try:
|
||||
if re.match(trusted_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(trusted_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == trusted_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""
|
||||
获取客户端真实IP地址(带基本验证和代理信任检查)
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
|
||||
Returns:
|
||||
客户端IP地址
|
||||
"""
|
||||
# 获取直接连接的客户端IP(用于验证代理)
|
||||
direct_client_ip = None
|
||||
if request.client:
|
||||
direct_client_ip = request.client.host
|
||||
|
||||
# 检查是否信任X-Forwarded-For头
|
||||
# TRUST_XFF 只表示"启用代理解析能力",但仍要求直连 IP 在 TRUSTED_PROXIES 中
|
||||
use_xff = False
|
||||
if TRUST_XFF and TRUSTED_PROXIES and direct_client_ip:
|
||||
# 只有在启用 TRUST_XFF 且直连 IP 在信任列表中时,才信任 XFF
|
||||
use_xff = self._is_trusted_proxy(direct_client_ip)
|
||||
|
||||
# 如果信任代理,优先从 X-Forwarded-For 获取
|
||||
if use_xff:
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# X-Forwarded-For 可能包含多个IP,取第一个
|
||||
ip = forwarded_for.split(",")[0].strip()
|
||||
# 基本验证IP格式
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 从 X-Real-IP 获取(如果信任代理)
|
||||
if use_xff:
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
ip = real_ip.strip()
|
||||
if self._validate_ip(ip):
|
||||
return ip
|
||||
|
||||
# 使用直接连接的客户端IP
|
||||
if direct_client_ip and self._validate_ip(direct_client_ip):
|
||||
return direct_client_ip
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _validate_ip(self, ip: str) -> bool:
|
||||
"""
|
||||
验证IP地址格式
|
||||
|
||||
Args:
|
||||
ip: IP地址字符串
|
||||
|
||||
Returns:
|
||||
如果格式有效则返回 True
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
return False
|
||||
|
||||
def _is_ip_allowed(self, ip: str) -> bool:
|
||||
"""
|
||||
检查IP是否在白名单中(支持精确IP、CIDR格式和通配符)
|
||||
|
||||
Args:
|
||||
ip: 客户端IP地址
|
||||
|
||||
Returns:
|
||||
如果IP在白名单中则返回 True
|
||||
"""
|
||||
if not ALLOWED_IPS or ip == "unknown":
|
||||
return False
|
||||
|
||||
# 检查白名单中的每个条目
|
||||
for allowed_entry in ALLOWED_IPS:
|
||||
# 通配符模式(字符串,正则表达式)
|
||||
if isinstance(allowed_entry, str):
|
||||
try:
|
||||
if re.match(allowed_entry, ip):
|
||||
return True
|
||||
except re.error:
|
||||
# 正则表达式错误,跳过
|
||||
continue
|
||||
# CIDR格式(网络对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Network, ipaddress.IPv6Network)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj in allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
# 精确IP(地址对象)
|
||||
elif isinstance(allowed_entry, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
|
||||
try:
|
||||
client_ip_obj = ipaddress.ip_address(ip)
|
||||
if client_ip_obj == allowed_entry:
|
||||
return True
|
||||
except (ValueError, AttributeError):
|
||||
# IP格式无效,跳过
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
"""
|
||||
处理请求
|
||||
|
||||
Args:
|
||||
request: 请求对象
|
||||
call_next: 下一个中间件或路由处理函数
|
||||
|
||||
Returns:
|
||||
响应对象
|
||||
"""
|
||||
# 如果未启用,直接通过
|
||||
if not self.enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问 robots.txt(由专门的路由处理)
|
||||
if request.url.path == "/robots.txt":
|
||||
return await call_next(request)
|
||||
|
||||
# 允许访问静态资源(CSS、JS、图片等)
|
||||
# 注意:.json 已移除,避免 API 路径绕过防护
|
||||
# 静态资源只在特定前缀下放行(/static/、/assets/、/dist/)
|
||||
static_extensions = {
|
||||
".css",
|
||||
".js",
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".svg",
|
||||
".ico",
|
||||
".woff",
|
||||
".woff2",
|
||||
".ttf",
|
||||
".eot",
|
||||
}
|
||||
static_prefixes = {"/static/", "/assets/", "/dist/"}
|
||||
|
||||
# 检查是否是静态资源路径(特定前缀下的静态文件)
|
||||
path = request.url.path
|
||||
is_static_path = any(path.startswith(prefix) for prefix in static_prefixes) and any(
|
||||
path.endswith(ext) for ext in static_extensions
|
||||
)
|
||||
|
||||
# 也允许根路径下的静态文件(如 /favicon.ico)
|
||||
is_root_static = path.count("/") == 1 and any(path.endswith(ext) for ext in static_extensions)
|
||||
|
||||
if is_static_path or is_root_static:
|
||||
return await call_next(request)
|
||||
|
||||
# 获取客户端IP(只获取一次,避免重复调用)
|
||||
client_ip = self._get_client_ip(request)
|
||||
|
||||
# 检查IP白名单(优先检查,白名单IP直接通过)
|
||||
if self._is_ip_allowed(client_ip):
|
||||
return await call_next(request)
|
||||
|
||||
# 获取 User-Agent
|
||||
user_agent = request.headers.get("User-Agent")
|
||||
|
||||
# 检测资产测绘工具(优先检测,因为更危险)
|
||||
if self.check_asset_scanner:
|
||||
is_scanner, scanner_name = self._detect_asset_scanner(request)
|
||||
if is_scanner:
|
||||
logger.warning(
|
||||
f"🚫 检测到资产测绘工具请求 - IP: {client_ip}, 工具: {scanner_name}, "
|
||||
f"User-Agent: {user_agent}, Path: {request.url.path}"
|
||||
)
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Asset scanning tools are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检测爬虫 User-Agent
|
||||
if self.check_user_agent and self._is_crawler_user_agent(user_agent):
|
||||
logger.warning(f"🚫 检测到爬虫请求 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
# 根据配置决定是否阻止
|
||||
if self.block_on_detect:
|
||||
return PlainTextResponse(
|
||||
"Access Denied: Crawlers are not allowed",
|
||||
status_code=403,
|
||||
)
|
||||
|
||||
# 检查请求频率限制
|
||||
if self.check_rate_limit and self._check_rate_limit(client_ip):
|
||||
logger.warning(f"🚫 请求频率过高 - IP: {client_ip}, User-Agent: {user_agent}, Path: {request.url.path}")
|
||||
return PlainTextResponse(
|
||||
"Too Many Requests: Rate limit exceeded",
|
||||
status_code=429,
|
||||
)
|
||||
|
||||
# 正常请求,继续处理
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
def create_robots_txt_response() -> PlainTextResponse:
|
||||
"""
|
||||
创建 robots.txt 响应
|
||||
|
||||
Returns:
|
||||
robots.txt 响应对象
|
||||
"""
|
||||
robots_content = """User-agent: *
|
||||
Disallow: /
|
||||
|
||||
# 禁止所有爬虫访问
|
||||
"""
|
||||
return PlainTextResponse(
|
||||
content=robots_content,
|
||||
media_type="text/plain",
|
||||
headers={"Cache-Control": "public, max-age=86400"}, # 缓存24小时
|
||||
)
|
||||
|
|
@ -1,301 +0,0 @@
|
|||
"""
|
||||
规划器监控API
|
||||
提供规划器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/planner", tags=["planner"])
|
||||
|
||||
# 规划器日志目录
|
||||
PLAN_LOG_DIR = Path("logs/plan")
|
||||
|
||||
|
||||
class ChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
plan_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class PlanLogSummary(BaseModel):
|
||||
"""规划日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
action_count: int
|
||||
action_types: List[str] # 动作类型列表
|
||||
total_plan_ms: float
|
||||
llm_duration_ms: float
|
||||
reasoning_preview: str
|
||||
|
||||
|
||||
class PlanLogDetail(BaseModel):
|
||||
"""规划日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
reasoning: str
|
||||
raw_output: str
|
||||
actions: List[Dict]
|
||||
timing: Dict
|
||||
extra: Optional[Dict] = None
|
||||
|
||||
|
||||
class PlannerOverview(BaseModel):
|
||||
"""规划器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_plans: int
|
||||
chats: List[ChatSummary]
|
||||
|
||||
|
||||
class PaginatedChatLogs(BaseModel):
|
||||
"""分页的聊天日志列表"""
|
||||
data: List[PlanLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=PlannerOverview)
|
||||
async def get_planner_overview():
|
||||
"""
|
||||
获取规划器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return PlannerOverview(total_chats=0, total_plans=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_plans = 0
|
||||
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
plan_count = len(json_files)
|
||||
total_plans += plan_count
|
||||
|
||||
if plan_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
plan_count=plan_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return PlannerOverview(
|
||||
total_chats=len(chats),
|
||||
total_plans=total_plans,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedChatLogs)
|
||||
async def get_chat_plan_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的规划日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = PLAN_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedChatLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
actions = data.get('actions', [])
|
||||
action_types = [a.get('action_type', '') for a in actions if a.get('action_type')]
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
action_count=len(actions),
|
||||
action_types=action_types,
|
||||
total_plan_ms=data.get('timing', {}).get('total_plan_ms', 0),
|
||||
llm_duration_ms=data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
reasoning_preview=reasoning[:100] if reasoning else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(PlanLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
action_count=0,
|
||||
action_types=[],
|
||||
total_plan_ms=0,
|
||||
llm_duration_ms=0,
|
||||
reasoning_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedChatLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=PlanLogDetail)
|
||||
async def get_log_detail(chat_id: str, filename: str):
|
||||
"""获取规划日志详情 - 按需加载完整内容"""
|
||||
log_file = PLAN_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return PlanLogDetail(**data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容旧接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_planner_stats():
|
||||
"""获取规划器统计信息 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
|
||||
# 获取最近10条计划的摘要
|
||||
recent_plans = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_plan_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_plans.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_plans.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_plans = recent_plans[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_plans": overview.total_plans,
|
||||
"avg_plan_time_ms": 0,
|
||||
"avg_llm_time_ms": 0,
|
||||
"recent_plans": recent_plans
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_chat_list():
|
||||
"""获取所有聊天ID列表 - 兼容旧接口"""
|
||||
overview = await get_planner_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
||||
|
||||
@router.get("/all-logs")
|
||||
async def get_all_logs(
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100)
|
||||
):
|
||||
"""获取所有规划日志 - 兼容旧接口"""
|
||||
if not PLAN_LOG_DIR.exists():
|
||||
return {"data": [], "total": 0, "page": page, "page_size": page_size}
|
||||
|
||||
# 收集所有文件
|
||||
all_files = []
|
||||
for chat_dir in PLAN_LOG_DIR.iterdir():
|
||||
if chat_dir.is_dir():
|
||||
for log_file in chat_dir.glob("*.json"):
|
||||
all_files.append((chat_dir.name, log_file))
|
||||
|
||||
# 按时间戳排序
|
||||
all_files.sort(key=lambda x: parse_timestamp_from_filename(x[1].name), reverse=True)
|
||||
|
||||
total = len(all_files)
|
||||
offset = (page - 1) * page_size
|
||||
page_files = all_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for chat_id, log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
reasoning = data.get('reasoning', '')
|
||||
logs.append({
|
||||
"chat_id": data.get('chat_id', chat_id),
|
||||
"timestamp": data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
"filename": log_file.name,
|
||||
"action_count": len(data.get('actions', [])),
|
||||
"total_plan_ms": data.get('timing', {}).get('total_plan_ms', 0),
|
||||
"llm_duration_ms": data.get('timing', {}).get('llm_duration_ms', 0),
|
||||
"reasoning_preview": reasoning[:100] if reasoning else ''
|
||||
})
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return {"data": logs, "total": total, "page": page, "page_size": page_size}
|
||||
|
|
@ -1,269 +0,0 @@
|
|||
"""
|
||||
回复器监控API
|
||||
提供回复器日志数据的查询接口
|
||||
|
||||
性能优化:
|
||||
1. 聊天摘要只统计文件数量和最新时间戳,不读取文件内容
|
||||
2. 日志列表使用文件名解析时间戳,只在需要时读取完整内容
|
||||
3. 详情按需加载
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(prefix="/api/replier", tags=["replier"])
|
||||
|
||||
# 回复器日志目录
|
||||
REPLY_LOG_DIR = Path("logs/reply")
|
||||
|
||||
|
||||
class ReplierChatSummary(BaseModel):
|
||||
"""聊天摘要 - 轻量级,不读取文件内容"""
|
||||
chat_id: str
|
||||
reply_count: int
|
||||
latest_timestamp: float
|
||||
latest_filename: str
|
||||
|
||||
|
||||
class ReplyLogSummary(BaseModel):
|
||||
"""回复日志摘要"""
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
filename: str
|
||||
model: str
|
||||
success: bool
|
||||
llm_ms: float
|
||||
overall_ms: float
|
||||
output_preview: str
|
||||
|
||||
|
||||
class ReplyLogDetail(BaseModel):
|
||||
"""回复日志详情"""
|
||||
type: str
|
||||
chat_id: str
|
||||
timestamp: float
|
||||
prompt: str
|
||||
output: str
|
||||
processed_output: List[str]
|
||||
model: str
|
||||
reasoning: str
|
||||
think_level: int
|
||||
timing: Dict
|
||||
error: Optional[str] = None
|
||||
success: bool
|
||||
|
||||
|
||||
class ReplierOverview(BaseModel):
|
||||
"""回复器总览 - 轻量级统计"""
|
||||
total_chats: int
|
||||
total_replies: int
|
||||
chats: List[ReplierChatSummary]
|
||||
|
||||
|
||||
class PaginatedReplyLogs(BaseModel):
|
||||
"""分页的回复日志列表"""
|
||||
data: List[ReplyLogSummary]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
chat_id: str
|
||||
|
||||
|
||||
def parse_timestamp_from_filename(filename: str) -> float:
|
||||
"""从文件名解析时间戳: 1766497488220_af92bdb1.json -> 1766497488.220"""
|
||||
try:
|
||||
timestamp_str = filename.split('_')[0]
|
||||
# 时间戳是毫秒级,需要转换为秒
|
||||
return float(timestamp_str) / 1000
|
||||
except (ValueError, IndexError):
|
||||
return 0
|
||||
|
||||
|
||||
@router.get("/overview", response_model=ReplierOverview)
|
||||
async def get_replier_overview():
|
||||
"""
|
||||
获取回复器总览 - 轻量级接口
|
||||
只统计文件数量,不读取文件内容
|
||||
"""
|
||||
if not REPLY_LOG_DIR.exists():
|
||||
return ReplierOverview(total_chats=0, total_replies=0, chats=[])
|
||||
|
||||
chats = []
|
||||
total_replies = 0
|
||||
|
||||
for chat_dir in REPLY_LOG_DIR.iterdir():
|
||||
if not chat_dir.is_dir():
|
||||
continue
|
||||
|
||||
# 只统计json文件数量
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
reply_count = len(json_files)
|
||||
total_replies += reply_count
|
||||
|
||||
if reply_count == 0:
|
||||
continue
|
||||
|
||||
# 从文件名获取最新时间戳
|
||||
latest_file = max(json_files, key=lambda f: parse_timestamp_from_filename(f.name))
|
||||
latest_timestamp = parse_timestamp_from_filename(latest_file.name)
|
||||
|
||||
chats.append(ReplierChatSummary(
|
||||
chat_id=chat_dir.name,
|
||||
reply_count=reply_count,
|
||||
latest_timestamp=latest_timestamp,
|
||||
latest_filename=latest_file.name
|
||||
))
|
||||
|
||||
# 按最新时间戳排序
|
||||
chats.sort(key=lambda x: x.latest_timestamp, reverse=True)
|
||||
|
||||
return ReplierOverview(
|
||||
total_chats=len(chats),
|
||||
total_replies=total_replies,
|
||||
chats=chats
|
||||
)
|
||||
|
||||
|
||||
@router.get("/chat/{chat_id}/logs", response_model=PaginatedReplyLogs)
|
||||
async def get_chat_reply_logs(
|
||||
chat_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(20, ge=1, le=100),
|
||||
search: Optional[str] = Query(None, description="搜索关键词,匹配提示词内容")
|
||||
):
|
||||
"""
|
||||
获取指定聊天的回复日志列表(分页)
|
||||
需要读取文件内容获取摘要信息
|
||||
支持搜索提示词内容
|
||||
"""
|
||||
chat_dir = REPLY_LOG_DIR / chat_id
|
||||
if not chat_dir.exists():
|
||||
return PaginatedReplyLogs(
|
||||
data=[], total=0, page=page, page_size=page_size, chat_id=chat_id
|
||||
)
|
||||
|
||||
# 先获取所有文件并按时间戳排序
|
||||
json_files = list(chat_dir.glob("*.json"))
|
||||
json_files.sort(key=lambda f: parse_timestamp_from_filename(f.name), reverse=True)
|
||||
|
||||
# 如果有搜索关键词,需要过滤文件
|
||||
if search:
|
||||
search_lower = search.lower()
|
||||
filtered_files = []
|
||||
for log_file in json_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
prompt = data.get('prompt', '')
|
||||
if search_lower in prompt.lower():
|
||||
filtered_files.append(log_file)
|
||||
except Exception:
|
||||
continue
|
||||
json_files = filtered_files
|
||||
|
||||
total = len(json_files)
|
||||
|
||||
# 分页 - 只读取当前页的文件
|
||||
offset = (page - 1) * page_size
|
||||
page_files = json_files[offset:offset + page_size]
|
||||
|
||||
logs = []
|
||||
for log_file in page_files:
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
output = data.get('output', '')
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', parse_timestamp_from_filename(log_file.name)),
|
||||
filename=log_file.name,
|
||||
model=data.get('model', ''),
|
||||
success=data.get('success', True),
|
||||
llm_ms=data.get('timing', {}).get('llm_ms', 0),
|
||||
overall_ms=data.get('timing', {}).get('overall_ms', 0),
|
||||
output_preview=output[:100] if output else ''
|
||||
))
|
||||
except Exception:
|
||||
# 文件读取失败时使用文件名信息
|
||||
logs.append(ReplyLogSummary(
|
||||
chat_id=chat_id,
|
||||
timestamp=parse_timestamp_from_filename(log_file.name),
|
||||
filename=log_file.name,
|
||||
model='',
|
||||
success=False,
|
||||
llm_ms=0,
|
||||
overall_ms=0,
|
||||
output_preview='[读取失败]'
|
||||
))
|
||||
|
||||
return PaginatedReplyLogs(
|
||||
data=logs,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
chat_id=chat_id
|
||||
)
|
||||
|
||||
|
||||
@router.get("/log/{chat_id}/{filename}", response_model=ReplyLogDetail)
|
||||
async def get_reply_log_detail(chat_id: str, filename: str):
|
||||
"""获取回复日志详情 - 按需加载完整内容"""
|
||||
log_file = REPLY_LOG_DIR / chat_id / filename
|
||||
if not log_file.exists():
|
||||
raise HTTPException(status_code=404, detail="日志文件不存在")
|
||||
|
||||
try:
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return ReplyLogDetail(
|
||||
type=data.get('type', 'reply'),
|
||||
chat_id=data.get('chat_id', chat_id),
|
||||
timestamp=data.get('timestamp', 0),
|
||||
prompt=data.get('prompt', ''),
|
||||
output=data.get('output', ''),
|
||||
processed_output=data.get('processed_output', []),
|
||||
model=data.get('model', ''),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
think_level=data.get('think_level', 0),
|
||||
timing=data.get('timing', {}),
|
||||
error=data.get('error'),
|
||||
success=data.get('success', True)
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"读取日志失败: {str(e)}")
|
||||
|
||||
|
||||
# ========== 兼容接口 ==========
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_replier_stats():
|
||||
"""获取回复器统计信息"""
|
||||
overview = await get_replier_overview()
|
||||
|
||||
# 获取最近10条回复的摘要
|
||||
recent_replies = []
|
||||
for chat in overview.chats[:5]: # 从最近5个聊天中获取
|
||||
try:
|
||||
chat_logs = await get_chat_reply_logs(chat.chat_id, page=1, page_size=2)
|
||||
recent_replies.extend(chat_logs.data)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 按时间排序取前10
|
||||
recent_replies.sort(key=lambda x: x.timestamp, reverse=True)
|
||||
recent_replies = recent_replies[:10]
|
||||
|
||||
return {
|
||||
"total_chats": overview.total_chats,
|
||||
"total_replies": overview.total_replies,
|
||||
"recent_replies": recent_replies
|
||||
}
|
||||
|
||||
|
||||
@router.get("/chats")
|
||||
async def get_replier_chat_list():
|
||||
"""获取所有聊天ID列表"""
|
||||
overview = await get_replier_overview()
|
||||
return [chat.chat_id for chat in overview.chats]
|
||||
|
|
@ -1,183 +0,0 @@
|
|||
"""
|
||||
WebUI 认证模块
|
||||
提供统一的认证依赖,支持 Cookie 和 Header 两种方式
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, Cookie, Header, Response, Request
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from .token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.auth")
|
||||
|
||||
# Cookie 配置
|
||||
COOKIE_NAME = "maibot_session"
|
||||
COOKIE_MAX_AGE = 7 * 24 * 60 * 60 # 7天
|
||||
|
||||
|
||||
def _is_secure_environment() -> bool:
|
||||
"""
|
||||
检测是否应该启用安全 Cookie(HTTPS)
|
||||
|
||||
Returns:
|
||||
bool: 如果应该使用 secure cookie 则返回 True
|
||||
"""
|
||||
# 从配置读取
|
||||
if global_config.webui.secure_cookie:
|
||||
logger.info("配置中启用了 secure_cookie")
|
||||
return True
|
||||
|
||||
# 检查是否是生产环境
|
||||
if global_config.webui.mode == "production":
|
||||
logger.info("WebUI运行在生产模式,启用 secure cookie")
|
||||
return True
|
||||
|
||||
# 默认:开发环境不启用(因为通常是 HTTP)
|
||||
logger.debug("WebUI运行在开发模式,禁用 secure cookie")
|
||||
return False
|
||||
|
||||
|
||||
def get_current_token(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> str:
|
||||
"""
|
||||
获取当前请求的 token,优先从 Cookie 获取,其次从 Header 获取
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization Header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证通过的 token
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str, request: Optional[Request] = None) -> None:
|
||||
"""
|
||||
设置认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: 要设置的 token
|
||||
request: FastAPI Request 对象(可选,用于检测协议)
|
||||
"""
|
||||
# 根据环境和实际请求协议决定安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
# 如果提供了 request,检测实际使用的协议
|
||||
if request:
|
||||
# 检查 X-Forwarded-Proto header(代理/负载均衡器)
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "").lower()
|
||||
if forwarded_proto:
|
||||
is_https = forwarded_proto == "https"
|
||||
logger.debug(f"检测到 X-Forwarded-Proto: {forwarded_proto}, is_https={is_https}")
|
||||
else:
|
||||
# 检查 request.url.scheme
|
||||
is_https = request.url.scheme == "https"
|
||||
logger.debug(f"检测到 scheme: {request.url.scheme}, is_https={is_https}")
|
||||
|
||||
# 如果是 HTTP 连接,强制禁用 secure 标志
|
||||
if not is_https and is_secure:
|
||||
logger.warning("=" * 80)
|
||||
logger.warning("检测到 HTTP 连接但环境配置要求 HTTPS (secure cookie)")
|
||||
logger.warning("已自动禁用 secure 标志以允许登录,但建议修改配置:")
|
||||
logger.warning("1. 在配置文件中设置: webui.secure_cookie = false")
|
||||
logger.warning("2. 如果使用反向代理,请确保正确配置 X-Forwarded-Proto 头")
|
||||
logger.warning("=" * 80)
|
||||
is_secure = False
|
||||
|
||||
# 设置 Cookie
|
||||
response.set_cookie(
|
||||
key=COOKIE_NAME,
|
||||
value=token,
|
||||
max_age=COOKIE_MAX_AGE,
|
||||
httponly=True, # 防止 JS 读取,阻止 XSS 窃取
|
||||
samesite="lax", # 使用 lax 以兼容更多场景(开发和生产)
|
||||
secure=is_secure, # 根据实际协议决定
|
||||
path="/", # 确保 Cookie 在所有路径下可用
|
||||
)
|
||||
|
||||
logger.info(f"已设置认证 Cookie: {token[:8]}... (secure={is_secure}, samesite=lax, httponly=True, path=/, max_age={COOKIE_MAX_AGE})")
|
||||
logger.debug(f"完整 token 前缀: {token[:20]}...")
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""
|
||||
清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
"""
|
||||
# 保持与 set_auth_cookie 相同的安全设置
|
||||
is_secure = _is_secure_environment()
|
||||
|
||||
response.delete_cookie(
|
||||
key=COOKIE_NAME,
|
||||
httponly=True,
|
||||
samesite="strict" if is_secure else "lax",
|
||||
secure=is_secure,
|
||||
path="/",
|
||||
)
|
||||
logger.debug("已清除认证 Cookie")
|
||||
|
||||
|
||||
def verify_auth_token_from_cookie_or_header(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
验证认证 Token,支持从 Cookie 或 Header 获取
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
验证成功返回 True
|
||||
|
||||
Raises:
|
||||
HTTPException: 认证失败时抛出 401 错误
|
||||
"""
|
||||
token = None
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
# 其次从 Header 获取(兼容旧版本)
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
# 验证 token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效或已过期")
|
||||
|
||||
return True
|
||||
|
|
@ -1,782 +0,0 @@
|
|||
"""本地聊天室路由 - WebUI 与麦麦直接对话
|
||||
|
||||
支持两种模式:
|
||||
1. WebUI 模式:使用 WebUI 平台独立身份聊天
|
||||
2. 虚拟身份模式:使用真实平台用户的身份,在虚拟群聊中与麦麦对话
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.bot import chat_bot
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.chat")
|
||||
|
||||
router = APIRouter(prefix="/api/chat", tags=["LocalChat"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# WebUI 聊天的虚拟群组 ID
|
||||
WEBUI_CHAT_GROUP_ID = "webui_local_chat"
|
||||
WEBUI_CHAT_PLATFORM = "webui"
|
||||
|
||||
# 虚拟身份模式的群 ID 前缀
|
||||
VIRTUAL_GROUP_ID_PREFIX = "webui_virtual_group_"
|
||||
|
||||
# 固定的 WebUI 用户 ID 前缀
|
||||
WEBUI_USER_ID_PREFIX = "webui_user_"
|
||||
|
||||
|
||||
class VirtualIdentityConfig(BaseModel):
|
||||
"""虚拟身份配置"""
|
||||
|
||||
enabled: bool = False # 是否启用虚拟身份模式
|
||||
platform: Optional[str] = None # 目标平台(如 qq, discord 等)
|
||||
person_id: Optional[str] = None # PersonInfo 的 person_id
|
||||
user_id: Optional[str] = None # 原始平台用户 ID
|
||||
user_nickname: Optional[str] = None # 用户昵称
|
||||
group_id: Optional[str] = None # 虚拟群 ID(自动生成或用户指定)
|
||||
group_name: Optional[str] = None # 虚拟群名(用户自定义)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseModel):
|
||||
"""聊天历史消息"""
|
||||
|
||||
id: str
|
||||
type: str # 'user' | 'bot' | 'system'
|
||||
content: str
|
||||
timestamp: float
|
||||
sender_name: str
|
||||
sender_id: Optional[str] = None
|
||||
is_bot: bool = False
|
||||
|
||||
|
||||
class ChatHistoryManager:
|
||||
"""聊天历史管理器 - 使用 SQLite 数据库存储"""
|
||||
|
||||
def __init__(self, max_messages: int = 200):
|
||||
self.max_messages = max_messages
|
||||
|
||||
def _message_to_dict(self, msg: Messages, group_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""将数据库消息转换为前端格式
|
||||
|
||||
Args:
|
||||
msg: 数据库消息对象
|
||||
group_id: 群 ID,用于判断是否是虚拟群
|
||||
"""
|
||||
# 判断是否是机器人消息
|
||||
user_id = msg.user_id or ""
|
||||
|
||||
# 对于虚拟群,通过比较机器人 QQ 账号来判断
|
||||
# 对于普通 WebUI 群,检查 user_id 是否以 webui_ 开头
|
||||
if group_id and group_id.startswith(VIRTUAL_GROUP_ID_PREFIX):
|
||||
# 虚拟群:user_id 等于机器人 QQ 账号的是机器人消息
|
||||
bot_qq = str(global_config.bot.qq_account)
|
||||
is_bot = user_id == bot_qq
|
||||
else:
|
||||
# 普通 WebUI 群:不以 webui_ 开头的是机器人消息
|
||||
is_bot = not user_id.startswith("webui_") and not user_id.startswith(WEBUI_USER_ID_PREFIX)
|
||||
|
||||
return {
|
||||
"id": msg.message_id,
|
||||
"type": "bot" if is_bot else "user",
|
||||
"content": msg.processed_plain_text or msg.display_message or "",
|
||||
"timestamp": msg.time,
|
||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||
"sender_id": "bot" if is_bot else user_id,
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
def get_history(self, limit: int = 50, group_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""从数据库获取最近的历史记录
|
||||
|
||||
Args:
|
||||
limit: 获取的消息数量
|
||||
group_id: 群 ID,默认为 WEBUI_CHAT_GROUP_ID
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
# 查询指定群的消息,按时间排序
|
||||
messages = (
|
||||
Messages.select()
|
||||
.where(Messages.chat_info_group_id == target_group_id)
|
||||
.order_by(Messages.time.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
# 转换为列表并反转(使最旧的消息在前)
|
||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||
result = [self._message_to_dict(msg, target_group_id) for msg in messages]
|
||||
result.reverse()
|
||||
|
||||
logger.debug(f"从数据库加载了 {len(result)} 条聊天记录 (group_id={target_group_id})")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库加载聊天记录失败: {e}")
|
||||
return []
|
||||
|
||||
def clear_history(self, group_id: Optional[str] = None) -> int:
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
try:
|
||||
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
|
||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||
return deleted
|
||||
except Exception as e:
|
||||
logger.error(f"清空聊天记录失败: {e}")
|
||||
return 0
|
||||
|
||||
|
||||
# 全局聊天历史管理器
|
||||
chat_history = ChatHistoryManager()
|
||||
|
||||
|
||||
# 存储 WebSocket 连接
|
||||
class ChatConnectionManager:
|
||||
"""聊天连接管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_connections: Dict[str, WebSocket] = {}
|
||||
self.user_sessions: Dict[str, str] = {} # user_id -> session_id 映射
|
||||
|
||||
async def connect(self, websocket: WebSocket, session_id: str, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections[session_id] = websocket
|
||||
self.user_sessions[user_id] = session_id
|
||||
logger.info(f"WebUI 聊天会话已连接: session={session_id}, user={user_id}")
|
||||
|
||||
def disconnect(self, session_id: str, user_id: str):
|
||||
if session_id in self.active_connections:
|
||||
del self.active_connections[session_id]
|
||||
if user_id in self.user_sessions and self.user_sessions[user_id] == session_id:
|
||||
del self.user_sessions[user_id]
|
||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||
|
||||
async def send_message(self, session_id: str, message: dict):
|
||||
if session_id in self.active_connections:
|
||||
try:
|
||||
await self.active_connections[session_id].send_json(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
async def broadcast(self, message: dict):
|
||||
"""广播消息给所有连接"""
|
||||
for session_id in list(self.active_connections.keys()):
|
||||
await self.send_message(session_id, message)
|
||||
|
||||
|
||||
chat_manager = ChatConnectionManager()
|
||||
|
||||
|
||||
def create_message_data(
|
||||
content: str,
|
||||
user_id: str,
|
||||
user_name: str,
|
||||
message_id: Optional[str] = None,
|
||||
is_at_bot: bool = True,
|
||||
virtual_config: Optional[VirtualIdentityConfig] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""创建符合麦麦消息格式的消息数据
|
||||
|
||||
Args:
|
||||
content: 消息内容
|
||||
user_id: 用户 ID
|
||||
user_name: 用户昵称
|
||||
message_id: 消息 ID(可选,自动生成)
|
||||
is_at_bot: 是否 @ 机器人
|
||||
virtual_config: 虚拟身份配置(可选,启用后使用真实平台身份)
|
||||
"""
|
||||
if message_id is None:
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# 确定使用的平台、群信息和用户信息
|
||||
if virtual_config and virtual_config.enabled:
|
||||
# 虚拟身份模式:使用真实平台身份
|
||||
platform = virtual_config.platform or WEBUI_CHAT_PLATFORM
|
||||
group_id = virtual_config.group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{uuid.uuid4().hex[:8]}"
|
||||
group_name = virtual_config.group_name or "WebUI虚拟群聊"
|
||||
actual_user_id = virtual_config.user_id or user_id
|
||||
actual_user_name = virtual_config.user_nickname or user_name
|
||||
else:
|
||||
# 标准 WebUI 模式
|
||||
platform = WEBUI_CHAT_PLATFORM
|
||||
group_id = WEBUI_CHAT_GROUP_ID
|
||||
group_name = "WebUI本地聊天室"
|
||||
actual_user_id = user_id
|
||||
actual_user_name = user_name
|
||||
|
||||
return {
|
||||
"message_info": {
|
||||
"platform": platform,
|
||||
"message_id": message_id,
|
||||
"time": time.time(),
|
||||
"group_info": {
|
||||
"group_id": group_id,
|
||||
"group_name": group_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"user_info": {
|
||||
"user_id": actual_user_id,
|
||||
"user_nickname": actual_user_name,
|
||||
"user_cardname": actual_user_name,
|
||||
"platform": platform,
|
||||
},
|
||||
"additional_config": {
|
||||
"at_bot": is_at_bot,
|
||||
},
|
||||
},
|
||||
"message_segment": {
|
||||
"type": "seglist",
|
||||
"data": [
|
||||
{
|
||||
"type": "text",
|
||||
"data": content,
|
||||
},
|
||||
{
|
||||
"type": "mention_bot",
|
||||
"data": "1.0",
|
||||
},
|
||||
],
|
||||
},
|
||||
"raw_message": content,
|
||||
"processed_plain_text": content,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/history")
|
||||
async def get_chat_history(
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
user_id: Optional[str] = Query(default=None), # 保留参数兼容性,但不用于过滤
|
||||
group_id: Optional[str] = Query(default=None), # 可选:指定群 ID 获取历史
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取聊天历史记录
|
||||
|
||||
所有 WebUI 用户共享同一个聊天室,因此返回所有历史记录
|
||||
如果指定了 group_id,则获取该虚拟群的历史记录
|
||||
"""
|
||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||
history = chat_history.get_history(limit, target_group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"messages": history,
|
||||
"total": len(history),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/platforms")
|
||||
async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||
"""获取可用平台列表
|
||||
|
||||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
from peewee import fn
|
||||
|
||||
# 查询所有不同的平台
|
||||
platforms = (
|
||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
||||
.group_by(PersonInfo.platform)
|
||||
.order_by(fn.COUNT(PersonInfo.id).desc())
|
||||
)
|
||||
|
||||
result = []
|
||||
for p in platforms:
|
||||
if p.platform: # 排除空平台
|
||||
result.append({"platform": p.platform, "count": p.count})
|
||||
|
||||
return {"success": True, "platforms": result}
|
||||
except Exception as e:
|
||||
logger.error(f"获取平台列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "platforms": []}
|
||||
|
||||
|
||||
@router.get("/persons")
|
||||
async def get_persons_by_platform(
|
||||
platform: str = Query(..., description="平台名称"),
|
||||
search: Optional[str] = Query(default=None, description="搜索关键词"),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取指定平台的用户列表
|
||||
|
||||
Args:
|
||||
platform: 平台名称(如 qq, discord 等)
|
||||
search: 搜索关键词(匹配昵称、用户名、user_id)
|
||||
limit: 返回数量限制
|
||||
"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = PersonInfo.select().where(PersonInfo.platform == platform)
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.limit(limit)
|
||||
|
||||
result = []
|
||||
for person in query:
|
||||
result.append(
|
||||
{
|
||||
"person_id": person.person_id,
|
||||
"user_id": person.user_id,
|
||||
"person_name": person.person_name,
|
||||
"nickname": person.nickname,
|
||||
"is_known": person.is_known,
|
||||
"platform": person.platform,
|
||||
"display_name": person.person_name or person.nickname or person.user_id,
|
||||
}
|
||||
)
|
||||
|
||||
return {"success": True, "persons": result, "total": len(result)}
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
return {"success": False, "error": str(e), "persons": []}
|
||||
|
||||
|
||||
@router.delete("/history")
|
||||
async def clear_chat_history(group_id: Optional[str] = Query(default=None), _auth: bool = Depends(require_auth)):
|
||||
"""清空聊天历史记录
|
||||
|
||||
Args:
|
||||
group_id: 可选,指定要清空的群 ID,默认清空 WebUI 默认聊天室
|
||||
"""
|
||||
deleted = chat_history.clear_history(group_id)
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"已清空 {deleted} 条聊天记录",
|
||||
}
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_chat(
|
||||
websocket: WebSocket,
|
||||
user_id: Optional[str] = Query(default=None),
|
||||
user_name: Optional[str] = Query(default="WebUI用户"),
|
||||
platform: Optional[str] = Query(default=None),
|
||||
person_id: Optional[str] = Query(default=None),
|
||||
group_name: Optional[str] = Query(default=None),
|
||||
group_id: Optional[str] = Query(default=None), # 前端传递的稳定 group_id
|
||||
token: Optional[str] = Query(default=None), # 认证 token
|
||||
):
|
||||
"""WebSocket 聊天端点
|
||||
|
||||
Args:
|
||||
user_id: 用户唯一标识(由前端生成并持久化)
|
||||
user_name: 用户显示昵称(可修改)
|
||||
platform: 虚拟身份模式的平台(可选)
|
||||
person_id: 虚拟身份模式的用户 person_id(可选)
|
||||
group_name: 虚拟身份模式的群名(可选)
|
||||
group_id: 虚拟身份模式的群 ID(可选,由前端生成并持久化)
|
||||
token: 认证 token(可选,也可从 Cookie 获取)
|
||||
|
||||
虚拟身份模式可通过 URL 参数直接配置,或通过消息中的 set_virtual_identity 配置
|
||||
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/api/chat/ws?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("聊天 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("聊天 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
# 生成会话 ID(每次连接都是新的)
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# 如果没有提供 user_id,生成一个新的
|
||||
if not user_id:
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{uuid.uuid4().hex[:16]}"
|
||||
elif not user_id.startswith(WEBUI_USER_ID_PREFIX):
|
||||
# 确保 user_id 有正确的前缀
|
||||
user_id = f"{WEBUI_USER_ID_PREFIX}{user_id}"
|
||||
|
||||
# 当前会话的虚拟身份配置(可通过消息动态更新)
|
||||
current_virtual_config: Optional[VirtualIdentityConfig] = None
|
||||
|
||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||
if platform and person_id:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=virtual_group_id,
|
||||
group_name=group_name or "WebUI虚拟群聊",
|
||||
)
|
||||
logger.info(
|
||||
f"虚拟身份模式已通过 URL 参数激活: {current_virtual_config.user_nickname} @ {current_virtual_config.platform}, group_id={virtual_group_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"通过 URL 参数配置虚拟身份失败: {e}")
|
||||
|
||||
await chat_manager.connect(websocket, session_id, user_id)
|
||||
|
||||
try:
|
||||
# 构建会话信息
|
||||
session_info_data = {
|
||||
"type": "session_info",
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"user_name": user_name,
|
||||
"bot_name": global_config.bot.nickname,
|
||||
}
|
||||
|
||||
# 如果有虚拟身份配置,添加到会话信息中
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
session_info_data["virtual_mode"] = True
|
||||
session_info_data["group_id"] = current_virtual_config.group_id
|
||||
session_info_data["virtual_identity"] = {
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
}
|
||||
|
||||
# 发送会话信息(包含用户 ID,前端需要保存)
|
||||
await chat_manager.send_message(session_id, session_info_data)
|
||||
|
||||
# 发送历史记录(根据模式选择不同的群)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
else:
|
||||
history = chat_history.get_history(50)
|
||||
if history:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": history,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送欢迎消息(不保存到历史)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
welcome_msg = f"已以 {current_virtual_config.user_nickname} 的身份连接到「{current_virtual_config.group_name}」,开始与 {global_config.bot.nickname} 对话吧!"
|
||||
else:
|
||||
welcome_msg = f"已连接到本地聊天室,可以开始与 {global_config.bot.nickname} 对话了!"
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": welcome_msg,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
data = await websocket.receive_json()
|
||||
|
||||
if data.get("type") == "message":
|
||||
content = data.get("content", "").strip()
|
||||
if not content:
|
||||
continue
|
||||
|
||||
# 用户可以更新昵称
|
||||
current_user_name = data.get("user_name", user_name)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
timestamp = time.time()
|
||||
|
||||
# 确定发送者信息(根据是否使用虚拟身份)
|
||||
if current_virtual_config and current_virtual_config.enabled:
|
||||
sender_name = current_virtual_config.user_nickname or current_user_name
|
||||
sender_user_id = current_virtual_config.user_id or user_id
|
||||
else:
|
||||
sender_name = current_user_name
|
||||
sender_user_id = user_id
|
||||
|
||||
# 广播用户消息给所有连接(包括发送者)
|
||||
# 注意:用户消息会在 chat_bot.message_process 中自动保存到数据库
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "user_message",
|
||||
"content": content,
|
||||
"message_id": message_id,
|
||||
"timestamp": timestamp,
|
||||
"sender": {
|
||||
"name": sender_name,
|
||||
"user_id": sender_user_id,
|
||||
"is_bot": False,
|
||||
},
|
||||
"virtual_mode": current_virtual_config.enabled if current_virtual_config else False,
|
||||
}
|
||||
)
|
||||
|
||||
# 创建麦麦消息格式
|
||||
message_data = create_message_data(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
user_name=current_user_name,
|
||||
message_id=message_id,
|
||||
is_at_bot=True,
|
||||
virtual_config=current_virtual_config,
|
||||
)
|
||||
|
||||
try:
|
||||
# 显示正在输入状态
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": True,
|
||||
}
|
||||
)
|
||||
|
||||
# 调用麦麦的消息处理
|
||||
await chat_bot.message_process(message_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息时出错: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"处理消息时出错: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
await chat_manager.broadcast(
|
||||
{
|
||||
"type": "typing",
|
||||
"is_typing": False,
|
||||
}
|
||||
)
|
||||
|
||||
elif data.get("type") == "ping":
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "pong",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "update_nickname":
|
||||
# 允许用户更新昵称
|
||||
if new_name := data.get("user_name", "").strip():
|
||||
current_user_name = new_name
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "nickname_updated",
|
||||
"user_name": current_user_name,
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
elif data.get("type") == "set_virtual_identity":
|
||||
# 设置或更新虚拟身份配置
|
||||
virtual_data = data.get("config", {})
|
||||
if virtual_data.get("enabled"):
|
||||
# 验证必要字段
|
||||
if not virtual_data.get("platform") or not virtual_data.get("person_id"):
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": "虚拟身份配置缺少必要字段: platform 和 person_id",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 获取用户信息
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
|
||||
if not person:
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"找不到用户: {virtual_data.get('person_id')}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
continue
|
||||
|
||||
# 生成虚拟群 ID
|
||||
custom_group_id = virtual_data.get("group_id")
|
||||
if custom_group_id:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{custom_group_id}"
|
||||
else:
|
||||
group_id = f"{VIRTUAL_GROUP_ID_PREFIX}{session_id[:8]}"
|
||||
|
||||
current_virtual_config = VirtualIdentityConfig(
|
||||
enabled=True,
|
||||
platform=person.platform,
|
||||
person_id=person.person_id,
|
||||
user_id=person.user_id,
|
||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
||||
group_id=group_id,
|
||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||
)
|
||||
|
||||
# 发送虚拟身份已激活的消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {
|
||||
"enabled": True,
|
||||
"platform": current_virtual_config.platform,
|
||||
"user_id": current_virtual_config.user_id,
|
||||
"user_nickname": current_virtual_config.user_nickname,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
"group_name": current_virtual_config.group_name,
|
||||
},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 加载虚拟群的历史记录
|
||||
virtual_history = chat_history.get_history(50, current_virtual_config.group_id)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": virtual_history,
|
||||
"group_id": current_virtual_config.group_id,
|
||||
},
|
||||
)
|
||||
|
||||
# 发送系统消息
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": f"已切换到虚拟身份模式:以 {current_virtual_config.user_nickname} 的身份在「{current_virtual_config.group_name}」与 {global_config.bot.nickname} 对话",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"设置虚拟身份失败: {e}")
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "error",
|
||||
"content": f"设置虚拟身份失败: {str(e)}",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# 禁用虚拟身份模式
|
||||
current_virtual_config = None
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "virtual_identity_set",
|
||||
"config": {"enabled": False},
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
# 重新加载默认聊天室历史
|
||||
default_history = chat_history.get_history(50, WEBUI_CHAT_GROUP_ID)
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "history",
|
||||
"messages": default_history,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
},
|
||||
)
|
||||
|
||||
await chat_manager.send_message(
|
||||
session_id,
|
||||
{
|
||||
"type": "system",
|
||||
"content": "已切换回 WebUI 独立用户模式",
|
||||
"timestamp": time.time(),
|
||||
},
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket 断开: session={session_id}, user={user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket 错误: {e}")
|
||||
finally:
|
||||
chat_manager.disconnect(session_id, user_id)
|
||||
|
||||
|
||||
@router.get("/info")
|
||||
async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||
"""获取聊天室信息"""
|
||||
return {
|
||||
"bot_name": global_config.bot.nickname,
|
||||
"platform": WEBUI_CHAT_PLATFORM,
|
||||
"group_id": WEBUI_CHAT_GROUP_ID,
|
||||
"active_sessions": len(chat_manager.active_connections),
|
||||
}
|
||||
|
||||
|
||||
def get_webui_chat_broadcaster() -> tuple:
|
||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||
|
||||
Returns:
|
||||
(chat_manager, WEBUI_CHAT_PLATFORM) 元组
|
||||
"""
|
||||
return (chat_manager, WEBUI_CHAT_PLATFORM)
|
||||
|
|
@ -1,597 +0,0 @@
|
|||
"""
|
||||
配置管理API路由
|
||||
"""
|
||||
|
||||
import os
|
||||
import tomlkit
|
||||
from fastapi import APIRouter, HTTPException, Body, Depends, Cookie, Header
|
||||
from typing import Any, Annotated, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
from src.common.toml_utils import save_toml_with_format, _update_toml_doc
|
||||
from src.config.config import Config, APIAdapterConfig, CONFIG_DIR, PROJECT_ROOT
|
||||
from src.config.official_configs import (
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
RelationshipConfig,
|
||||
ChatConfig,
|
||||
MessageReceiveConfig,
|
||||
EmojiConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
TelemetryConfig,
|
||||
ExperimentalConfig,
|
||||
MaimMessageConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
ToolConfig,
|
||||
MemoryConfig,
|
||||
DebugConfig,
|
||||
VoiceConfig,
|
||||
)
|
||||
from src.config.api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
)
|
||||
from src.webui.config_schema import ConfigSchemaGenerator
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
# 模块级别的类型别名(解决 B008 ruff 错误)
|
||||
ConfigBody = Annotated[dict[str, Any], Body()]
|
||||
SectionBody = Annotated[Any, Body()]
|
||||
RawContentBody = Annotated[str, Body(embed=True)]
|
||||
PathBody = Annotated[dict[str, str], Body()]
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# ===== 架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/bot")
|
||||
async def get_bot_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置架构"""
|
||||
try:
|
||||
# Config 类包含所有子配置
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(Config)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/schema/model")
|
||||
async def get_model_config_schema(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置架构(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
schema = ConfigSchemaGenerator.generate_config_schema(APIAdapterConfig)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型配置架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型配置架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 子配置架构获取接口 =====
|
||||
|
||||
|
||||
@router.get("/schema/section/{section_name}")
|
||||
async def get_config_section_schema(section_name: str, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取指定配置节的架构
|
||||
|
||||
支持的section_name:
|
||||
- bot: BotConfig
|
||||
- personality: PersonalityConfig
|
||||
- relationship: RelationshipConfig
|
||||
- chat: ChatConfig
|
||||
- message_receive: MessageReceiveConfig
|
||||
- emoji: EmojiConfig
|
||||
- expression: ExpressionConfig
|
||||
- keyword_reaction: KeywordReactionConfig
|
||||
- chinese_typo: ChineseTypoConfig
|
||||
- response_post_process: ResponsePostProcessConfig
|
||||
- response_splitter: ResponseSplitterConfig
|
||||
- telemetry: TelemetryConfig
|
||||
- experimental: ExperimentalConfig
|
||||
- maim_message: MaimMessageConfig
|
||||
- lpmm_knowledge: LPMMKnowledgeConfig
|
||||
- tool: ToolConfig
|
||||
- memory: MemoryConfig
|
||||
- debug: DebugConfig
|
||||
- voice: VoiceConfig
|
||||
- jargon: JargonConfig
|
||||
- model_task_config: ModelTaskConfig
|
||||
- api_provider: APIProvider
|
||||
- model_info: ModelInfo
|
||||
"""
|
||||
section_map = {
|
||||
"bot": BotConfig,
|
||||
"personality": PersonalityConfig,
|
||||
"relationship": RelationshipConfig,
|
||||
"chat": ChatConfig,
|
||||
"message_receive": MessageReceiveConfig,
|
||||
"emoji": EmojiConfig,
|
||||
"expression": ExpressionConfig,
|
||||
"keyword_reaction": KeywordReactionConfig,
|
||||
"chinese_typo": ChineseTypoConfig,
|
||||
"response_post_process": ResponsePostProcessConfig,
|
||||
"response_splitter": ResponseSplitterConfig,
|
||||
"telemetry": TelemetryConfig,
|
||||
"experimental": ExperimentalConfig,
|
||||
"maim_message": MaimMessageConfig,
|
||||
"lpmm_knowledge": LPMMKnowledgeConfig,
|
||||
"tool": ToolConfig,
|
||||
"memory": MemoryConfig,
|
||||
"debug": DebugConfig,
|
||||
"voice": VoiceConfig,
|
||||
"model_task_config": ModelTaskConfig,
|
||||
"api_provider": APIProvider,
|
||||
"model_info": ModelInfo,
|
||||
}
|
||||
|
||||
if section_name not in section_map:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
try:
|
||||
config_class = section_map[section_name]
|
||||
schema = ConfigSchemaGenerator.generate_schema(config_class, include_nested=False)
|
||||
return {"success": True, "schema": schema}
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置节架构失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置节架构失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置读取接口 =====
|
||||
|
||||
|
||||
@router.get("/bot")
|
||||
async def get_bot_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": config_data}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/model")
|
||||
async def get_model_config(_auth: bool = Depends(require_auth)):
|
||||
"""获取模型配置(包含提供商和模型任务配置)"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
return {"success": True, "config": config_data}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot")
|
||||
async def update_bot_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("麦麦主程序配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model")
|
||||
async def update_model_config(config_data: ConfigBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新模型配置"""
|
||||
try:
|
||||
# 验证配置数据
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件(自动保留注释和格式)
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info("模型配置已更新")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 配置节更新接口 =====
|
||||
|
||||
|
||||
@router.post("/bot/section/{section_name}")
|
||||
async def update_bot_config_section(section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 更新指定节
|
||||
if section_name not in config_data:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
# 使用递归合并保留注释(对于字典类型)
|
||||
# 对于数组类型(如 platforms, aliases),直接替换
|
||||
if isinstance(section_data, list):
|
||||
# 列表直接替换
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 原始 TOML 文件操作接口 =====
|
||||
|
||||
|
||||
@router.get("/bot/raw")
|
||||
async def get_bot_config_raw(_auth: bool = Depends(require_auth)):
|
||||
"""获取麦麦主程序配置的原始 TOML 内容"""
|
||||
try:
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
raw_content = f.read()
|
||||
|
||||
return {"success": True, "content": raw_content}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/bot/raw")
|
||||
async def update_bot_config_raw(raw_content: RawContentBody, _auth: bool = Depends(require_auth)):
|
||||
"""更新麦麦主程序配置(直接保存原始 TOML 内容,会先验证格式)"""
|
||||
try:
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
config_data = tomlkit.loads(raw_content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 验证配置数据结构
|
||||
try:
|
||||
Config.from_dict(config_data)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置文件
|
||||
config_path = os.path.join(CONFIG_DIR, "bot_config.toml")
|
||||
with open(config_path, "w", encoding="utf-8") as f:
|
||||
f.write(raw_content)
|
||||
|
||||
logger.info("麦麦主程序配置已更新(原始模式)")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置文件失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/model/section/{section_name}")
|
||||
async def update_model_config_section(
|
||||
section_name: str, section_data: SectionBody, _auth: bool = Depends(require_auth)
|
||||
):
|
||||
"""更新模型配置的指定节(保留注释和格式)"""
|
||||
try:
|
||||
# 读取现有配置
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 更新指定节
|
||||
if section_name not in config_data:
|
||||
raise HTTPException(status_code=404, detail=f"配置节 '{section_name}' 不存在")
|
||||
|
||||
# 使用递归合并保留注释(对于字典类型)
|
||||
# 对于数组表(如 [[models]], [[api_providers]]),直接替换
|
||||
if isinstance(section_data, list):
|
||||
# 列表直接替换
|
||||
config_data[section_name] = section_data
|
||||
elif isinstance(section_data, dict) and isinstance(config_data[section_name], dict):
|
||||
# 字典递归合并
|
||||
_update_toml_doc(config_data[section_name], section_data)
|
||||
else:
|
||||
# 其他类型直接替换
|
||||
config_data[section_name] = section_data
|
||||
|
||||
# 验证完整配置
|
||||
try:
|
||||
APIAdapterConfig.from_dict(config_data)
|
||||
except Exception as e:
|
||||
logger.error(f"配置数据验证失败,详细错误: {str(e)}")
|
||||
# 特殊处理:如果是更新 api_providers,检查是否有模型引用了已删除的provider
|
||||
if section_name == "api_providers" and "api_provider" in str(e):
|
||||
provider_names = {p.get("name") for p in section_data if isinstance(p, dict)}
|
||||
models = config_data.get("models", [])
|
||||
orphaned_models = [
|
||||
m.get("name") for m in models if isinstance(m, dict) and m.get("api_provider") not in provider_names
|
||||
]
|
||||
if orphaned_models:
|
||||
error_msg = f"以下模型引用了已删除的提供商: {', '.join(orphaned_models)}。请先在模型管理页面删除这些模型,或重新分配它们的提供商。"
|
||||
raise HTTPException(status_code=400, detail=error_msg) from e
|
||||
raise HTTPException(status_code=400, detail=f"配置数据验证失败: {str(e)}") from e
|
||||
|
||||
# 保存配置(格式化数组为多行,保留注释)
|
||||
save_toml_with_format(config_data, config_path)
|
||||
|
||||
logger.info(f"配置节 '{section_name}' 已更新(保留注释)")
|
||||
return {"success": True, "message": f"配置节 '{section_name}' 已保存"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新配置节失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新配置节失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ===== 适配器配置管理接口 =====
|
||||
|
||||
|
||||
def _normalize_adapter_path(path: str) -> str:
|
||||
"""将路径转换为绝对路径(如果是相对路径,则相对于项目根目录)"""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
# 如果已经是绝对路径,直接返回
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
|
||||
# 相对路径,转换为相对于项目根目录的绝对路径
|
||||
return os.path.normpath(os.path.join(PROJECT_ROOT, path))
|
||||
|
||||
|
||||
def _to_relative_path(path: str) -> str:
|
||||
"""尝试将绝对路径转换为相对于项目根目录的相对路径,如果无法转换则返回原路径"""
|
||||
if not path or not os.path.isabs(path):
|
||||
return path
|
||||
|
||||
try:
|
||||
# 尝试获取相对路径
|
||||
rel_path = os.path.relpath(path, PROJECT_ROOT)
|
||||
# 如果相对路径不是以 .. 开头(说明文件在项目目录内),则返回相对路径
|
||||
if not rel_path.startswith(".."):
|
||||
return rel_path
|
||||
except (ValueError, TypeError):
|
||||
# 在 Windows 上,如果路径在不同驱动器,relpath 会抛出 ValueError
|
||||
pass
|
||||
|
||||
# 无法转换为相对路径,返回绝对路径
|
||||
return path
|
||||
|
||||
|
||||
@router.get("/adapter-config/path")
|
||||
async def get_adapter_config_path(_auth: bool = Depends(require_auth)):
|
||||
"""获取保存的适配器配置文件路径"""
|
||||
try:
|
||||
# 从 data/webui.json 读取路径偏好
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
if not os.path.exists(webui_data_path):
|
||||
return {"success": True, "path": None}
|
||||
|
||||
import json
|
||||
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
|
||||
adapter_config_path = webui_data.get("adapter_config_path")
|
||||
if not adapter_config_path:
|
||||
return {"success": True, "path": None}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(adapter_config_path)
|
||||
|
||||
# 检查文件是否存在并返回最后修改时间
|
||||
if os.path.exists(abs_path):
|
||||
import datetime
|
||||
|
||||
mtime = os.path.getmtime(abs_path)
|
||||
last_modified = datetime.datetime.fromtimestamp(mtime).isoformat()
|
||||
# 返回相对路径(如果可能)
|
||||
display_path = _to_relative_path(abs_path)
|
||||
return {"success": True, "path": display_path, "lastModified": last_modified}
|
||||
else:
|
||||
# 文件不存在,返回原路径
|
||||
return {"success": True, "path": adapter_config_path, "lastModified": None}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取配置路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config/path")
|
||||
async def save_adapter_config_path(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置文件路径偏好"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
|
||||
# 保存到 data/webui.json
|
||||
webui_data_path = os.path.join("data", "webui.json")
|
||||
import json
|
||||
|
||||
# 读取现有数据
|
||||
if os.path.exists(webui_data_path):
|
||||
with open(webui_data_path, "r", encoding="utf-8") as f:
|
||||
webui_data = json.load(f)
|
||||
else:
|
||||
webui_data = {}
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 尝试转换为相对路径保存(如果文件在项目目录内)
|
||||
save_path = _to_relative_path(abs_path)
|
||||
|
||||
# 更新路径
|
||||
webui_data["adapter_config_path"] = save_path
|
||||
|
||||
# 保存
|
||||
os.makedirs("data", exist_ok=True)
|
||||
with open(webui_data_path, "w", encoding="utf-8") as f:
|
||||
json.dump(webui_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"适配器配置路径已保存: {save_path}(绝对路径: {abs_path})")
|
||||
return {"success": True, "message": "路径已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置路径失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存路径失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/adapter-config")
|
||||
async def get_adapter_config(path: str, _auth: bool = Depends(require_auth)):
|
||||
"""从指定路径读取适配器配置文件"""
|
||||
try:
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径参数不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(abs_path):
|
||||
raise HTTPException(status_code=404, detail=f"配置文件不存在: {path}")
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 读取文件内容
|
||||
with open(abs_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
logger.info(f"已读取适配器配置: {path} (绝对路径: {abs_path})")
|
||||
return {"success": True, "content": content}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"读取适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"读取配置失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/adapter-config")
|
||||
async def save_adapter_config(data: PathBody, _auth: bool = Depends(require_auth)):
|
||||
"""保存适配器配置到指定路径"""
|
||||
try:
|
||||
path = data.get("path")
|
||||
content = data.get("content")
|
||||
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="路径不能为空")
|
||||
if content is None:
|
||||
raise HTTPException(status_code=400, detail="配置内容不能为空")
|
||||
|
||||
# 将路径规范化为绝对路径
|
||||
abs_path = _normalize_adapter_path(path)
|
||||
|
||||
# 检查文件扩展名
|
||||
if not abs_path.endswith(".toml"):
|
||||
raise HTTPException(status_code=400, detail="只支持 .toml 格式的配置文件")
|
||||
|
||||
# 验证 TOML 格式
|
||||
try:
|
||||
tomlkit.loads(content)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"TOML 格式错误: {str(e)}") from e
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = os.path.dirname(abs_path)
|
||||
if dir_path:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
|
||||
# 保存文件
|
||||
with open(abs_path, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
|
||||
logger.info(f"适配器配置已保存: {path} (绝对路径: {abs_path})")
|
||||
return {"success": True, "message": "配置已保存"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"保存适配器配置失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"保存配置失败: {str(e)}") from e
|
||||
|
|
@ -1,335 +0,0 @@
|
|||
"""
|
||||
配置架构生成器 - 自动从配置类生成前端表单架构
|
||||
"""
|
||||
|
||||
import inspect
|
||||
from dataclasses import fields, MISSING
|
||||
from typing import Any, get_origin, get_args, Literal, Optional
|
||||
from enum import Enum
|
||||
|
||||
from src.config.config_base import ConfigBase
|
||||
|
||||
|
||||
class FieldType(str, Enum):
|
||||
"""字段类型枚举"""
|
||||
|
||||
STRING = "string"
|
||||
NUMBER = "number"
|
||||
INTEGER = "integer"
|
||||
BOOLEAN = "boolean"
|
||||
SELECT = "select"
|
||||
ARRAY = "array"
|
||||
OBJECT = "object"
|
||||
TEXTAREA = "textarea"
|
||||
|
||||
|
||||
class FieldSchema:
|
||||
"""字段架构"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
type: FieldType,
|
||||
label: str,
|
||||
description: str = "",
|
||||
default: Any = None,
|
||||
required: bool = True,
|
||||
options: Optional[list[str]] = None,
|
||||
min_value: Optional[float] = None,
|
||||
max_value: Optional[float] = None,
|
||||
items: Optional[dict] = None,
|
||||
properties: Optional[dict] = None,
|
||||
):
|
||||
self.name = name
|
||||
self.type = type
|
||||
self.label = label
|
||||
self.description = description
|
||||
self.default = default
|
||||
self.required = required
|
||||
self.options = options
|
||||
self.min_value = min_value
|
||||
self.max_value = max_value
|
||||
self.items = items
|
||||
self.properties = properties
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"name": self.name,
|
||||
"type": self.type.value,
|
||||
"label": self.label,
|
||||
"description": self.description,
|
||||
"required": self.required,
|
||||
}
|
||||
|
||||
if self.default is not None:
|
||||
result["default"] = self.default
|
||||
|
||||
if self.options is not None:
|
||||
result["options"] = self.options
|
||||
|
||||
if self.min_value is not None:
|
||||
result["minValue"] = self.min_value
|
||||
|
||||
if self.max_value is not None:
|
||||
result["maxValue"] = self.max_value
|
||||
|
||||
if self.items is not None:
|
||||
result["items"] = self.items
|
||||
|
||||
if self.properties is not None:
|
||||
result["properties"] = self.properties
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class ConfigSchemaGenerator:
|
||||
"""配置架构生成器"""
|
||||
|
||||
@staticmethod
|
||||
def _extract_field_description(config_class: type, field_name: str) -> str:
|
||||
"""
|
||||
从类定义中提取字段的文档字符串描述
|
||||
|
||||
Args:
|
||||
config_class: 配置类
|
||||
field_name: 字段名
|
||||
|
||||
Returns:
|
||||
str: 字段描述
|
||||
"""
|
||||
try:
|
||||
# 获取源代码
|
||||
source = inspect.getsource(config_class)
|
||||
lines = source.split("\n")
|
||||
|
||||
# 查找字段定义
|
||||
field_found = False
|
||||
description_lines = []
|
||||
|
||||
for i, line in enumerate(lines):
|
||||
# 匹配字段定义行,例如: platform: str
|
||||
if f"{field_name}:" in line and "=" in line:
|
||||
field_found = True
|
||||
# 查找下一行的文档字符串
|
||||
if i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
# 单行文档字符串
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
# 多行文档字符串
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
for j in range(i + 2, len(lines)):
|
||||
if quote in lines[j]:
|
||||
description_lines.append(lines[j].split(quote)[0].strip())
|
||||
break
|
||||
description_lines.append(lines[j].strip())
|
||||
break
|
||||
elif f"{field_name}:" in line and "=" not in line:
|
||||
# 没有默认值的字段
|
||||
field_found = True
|
||||
if i + 1 < len(lines):
|
||||
next_line = lines[i + 1].strip()
|
||||
if next_line.startswith('"""') or next_line.startswith("'''"):
|
||||
if next_line.count('"""') == 2 or next_line.count("'''") == 2:
|
||||
description_lines.append(next_line.replace('"""', "").replace("'''", "").strip())
|
||||
else:
|
||||
quote = '"""' if next_line.startswith('"""') else "'''"
|
||||
description_lines.append(next_line.strip(quote).strip())
|
||||
for j in range(i + 2, len(lines)):
|
||||
if quote in lines[j]:
|
||||
description_lines.append(lines[j].split(quote)[0].strip())
|
||||
break
|
||||
description_lines.append(lines[j].strip())
|
||||
break
|
||||
|
||||
if field_found and description_lines:
|
||||
return " ".join(description_lines)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _get_field_type_and_options(field_type: type) -> tuple[FieldType, Optional[list[str]], Optional[dict]]:
|
||||
"""
|
||||
获取字段类型和选项
|
||||
|
||||
Args:
|
||||
field_type: 字段类型
|
||||
|
||||
Returns:
|
||||
tuple: (FieldType, options, items)
|
||||
"""
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
# 处理 Literal 类型(枚举选项)
|
||||
if origin is Literal:
|
||||
return FieldType.SELECT, [str(arg) for arg in args], None
|
||||
|
||||
# 处理 list 类型
|
||||
if origin is list:
|
||||
item_type = args[0] if args else str
|
||||
if item_type is str:
|
||||
items = {"type": "string"}
|
||||
elif item_type is int:
|
||||
items = {"type": "integer"}
|
||||
elif item_type is float:
|
||||
items = {"type": "number"}
|
||||
elif item_type is bool:
|
||||
items = {"type": "boolean"}
|
||||
elif item_type is dict:
|
||||
items = {"type": "object"}
|
||||
else:
|
||||
items = {"type": "string"}
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理 set 类型(与 list 类似)
|
||||
if origin is set:
|
||||
item_type = args[0] if args else str
|
||||
if item_type is str:
|
||||
items = {"type": "string"}
|
||||
else:
|
||||
items = {"type": "string"}
|
||||
return FieldType.ARRAY, None, items
|
||||
|
||||
# 处理基本类型
|
||||
if field_type is bool:
|
||||
return FieldType.BOOLEAN, None, None
|
||||
elif field_type is int:
|
||||
return FieldType.INTEGER, None, None
|
||||
elif field_type is float:
|
||||
return FieldType.NUMBER, None, None
|
||||
elif field_type is str:
|
||||
return FieldType.STRING, None, None
|
||||
elif field_type is dict or origin is dict:
|
||||
return FieldType.OBJECT, None, None
|
||||
|
||||
# 默认为字符串
|
||||
return FieldType.STRING, None, None
|
||||
|
||||
@staticmethod
|
||||
def _format_field_name(name: str) -> str:
|
||||
"""
|
||||
格式化字段名为可读的标签
|
||||
|
||||
Args:
|
||||
name: 原始字段名
|
||||
|
||||
Returns:
|
||||
str: 格式化后的标签
|
||||
"""
|
||||
# 将下划线替换为空格,并首字母大写
|
||||
return " ".join(word.capitalize() for word in name.split("_"))
|
||||
|
||||
@staticmethod
|
||||
def generate_schema(config_class: type[ConfigBase], include_nested: bool = True) -> dict:
|
||||
"""
|
||||
从配置类生成前端表单架构
|
||||
|
||||
Args:
|
||||
config_class: 配置类(必须继承自 ConfigBase)
|
||||
include_nested: 是否包含嵌套的配置对象
|
||||
|
||||
Returns:
|
||||
dict: 前端表单架构
|
||||
"""
|
||||
if not issubclass(config_class, ConfigBase):
|
||||
raise ValueError(f"{config_class.__name__} 必须继承自 ConfigBase")
|
||||
|
||||
schema_fields = []
|
||||
nested_schemas = {}
|
||||
|
||||
for field in fields(config_class):
|
||||
# 跳过私有字段和内部字段
|
||||
if field.name.startswith("_") or field.name in ["MMC_VERSION"]:
|
||||
continue
|
||||
|
||||
# 提取字段描述
|
||||
description = ConfigSchemaGenerator._extract_field_description(config_class, field.name)
|
||||
|
||||
# 判断是否必填
|
||||
required = field.default is MISSING and field.default_factory is MISSING
|
||||
|
||||
# 获取默认值
|
||||
default_value = None
|
||||
if field.default is not MISSING:
|
||||
default_value = field.default
|
||||
elif field.default_factory is not MISSING:
|
||||
try:
|
||||
default_value = field.default_factory()
|
||||
except Exception:
|
||||
default_value = None
|
||||
|
||||
# 检查是否为嵌套的 ConfigBase
|
||||
if isinstance(field.type, type) and issubclass(field.type, ConfigBase):
|
||||
if include_nested:
|
||||
# 递归生成嵌套配置的架构
|
||||
nested_schema = ConfigSchemaGenerator.generate_schema(field.type, include_nested=True)
|
||||
nested_schemas[field.name] = nested_schema
|
||||
|
||||
field_schema = FieldSchema(
|
||||
name=field.name,
|
||||
type=FieldType.OBJECT,
|
||||
label=ConfigSchemaGenerator._format_field_name(field.name),
|
||||
description=description or field.type.__doc__ or "",
|
||||
default=default_value,
|
||||
required=required,
|
||||
properties=nested_schema,
|
||||
)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
# 获取字段类型和选项
|
||||
field_type, options, items = ConfigSchemaGenerator._get_field_type_and_options(field.type)
|
||||
|
||||
# 特殊处理:长文本使用 textarea
|
||||
if field_type == FieldType.STRING and field.name in [
|
||||
"personality",
|
||||
"reply_style",
|
||||
"interest",
|
||||
"plan_style",
|
||||
"visual_style",
|
||||
"private_plan_style",
|
||||
"reaction",
|
||||
"filtration_prompt",
|
||||
]:
|
||||
field_type = FieldType.TEXTAREA
|
||||
|
||||
field_schema = FieldSchema(
|
||||
name=field.name,
|
||||
type=field_type,
|
||||
label=ConfigSchemaGenerator._format_field_name(field.name),
|
||||
description=description,
|
||||
default=default_value,
|
||||
required=required,
|
||||
options=options,
|
||||
items=items,
|
||||
)
|
||||
|
||||
schema_fields.append(field_schema.to_dict())
|
||||
|
||||
return {
|
||||
"className": config_class.__name__,
|
||||
"classDoc": config_class.__doc__ or "",
|
||||
"fields": schema_fields,
|
||||
"nested": nested_schemas if nested_schemas else None,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def generate_config_schema(config_class: type[ConfigBase]) -> dict:
|
||||
"""
|
||||
生成完整的配置架构(包含所有嵌套的子配置)
|
||||
|
||||
Args:
|
||||
config_class: 配置类
|
||||
|
||||
Returns:
|
||||
dict: 完整的配置架构
|
||||
"""
|
||||
return ConfigSchemaGenerator.generate_schema(config_class, include_nested=True)
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,790 +0,0 @@
|
|||
"""表达方式管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.expression")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/expression", tags=["Expression"])
|
||||
|
||||
|
||||
class ExpressionResponse(BaseModel):
|
||||
"""表达方式响应"""
|
||||
|
||||
id: int
|
||||
situation: str
|
||||
style: str
|
||||
last_active_time: float
|
||||
chat_id: str
|
||||
create_date: Optional[float]
|
||||
checked: bool
|
||||
rejected: bool
|
||||
modified_by: Optional[str] = None # 'ai' 或 'user' 或 None
|
||||
|
||||
|
||||
class ExpressionListResponse(BaseModel):
|
||||
"""表达方式列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[ExpressionResponse]
|
||||
|
||||
|
||||
class ExpressionDetailResponse(BaseModel):
|
||||
"""表达方式详情响应"""
|
||||
|
||||
success: bool
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
class ExpressionCreateRequest(BaseModel):
|
||||
"""表达方式创建请求"""
|
||||
|
||||
situation: str
|
||||
style: str
|
||||
chat_id: str
|
||||
|
||||
|
||||
class ExpressionUpdateRequest(BaseModel):
|
||||
"""表达方式更新请求"""
|
||||
|
||||
situation: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
checked: Optional[bool] = None
|
||||
rejected: Optional[bool] = None
|
||||
require_unchecked: Optional[bool] = False # 用于人工审核时的冲突检测
|
||||
|
||||
|
||||
class ExpressionUpdateResponse(BaseModel):
|
||||
"""表达方式更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[ExpressionResponse] = None
|
||||
|
||||
|
||||
class ExpressionDeleteResponse(BaseModel):
|
||||
"""表达方式删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class ExpressionCreateResponse(BaseModel):
|
||||
"""表达方式创建响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: ExpressionResponse
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||
"""将 Expression 模型转换为响应对象"""
|
||||
return ExpressionResponse(
|
||||
id=expression.id,
|
||||
situation=expression.situation,
|
||||
style=expression.style,
|
||||
last_active_time=expression.last_active_time,
|
||||
chat_id=expression.chat_id,
|
||||
create_date=expression.create_date,
|
||||
checked=expression.checked,
|
||||
rejected=expression.rejected,
|
||||
modified_by=expression.modified_by,
|
||||
)
|
||||
|
||||
|
||||
def get_chat_name(chat_id: str) -> str:
|
||||
"""根据 chat_id 获取聊天名称"""
|
||||
try:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
||||
if chat_stream:
|
||||
# 优先使用群聊名称,否则使用用户昵称
|
||||
if chat_stream.group_name:
|
||||
return chat_stream.group_name
|
||||
elif chat_stream.user_nickname:
|
||||
return chat_stream.user_nickname
|
||||
return chat_id # 找不到时返回原始ID
|
||||
except Exception:
|
||||
return chat_id
|
||||
|
||||
|
||||
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||
"""批量获取聊天名称"""
|
||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||
try:
|
||||
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
||||
for cs in chat_streams:
|
||||
if cs.group_name:
|
||||
result[cs.stream_id] = cs.group_name
|
||||
elif cs.user_nickname:
|
||||
result[cs.stream_id] = cs.user_nickname
|
||||
except Exception as e:
|
||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||
return result
|
||||
|
||||
|
||||
class ChatInfo(BaseModel):
|
||||
"""聊天信息"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool
|
||||
data: List[ChatInfo]
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取所有聊天列表(用于下拉选择)
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
聊天列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
chat_list = []
|
||||
for cs in ChatStreams.select():
|
||||
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
||||
chat_list.append(
|
||||
ChatInfo(
|
||||
chat_id=cs.stream_id,
|
||||
chat_name=chat_name,
|
||||
platform=cs.platform,
|
||||
is_group=bool(cs.group_id),
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称排序
|
||||
chat_list.sort(key=lambda x: x.chat_name)
|
||||
|
||||
return ChatListResponse(success=True, data=chat_list)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/list", response_model=ExpressionListResponse)
|
||||
async def get_expression_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取表达方式列表
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 situation, style)
|
||||
chat_id: 聊天ID筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = Expression.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search))
|
||||
| (Expression.style.contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
# 转换为响应对象
|
||||
data = [expression_to_response(expr) for expr in expressions]
|
||||
|
||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表达方式列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表达方式列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{expression_id}", response_model=ExpressionDetailResponse)
|
||||
async def get_expression_detail(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式详细信息
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
表达方式详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
return ExpressionDetailResponse(success=True, data=expression_to_response(expression))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取表达方式详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取表达方式详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=ExpressionCreateResponse)
|
||||
async def create_expression(
|
||||
request: ExpressionCreateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
创建新的表达方式
|
||||
|
||||
Args:
|
||||
request: 创建请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
创建结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
current_time = time.time()
|
||||
|
||||
# 创建表达方式
|
||||
expression = Expression.create(
|
||||
situation=request.situation,
|
||||
style=request.style,
|
||||
chat_id=request.chat_id,
|
||||
last_active_time=current_time,
|
||||
create_date=current_time,
|
||||
)
|
||||
|
||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||
|
||||
return ExpressionCreateResponse(
|
||||
success=True, message="表达方式创建成功", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"创建表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{expression_id}", response_model=ExpressionUpdateResponse)
|
||||
async def update_expression(
|
||||
expression_id: int,
|
||||
request: ExpressionUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新表达方式(只更新提供的字段)
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 冲突检测:如果要求未检查状态,但已经被检查了
|
||||
if request.require_unchecked and expression.checked:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表"
|
||||
)
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
# 移除 require_unchecked,它不是数据库字段
|
||||
update_data.pop('require_unchecked', None)
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 如果更新了 checked 或 rejected,标记为用户修改
|
||||
if 'checked' in update_data or 'rejected' in update_data:
|
||||
update_data['modified_by'] = 'user'
|
||||
|
||||
# 更新最后活跃时间
|
||||
update_data["last_active_time"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(expression, field, value)
|
||||
|
||||
expression.save()
|
||||
|
||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return ExpressionUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=expression_to_response(expression)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"更新表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{expression_id}", response_model=ExpressionDeleteResponse)
|
||||
async def delete_expression(
|
||||
expression_id: int, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除表达方式
|
||||
|
||||
Args:
|
||||
expression_id: 表达方式ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
||||
|
||||
if not expression:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||
|
||||
# 记录删除信息
|
||||
situation = expression.situation
|
||||
|
||||
# 执行删除
|
||||
expression.delete_instance()
|
||||
|
||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除表达方式: {situation}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"删除表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int]
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
|
||||
async def batch_delete_expressions(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除表达方式
|
||||
|
||||
Args:
|
||||
request: 包含要删除的ID列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||
|
||||
# 查找所有要删除的表达方式
|
||||
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
||||
found_ids = [expr.id for expr in expressions]
|
||||
|
||||
# 检查是否有未找到的ID
|
||||
not_found_ids = set(request.ids) - set(found_ids)
|
||||
if not_found_ids:
|
||||
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||
|
||||
# 执行批量删除
|
||||
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
||||
|
||||
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||
|
||||
return ExpressionDeleteResponse(success=True, message=f"成功删除 {deleted_count} 个表达方式")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除表达方式失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除表达方式失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_expression_stats(
|
||||
maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取表达方式统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
|
||||
# 按 chat_id 统计
|
||||
chat_stats = {}
|
||||
for expr in Expression.select(Expression.chat_id):
|
||||
chat_id = expr.chat_id
|
||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||
|
||||
# 获取最近创建的记录数(7天内)
|
||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
||||
recent = (
|
||||
Expression.select()
|
||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
||||
.count()
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total": total,
|
||||
"recent_7days": recent,
|
||||
"chat_count": len(chat_stats),
|
||||
"top_chats": dict(sorted(chat_stats.items(), key=lambda x: x[1], reverse=True)[:10]),
|
||||
},
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取统计数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
# ============ 审核相关接口 ============
|
||||
|
||||
class ReviewStatsResponse(BaseModel):
|
||||
"""审核统计响应"""
|
||||
total: int
|
||||
unchecked: int
|
||||
passed: int
|
||||
rejected: int
|
||||
ai_checked: int
|
||||
user_checked: int
|
||||
|
||||
|
||||
@router.get("/review/stats", response_model=ReviewStatsResponse)
|
||||
async def get_review_stats(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取审核统计数据
|
||||
|
||||
Returns:
|
||||
审核统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = Expression.select().count()
|
||||
unchecked = Expression.select().where(Expression.checked == False).count()
|
||||
passed = Expression.select().where(
|
||||
(Expression.checked == True) & (Expression.rejected == False)
|
||||
).count()
|
||||
rejected = Expression.select().where(
|
||||
(Expression.checked == True) & (Expression.rejected == True)
|
||||
).count()
|
||||
ai_checked = Expression.select().where(Expression.modified_by == 'ai').count()
|
||||
user_checked = Expression.select().where(Expression.modified_by == 'user').count()
|
||||
|
||||
return ReviewStatsResponse(
|
||||
total=total,
|
||||
unchecked=unchecked,
|
||||
passed=passed,
|
||||
rejected=rejected,
|
||||
ai_checked=ai_checked,
|
||||
user_checked=user_checked
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取审核统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取审核统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
class ReviewListResponse(BaseModel):
|
||||
"""审核列表响应"""
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[ExpressionResponse]
|
||||
|
||||
|
||||
@router.get("/review/list", response_model=ReviewListResponse)
|
||||
async def get_review_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
filter_type: str = Query("unchecked", description="筛选类型: unchecked/passed/rejected/all"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="聊天ID筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取待审核/已审核的表达方式列表
|
||||
|
||||
Args:
|
||||
page: 页码
|
||||
page_size: 每页数量
|
||||
filter_type: 筛选类型 (unchecked/passed/rejected/all)
|
||||
search: 搜索关键词
|
||||
chat_id: 聊天ID筛选
|
||||
|
||||
Returns:
|
||||
表达方式列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
query = Expression.select()
|
||||
|
||||
# 根据筛选类型过滤
|
||||
if filter_type == "unchecked":
|
||||
query = query.where(Expression.checked == False)
|
||||
elif filter_type == "passed":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == False))
|
||||
elif filter_type == "rejected":
|
||||
query = query.where((Expression.checked == True) & (Expression.rejected == True))
|
||||
# all 不需要额外过滤
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Expression.situation.contains(search)) | (Expression.style.contains(search))
|
||||
)
|
||||
|
||||
# 聊天ID过滤
|
||||
if chat_id:
|
||||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
from peewee import Case
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.create_date.is_null(), 1)], 0),
|
||||
Expression.create_date.desc()
|
||||
)
|
||||
|
||||
total = query.count()
|
||||
offset = (page - 1) * page_size
|
||||
expressions = query.offset(offset).limit(page_size)
|
||||
|
||||
return ReviewListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=[expression_to_response(expr) for expr in expressions]
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取审核列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取审核列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
class BatchReviewItem(BaseModel):
|
||||
"""批量审核项"""
|
||||
id: int
|
||||
rejected: bool
|
||||
require_unchecked: bool = True # 默认要求未检查状态
|
||||
|
||||
|
||||
class BatchReviewRequest(BaseModel):
|
||||
"""批量审核请求"""
|
||||
items: List[BatchReviewItem]
|
||||
|
||||
|
||||
class BatchReviewResultItem(BaseModel):
|
||||
"""批量审核结果项"""
|
||||
id: int
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchReviewResponse(BaseModel):
|
||||
"""批量审核响应"""
|
||||
success: bool
|
||||
total: int
|
||||
succeeded: int
|
||||
failed: int
|
||||
results: List[BatchReviewResultItem]
|
||||
|
||||
|
||||
@router.post("/review/batch", response_model=BatchReviewResponse)
|
||||
async def batch_review_expressions(
|
||||
request: BatchReviewRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量审核表达方式
|
||||
|
||||
Args:
|
||||
request: 批量审核请求
|
||||
|
||||
Returns:
|
||||
批量审核结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.items:
|
||||
raise HTTPException(status_code=400, detail="未提供要审核的表达方式")
|
||||
|
||||
results = []
|
||||
succeeded = 0
|
||||
failed = 0
|
||||
|
||||
for item in request.items:
|
||||
try:
|
||||
expression = Expression.get_or_none(Expression.id == item.id)
|
||||
|
||||
if not expression:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=f"未找到 ID 为 {item.id} 的表达方式"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 冲突检测
|
||||
if item.require_unchecked and expression.checked:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查"
|
||||
))
|
||||
failed += 1
|
||||
continue
|
||||
|
||||
# 更新状态
|
||||
expression.checked = True
|
||||
expression.rejected = item.rejected
|
||||
expression.modified_by = 'user'
|
||||
expression.last_active_time = time.time()
|
||||
expression.save()
|
||||
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=True,
|
||||
message="通过" if not item.rejected else "拒绝"
|
||||
))
|
||||
succeeded += 1
|
||||
|
||||
except Exception as e:
|
||||
results.append(BatchReviewResultItem(
|
||||
id=item.id,
|
||||
success=False,
|
||||
message=str(e)
|
||||
))
|
||||
failed += 1
|
||||
|
||||
logger.info(f"批量审核完成: 成功 {succeeded}, 失败 {failed}")
|
||||
|
||||
return BatchReviewResponse(
|
||||
success=True,
|
||||
total=len(request.items),
|
||||
succeeded=succeeded,
|
||||
failed=failed,
|
||||
results=results
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量审核失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量审核失败: {str(e)}") from e
|
||||
|
|
@ -1,662 +0,0 @@
|
|||
"""Git 镜像源服务 - 支持多镜像源、错误重试、Git 克隆和 Raw 文件获取"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from enum import Enum
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
import subprocess
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.git_mirror")
|
||||
|
||||
# 导入进度更新函数(避免循环导入)
|
||||
_update_progress = None
|
||||
|
||||
|
||||
def set_update_progress_callback(callback):
|
||||
"""设置进度更新回调函数"""
|
||||
global _update_progress
|
||||
_update_progress = callback
|
||||
|
||||
|
||||
class MirrorType(str, Enum):
|
||||
"""镜像源类型"""
|
||||
|
||||
GH_PROXY = "gh-proxy" # gh-proxy 主节点
|
||||
HK_GH_PROXY = "hk-gh-proxy" # gh-proxy 香港节点
|
||||
CDN_GH_PROXY = "cdn-gh-proxy" # gh-proxy CDN 节点
|
||||
EDGEONE_GH_PROXY = "edgeone-gh-proxy" # gh-proxy EdgeOne 节点
|
||||
MEYZH_GITHUB = "meyzh-github" # Meyzh GitHub 镜像
|
||||
GITHUB = "github" # GitHub 官方源(兜底)
|
||||
CUSTOM = "custom" # 自定义镜像源
|
||||
|
||||
|
||||
class GitMirrorConfig:
|
||||
"""Git 镜像源配置管理"""
|
||||
|
||||
# 配置文件路径
|
||||
CONFIG_FILE = Path("data/webui.json")
|
||||
|
||||
# 默认镜像源配置
|
||||
DEFAULT_MIRRORS = [
|
||||
{
|
||||
"id": "gh-proxy",
|
||||
"name": "gh-proxy 镜像",
|
||||
"raw_prefix": "https://gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 1,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "hk-gh-proxy",
|
||||
"name": "gh-proxy 香港节点",
|
||||
"raw_prefix": "https://hk.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://hk.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 2,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "cdn-gh-proxy",
|
||||
"name": "gh-proxy CDN 节点",
|
||||
"raw_prefix": "https://cdn.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://cdn.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 3,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "edgeone-gh-proxy",
|
||||
"name": "gh-proxy EdgeOne 节点",
|
||||
"raw_prefix": "https://edgeone.gh-proxy.org/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://edgeone.gh-proxy.org/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 4,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "meyzh-github",
|
||||
"name": "Meyzh GitHub 镜像",
|
||||
"raw_prefix": "https://meyzh.github.io/https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://meyzh.github.io/https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 5,
|
||||
"created_at": None,
|
||||
},
|
||||
{
|
||||
"id": "github",
|
||||
"name": "GitHub 官方源(兜底)",
|
||||
"raw_prefix": "https://raw.githubusercontent.com",
|
||||
"clone_prefix": "https://github.com",
|
||||
"enabled": True,
|
||||
"priority": 999,
|
||||
"created_at": None,
|
||||
},
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""初始化配置管理器"""
|
||||
self.config_file = self.CONFIG_FILE
|
||||
self.mirrors: List[Dict[str, Any]] = []
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self) -> None:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# 检查是否有镜像源配置
|
||||
if "git_mirrors" not in data or not data["git_mirrors"]:
|
||||
logger.info("配置文件中未找到镜像源配置,使用默认配置")
|
||||
self._init_default_mirrors()
|
||||
else:
|
||||
self.mirrors = data["git_mirrors"]
|
||||
logger.info(f"已加载 {len(self.mirrors)} 个镜像源配置")
|
||||
else:
|
||||
logger.info("配置文件不存在,创建默认配置")
|
||||
self._init_default_mirrors()
|
||||
except Exception as e:
|
||||
logger.error(f"加载配置文件失败: {e}")
|
||||
self._init_default_mirrors()
|
||||
|
||||
def _init_default_mirrors(self) -> None:
|
||||
"""初始化默认镜像源"""
|
||||
current_time = datetime.now().isoformat()
|
||||
self.mirrors = []
|
||||
|
||||
for mirror in self.DEFAULT_MIRRORS:
|
||||
mirror_copy = mirror.copy()
|
||||
mirror_copy["created_at"] = current_time
|
||||
self.mirrors.append(mirror_copy)
|
||||
|
||||
self._save_config()
|
||||
logger.info(f"已初始化 {len(self.mirrors)} 个默认镜像源")
|
||||
|
||||
def _save_config(self) -> None:
|
||||
"""保存配置到文件"""
|
||||
try:
|
||||
# 确保目录存在
|
||||
self.config_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取现有配置
|
||||
existing_data = {}
|
||||
if self.config_file.exists():
|
||||
with open(self.config_file, "r", encoding="utf-8") as f:
|
||||
existing_data = json.load(f)
|
||||
|
||||
# 更新镜像源配置
|
||||
existing_data["git_mirrors"] = self.mirrors
|
||||
|
||||
# 写入文件
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing_data, f, indent=2, ensure_ascii=False)
|
||||
|
||||
logger.debug(f"配置已保存到 {self.config_file}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存配置文件失败: {e}")
|
||||
|
||||
def get_all_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有镜像源"""
|
||||
return self.mirrors.copy()
|
||||
|
||||
def get_enabled_mirrors(self) -> List[Dict[str, Any]]:
|
||||
"""获取所有启用的镜像源,按优先级排序"""
|
||||
enabled = [m for m in self.mirrors if m.get("enabled", False)]
|
||||
return sorted(enabled, key=lambda x: x.get("priority", 999))
|
||||
|
||||
def get_mirror_by_id(self, mirror_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""根据 ID 获取镜像源"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
return mirror.copy()
|
||||
return None
|
||||
|
||||
def add_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: str,
|
||||
raw_prefix: str,
|
||||
clone_prefix: str,
|
||||
enabled: bool = True,
|
||||
priority: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
添加新的镜像源
|
||||
|
||||
Returns:
|
||||
添加的镜像源配置
|
||||
|
||||
Raises:
|
||||
ValueError: 如果镜像源 ID 已存在
|
||||
"""
|
||||
# 检查 ID 是否已存在
|
||||
if self.get_mirror_by_id(mirror_id):
|
||||
raise ValueError(f"镜像源 ID 已存在: {mirror_id}")
|
||||
|
||||
# 如果未指定优先级,使用最大优先级 + 1
|
||||
if priority is None:
|
||||
max_priority = max((m.get("priority", 0) for m in self.mirrors), default=0)
|
||||
priority = max_priority + 1
|
||||
|
||||
new_mirror = {
|
||||
"id": mirror_id,
|
||||
"name": name,
|
||||
"raw_prefix": raw_prefix,
|
||||
"clone_prefix": clone_prefix,
|
||||
"enabled": enabled,
|
||||
"priority": priority,
|
||||
"created_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
self.mirrors.append(new_mirror)
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已添加镜像源: {mirror_id} - {name}")
|
||||
return new_mirror.copy()
|
||||
|
||||
def update_mirror(
|
||||
self,
|
||||
mirror_id: str,
|
||||
name: Optional[str] = None,
|
||||
raw_prefix: Optional[str] = None,
|
||||
clone_prefix: Optional[str] = None,
|
||||
enabled: Optional[bool] = None,
|
||||
priority: Optional[int] = None,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
更新镜像源配置
|
||||
|
||||
Returns:
|
||||
更新后的镜像源配置,如果不存在则返回 None
|
||||
"""
|
||||
for mirror in self.mirrors:
|
||||
if mirror.get("id") == mirror_id:
|
||||
if name is not None:
|
||||
mirror["name"] = name
|
||||
if raw_prefix is not None:
|
||||
mirror["raw_prefix"] = raw_prefix
|
||||
if clone_prefix is not None:
|
||||
mirror["clone_prefix"] = clone_prefix
|
||||
if enabled is not None:
|
||||
mirror["enabled"] = enabled
|
||||
if priority is not None:
|
||||
mirror["priority"] = priority
|
||||
|
||||
mirror["updated_at"] = datetime.now().isoformat()
|
||||
self._save_config()
|
||||
|
||||
logger.info(f"已更新镜像源: {mirror_id}")
|
||||
return mirror.copy()
|
||||
|
||||
return None
|
||||
|
||||
def delete_mirror(self, mirror_id: str) -> bool:
|
||||
"""
|
||||
删除镜像源
|
||||
|
||||
Returns:
|
||||
True 如果删除成功,False 如果镜像源不存在
|
||||
"""
|
||||
for i, mirror in enumerate(self.mirrors):
|
||||
if mirror.get("id") == mirror_id:
|
||||
self.mirrors.pop(i)
|
||||
self._save_config()
|
||||
logger.info(f"已删除镜像源: {mirror_id}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def get_default_priority_list(self) -> List[str]:
|
||||
"""获取默认优先级列表(仅启用的镜像源 ID)"""
|
||||
enabled = self.get_enabled_mirrors()
|
||||
return [m["id"] for m in enabled]
|
||||
|
||||
|
||||
class GitMirrorService:
|
||||
"""Git 镜像源服务"""
|
||||
|
||||
def __init__(self, max_retries: int = 3, timeout: int = 30, config: Optional[GitMirrorConfig] = None):
|
||||
"""
|
||||
初始化 Git 镜像源服务
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
timeout: 请求超时时间(秒)
|
||||
config: 镜像源配置管理器(可选,默认创建新实例)
|
||||
"""
|
||||
self.max_retries = max_retries
|
||||
self.timeout = timeout
|
||||
self.config = config or GitMirrorConfig()
|
||||
logger.info(f"Git镜像源服务初始化完成,已加载 {len(self.config.get_enabled_mirrors())} 个启用的镜像源")
|
||||
|
||||
def get_mirror_config(self) -> GitMirrorConfig:
|
||||
"""获取镜像源配置管理器"""
|
||||
return self.config
|
||||
|
||||
@staticmethod
|
||||
def check_git_installed() -> Dict[str, Any]:
|
||||
"""
|
||||
检查本机是否安装了 Git
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- installed: bool - 是否已安装 Git
|
||||
- version: str - Git 版本号(如果已安装)
|
||||
- path: str - Git 可执行文件路径(如果已安装)
|
||||
- error: str - 错误信息(如果未安装或检测失败)
|
||||
"""
|
||||
import subprocess
|
||||
import shutil
|
||||
|
||||
try:
|
||||
# 查找 git 可执行文件路径
|
||||
git_path = shutil.which("git")
|
||||
|
||||
if not git_path:
|
||||
logger.warning("未找到 Git 可执行文件")
|
||||
return {"installed": False, "error": "系统中未找到 Git,请先安装 Git"}
|
||||
|
||||
# 获取 Git 版本
|
||||
result = subprocess.run(["git", "--version"], capture_output=True, text=True, timeout=5)
|
||||
|
||||
if result.returncode == 0:
|
||||
version = result.stdout.strip()
|
||||
logger.info(f"检测到 Git: {version} at {git_path}")
|
||||
return {"installed": True, "version": version, "path": git_path}
|
||||
else:
|
||||
logger.warning(f"Git 命令执行失败: {result.stderr}")
|
||||
return {"installed": False, "error": f"Git 命令执行失败: {result.stderr}"}
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Git 版本检测超时")
|
||||
return {"installed": False, "error": "Git 版本检测超时"}
|
||||
except Exception as e:
|
||||
logger.error(f"检测 Git 时发生错误: {e}")
|
||||
return {"installed": False, "error": f"检测 Git 时发生错误: {str(e)}"}
|
||||
|
||||
async def fetch_raw_file(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
branch: str,
|
||||
file_path: str,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 GitHub 仓库的 Raw 文件内容
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
branch: 分支名称
|
||||
file_path: 文件路径
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义完整 URL(如果提供,将忽略其他参数)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- data: str - 文件内容(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始获取 Raw 文件: {owner}/{repo}/{branch}/{file_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._fetch_with_url(custom_url, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
total_mirrors = len(mirrors_to_try)
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for index, mirror in enumerate(mirrors_to_try, 1):
|
||||
# 推送进度:正在尝试第 N 个镜像源
|
||||
if _update_progress:
|
||||
try:
|
||||
progress = 30 + int((index - 1) / total_mirrors * 40) # 30% - 70%
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=progress,
|
||||
message=f"正在尝试镜像源 {index}/{total_mirrors}: {mirror['name']}",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
result = await self._fetch_raw_from_mirror(owner, repo, branch, file_path, mirror)
|
||||
|
||||
if result["success"]:
|
||||
# 成功,推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=70,
|
||||
message=f"成功从 {mirror['name']} 获取数据",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
return result
|
||||
|
||||
# 失败,记录日志并推送失败信息
|
||||
logger.warning(f"镜像源 {mirror['id']} 失败: {result.get('error')}")
|
||||
|
||||
if _update_progress and index < total_mirrors:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=30 + int(index / total_mirrors * 40),
|
||||
message=f"镜像源 {mirror['name']} 失败,尝试下一个...",
|
||||
total_plugins=0,
|
||||
loaded_plugins=0,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _fetch_raw_from_mirror(
|
||||
self, owner: str, repo: str, branch: str, file_path: str, mirror: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源获取文件"""
|
||||
# 构建 URL
|
||||
raw_prefix = mirror["raw_prefix"]
|
||||
url = f"{raw_prefix}/{owner}/{repo}/{branch}/{file_path}"
|
||||
|
||||
return await self._fetch_with_url(url, mirror["id"])
|
||||
|
||||
async def _fetch_with_url(self, url: str, mirror_type: str) -> Dict[str, Any]:
|
||||
"""使用指定 URL 获取文件,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
try:
|
||||
logger.debug(f"尝试 #{attempt + 1}: {url}")
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.info(f"成功获取文件: {url}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": response.text,
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
}
|
||||
except httpx.HTTPStatusError as e:
|
||||
last_error = f"HTTP {e.response.status_code}: {e}"
|
||||
logger.warning(f"HTTP 错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except httpx.TimeoutException as e:
|
||||
last_error = f"请求超时: {e}"
|
||||
logger.warning(f"超时 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
async def clone_repository(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str] = None,
|
||||
mirror_id: Optional[str] = None,
|
||||
custom_url: Optional[str] = None,
|
||||
depth: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
克隆 GitHub 仓库
|
||||
|
||||
Args:
|
||||
owner: 仓库所有者
|
||||
repo: 仓库名称
|
||||
target_path: 目标路径
|
||||
branch: 分支名称(可选)
|
||||
mirror_id: 指定的镜像源 ID
|
||||
custom_url: 自定义克隆 URL
|
||||
depth: 克隆深度(浅克隆)
|
||||
|
||||
Returns:
|
||||
Dict 包含:
|
||||
- success: bool - 是否成功
|
||||
- path: str - 克隆路径(成功时)
|
||||
- error: str - 错误信息(失败时)
|
||||
- mirror_used: str - 使用的镜像源
|
||||
- attempts: int - 尝试次数
|
||||
"""
|
||||
logger.info(f"开始克隆仓库: {owner}/{repo} 到 {target_path}")
|
||||
|
||||
if custom_url:
|
||||
# 使用自定义 URL
|
||||
return await self._clone_with_url(custom_url, target_path, branch, depth, "custom")
|
||||
|
||||
# 确定要使用的镜像源列表
|
||||
if mirror_id:
|
||||
# 使用指定的镜像源
|
||||
mirror = self.config.get_mirror_by_id(mirror_id)
|
||||
if not mirror:
|
||||
return {"success": False, "error": f"未找到镜像源: {mirror_id}", "mirror_used": None, "attempts": 0}
|
||||
mirrors_to_try = [mirror]
|
||||
else:
|
||||
# 使用所有启用的镜像源
|
||||
mirrors_to_try = self.config.get_enabled_mirrors()
|
||||
|
||||
# 依次尝试每个镜像源
|
||||
for mirror in mirrors_to_try:
|
||||
result = await self._clone_from_mirror(owner, repo, target_path, branch, depth, mirror)
|
||||
if result["success"]:
|
||||
return result
|
||||
logger.warning(f"镜像源 {mirror['id']} 克隆失败: {result.get('error')}")
|
||||
|
||||
# 所有镜像源都失败
|
||||
return {"success": False, "error": "所有镜像源克隆均失败", "mirror_used": None, "attempts": len(mirrors_to_try)}
|
||||
|
||||
async def _clone_from_mirror(
|
||||
self,
|
||||
owner: str,
|
||||
repo: str,
|
||||
target_path: Path,
|
||||
branch: Optional[str],
|
||||
depth: Optional[int],
|
||||
mirror: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""从指定镜像源克隆仓库"""
|
||||
# 构建克隆 URL
|
||||
clone_prefix = mirror["clone_prefix"]
|
||||
url = f"{clone_prefix}/{owner}/{repo}.git"
|
||||
|
||||
return await self._clone_with_url(url, target_path, branch, depth, mirror["id"])
|
||||
|
||||
async def _clone_with_url(
|
||||
self, url: str, target_path: Path, branch: Optional[str], depth: Optional[int], mirror_type: str
|
||||
) -> Dict[str, Any]:
|
||||
"""使用指定 URL 克隆仓库,支持重试"""
|
||||
attempts = 0
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
attempts += 1
|
||||
|
||||
try:
|
||||
# 确保目标路径不存在
|
||||
if target_path.exists():
|
||||
logger.warning(f"目标路径已存在,删除: {target_path}")
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
# 构建 git clone 命令
|
||||
cmd = ["git", "clone"]
|
||||
|
||||
# 添加分支参数
|
||||
if branch:
|
||||
cmd.extend(["-b", branch])
|
||||
|
||||
# 添加深度参数(浅克隆)
|
||||
if depth:
|
||||
cmd.extend(["--depth", str(depth)])
|
||||
|
||||
# 添加 URL 和目标路径
|
||||
cmd.extend([url, str(target_path)])
|
||||
|
||||
logger.info(f"尝试克隆 #{attempt + 1}: {' '.join(cmd)}")
|
||||
|
||||
# 推送进度
|
||||
if _update_progress:
|
||||
try:
|
||||
await _update_progress(
|
||||
stage="loading",
|
||||
progress=20 + attempt * 10,
|
||||
message=f"正在克隆仓库 (尝试 {attempt + 1}/{self.max_retries})...",
|
||||
operation="install",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"推送进度失败: {e}")
|
||||
|
||||
# 执行 git clone(在线程池中运行以避免阻塞)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def run_git_clone(clone_cmd=cmd):
|
||||
return subprocess.run(
|
||||
clone_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=300, # 5分钟超时
|
||||
)
|
||||
|
||||
process = await loop.run_in_executor(None, run_git_clone)
|
||||
|
||||
if process.returncode == 0:
|
||||
logger.info(f"成功克隆仓库: {url} -> {target_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"path": str(target_path),
|
||||
"mirror_used": mirror_type,
|
||||
"attempts": attempts,
|
||||
"url": url,
|
||||
"branch": branch or "default",
|
||||
}
|
||||
else:
|
||||
last_error = f"Git 克隆失败: {process.stderr}"
|
||||
logger.warning(f"克隆失败 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
last_error = "克隆超时(超过 5 分钟)"
|
||||
logger.warning(f"克隆超时 (尝试 {attempt + 1}/{self.max_retries})")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
except FileNotFoundError:
|
||||
last_error = "Git 未安装或不在 PATH 中"
|
||||
logger.error(f"Git 未找到: {last_error}")
|
||||
break # Git 不存在,不需要重试
|
||||
|
||||
except Exception as e:
|
||||
last_error = f"未知错误: {e}"
|
||||
logger.error(f"克隆错误 (尝试 {attempt + 1}/{self.max_retries}): {last_error}")
|
||||
|
||||
# 清理可能的部分克隆
|
||||
if target_path.exists():
|
||||
shutil.rmtree(target_path, ignore_errors=True)
|
||||
|
||||
return {"success": False, "error": last_error, "mirror_used": mirror_type, "attempts": attempts, "url": url}
|
||||
|
||||
|
||||
# 全局服务实例
|
||||
_git_mirror_service: Optional[GitMirrorService] = None
|
||||
|
||||
|
||||
def get_git_mirror_service() -> GitMirrorService:
|
||||
"""获取 Git 镜像源服务实例(单例)"""
|
||||
global _git_mirror_service
|
||||
if _git_mirror_service is None:
|
||||
_git_mirror_service = GitMirrorService()
|
||||
return _git_mirror_service
|
||||
|
|
@ -1,532 +0,0 @@
|
|||
"""黑话(俚语)管理路由"""
|
||||
|
||||
import json
|
||||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
||||
logger = get_logger("webui.jargon")
|
||||
|
||||
router = APIRouter(prefix="/jargon", tags=["Jargon"])
|
||||
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
|
||||
def parse_chat_id_to_stream_ids(chat_id_str: str) -> List[str]:
|
||||
"""
|
||||
解析 chat_id 字段,提取所有 stream_id
|
||||
chat_id 格式: [["stream_id", user_id], ...] 或直接是 stream_id 字符串
|
||||
"""
|
||||
if not chat_id_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 尝试解析为 JSON
|
||||
parsed = json.loads(chat_id_str)
|
||||
if isinstance(parsed, list):
|
||||
# 格式: [["stream_id", user_id], ...]
|
||||
stream_ids = []
|
||||
for item in parsed:
|
||||
if isinstance(item, list) and len(item) >= 1:
|
||||
stream_ids.append(str(item[0]))
|
||||
return stream_ids
|
||||
else:
|
||||
# 其他格式,返回原始字符串
|
||||
return [chat_id_str]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 不是有效的 JSON,可能是直接的 stream_id
|
||||
return [chat_id_str]
|
||||
|
||||
|
||||
def get_display_name_for_chat_id(chat_id_str: str) -> str:
|
||||
"""
|
||||
获取 chat_id 的显示名称
|
||||
尝试解析 JSON 并查询 ChatStreams 表获取群聊名称
|
||||
"""
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id_str)
|
||||
|
||||
if not stream_ids:
|
||||
return chat_id_str
|
||||
|
||||
# 查询所有 stream_id 对应的名称
|
||||
names = []
|
||||
for stream_id in stream_ids:
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream and chat_stream.group_name:
|
||||
names.append(chat_stream.group_name)
|
||||
else:
|
||||
# 如果没找到,显示截断的 stream_id
|
||||
names.append(stream_id[:8] + "..." if len(stream_id) > 8 else stream_id)
|
||||
|
||||
return ", ".join(names) if names else chat_id_str
|
||||
|
||||
|
||||
# ==================== 请求/响应模型 ====================
|
||||
|
||||
|
||||
class JargonResponse(BaseModel):
|
||||
"""黑话信息响应"""
|
||||
|
||||
id: int
|
||||
content: str
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: str
|
||||
stream_id: Optional[str] = None # 解析后的 stream_id,用于前端编辑时匹配
|
||||
chat_name: Optional[str] = None # 解析后的聊天名称,用于前端显示
|
||||
is_global: bool = False
|
||||
count: int = 0
|
||||
is_jargon: Optional[bool] = None
|
||||
is_complete: bool = False
|
||||
inference_with_context: Optional[str] = None
|
||||
inference_content_only: Optional[str] = None
|
||||
|
||||
|
||||
class JargonListResponse(BaseModel):
|
||||
"""黑话列表响应"""
|
||||
|
||||
success: bool = True
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[JargonResponse]
|
||||
|
||||
|
||||
class JargonDetailResponse(BaseModel):
|
||||
"""黑话详情响应"""
|
||||
|
||||
success: bool = True
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonCreateRequest(BaseModel):
|
||||
"""黑话创建请求"""
|
||||
|
||||
content: str = Field(..., description="黑话内容")
|
||||
raw_content: Optional[str] = Field(None, description="原始内容")
|
||||
meaning: Optional[str] = Field(None, description="含义")
|
||||
chat_id: str = Field(..., description="聊天ID")
|
||||
is_global: bool = Field(False, description="是否全局")
|
||||
|
||||
|
||||
class JargonUpdateRequest(BaseModel):
|
||||
"""黑话更新请求"""
|
||||
|
||||
content: Optional[str] = None
|
||||
raw_content: Optional[str] = None
|
||||
meaning: Optional[str] = None
|
||||
chat_id: Optional[str] = None
|
||||
is_global: Optional[bool] = None
|
||||
is_jargon: Optional[bool] = None
|
||||
|
||||
|
||||
class JargonCreateResponse(BaseModel):
|
||||
"""黑话创建响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: JargonResponse
|
||||
|
||||
|
||||
class JargonUpdateResponse(BaseModel):
|
||||
"""黑话更新响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
data: Optional[JargonResponse] = None
|
||||
|
||||
|
||||
class JargonDeleteResponse(BaseModel):
|
||||
"""黑话删除响应"""
|
||||
|
||||
success: bool = True
|
||||
message: str
|
||||
deleted_count: int = 0
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
ids: List[int] = Field(..., description="要删除的黑话ID列表")
|
||||
|
||||
|
||||
class JargonStatsResponse(BaseModel):
|
||||
"""黑话统计响应"""
|
||||
|
||||
success: bool = True
|
||||
data: dict
|
||||
|
||||
|
||||
class ChatInfoResponse(BaseModel):
|
||||
"""聊天信息响应"""
|
||||
|
||||
chat_id: str
|
||||
chat_name: str
|
||||
platform: Optional[str] = None
|
||||
is_group: bool = False
|
||||
|
||||
|
||||
class ChatListResponse(BaseModel):
|
||||
"""聊天列表响应"""
|
||||
|
||||
success: bool = True
|
||||
data: List[ChatInfoResponse]
|
||||
|
||||
|
||||
# ==================== 工具函数 ====================
|
||||
|
||||
|
||||
def jargon_to_dict(jargon: Jargon) -> dict:
|
||||
"""将 Jargon ORM 对象转换为字典"""
|
||||
# 解析 chat_id 获取显示名称和 stream_id
|
||||
chat_name = get_display_name_for_chat_id(jargon.chat_id) if jargon.chat_id else None
|
||||
stream_ids = parse_chat_id_to_stream_ids(jargon.chat_id) if jargon.chat_id else []
|
||||
stream_id = stream_ids[0] if stream_ids else None
|
||||
|
||||
return {
|
||||
"id": jargon.id,
|
||||
"content": jargon.content,
|
||||
"raw_content": jargon.raw_content,
|
||||
"meaning": jargon.meaning,
|
||||
"chat_id": jargon.chat_id,
|
||||
"stream_id": stream_id,
|
||||
"chat_name": chat_name,
|
||||
"is_global": jargon.is_global,
|
||||
"count": jargon.count,
|
||||
"is_jargon": jargon.is_jargon,
|
||||
"is_complete": jargon.is_complete,
|
||||
"inference_with_context": jargon.inference_with_context,
|
||||
"inference_content_only": jargon.inference_content_only,
|
||||
}
|
||||
|
||||
|
||||
# ==================== API 端点 ====================
|
||||
|
||||
|
||||
@router.get("/list", response_model=JargonListResponse)
|
||||
async def get_jargon_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
chat_id: Optional[str] = Query(None, description="按聊天ID筛选"),
|
||||
is_jargon: Optional[bool] = Query(None, description="按是否是黑话筛选"),
|
||||
is_global: Optional[bool] = Query(None, description="按是否全局筛选"),
|
||||
):
|
||||
"""获取黑话列表"""
|
||||
try:
|
||||
# 构建查询
|
||||
query = Jargon.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(Jargon.content.contains(search))
|
||||
| (Jargon.meaning.contains(search))
|
||||
| (Jargon.raw_content.contains(search))
|
||||
)
|
||||
|
||||
# 按聊天ID筛选(使用 contains 匹配,因为 chat_id 是 JSON 格式)
|
||||
if chat_id:
|
||||
# 从传入的 chat_id 中解析出 stream_id
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
# 使用第一个 stream_id 进行模糊匹配
|
||||
query = query.where(Jargon.chat_id.contains(stream_ids[0]))
|
||||
else:
|
||||
# 如果无法解析,使用精确匹配
|
||||
query = query.where(Jargon.chat_id == chat_id)
|
||||
|
||||
# 按是否是黑话筛选
|
||||
if is_jargon is not None:
|
||||
query = query.where(Jargon.is_jargon == is_jargon)
|
||||
|
||||
# 按是否全局筛选
|
||||
if is_global is not None:
|
||||
query = query.where(Jargon.is_global == is_global)
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页和排序(按使用次数降序)
|
||||
query = query.order_by(Jargon.count.desc(), Jargon.id.desc())
|
||||
query = query.paginate(page, page_size)
|
||||
|
||||
# 转换为响应格式
|
||||
data = [jargon_to_dict(j) for j in query]
|
||||
|
||||
return JargonListResponse(
|
||||
success=True,
|
||||
total=total,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
data=data,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/chats", response_model=ChatListResponse)
|
||||
async def get_chat_list():
|
||||
"""获取所有有黑话记录的聊天列表"""
|
||||
try:
|
||||
# 获取所有不同的 chat_id
|
||||
chat_ids = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False))
|
||||
|
||||
chat_id_list = [j.chat_id for j in chat_ids if j.chat_id]
|
||||
|
||||
# 用于按 stream_id 去重
|
||||
seen_stream_ids: set[str] = set()
|
||||
|
||||
for chat_id in chat_id_list:
|
||||
stream_ids = parse_chat_id_to_stream_ids(chat_id)
|
||||
if stream_ids:
|
||||
seen_stream_ids.add(stream_ids[0])
|
||||
|
||||
result = []
|
||||
for stream_id in seen_stream_ids:
|
||||
# 尝试从 ChatStreams 表获取聊天名称
|
||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == stream_id)
|
||||
if chat_stream:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id,方便筛选匹配
|
||||
chat_name=chat_stream.group_name or stream_id,
|
||||
platform=chat_stream.platform,
|
||||
is_group=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(
|
||||
ChatInfoResponse(
|
||||
chat_id=stream_id, # 使用 stream_id
|
||||
chat_name=stream_id[:8] + "..." if len(stream_id) > 8 else stream_id,
|
||||
platform=None,
|
||||
is_group=False,
|
||||
)
|
||||
)
|
||||
|
||||
return ChatListResponse(success=True, data=result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取聊天列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary", response_model=JargonStatsResponse)
|
||||
async def get_jargon_stats():
|
||||
"""获取黑话统计数据"""
|
||||
try:
|
||||
# 总数量
|
||||
total = Jargon.select().count()
|
||||
|
||||
# 已确认是黑话的数量
|
||||
confirmed_jargon = Jargon.select().where(Jargon.is_jargon).count()
|
||||
|
||||
# 已确认不是黑话的数量
|
||||
confirmed_not_jargon = Jargon.select().where(~Jargon.is_jargon).count()
|
||||
|
||||
# 未判定的数量
|
||||
pending = Jargon.select().where(Jargon.is_jargon.is_null()).count()
|
||||
|
||||
# 全局黑话数量
|
||||
global_count = Jargon.select().where(Jargon.is_global).count()
|
||||
|
||||
# 已完成推断的数量
|
||||
complete_count = Jargon.select().where(Jargon.is_complete).count()
|
||||
|
||||
# 关联的聊天数量
|
||||
chat_count = Jargon.select(Jargon.chat_id).distinct().where(Jargon.chat_id.is_null(False)).count()
|
||||
|
||||
# 按聊天统计 TOP 5
|
||||
top_chats = (
|
||||
Jargon.select(Jargon.chat_id, fn.COUNT(Jargon.id).alias("count"))
|
||||
.group_by(Jargon.chat_id)
|
||||
.order_by(fn.COUNT(Jargon.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
top_chats_dict = {j.chat_id: j.count for j in top_chats if j.chat_id}
|
||||
|
||||
return JargonStatsResponse(
|
||||
success=True,
|
||||
data={
|
||||
"total": total,
|
||||
"confirmed_jargon": confirmed_jargon,
|
||||
"confirmed_not_jargon": confirmed_not_jargon,
|
||||
"pending": pending,
|
||||
"global_count": global_count,
|
||||
"complete_count": complete_count,
|
||||
"chat_count": chat_count,
|
||||
"top_chats": top_chats_dict,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话统计失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{jargon_id}", response_model=JargonDetailResponse)
|
||||
async def get_jargon_detail(jargon_id: int):
|
||||
"""获取黑话详情"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
return JargonDetailResponse(success=True, data=jargon_to_dict(jargon))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取黑话详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取黑话详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/", response_model=JargonCreateResponse)
|
||||
async def create_jargon(request: JargonCreateRequest):
|
||||
"""创建黑话"""
|
||||
try:
|
||||
# 检查是否已存在相同内容的黑话
|
||||
existing = Jargon.get_or_none((Jargon.content == request.content) & (Jargon.chat_id == request.chat_id))
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="该聊天中已存在相同内容的黑话")
|
||||
|
||||
# 创建黑话
|
||||
jargon = Jargon.create(
|
||||
content=request.content,
|
||||
raw_content=request.raw_content,
|
||||
meaning=request.meaning,
|
||||
chat_id=request.chat_id,
|
||||
is_global=request.is_global,
|
||||
count=0,
|
||||
is_jargon=None,
|
||||
is_complete=False,
|
||||
)
|
||||
|
||||
logger.info(f"创建黑话成功: id={jargon.id}, content={request.content}")
|
||||
|
||||
return JargonCreateResponse(
|
||||
success=True,
|
||||
message="创建成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"创建黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"创建黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{jargon_id}", response_model=JargonUpdateResponse)
|
||||
async def update_jargon(jargon_id: int, request: JargonUpdateRequest):
|
||||
"""更新黑话(增量更新)"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
# 增量更新字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
if update_data:
|
||||
for field, value in update_data.items():
|
||||
if value is not None or field in ["meaning", "raw_content", "is_jargon"]:
|
||||
setattr(jargon, field, value)
|
||||
jargon.save()
|
||||
|
||||
logger.info(f"更新黑话成功: id={jargon_id}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message="更新成功",
|
||||
data=jargon_to_dict(jargon),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"更新黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{jargon_id}", response_model=JargonDeleteResponse)
|
||||
async def delete_jargon(jargon_id: int):
|
||||
"""删除黑话"""
|
||||
try:
|
||||
jargon = Jargon.get_or_none(Jargon.id == jargon_id)
|
||||
if not jargon:
|
||||
raise HTTPException(status_code=404, detail="黑话不存在")
|
||||
|
||||
content = jargon.content
|
||||
jargon.delete_instance()
|
||||
|
||||
logger.info(f"删除黑话成功: id={jargon_id}, content={content}")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message="删除成功",
|
||||
deleted_count=1,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=JargonDeleteResponse)
|
||||
async def batch_delete_jargons(request: BatchDeleteRequest):
|
||||
"""批量删除黑话"""
|
||||
try:
|
||||
if not request.ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
deleted_count = Jargon.delete().where(Jargon.id.in_(request.ids)).execute()
|
||||
|
||||
logger.info(f"批量删除黑话成功: 删除了 {deleted_count} 条记录")
|
||||
|
||||
return JargonDeleteResponse(
|
||||
success=True,
|
||||
message=f"成功删除 {deleted_count} 条黑话",
|
||||
deleted_count=deleted_count,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量删除黑话失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除黑话失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/set-jargon", response_model=JargonUpdateResponse)
|
||||
async def batch_set_jargon_status(
|
||||
ids: Annotated[List[int], Query(description="黑话ID列表")],
|
||||
is_jargon: Annotated[bool, Query(description="是否是黑话")],
|
||||
):
|
||||
"""批量设置黑话状态"""
|
||||
try:
|
||||
if not ids:
|
||||
raise HTTPException(status_code=400, detail="ID列表不能为空")
|
||||
|
||||
updated_count = Jargon.update(is_jargon=is_jargon).where(Jargon.id.in_(ids)).execute()
|
||||
|
||||
logger.info(f"批量更新黑话状态成功: 更新了 {updated_count} 条记录,is_jargon={is_jargon}")
|
||||
|
||||
return JargonUpdateResponse(
|
||||
success=True,
|
||||
message=f"成功更新 {updated_count} 条黑话状态",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"批量更新黑话状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量更新黑话状态失败: {str(e)}") from e
|
||||
|
|
@ -1,298 +0,0 @@
|
|||
"""知识库图谱可视化 API 路由"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/api/webui/knowledge", tags=["knowledge"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class KnowledgeNode(BaseModel):
|
||||
"""知识节点"""
|
||||
|
||||
id: str
|
||||
type: str # 'entity' or 'paragraph'
|
||||
content: str
|
||||
create_time: Optional[float] = None
|
||||
|
||||
|
||||
class KnowledgeEdge(BaseModel):
|
||||
"""知识边"""
|
||||
|
||||
source: str
|
||||
target: str
|
||||
weight: float
|
||||
create_time: Optional[float] = None
|
||||
update_time: Optional[float] = None
|
||||
|
||||
|
||||
class KnowledgeGraph(BaseModel):
|
||||
"""知识图谱"""
|
||||
|
||||
nodes: List[KnowledgeNode]
|
||||
edges: List[KnowledgeEdge]
|
||||
|
||||
|
||||
class KnowledgeStats(BaseModel):
|
||||
"""知识库统计信息"""
|
||||
|
||||
total_nodes: int
|
||||
total_edges: int
|
||||
entity_nodes: int
|
||||
paragraph_nodes: int
|
||||
avg_connections: float
|
||||
|
||||
|
||||
def _load_kg_manager():
|
||||
"""延迟加载 KGManager"""
|
||||
try:
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
|
||||
kg_manager = KGManager()
|
||||
kg_manager.load_from_file()
|
||||
return kg_manager
|
||||
except Exception as e:
|
||||
logger.error(f"加载 KGManager 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _convert_graph_to_json(kg_manager) -> KnowledgeGraph:
|
||||
"""将 DiGraph 转换为 JSON 格式"""
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
graph = kg_manager.graph
|
||||
nodes = []
|
||||
edges = []
|
||||
|
||||
# 转换节点
|
||||
node_list = graph.get_node_list()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
# 节点类型: "ent" -> "entity", "pg" -> "paragraph"
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
# 转换边
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
try:
|
||||
# edge_tuple 是 (source, target) 元组
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 通过 graph[source, target] 获取边的属性数据
|
||||
edge_data = graph[source, target]
|
||||
|
||||
# edge_data 支持 [] 操作符但不支持 .get()
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
return KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
|
||||
|
||||
@router.get("/graph", response_model=KnowledgeGraph)
|
||||
async def get_knowledge_graph(
|
||||
limit: int = Query(100, ge=1, le=10000, description="返回的最大节点数"),
|
||||
node_type: str = Query("all", description="节点类型过滤: all, entity, paragraph"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""获取知识图谱(限制节点数量)
|
||||
|
||||
Args:
|
||||
limit: 返回的最大节点数,默认 100,最大 10000
|
||||
node_type: 节点类型过滤 - all(全部), entity(实体), paragraph(段落)
|
||||
|
||||
Returns:
|
||||
KnowledgeGraph: 包含指定数量节点和相关边的知识图谱
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None:
|
||||
logger.warning("KGManager 未初始化,返回空图谱")
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
graph = kg_manager.graph
|
||||
all_node_list = graph.get_node_list()
|
||||
|
||||
# 按类型过滤节点
|
||||
if node_type == "entity":
|
||||
all_node_list = [
|
||||
n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "ent"
|
||||
]
|
||||
elif node_type == "paragraph":
|
||||
all_node_list = [n for n in all_node_list if n in graph and "type" in graph[n] and graph[n]["type"] == "pg"]
|
||||
|
||||
# 限制节点数量
|
||||
total_nodes = len(all_node_list)
|
||||
if len(all_node_list) > limit:
|
||||
node_list = all_node_list[:limit]
|
||||
else:
|
||||
node_list = all_node_list
|
||||
|
||||
logger.info(f"总节点数: {total_nodes}, 返回节点: {len(node_list)} (limit={limit}, type={node_type})")
|
||||
|
||||
# 转换节点
|
||||
nodes = []
|
||||
node_ids = set()
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type_val = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
|
||||
nodes.append(KnowledgeNode(id=node_id, type=node_type_val, content=content, create_time=create_time))
|
||||
node_ids.add(node_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过节点 {node_id}: {e}")
|
||||
continue
|
||||
|
||||
# 只获取涉及当前节点集的边(保证图的完整性)
|
||||
edges = []
|
||||
edge_list = graph.get_edge_list()
|
||||
for edge_tuple in edge_list:
|
||||
try:
|
||||
source, target = edge_tuple[0], edge_tuple[1]
|
||||
# 只包含两端都在当前节点集中的边
|
||||
if source not in node_ids or target not in node_ids:
|
||||
continue
|
||||
|
||||
edge_data = graph[source, target]
|
||||
weight = edge_data["weight"] if "weight" in edge_data else 1.0
|
||||
create_time = edge_data["create_time"] if "create_time" in edge_data else None
|
||||
update_time = edge_data["update_time"] if "update_time" in edge_data else None
|
||||
|
||||
edges.append(
|
||||
KnowledgeEdge(
|
||||
source=source, target=target, weight=weight, create_time=create_time, update_time=update_time
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过边 {edge_tuple}: {e}")
|
||||
continue
|
||||
|
||||
graph_data = KnowledgeGraph(nodes=nodes, edges=edges)
|
||||
logger.info(f"返回知识图谱: {len(nodes)} 个节点, {len(edges)} 条边")
|
||||
return graph_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取知识图谱失败: {e}", exc_info=True)
|
||||
return KnowledgeGraph(nodes=[], edges=[])
|
||||
|
||||
|
||||
@router.get("/stats", response_model=KnowledgeStats)
|
||||
async def get_knowledge_stats(_auth: bool = Depends(require_auth)):
|
||||
"""获取知识库统计信息
|
||||
|
||||
Returns:
|
||||
KnowledgeStats: 统计信息
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
edge_list = graph.get_edge_list()
|
||||
|
||||
total_nodes = len(node_list)
|
||||
total_edges = len(edge_list)
|
||||
|
||||
# 统计节点类型
|
||||
entity_nodes = 0
|
||||
paragraph_nodes = 0
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
node_type = node_data["type"] if "type" in node_data else "ent"
|
||||
if node_type == "ent":
|
||||
entity_nodes += 1
|
||||
elif node_type == "pg":
|
||||
paragraph_nodes += 1
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 计算平均连接数
|
||||
avg_connections = (total_edges * 2) / total_nodes if total_nodes > 0 else 0.0
|
||||
|
||||
return KnowledgeStats(
|
||||
total_nodes=total_nodes,
|
||||
total_edges=total_edges,
|
||||
entity_nodes=entity_nodes,
|
||||
paragraph_nodes=paragraph_nodes,
|
||||
avg_connections=round(avg_connections, 2),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计信息失败: {e}", exc_info=True)
|
||||
return KnowledgeStats(total_nodes=0, total_edges=0, entity_nodes=0, paragraph_nodes=0, avg_connections=0.0)
|
||||
|
||||
|
||||
@router.get("/search", response_model=List[KnowledgeNode])
|
||||
async def search_knowledge_node(query: str = Query(..., min_length=1), _auth: bool = Depends(require_auth)):
|
||||
"""搜索知识节点
|
||||
|
||||
Args:
|
||||
query: 搜索关键词
|
||||
|
||||
Returns:
|
||||
List[KnowledgeNode]: 匹配的节点列表
|
||||
"""
|
||||
try:
|
||||
kg_manager = _load_kg_manager()
|
||||
if kg_manager is None or kg_manager.graph is None:
|
||||
return []
|
||||
|
||||
graph = kg_manager.graph
|
||||
node_list = graph.get_node_list()
|
||||
results = []
|
||||
query_lower = query.lower()
|
||||
|
||||
# 在节点内容中搜索
|
||||
for node_id in node_list:
|
||||
try:
|
||||
node_data = graph[node_id]
|
||||
content = node_data["content"] if "content" in node_data else node_id
|
||||
node_type = "entity" if ("type" in node_data and node_data["type"] == "ent") else "paragraph"
|
||||
|
||||
if query_lower in content.lower() or query_lower in node_id.lower():
|
||||
create_time = node_data["create_time"] if "create_time" in node_data else None
|
||||
results.append(KnowledgeNode(id=node_id, type=node_type, content=content, create_time=create_time))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
logger.info(f"搜索 '{query}' 找到 {len(results)} 个节点")
|
||||
return results[:50] # 限制返回数量
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"搜索节点失败: {e}", exc_info=True)
|
||||
return []
|
||||
|
|
@ -1,177 +0,0 @@
|
|||
"""WebSocket 日志推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Optional
|
||||
import json
|
||||
from pathlib import Path
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.logs_ws")
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
|
||||
def load_recent_logs(limit: int = 100) -> list[dict]:
|
||||
"""从日志文件中加载最近的日志
|
||||
|
||||
Args:
|
||||
limit: 返回的最大日志条数
|
||||
|
||||
Returns:
|
||||
日志列表
|
||||
"""
|
||||
logs = []
|
||||
log_dir = Path("logs")
|
||||
|
||||
if not log_dir.exists():
|
||||
return logs
|
||||
|
||||
# 获取所有日志文件,按修改时间排序
|
||||
log_files = sorted(log_dir.glob("app_*.log.jsonl"), key=lambda f: f.stat().st_mtime, reverse=True)
|
||||
|
||||
# 用于生成唯一 ID 的计数器
|
||||
log_counter = 0
|
||||
|
||||
# 从最新的文件开始读取
|
||||
for log_file in log_files:
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
|
||||
try:
|
||||
with open(log_file, "r", encoding="utf-8") as f:
|
||||
lines = f.readlines()
|
||||
# 从文件末尾开始读取
|
||||
for line in reversed(lines):
|
||||
if len(logs) >= limit:
|
||||
break
|
||||
try:
|
||||
log_entry = json.loads(line.strip())
|
||||
# 转换为前端期望的格式
|
||||
# 使用时间戳 + 计数器生成唯一 ID
|
||||
timestamp_id = (
|
||||
log_entry.get("timestamp", "0").replace("-", "").replace(" ", "").replace(":", "")
|
||||
)
|
||||
formatted_log = {
|
||||
"id": f"{timestamp_id}_{log_counter}",
|
||||
"timestamp": log_entry.get("timestamp", ""),
|
||||
"level": log_entry.get("level", "INFO").upper(),
|
||||
"module": log_entry.get("logger_name", ""),
|
||||
"message": log_entry.get("event", ""),
|
||||
}
|
||||
logs.append(formatted_log)
|
||||
log_counter += 1
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"读取日志文件失败 {log_file}: {e}")
|
||||
continue
|
||||
|
||||
# 反转列表,使其按时间顺序排列(旧到新)
|
||||
return list(reversed(logs))
|
||||
|
||||
|
||||
@router.websocket("/ws/logs")
|
||||
async def websocket_logs(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 日志推送端点
|
||||
|
||||
客户端连接后会持续接收服务器端的日志消息
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/logs?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
# 连接建立后,立即发送历史日志
|
||||
try:
|
||||
recent_logs = load_recent_logs(limit=100)
|
||||
logger.info(f"发送 {len(recent_logs)} 条历史日志到客户端")
|
||||
|
||||
for log_entry in recent_logs:
|
||||
await websocket.send_text(json.dumps(log_entry, ensure_ascii=False))
|
||||
except Exception as e:
|
||||
logger.error(f"发送历史日志失败: {e}")
|
||||
|
||||
try:
|
||||
# 保持连接,等待客户端消息或断开
|
||||
while True:
|
||||
# 接收客户端消息(用于心跳或控制指令)
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 可以处理客户端的控制消息,例如:
|
||||
# - "ping" -> 心跳检测
|
||||
# - {"filter": "ERROR"} -> 设置日志级别过滤
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def broadcast_log(log_data: dict):
|
||||
"""广播日志到所有连接的 WebSocket 客户端
|
||||
|
||||
Args:
|
||||
log_data: 日志数据字典
|
||||
"""
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
# 格式化为 JSON
|
||||
message = json.dumps(log_data, ensure_ascii=False)
|
||||
|
||||
# 记录需要断开的连接
|
||||
disconnected = set()
|
||||
|
||||
# 广播到所有客户端
|
||||
for connection in active_connections:
|
||||
try:
|
||||
await connection.send_text(message)
|
||||
except Exception:
|
||||
# 发送失败,标记为断开
|
||||
disconnected.add(connection)
|
||||
|
||||
# 清理断开的连接
|
||||
if disconnected:
|
||||
active_connections.difference_update(disconnected)
|
||||
logger.debug(f"清理了 {len(disconnected)} 个断开的 WebSocket 连接")
|
||||
|
|
@ -1,383 +0,0 @@
|
|||
"""
|
||||
模型列表获取API路由
|
||||
|
||||
提供从各个 AI 厂商 API 获取可用模型列表的代理接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Query, Depends, Cookie, Header
|
||||
from typing import Optional
|
||||
import tomlkit
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import CONFIG_DIR
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
router = APIRouter(prefix="/models", tags=["models"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
# 模型获取器配置
|
||||
MODEL_FETCHER_CONFIG = {
|
||||
# OpenAI 兼容格式的提供商
|
||||
"openai": {
|
||||
"endpoint": "/models",
|
||||
"parser": "openai",
|
||||
},
|
||||
# Gemini 格式
|
||||
"gemini": {
|
||||
"endpoint": "/models",
|
||||
"parser": "gemini",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _normalize_url(url: str) -> str:
|
||||
"""规范化 URL(去掉尾部斜杠)"""
|
||||
if not url:
|
||||
return ""
|
||||
return url.rstrip("/")
|
||||
|
||||
|
||||
def _parse_openai_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 OpenAI 格式的模型列表响应
|
||||
|
||||
格式: { "data": [{ "id": "gpt-4", "object": "model", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "data" in data and isinstance(data["data"], list):
|
||||
for model in data["data"]:
|
||||
if isinstance(model, dict) and "id" in model:
|
||||
models.append(
|
||||
{
|
||||
"id": model["id"],
|
||||
"name": model.get("name") or model["id"],
|
||||
"owned_by": model.get("owned_by", ""),
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
def _parse_gemini_response(data: dict) -> list[dict]:
|
||||
"""
|
||||
解析 Gemini 格式的模型列表响应
|
||||
|
||||
格式: { "models": [{ "name": "models/gemini-pro", "displayName": "Gemini Pro", ... }] }
|
||||
"""
|
||||
models = []
|
||||
if "models" in data and isinstance(data["models"], list):
|
||||
for model in data["models"]:
|
||||
if isinstance(model, dict) and "name" in model:
|
||||
# Gemini 的 name 格式是 "models/gemini-pro",我们只取后面部分
|
||||
model_id = model["name"]
|
||||
if model_id.startswith("models/"):
|
||||
model_id = model_id[7:] # 去掉 "models/" 前缀
|
||||
models.append(
|
||||
{
|
||||
"id": model_id,
|
||||
"name": model.get("displayName") or model_id,
|
||||
"owned_by": "google",
|
||||
}
|
||||
)
|
||||
return models
|
||||
|
||||
|
||||
async def _fetch_models_from_provider(
|
||||
base_url: str,
|
||||
api_key: str,
|
||||
endpoint: str,
|
||||
parser: str,
|
||||
client_type: str = "openai",
|
||||
) -> list[dict]:
|
||||
"""
|
||||
从提供商 API 获取模型列表
|
||||
|
||||
Args:
|
||||
base_url: 提供商的基础 URL
|
||||
api_key: API 密钥
|
||||
endpoint: 获取模型列表的端点
|
||||
parser: 响应解析器类型 ('openai' | 'gemini')
|
||||
client_type: 客户端类型 ('openai' | 'gemini')
|
||||
|
||||
Returns:
|
||||
模型列表
|
||||
"""
|
||||
url = f"{_normalize_url(base_url)}{endpoint}"
|
||||
|
||||
# 根据客户端类型设置请求头
|
||||
headers = {}
|
||||
params = {}
|
||||
|
||||
if client_type == "gemini":
|
||||
# Gemini 使用 URL 参数传递 API Key
|
||||
params["key"] = api_key
|
||||
else:
|
||||
# OpenAI 兼容格式使用 Authorization 头
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except httpx.TimeoutException as e:
|
||||
raise HTTPException(status_code=504, detail="请求超时,请稍后重试") from e
|
||||
except httpx.HTTPStatusError as e:
|
||||
# 注意:使用 502 Bad Gateway 而不是原始的 401/403,
|
||||
# 因为前端的 fetchWithAuth 会把 401 当作 WebUI 认证失败处理
|
||||
if e.response.status_code == 401:
|
||||
raise HTTPException(status_code=502, detail="API Key 无效或已过期") from e
|
||||
elif e.response.status_code == 403:
|
||||
raise HTTPException(status_code=502, detail="没有权限访问模型列表,请检查 API Key 权限") from e
|
||||
elif e.response.status_code == 404:
|
||||
raise HTTPException(status_code=502, detail="该提供商不支持获取模型列表") from e
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502, detail=f"上游服务请求失败 ({e.response.status_code}): {e.response.text[:200]}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取模型列表失败: {str(e)}") from e
|
||||
|
||||
# 根据解析器类型解析响应
|
||||
if parser == "openai":
|
||||
return _parse_openai_response(data)
|
||||
elif parser == "gemini":
|
||||
return _parse_gemini_response(data)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的解析器类型: {parser}")
|
||||
|
||||
|
||||
def _get_provider_config(provider_name: str) -> Optional[dict]:
|
||||
"""
|
||||
从 model_config.toml 获取指定提供商的配置
|
||||
|
||||
Args:
|
||||
provider_name: 提供商名称
|
||||
|
||||
Returns:
|
||||
提供商配置,如果未找到则返回 None
|
||||
"""
|
||||
config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
providers = config_data.get("api_providers", [])
|
||||
for provider in providers:
|
||||
if provider.get("name") == provider_name:
|
||||
return dict(provider)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"读取提供商配置失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/list")
|
||||
async def get_provider_models(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
获取指定提供商的可用模型列表
|
||||
|
||||
通过提供商名称查找配置,然后请求对应的模型列表端点
|
||||
"""
|
||||
# 获取提供商配置
|
||||
provider_config = _get_provider_config(provider_name)
|
||||
if not provider_config:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider_config.get("base_url")
|
||||
api_key = provider_config.get("api_key")
|
||||
client_type = provider_config.get("client_type", "openai")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 api_key")
|
||||
|
||||
# 获取模型列表
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"provider": provider_name,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/list-by-url")
|
||||
async def get_models_by_url(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: str = Query(..., description="API Key"),
|
||||
parser: str = Query("openai", description="响应解析器类型 (openai | gemini)"),
|
||||
endpoint: str = Query("/models", description="获取模型列表的端点"),
|
||||
client_type: str = Query("openai", description="客户端类型 (openai | gemini)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过 URL 直接获取模型列表(用于自定义提供商)
|
||||
"""
|
||||
models = await _fetch_models_from_provider(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
parser=parser,
|
||||
client_type=client_type,
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"models": models,
|
||||
"count": len(models),
|
||||
}
|
||||
|
||||
|
||||
@router.get("/test-connection")
|
||||
async def test_provider_connection(
|
||||
base_url: str = Query(..., description="提供商的基础 URL"),
|
||||
api_key: Optional[str] = Query(None, description="API Key(可选,用于验证 Key 有效性)"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
测试提供商连接状态
|
||||
|
||||
分两步测试:
|
||||
1. 网络连通性测试:向 base_url 发送请求,检查是否能连接
|
||||
2. API Key 验证(可选):如果提供了 api_key,尝试获取模型列表验证 Key 是否有效
|
||||
|
||||
返回:
|
||||
- network_ok: 网络是否连通
|
||||
- api_key_valid: API Key 是否有效(仅在提供 api_key 时返回)
|
||||
- latency_ms: 响应延迟(毫秒)
|
||||
- error: 错误信息(如果有)
|
||||
"""
|
||||
import time
|
||||
|
||||
base_url = _normalize_url(base_url)
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="base_url 不能为空")
|
||||
|
||||
result = {
|
||||
"network_ok": False,
|
||||
"api_key_valid": None,
|
||||
"latency_ms": None,
|
||||
"error": None,
|
||||
"http_status": None,
|
||||
}
|
||||
|
||||
# 第一步:测试网络连通性
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=10.0, follow_redirects=True) as client:
|
||||
# 尝试 GET 请求 base_url(不需要 API Key)
|
||||
response = await client.get(base_url)
|
||||
latency = (time.time() - start_time) * 1000
|
||||
|
||||
result["network_ok"] = True
|
||||
result["latency_ms"] = round(latency, 2)
|
||||
result["http_status"] = response.status_code
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
result["error"] = f"连接失败:无法连接到服务器 ({str(e)})"
|
||||
return result
|
||||
except httpx.TimeoutException:
|
||||
result["error"] = "连接超时:服务器响应时间过长"
|
||||
return result
|
||||
except httpx.RequestError as e:
|
||||
result["error"] = f"请求错误:{str(e)}"
|
||||
return result
|
||||
except Exception as e:
|
||||
result["error"] = f"未知错误:{str(e)}"
|
||||
return result
|
||||
|
||||
# 第二步:如果提供了 API Key,验证其有效性
|
||||
if api_key:
|
||||
try:
|
||||
start_time = time.time()
|
||||
async with httpx.AsyncClient(timeout=15.0, follow_redirects=True) as client:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# 尝试获取模型列表
|
||||
models_url = f"{base_url}/models"
|
||||
response = await client.get(models_url, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
result["api_key_valid"] = True
|
||||
elif response.status_code in (401, 403):
|
||||
result["api_key_valid"] = False
|
||||
result["error"] = "API Key 无效或已过期"
|
||||
else:
|
||||
# 其他状态码,可能是端点不支持,但 Key 可能是有效的
|
||||
result["api_key_valid"] = None
|
||||
|
||||
except Exception as e:
|
||||
# API Key 验证失败不影响网络连通性结果
|
||||
logger.warning(f"API Key 验证失败: {e}")
|
||||
result["api_key_valid"] = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.post("/test-connection-by-name")
|
||||
async def test_provider_connection_by_name(
|
||||
provider_name: str = Query(..., description="提供商名称"),
|
||||
_auth: bool = Depends(require_auth),
|
||||
):
|
||||
"""
|
||||
通过提供商名称测试连接(从配置文件读取信息)
|
||||
"""
|
||||
# 读取配置文件
|
||||
model_config_path = os.path.join(CONFIG_DIR, "model_config.toml")
|
||||
if not os.path.exists(model_config_path):
|
||||
raise HTTPException(status_code=404, detail="配置文件不存在")
|
||||
|
||||
with open(model_config_path, "r", encoding="utf-8") as f:
|
||||
config = tomlkit.load(f)
|
||||
|
||||
# 查找提供商
|
||||
providers = config.get("api_providers", [])
|
||||
provider = None
|
||||
for p in providers:
|
||||
if p.get("name") == provider_name:
|
||||
provider = p
|
||||
break
|
||||
|
||||
if not provider:
|
||||
raise HTTPException(status_code=404, detail=f"未找到提供商: {provider_name}")
|
||||
|
||||
base_url = provider.get("base_url", "")
|
||||
api_key = provider.get("api_key", "")
|
||||
|
||||
if not base_url:
|
||||
raise HTTPException(status_code=400, detail="提供商配置缺少 base_url")
|
||||
|
||||
# 调用测试接口
|
||||
return await test_provider_connection(base_url=base_url, api_key=api_key if api_key else None)
|
||||
|
|
@ -1,416 +0,0 @@
|
|||
"""人物信息管理 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from .auth import verify_auth_token_from_cookie_or_header
|
||||
import json
|
||||
import time
|
||||
|
||||
logger = get_logger("webui.person")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/person", tags=["Person"])
|
||||
|
||||
|
||||
class PersonInfoResponse(BaseModel):
|
||||
"""人物信息响应"""
|
||||
|
||||
id: int
|
||||
is_known: bool
|
||||
person_id: str
|
||||
person_name: Optional[str]
|
||||
name_reason: Optional[str]
|
||||
platform: str
|
||||
user_id: str
|
||||
nickname: Optional[str]
|
||||
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
|
||||
memory_points: Optional[str]
|
||||
know_times: Optional[float]
|
||||
know_since: Optional[float]
|
||||
last_know: Optional[float]
|
||||
|
||||
|
||||
class PersonListResponse(BaseModel):
|
||||
"""人物列表响应"""
|
||||
|
||||
success: bool
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
data: List[PersonInfoResponse]
|
||||
|
||||
|
||||
class PersonDetailResponse(BaseModel):
|
||||
"""人物详情响应"""
|
||||
|
||||
success: bool
|
||||
data: PersonInfoResponse
|
||||
|
||||
|
||||
class PersonUpdateRequest(BaseModel):
|
||||
"""人物信息更新请求"""
|
||||
|
||||
person_name: Optional[str] = None
|
||||
name_reason: Optional[str] = None
|
||||
nickname: Optional[str] = None
|
||||
memory_points: Optional[str] = None
|
||||
is_known: Optional[bool] = None
|
||||
|
||||
|
||||
class PersonUpdateResponse(BaseModel):
|
||||
"""人物信息更新响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
data: Optional[PersonInfoResponse] = None
|
||||
|
||||
|
||||
class PersonDeleteResponse(BaseModel):
|
||||
"""人物删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class BatchDeleteRequest(BaseModel):
|
||||
"""批量删除请求"""
|
||||
|
||||
person_ids: List[str]
|
||||
|
||||
|
||||
class BatchDeleteResponse(BaseModel):
|
||||
"""批量删除响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
deleted_count: int
|
||||
failed_count: int
|
||||
failed_ids: List[str] = []
|
||||
|
||||
|
||||
def verify_auth_token(
|
||||
maibot_session: Optional[str] = None,
|
||||
authorization: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""验证认证 Token,支持 Cookie 和 Header"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[Dict[str, str]]]:
|
||||
"""解析群昵称 JSON 字符串"""
|
||||
if not group_nick_name_str:
|
||||
return None
|
||||
try:
|
||||
return json.loads(group_nick_name_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
||||
"""将 PersonInfo 模型转换为响应对象"""
|
||||
return PersonInfoResponse(
|
||||
id=person.id,
|
||||
is_known=person.is_known,
|
||||
person_id=person.person_id,
|
||||
person_name=person.person_name,
|
||||
name_reason=person.name_reason,
|
||||
platform=person.platform,
|
||||
user_id=person.user_id,
|
||||
nickname=person.nickname,
|
||||
group_nick_name=parse_group_nick_name(person.group_nick_name),
|
||||
memory_points=person.memory_points,
|
||||
know_times=person.know_times,
|
||||
know_since=person.know_since,
|
||||
last_know=person.last_know,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/list", response_model=PersonListResponse)
|
||||
async def get_person_list(
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
page_size: int = Query(20, ge=1, le=100, description="每页数量"),
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_known: Optional[bool] = Query(None, description="是否已认识筛选"),
|
||||
platform: Optional[str] = Query(None, description="平台筛选"),
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取人物信息列表
|
||||
|
||||
Args:
|
||||
page: 页码 (从 1 开始)
|
||||
page_size: 每页数量 (1-100)
|
||||
search: 搜索关键词 (匹配 person_name, nickname, user_id)
|
||||
is_known: 是否已认识筛选
|
||||
platform: 平台筛选
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物信息列表
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
# 构建查询
|
||||
query = PersonInfo.select()
|
||||
|
||||
# 搜索过滤
|
||||
if search:
|
||||
query = query.where(
|
||||
(PersonInfo.person_name.contains(search))
|
||||
| (PersonInfo.nickname.contains(search))
|
||||
| (PersonInfo.user_id.contains(search))
|
||||
)
|
||||
|
||||
# 已认识状态过滤
|
||||
if is_known is not None:
|
||||
query = query.where(PersonInfo.is_known == is_known)
|
||||
|
||||
# 平台过滤
|
||||
if platform:
|
||||
query = query.where(PersonInfo.platform == platform)
|
||||
|
||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
||||
# 分页
|
||||
offset = (page - 1) * page_size
|
||||
persons = query.offset(offset).limit(page_size)
|
||||
|
||||
# 转换为响应对象
|
||||
data = [person_to_response(person) for person in persons]
|
||||
|
||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取人物列表失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取人物列表失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/{person_id}", response_model=PersonDetailResponse)
|
||||
async def get_person_detail(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
获取人物详细信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
人物详细信息
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
return PersonDetailResponse(success=True, data=person_to_response(person))
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取人物详情失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取人物详情失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.patch("/{person_id}", response_model=PersonUpdateResponse)
|
||||
async def update_person(
|
||||
person_id: str,
|
||||
request: PersonUpdateRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
增量更新人物信息(只更新提供的字段)
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
request: 更新请求(只包含需要更新的字段)
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 只更新提供的字段
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
|
||||
if not update_data:
|
||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||
|
||||
# 更新最后修改时间
|
||||
update_data["last_know"] = time.time()
|
||||
|
||||
# 执行更新
|
||||
for field, value in update_data.items():
|
||||
setattr(person, field, value)
|
||||
|
||||
person.save()
|
||||
|
||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||
|
||||
return PersonUpdateResponse(
|
||||
success=True, message=f"成功更新 {len(update_data)} 个字段", data=person_to_response(person)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"更新人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"更新人物信息失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.delete("/{person_id}", response_model=PersonDeleteResponse)
|
||||
async def delete_person(
|
||||
person_id: str, maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
删除人物信息
|
||||
|
||||
Args:
|
||||
person_id: 人物唯一 ID
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
|
||||
if not person:
|
||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||
|
||||
# 记录删除信息
|
||||
person_name = person.person_name or person.nickname or person.user_id
|
||||
|
||||
# 执行删除
|
||||
person.delete_instance()
|
||||
|
||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||
|
||||
return PersonDeleteResponse(success=True, message=f"成功删除人物信息: {person_name}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"删除人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"删除人物信息失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/stats/summary")
|
||||
async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authorization: Optional[str] = Header(None)):
|
||||
"""
|
||||
获取人物信息统计数据
|
||||
|
||||
Args:
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
统计数据
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
total = PersonInfo.select().count()
|
||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
||||
unknown = total - known
|
||||
|
||||
# 按平台统计
|
||||
platforms = {}
|
||||
for person in PersonInfo.select(PersonInfo.platform):
|
||||
platform = person.platform
|
||||
platforms[platform] = platforms.get(platform, 0) + 1
|
||||
|
||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"获取统计数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.post("/batch/delete", response_model=BatchDeleteResponse)
|
||||
async def batch_delete_persons(
|
||||
request: BatchDeleteRequest,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
批量删除人物信息
|
||||
|
||||
Args:
|
||||
request: 包含person_ids列表的请求
|
||||
authorization: Authorization header
|
||||
|
||||
Returns:
|
||||
批量删除结果
|
||||
"""
|
||||
try:
|
||||
verify_auth_token(maibot_session, authorization)
|
||||
|
||||
if not request.person_ids:
|
||||
raise HTTPException(status_code=400, detail="未提供要删除的人物ID")
|
||||
|
||||
deleted_count = 0
|
||||
failed_count = 0
|
||||
failed_ids = []
|
||||
|
||||
for person_id in request.person_ids:
|
||||
try:
|
||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
||||
if person:
|
||||
person.delete_instance()
|
||||
deleted_count += 1
|
||||
logger.info(f"批量删除: {person_id}")
|
||||
else:
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
except Exception as e:
|
||||
logger.error(f"删除 {person_id} 失败: {e}")
|
||||
failed_count += 1
|
||||
failed_ids.append(person_id)
|
||||
|
||||
message = f"成功删除 {deleted_count} 个人物"
|
||||
if failed_count > 0:
|
||||
message += f",{failed_count} 个失败"
|
||||
|
||||
return BatchDeleteResponse(
|
||||
success=True,
|
||||
message=message,
|
||||
deleted_count=deleted_count,
|
||||
failed_count=failed_count,
|
||||
failed_ids=failed_ids,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"批量删除人物信息失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"批量删除失败: {str(e)}") from e
|
||||
|
|
@ -1,164 +0,0 @@
|
|||
"""WebSocket 插件加载进度推送模块"""
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query
|
||||
from typing import Set, Dict, Any, Optional
|
||||
import json
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
from src.webui.ws_auth import verify_ws_token
|
||||
|
||||
logger = get_logger("webui.plugin_progress")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter()
|
||||
|
||||
# 全局 WebSocket 连接池
|
||||
active_connections: Set[WebSocket] = set()
|
||||
|
||||
# 当前加载进度状态
|
||||
current_progress: Dict[str, Any] = {
|
||||
"operation": "idle", # idle, fetch, install, uninstall, update
|
||||
"stage": "idle", # idle, loading, success, error
|
||||
"progress": 0, # 0-100
|
||||
"message": "",
|
||||
"error": None,
|
||||
"plugin_id": None, # 当前操作的插件 ID
|
||||
"total_plugins": 0,
|
||||
"loaded_plugins": 0,
|
||||
}
|
||||
|
||||
|
||||
async def broadcast_progress(progress_data: Dict[str, Any]):
|
||||
"""广播进度更新到所有连接的客户端"""
|
||||
global current_progress
|
||||
current_progress = progress_data.copy()
|
||||
|
||||
if not active_connections:
|
||||
return
|
||||
|
||||
message = json.dumps(progress_data, ensure_ascii=False)
|
||||
disconnected = set()
|
||||
|
||||
for websocket in active_connections:
|
||||
try:
|
||||
await websocket.send_text(message)
|
||||
except Exception as e:
|
||||
logger.error(f"发送进度更新失败: {e}")
|
||||
disconnected.add(websocket)
|
||||
|
||||
# 移除断开的连接
|
||||
for websocket in disconnected:
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
async def update_progress(
|
||||
stage: str,
|
||||
progress: int,
|
||||
message: str,
|
||||
operation: str = "fetch",
|
||||
error: str = None,
|
||||
plugin_id: str = None,
|
||||
total_plugins: int = 0,
|
||||
loaded_plugins: int = 0,
|
||||
):
|
||||
"""更新并广播进度
|
||||
|
||||
Args:
|
||||
stage: 阶段 (idle, loading, success, error)
|
||||
progress: 进度百分比 (0-100)
|
||||
message: 当前消息
|
||||
operation: 操作类型 (fetch, install, uninstall, update)
|
||||
error: 错误信息(可选)
|
||||
plugin_id: 当前操作的插件 ID
|
||||
total_plugins: 总插件数
|
||||
loaded_plugins: 已加载插件数
|
||||
"""
|
||||
progress_data = {
|
||||
"operation": operation,
|
||||
"stage": stage,
|
||||
"progress": progress,
|
||||
"message": message,
|
||||
"error": error,
|
||||
"plugin_id": plugin_id,
|
||||
"total_plugins": total_plugins,
|
||||
"loaded_plugins": loaded_plugins,
|
||||
"timestamp": asyncio.get_event_loop().time(),
|
||||
}
|
||||
|
||||
await broadcast_progress(progress_data)
|
||||
logger.debug(f"进度更新: [{operation}] {stage} - {progress}% - {message}")
|
||||
|
||||
|
||||
@router.websocket("/ws/plugin-progress")
|
||||
async def websocket_plugin_progress(websocket: WebSocket, token: Optional[str] = Query(None)):
|
||||
"""WebSocket 插件加载进度推送端点
|
||||
|
||||
客户端连接后会立即收到当前进度状态
|
||||
支持三种认证方式(按优先级):
|
||||
1. query 参数 token(推荐,通过 /api/webui/ws-token 获取临时 token)
|
||||
2. Cookie 中的 maibot_session
|
||||
3. 直接使用 session token(兼容)
|
||||
|
||||
示例:ws://host/ws/plugin-progress?token=xxx
|
||||
"""
|
||||
is_authenticated = False
|
||||
|
||||
# 方式 1: 尝试验证临时 WebSocket token(推荐方式)
|
||||
if token and verify_ws_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用临时 token 认证成功")
|
||||
|
||||
# 方式 2: 尝试从 Cookie 获取 session token
|
||||
if not is_authenticated:
|
||||
cookie_token = websocket.cookies.get("maibot_session")
|
||||
if cookie_token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(cookie_token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 Cookie 认证成功")
|
||||
|
||||
# 方式 3: 尝试直接验证 query 参数作为 session token(兼容旧方式)
|
||||
if not is_authenticated and token:
|
||||
token_manager = get_token_manager()
|
||||
if token_manager.verify_token(token):
|
||||
is_authenticated = True
|
||||
logger.debug("插件进度 WebSocket 使用 session token 认证成功")
|
||||
|
||||
if not is_authenticated:
|
||||
logger.warning("插件进度 WebSocket 连接被拒绝:认证失败")
|
||||
await websocket.close(code=4001, reason="认证失败,请重新登录")
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
active_connections.add(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已连接(已认证),当前连接数: {len(active_connections)}")
|
||||
|
||||
try:
|
||||
# 发送当前进度状态
|
||||
await websocket.send_text(json.dumps(current_progress, ensure_ascii=False))
|
||||
|
||||
# 保持连接并处理客户端消息
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_text()
|
||||
|
||||
# 处理客户端心跳
|
||||
if data == "ping":
|
||||
await websocket.send_text("pong")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理客户端消息时出错: {e}")
|
||||
break
|
||||
|
||||
except WebSocketDisconnect:
|
||||
active_connections.discard(websocket)
|
||||
logger.info(f"📡 插件进度 WebSocket 客户端已断开,当前连接数: {len(active_connections)}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebSocket 错误: {e}")
|
||||
active_connections.discard(websocket)
|
||||
|
||||
|
||||
def get_progress_router() -> APIRouter:
|
||||
"""获取插件进度 WebSocket 路由器"""
|
||||
return router
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,245 +0,0 @@
|
|||
"""
|
||||
WebUI 请求频率限制模块
|
||||
防止暴力破解和 API 滥用
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, Tuple, Optional
|
||||
from fastapi import Request, HTTPException
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui.rate_limiter")
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""
|
||||
简单的内存请求频率限制器
|
||||
|
||||
使用滑动窗口算法实现
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 存储格式: {key: [(timestamp, count), ...]}
|
||||
self._requests: Dict[str, list] = defaultdict(list)
|
||||
# 被封禁的 IP: {ip: unblock_timestamp}
|
||||
self._blocked: Dict[str, float] = {}
|
||||
|
||||
def _get_client_ip(self, request: Request) -> str:
|
||||
"""获取客户端 IP 地址"""
|
||||
# 检查代理头
|
||||
forwarded = request.headers.get("X-Forwarded-For")
|
||||
if forwarded:
|
||||
# 取第一个 IP(最原始的客户端)
|
||||
return forwarded.split(",")[0].strip()
|
||||
|
||||
real_ip = request.headers.get("X-Real-IP")
|
||||
if real_ip:
|
||||
return real_ip
|
||||
|
||||
# 直接连接的客户端
|
||||
if request.client:
|
||||
return request.client.host
|
||||
|
||||
return "unknown"
|
||||
|
||||
def _cleanup_old_requests(self, key: str, window_seconds: int):
|
||||
"""清理过期的请求记录"""
|
||||
now = time.time()
|
||||
cutoff = now - window_seconds
|
||||
self._requests[key] = [(ts, count) for ts, count in self._requests[key] if ts > cutoff]
|
||||
|
||||
def _cleanup_expired_blocks(self):
|
||||
"""清理过期的封禁"""
|
||||
now = time.time()
|
||||
expired = [ip for ip, unblock_time in self._blocked.items() if now > unblock_time]
|
||||
for ip in expired:
|
||||
del self._blocked[ip]
|
||||
logger.info(f"🔓 IP {ip} 封禁已解除")
|
||||
|
||||
def is_blocked(self, request: Request) -> Tuple[bool, Optional[int]]:
|
||||
"""
|
||||
检查 IP 是否被封禁
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余封禁秒数)
|
||||
"""
|
||||
self._cleanup_expired_blocks()
|
||||
ip = self._get_client_ip(request)
|
||||
|
||||
if ip in self._blocked:
|
||||
remaining = int(self._blocked[ip] - time.time())
|
||||
return True, max(0, remaining)
|
||||
|
||||
return False, None
|
||||
|
||||
def check_rate_limit(
|
||||
self, request: Request, max_requests: int, window_seconds: int, key_suffix: str = ""
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
检查请求是否超过频率限制
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_requests: 窗口期内允许的最大请求数
|
||||
window_seconds: 窗口时间(秒)
|
||||
key_suffix: 键后缀,用于区分不同的限制规则
|
||||
|
||||
Returns:
|
||||
(是否允许, 剩余请求数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:{key_suffix}" if key_suffix else ip
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前窗口内的请求数
|
||||
current_count = sum(count for _, count in self._requests[key])
|
||||
|
||||
if current_count >= max_requests:
|
||||
return False, 0
|
||||
|
||||
# 记录新请求
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
|
||||
remaining = max_requests - current_count - 1
|
||||
return True, remaining
|
||||
|
||||
def block_ip(self, request: Request, duration_seconds: int):
|
||||
"""
|
||||
封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
duration_seconds: 封禁时长(秒)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
self._blocked[ip] = time.time() + duration_seconds
|
||||
logger.warning(f"🔒 IP {ip} 已被封禁 {duration_seconds} 秒")
|
||||
|
||||
def record_failed_attempt(
|
||||
self, request: Request, max_failures: int = 5, window_seconds: int = 300, block_duration: int = 600
|
||||
) -> Tuple[bool, int]:
|
||||
"""
|
||||
记录失败尝试(如登录失败)
|
||||
|
||||
如果在窗口期内失败次数过多,自动封禁 IP
|
||||
|
||||
Args:
|
||||
request: FastAPI Request 对象
|
||||
max_failures: 允许的最大失败次数
|
||||
window_seconds: 统计窗口(秒)
|
||||
block_duration: 封禁时长(秒)
|
||||
|
||||
Returns:
|
||||
(是否被封禁, 剩余尝试次数)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
|
||||
# 清理过期记录
|
||||
self._cleanup_old_requests(key, window_seconds)
|
||||
|
||||
# 计算当前失败次数
|
||||
current_failures = sum(count for _, count in self._requests[key])
|
||||
|
||||
# 记录本次失败
|
||||
now = time.time()
|
||||
self._requests[key].append((now, 1))
|
||||
current_failures += 1
|
||||
|
||||
remaining = max_failures - current_failures
|
||||
|
||||
# 检查是否需要封禁
|
||||
if current_failures >= max_failures:
|
||||
self.block_ip(request, block_duration)
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败次数过多 ({current_failures}/{max_failures}),已封禁")
|
||||
return True, 0
|
||||
|
||||
if current_failures >= max_failures - 2:
|
||||
logger.warning(f"⚠️ IP {ip} 认证失败 {current_failures}/{max_failures} 次")
|
||||
|
||||
return False, max(0, remaining)
|
||||
|
||||
def reset_failures(self, request: Request):
|
||||
"""
|
||||
重置失败计数(认证成功后调用)
|
||||
"""
|
||||
ip = self._get_client_ip(request)
|
||||
key = f"{ip}:auth_failures"
|
||||
if key in self._requests:
|
||||
del self._requests[key]
|
||||
|
||||
|
||||
# 全局单例
|
||||
_rate_limiter: Optional[RateLimiter] = None
|
||||
|
||||
|
||||
def get_rate_limiter() -> RateLimiter:
|
||||
"""获取 RateLimiter 单例"""
|
||||
global _rate_limiter
|
||||
if _rate_limiter is None:
|
||||
_rate_limiter = RateLimiter()
|
||||
return _rate_limiter
|
||||
|
||||
|
||||
async def check_auth_rate_limit(request: Request):
|
||||
"""
|
||||
认证接口的频率限制依赖
|
||||
|
||||
规则:
|
||||
- 每个 IP 每分钟最多 10 次认证请求
|
||||
- 连续失败 5 次后封禁 10 分钟
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, remaining = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=10, # 每分钟 10 次
|
||||
window_seconds=60,
|
||||
key_suffix="auth",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="认证请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
||||
|
||||
async def check_api_rate_limit(request: Request):
|
||||
"""
|
||||
普通 API 的频率限制依赖
|
||||
|
||||
规则:每个 IP 每分钟最多 100 次请求
|
||||
"""
|
||||
limiter = get_rate_limiter()
|
||||
|
||||
# 检查是否被封禁
|
||||
blocked, remaining_block = limiter.is_blocked(request)
|
||||
if blocked:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"请求过于频繁,请在 {remaining_block} 秒后重试",
|
||||
headers={"Retry-After": str(remaining_block)},
|
||||
)
|
||||
|
||||
# 检查频率限制
|
||||
allowed, _ = limiter.check_rate_limit(
|
||||
request,
|
||||
max_requests=100, # 每分钟 100 次
|
||||
window_seconds=60,
|
||||
key_suffix="api",
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试", headers={"Retry-After": "60"})
|
||||
|
|
@ -1,113 +0,0 @@
|
|||
"""
|
||||
系统控制路由
|
||||
|
||||
提供系统重启、状态查询等功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
from src.config.config import MMC_VERSION
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
router = APIRouter(prefix="/system", tags=["system"])
|
||||
logger = get_logger("webui_system")
|
||||
|
||||
# 记录启动时间
|
||||
_start_time = time.time()
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class RestartResponse(BaseModel):
|
||||
"""重启响应"""
|
||||
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class StatusResponse(BaseModel):
|
||||
"""状态响应"""
|
||||
|
||||
running: bool
|
||||
uptime: float
|
||||
version: str
|
||||
start_time: str
|
||||
|
||||
|
||||
@router.post("/restart", response_model=RestartResponse)
|
||||
async def restart_maibot(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
重启麦麦主程序
|
||||
|
||||
请求重启当前进程,配置更改将在重启后生效。
|
||||
注意:此操作会使麦麦暂时离线。
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
try:
|
||||
# 记录重启操作
|
||||
logger.info("WebUI 触发重启操作")
|
||||
|
||||
# 定义延迟重启的异步任务
|
||||
async def delayed_restart():
|
||||
await asyncio.sleep(0.5) # 延迟0.5秒,确保响应已发送
|
||||
# 使用 os._exit(42) 退出当前进程,配合外部 runner 脚本进行重启
|
||||
# 42 是约定的重启状态码
|
||||
logger.info("WebUI 请求重启,退出代码 42")
|
||||
os._exit(42)
|
||||
|
||||
# 创建后台任务执行重启
|
||||
asyncio.create_task(delayed_restart())
|
||||
|
||||
# 立即返回成功响应
|
||||
return RestartResponse(success=True, message="麦麦正在重启中...")
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"重启失败: {str(e)}") from e
|
||||
|
||||
|
||||
@router.get("/status", response_model=StatusResponse)
|
||||
async def get_maibot_status(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取麦麦运行状态
|
||||
|
||||
返回麦麦的运行状态、运行时长和版本信息。
|
||||
"""
|
||||
try:
|
||||
uptime = time.time() - _start_time
|
||||
|
||||
# 尝试获取版本信息(需要根据实际情况调整)
|
||||
version = MMC_VERSION # 可以从配置或常量中读取
|
||||
|
||||
return StatusResponse(
|
||||
running=True, uptime=uptime, version=version, start_time=datetime.fromtimestamp(_start_time).isoformat()
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"获取状态失败: {str(e)}") from e
|
||||
|
||||
|
||||
# 可选:添加更多系统控制功能
|
||||
|
||||
|
||||
@router.post("/reload-config")
|
||||
async def reload_config(_auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
热重载配置(不重启进程)
|
||||
|
||||
仅重新加载配置文件,某些配置可能需要重启才能生效。
|
||||
此功能需要在主程序中实现配置热重载逻辑。
|
||||
"""
|
||||
# 这里需要调用主程序的配置重载函数
|
||||
# 示例:await app_instance.reload_config()
|
||||
|
||||
return {"success": True, "message": "配置重载功能待实现"}
|
||||
|
|
@ -1,456 +0,0 @@
|
|||
"""WebUI API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Header, Response, Request, Cookie, Depends
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional
|
||||
from src.common.logger import get_logger
|
||||
from .token_manager import get_token_manager
|
||||
from .auth import set_auth_cookie, clear_auth_cookie
|
||||
from .rate_limiter import get_rate_limiter, check_auth_rate_limit
|
||||
from .config_routes import router as config_router
|
||||
from .statistics_routes import router as statistics_router
|
||||
from .person_routes import router as person_router
|
||||
from .expression_routes import router as expression_router
|
||||
from .jargon_routes import router as jargon_router
|
||||
from .emoji_routes import router as emoji_router
|
||||
from .plugin_routes import router as plugin_router
|
||||
from .plugin_progress_ws import get_progress_router
|
||||
from .routers.system import router as system_router
|
||||
from .model_routes import router as model_router
|
||||
from .ws_auth import router as ws_auth_router
|
||||
from .annual_report_routes import router as annual_report_router
|
||||
|
||||
logger = get_logger("webui.api")
|
||||
|
||||
# 创建路由器
|
||||
router = APIRouter(prefix="/api/webui", tags=["WebUI"])
|
||||
|
||||
# 注册配置管理路由
|
||||
router.include_router(config_router)
|
||||
# 注册统计数据路由
|
||||
router.include_router(statistics_router)
|
||||
# 注册人物信息管理路由
|
||||
router.include_router(person_router)
|
||||
# 注册表达方式管理路由
|
||||
router.include_router(expression_router)
|
||||
# 注册黑话管理路由
|
||||
router.include_router(jargon_router)
|
||||
# 注册表情包管理路由
|
||||
router.include_router(emoji_router)
|
||||
# 注册插件管理路由
|
||||
router.include_router(plugin_router)
|
||||
# 注册插件进度 WebSocket 路由
|
||||
router.include_router(get_progress_router())
|
||||
# 注册系统控制路由
|
||||
router.include_router(system_router)
|
||||
# 注册模型列表获取路由
|
||||
router.include_router(model_router)
|
||||
# 注册 WebSocket 认证路由
|
||||
router.include_router(ws_auth_router)
|
||||
# 注册年度报告路由
|
||||
router.include_router(annual_report_router)
|
||||
|
||||
|
||||
class TokenVerifyRequest(BaseModel):
|
||||
"""Token 验证请求"""
|
||||
|
||||
token: str = Field(..., description="访问令牌")
|
||||
|
||||
|
||||
class TokenVerifyResponse(BaseModel):
|
||||
"""Token 验证响应"""
|
||||
|
||||
valid: bool = Field(..., description="Token 是否有效")
|
||||
message: str = Field(..., description="验证结果消息")
|
||||
is_first_setup: bool = Field(False, description="是否为首次设置")
|
||||
|
||||
|
||||
class TokenUpdateRequest(BaseModel):
|
||||
"""Token 更新请求"""
|
||||
|
||||
new_token: str = Field(..., description="新的访问令牌", min_length=10)
|
||||
|
||||
|
||||
class TokenUpdateResponse(BaseModel):
|
||||
"""Token 更新响应"""
|
||||
|
||||
success: bool = Field(..., description="是否更新成功")
|
||||
message: str = Field(..., description="更新结果消息")
|
||||
|
||||
|
||||
class TokenRegenerateResponse(BaseModel):
|
||||
"""Token 重新生成响应"""
|
||||
|
||||
success: bool = Field(..., description="是否生成成功")
|
||||
token: str = Field(..., description="新生成的令牌")
|
||||
message: str = Field(..., description="生成结果消息")
|
||||
|
||||
|
||||
class FirstSetupStatusResponse(BaseModel):
|
||||
"""首次配置状态响应"""
|
||||
|
||||
is_first_setup: bool = Field(..., description="是否为首次配置")
|
||||
message: str = Field(..., description="状态消息")
|
||||
|
||||
|
||||
class CompleteSetupResponse(BaseModel):
|
||||
"""完成配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
class ResetSetupResponse(BaseModel):
|
||||
"""重置配置响应"""
|
||||
|
||||
success: bool = Field(..., description="是否成功")
|
||||
message: str = Field(..., description="结果消息")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""健康检查"""
|
||||
return {"status": "healthy", "service": "MaiBot WebUI"}
|
||||
|
||||
|
||||
@router.post("/auth/verify", response_model=TokenVerifyResponse)
|
||||
async def verify_token(
|
||||
request_body: TokenVerifyRequest,
|
||||
request: Request,
|
||||
response: Response,
|
||||
_rate_limit: None = Depends(check_auth_rate_limit),
|
||||
):
|
||||
"""
|
||||
验证访问令牌,验证成功后设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
request_body: 包含 token 的验证请求
|
||||
request: FastAPI Request 对象(用于获取客户端 IP)
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
验证结果(包含首次配置状态)
|
||||
"""
|
||||
try:
|
||||
token_manager = get_token_manager()
|
||||
rate_limiter = get_rate_limiter()
|
||||
|
||||
is_valid = token_manager.verify_token(request_body.token)
|
||||
|
||||
if is_valid:
|
||||
# 认证成功,重置失败计数
|
||||
rate_limiter.reset_failures(request)
|
||||
# 设置 HttpOnly Cookie(传入 request 以检测协议)
|
||||
set_auth_cookie(response, request_body.token, request)
|
||||
# 同时返回首次配置状态,避免额外请求
|
||||
is_first_setup = token_manager.is_first_setup()
|
||||
return TokenVerifyResponse(valid=True, message="Token 验证成功", is_first_setup=is_first_setup)
|
||||
else:
|
||||
# 记录失败尝试
|
||||
blocked, remaining = rate_limiter.record_failed_attempt(
|
||||
request,
|
||||
max_failures=5, # 5 次失败
|
||||
window_seconds=300, # 5 分钟窗口
|
||||
block_duration=600, # 封禁 10 分钟
|
||||
)
|
||||
|
||||
if blocked:
|
||||
raise HTTPException(status_code=429, detail="认证失败次数过多,您的 IP 已被临时封禁 10 分钟")
|
||||
|
||||
message = "Token 无效或已过期"
|
||||
if remaining <= 2:
|
||||
message += f"(剩余 {remaining} 次尝试机会)"
|
||||
|
||||
return TokenVerifyResponse(valid=False, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 验证失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 验证失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/logout")
|
||||
async def logout(response: Response):
|
||||
"""
|
||||
登出并清除认证 Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
|
||||
Returns:
|
||||
登出结果
|
||||
"""
|
||||
clear_auth_cookie(response)
|
||||
return {"success": True, "message": "已成功登出"}
|
||||
|
||||
|
||||
@router.get("/auth/check")
|
||||
async def check_auth_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
检查当前认证状态(用于前端判断是否已登录)
|
||||
|
||||
Returns:
|
||||
认证状态
|
||||
"""
|
||||
try:
|
||||
token = None
|
||||
|
||||
# 记录请求信息用于调试
|
||||
logger.debug(f"检查认证状态 - Cookie: {maibot_session[:20] if maibot_session else 'None'}..., Authorization: {'Present' if authorization else 'None'}")
|
||||
|
||||
# 优先从 Cookie 获取
|
||||
if maibot_session:
|
||||
token = maibot_session
|
||||
logger.debug("使用 Cookie 中的 token")
|
||||
# 其次从 Header 获取
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
token = authorization.replace("Bearer ", "")
|
||||
logger.debug("使用 Header 中的 token")
|
||||
|
||||
if not token:
|
||||
logger.debug("未找到 token,返回未认证")
|
||||
return {"authenticated": False}
|
||||
|
||||
token_manager = get_token_manager()
|
||||
is_valid = token_manager.verify_token(token)
|
||||
logger.debug(f"Token 验证结果: {is_valid}")
|
||||
|
||||
if is_valid:
|
||||
return {"authenticated": True}
|
||||
else:
|
||||
return {"authenticated": False}
|
||||
except Exception as e:
|
||||
logger.error(f"认证检查失败: {e}", exc_info=True)
|
||||
return {"authenticated": False}
|
||||
|
||||
|
||||
@router.post("/auth/update", response_model=TokenUpdateResponse)
|
||||
async def update_token(
|
||||
request: TokenUpdateRequest,
|
||||
response: Response,
|
||||
req: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
更新访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
request: 包含新 token 的更新请求
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
更新结果
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 更新 token
|
||||
success, message = token_manager.update_token(request.new_token)
|
||||
|
||||
# 如果更新成功,清除 Cookie,要求用户重新登录
|
||||
if success:
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenUpdateResponse(success=success, message=message)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 更新失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 更新失败") from e
|
||||
|
||||
|
||||
@router.post("/auth/regenerate", response_model=TokenRegenerateResponse)
|
||||
async def regenerate_token(
|
||||
response: Response,
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重新生成访问令牌(需要当前有效的 token)
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
新生成的 token
|
||||
"""
|
||||
try:
|
||||
# 验证当前 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="当前 Token 无效")
|
||||
|
||||
# 重新生成 token
|
||||
new_token = token_manager.regenerate_token()
|
||||
|
||||
# 清除 Cookie,要求用户重新登录
|
||||
clear_auth_cookie(response)
|
||||
|
||||
return TokenRegenerateResponse(success=True, token=new_token, message="Token 已重新生成")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Token 重新生成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="Token 重新生成失败") from e
|
||||
|
||||
|
||||
@router.get("/setup/status", response_model=FirstSetupStatusResponse)
|
||||
async def get_setup_status(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取首次配置状态
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
首次配置状态
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 检查是否为首次配置
|
||||
is_first = token_manager.is_first_setup()
|
||||
|
||||
return FirstSetupStatusResponse(is_first_setup=is_first, message="首次配置" if is_first else "已完成配置")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="获取配置状态失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/complete", response_model=CompleteSetupResponse)
|
||||
async def complete_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
标记首次配置完成
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
完成结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 标记配置完成
|
||||
success = token_manager.mark_setup_completed()
|
||||
|
||||
return CompleteSetupResponse(success=success, message="配置已完成" if success else "标记失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"标记配置完成失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="标记配置完成失败") from e
|
||||
|
||||
|
||||
@router.post("/setup/reset", response_model=ResetSetupResponse)
|
||||
async def reset_setup(
|
||||
request: Request,
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Args:
|
||||
maibot_session: Cookie 中的 token
|
||||
authorization: Authorization header (Bearer token)
|
||||
|
||||
Returns:
|
||||
重置结果
|
||||
"""
|
||||
try:
|
||||
# 验证 token(优先 Cookie,其次 Header)
|
||||
current_token = None
|
||||
if maibot_session:
|
||||
current_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
current_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not current_token:
|
||||
raise HTTPException(status_code=401, detail="未提供有效的认证信息")
|
||||
|
||||
token_manager = get_token_manager()
|
||||
|
||||
if not token_manager.verify_token(current_token):
|
||||
raise HTTPException(status_code=401, detail="Token 无效")
|
||||
|
||||
# 重置配置状态
|
||||
success = token_manager.reset_setup_status()
|
||||
|
||||
return ResetSetupResponse(success=success, message="配置状态已重置" if success else "重置失败")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"重置配置状态失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="重置配置状态失败") from e
|
||||
|
|
@ -1,319 +0,0 @@
|
|||
"""统计数据 API 路由"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from peewee import fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
from src.webui.auth import verify_auth_token_from_cookie_or_header
|
||||
|
||||
logger = get_logger("webui.statistics")
|
||||
|
||||
router = APIRouter(prefix="/statistics", tags=["statistics"])
|
||||
|
||||
|
||||
def require_auth(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
) -> bool:
|
||||
"""认证依赖:验证用户是否已登录"""
|
||||
return verify_auth_token_from_cookie_or_header(maibot_session, authorization)
|
||||
|
||||
|
||||
class StatisticsSummary(BaseModel):
|
||||
"""统计数据摘要"""
|
||||
|
||||
total_requests: int = Field(0, description="总请求数")
|
||||
total_cost: float = Field(0.0, description="总花费")
|
||||
total_tokens: int = Field(0, description="总token数")
|
||||
online_time: float = Field(0.0, description="在线时间(秒)")
|
||||
total_messages: int = Field(0, description="总消息数")
|
||||
total_replies: int = Field(0, description="总回复数")
|
||||
avg_response_time: float = Field(0.0, description="平均响应时间")
|
||||
cost_per_hour: float = Field(0.0, description="每小时花费")
|
||||
tokens_per_hour: float = Field(0.0, description="每小时token数")
|
||||
|
||||
|
||||
class ModelStatistics(BaseModel):
|
||||
"""模型统计"""
|
||||
|
||||
model_name: str
|
||||
request_count: int
|
||||
total_cost: float
|
||||
total_tokens: int
|
||||
avg_response_time: float
|
||||
|
||||
|
||||
class TimeSeriesData(BaseModel):
|
||||
"""时间序列数据"""
|
||||
|
||||
timestamp: str
|
||||
requests: int = 0
|
||||
cost: float = 0.0
|
||||
tokens: int = 0
|
||||
|
||||
|
||||
class DashboardData(BaseModel):
|
||||
"""仪表盘数据"""
|
||||
|
||||
summary: StatisticsSummary
|
||||
model_stats: List[ModelStatistics]
|
||||
hourly_data: List[TimeSeriesData]
|
||||
daily_data: List[TimeSeriesData]
|
||||
recent_activity: List[Dict[str, Any]]
|
||||
|
||||
|
||||
@router.get("/dashboard", response_model=DashboardData)
|
||||
async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取仪表盘统计数据
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时),默认24小时
|
||||
|
||||
Returns:
|
||||
仪表盘数据
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
|
||||
# 获取摘要数据
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
|
||||
# 获取模型统计
|
||||
model_stats = await _get_model_statistics(start_time)
|
||||
|
||||
# 获取小时级时间序列数据
|
||||
hourly_data = await _get_hourly_statistics(start_time, now)
|
||||
|
||||
# 获取日级时间序列数据(最近7天)
|
||||
daily_start = now - timedelta(days=7)
|
||||
daily_data = await _get_daily_statistics(daily_start, now)
|
||||
|
||||
# 获取最近活动
|
||||
recent_activity = await _get_recent_activity(limit=10)
|
||||
|
||||
return DashboardData(
|
||||
summary=summary,
|
||||
model_stats=model_stats,
|
||||
hourly_data=hourly_data,
|
||||
daily_data=daily_data,
|
||||
recent_activity=recent_activity,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"获取仪表盘数据失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"获取统计数据失败: {str(e)}") from e
|
||||
|
||||
|
||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||
summary = StatisticsSummary()
|
||||
|
||||
# 使用聚合查询替代全量加载
|
||||
query = LLMUsage.select(
|
||||
fn.COUNT(LLMUsage.id).alias("total_requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
|
||||
result = query.dicts().get()
|
||||
summary.total_requests = result["total_requests"]
|
||||
summary.total_cost = result["total_cost"]
|
||||
summary.total_tokens = result["total_tokens"]
|
||||
summary.avg_response_time = result["avg_response_time"] or 0.0
|
||||
|
||||
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||
online_records = list(
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
||||
)
|
||||
|
||||
for record in online_records:
|
||||
start = max(record.start_timestamp, start_time)
|
||||
end = min(record.end_timestamp, end_time)
|
||||
if end > start:
|
||||
summary.online_time += (end - start).total_seconds()
|
||||
|
||||
# 查询消息数量 - 使用聚合优化
|
||||
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
|
||||
)
|
||||
summary.total_messages = messages_query.scalar() or 0
|
||||
|
||||
# 统计回复数量
|
||||
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
||||
(Messages.time >= start_time.timestamp())
|
||||
& (Messages.time <= end_time.timestamp())
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
summary.total_replies = replies_query.scalar() or 0
|
||||
|
||||
# 计算派生指标
|
||||
if summary.online_time > 0:
|
||||
online_hours = summary.online_time / 3600.0
|
||||
summary.cost_per_hour = summary.total_cost / online_hours
|
||||
summary.tokens_per_hour = summary.total_tokens / online_hours
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||
# 使用GROUP BY聚合,避免全量加载
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
|
||||
fn.COUNT(LLMUsage.id).alias("request_count"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
||||
)
|
||||
.where(LLMUsage.timestamp >= start_time)
|
||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
|
||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(10) # 只取前10个
|
||||
)
|
||||
|
||||
result = []
|
||||
for row in query.dicts():
|
||||
result.append(
|
||||
ModelStatistics(
|
||||
model_name=row["model_name"],
|
||||
request_count=row["request_count"],
|
||||
total_cost=row["total_cost"],
|
||||
total_tokens=row["total_tokens"],
|
||||
avg_response_time=row["avg_response_time"] or 0.0,
|
||||
)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||
# SQLite的日期时间函数进行小时分组
|
||||
# 使用strftime将timestamp格式化为小时级别
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
# 转换为字典以快速查找
|
||||
data_dict = {row["hour"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有小时(包括没有数据的)
|
||||
result = []
|
||||
current = start_time.replace(minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
hour_str = current.strftime("%Y-%m-%dT%H:00:00")
|
||||
if hour_str in data_dict:
|
||||
row = data_dict[hour_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(hours=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||
# 使用strftime按日期分组
|
||||
query = (
|
||||
LLMUsage.select(
|
||||
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
|
||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
||||
)
|
||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
||||
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
|
||||
)
|
||||
|
||||
# 转换为字典
|
||||
data_dict = {row["day"]: row for row in query.dicts()}
|
||||
|
||||
# 填充所有天
|
||||
result = []
|
||||
current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
|
||||
while current <= end_time:
|
||||
day_str = current.strftime("%Y-%m-%dT00:00:00")
|
||||
if day_str in data_dict:
|
||||
row = data_dict[day_str]
|
||||
result.append(
|
||||
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
||||
)
|
||||
else:
|
||||
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||
current += timedelta(days=1)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""获取最近活动"""
|
||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
||||
|
||||
activities = []
|
||||
for record in records:
|
||||
activities.append(
|
||||
{
|
||||
"timestamp": record.timestamp.isoformat(),
|
||||
"model": record.model_assign_name or record.model_name,
|
||||
"request_type": record.request_type,
|
||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
||||
"cost": record.cost or 0.0,
|
||||
"time_cost": record.time_cost or 0.0,
|
||||
"status": record.status,
|
||||
}
|
||||
)
|
||||
|
||||
return activities
|
||||
|
||||
|
||||
@router.get("/summary")
|
||||
async def get_summary(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取统计摘要
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
summary = await _get_summary_statistics(start_time, now)
|
||||
return summary
|
||||
except Exception as e:
|
||||
logger.error(f"获取统计摘要失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.get("/models")
|
||||
async def get_model_stats(hours: int = 24, _auth: bool = Depends(require_auth)):
|
||||
"""
|
||||
获取模型统计
|
||||
|
||||
Args:
|
||||
hours: 统计时间范围(小时)
|
||||
"""
|
||||
try:
|
||||
now = datetime.now()
|
||||
start_time = now - timedelta(hours=hours)
|
||||
stats = await _get_model_statistics(start_time)
|
||||
return stats
|
||||
except Exception as e:
|
||||
logger.error(f"获取模型统计失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
|
@ -1,309 +0,0 @@
|
|||
"""
|
||||
WebUI Token 管理模块
|
||||
负责生成、保存、验证和更新访问令牌
|
||||
"""
|
||||
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui")
|
||||
|
||||
|
||||
class TokenManager:
|
||||
"""Token 管理器"""
|
||||
|
||||
def __init__(self, config_path: Optional[Path] = None):
|
||||
"""
|
||||
初始化 Token 管理器
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径,默认为项目根目录的 data/webui.json
|
||||
"""
|
||||
if config_path is None:
|
||||
# 获取项目根目录 (src/webui -> src -> 根目录)
|
||||
project_root = Path(__file__).parent.parent.parent
|
||||
config_path = project_root / "data" / "webui.json"
|
||||
|
||||
self.config_path = config_path
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 确保配置文件存在并包含有效的 token
|
||||
self._ensure_config()
|
||||
|
||||
def _ensure_config(self):
|
||||
"""确保配置文件存在且包含有效的 token"""
|
||||
if not self.config_path.exists():
|
||||
logger.info(f"WebUI 配置文件不存在,正在创建: {self.config_path}")
|
||||
self._create_new_token()
|
||||
else:
|
||||
# 验证配置文件格式
|
||||
try:
|
||||
config = self._load_config()
|
||||
if not config.get("access_token"):
|
||||
logger.warning("WebUI 配置文件中缺少 access_token,正在重新生成")
|
||||
self._create_new_token()
|
||||
else:
|
||||
logger.info(f"WebUI Token 已加载: {config['access_token'][:8]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"读取 WebUI 配置文件失败: {e},正在重新创建")
|
||||
self._create_new_token()
|
||||
|
||||
def _load_config(self) -> dict:
|
||||
"""加载配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"加载 WebUI 配置失败: {e}")
|
||||
return {}
|
||||
|
||||
def _save_config(self, config: dict):
|
||||
"""保存配置文件"""
|
||||
try:
|
||||
with open(self.config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
logger.info(f"WebUI 配置已保存到: {self.config_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存 WebUI 配置失败: {e}")
|
||||
raise
|
||||
|
||||
def _create_new_token(self) -> str:
|
||||
"""生成新的 64 位随机 token"""
|
||||
# 生成 64 位十六进制字符串 (32 字节 = 64 hex 字符)
|
||||
token = secrets.token_hex(32)
|
||||
|
||||
config = {
|
||||
"access_token": token,
|
||||
"created_at": self._get_current_timestamp(),
|
||||
"updated_at": self._get_current_timestamp(),
|
||||
"first_setup_completed": False, # 标记首次配置未完成
|
||||
}
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"新的 WebUI Token 已生成: {token[:8]}...")
|
||||
|
||||
return token
|
||||
|
||||
def _get_current_timestamp(self) -> str:
|
||||
"""获取当前时间戳字符串"""
|
||||
from datetime import datetime
|
||||
|
||||
return datetime.now().isoformat()
|
||||
|
||||
def get_token(self) -> str:
|
||||
"""获取当前有效的 token"""
|
||||
config = self._load_config()
|
||||
return config.get("access_token", "")
|
||||
|
||||
def verify_token(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 是否有效
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: token 是否有效
|
||||
"""
|
||||
if not token:
|
||||
return False
|
||||
|
||||
current_token = self.get_token()
|
||||
if not current_token:
|
||||
logger.error("系统中没有有效的 token")
|
||||
return False
|
||||
|
||||
# 使用 secrets.compare_digest 防止时序攻击
|
||||
is_valid = secrets.compare_digest(token, current_token)
|
||||
|
||||
if is_valid:
|
||||
logger.debug("Token 验证成功")
|
||||
else:
|
||||
logger.warning("Token 验证失败")
|
||||
|
||||
return is_valid
|
||||
|
||||
def update_token(self, new_token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
更新 token
|
||||
|
||||
Args:
|
||||
new_token: 新的 token (最少 10 位,必须包含大小写字母和特殊符号)
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否更新成功, 错误消息)
|
||||
"""
|
||||
# 验证新 token 格式
|
||||
is_valid, error_msg = self._validate_custom_token(new_token)
|
||||
if not is_valid:
|
||||
logger.error(f"Token 格式无效: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
try:
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8]
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"Token 已更新: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return True, "Token 更新成功"
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Token 失败: {e}")
|
||||
return False, f"更新失败: {str(e)}"
|
||||
|
||||
def regenerate_token(self) -> str:
|
||||
"""
|
||||
重新生成 token(保留 first_setup_completed 状态)
|
||||
|
||||
Returns:
|
||||
str: 新生成的 token
|
||||
"""
|
||||
logger.info("正在重新生成 WebUI Token...")
|
||||
|
||||
# 生成新的 64 位十六进制字符串
|
||||
new_token = secrets.token_hex(32)
|
||||
|
||||
# 加载现有配置,保留 first_setup_completed 状态
|
||||
config = self._load_config()
|
||||
old_token = config.get("access_token", "")[:8] if config.get("access_token") else "无"
|
||||
first_setup_completed = config.get("first_setup_completed", True) # 默认为 True,表示已完成配置
|
||||
|
||||
config["access_token"] = new_token
|
||||
config["updated_at"] = self._get_current_timestamp()
|
||||
config["first_setup_completed"] = first_setup_completed # 保留原来的状态
|
||||
|
||||
self._save_config(config)
|
||||
logger.info(f"WebUI Token 已重新生成: {old_token}... -> {new_token[:8]}...")
|
||||
|
||||
return new_token
|
||||
|
||||
def _validate_token_format(self, token: str) -> bool:
|
||||
"""
|
||||
验证 token 格式是否正确(旧的 64 位十六进制验证,保留用于系统生成的 token)
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
bool: 格式是否正确
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False
|
||||
|
||||
# 必须是 64 位十六进制字符串
|
||||
if len(token) != 64:
|
||||
return False
|
||||
|
||||
# 验证是否为有效的十六进制字符串
|
||||
try:
|
||||
int(token, 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _validate_custom_token(self, token: str) -> tuple[bool, str]:
|
||||
"""
|
||||
验证自定义 token 格式
|
||||
|
||||
要求:
|
||||
- 最少 10 位
|
||||
- 包含大写字母
|
||||
- 包含小写字母
|
||||
- 包含特殊符号
|
||||
|
||||
Args:
|
||||
token: 待验证的 token
|
||||
|
||||
Returns:
|
||||
tuple[bool, str]: (是否有效, 错误消息)
|
||||
"""
|
||||
if not token or not isinstance(token, str):
|
||||
return False, "Token 不能为空"
|
||||
|
||||
# 检查长度
|
||||
if len(token) < 10:
|
||||
return False, "Token 长度至少为 10 位"
|
||||
|
||||
# 检查是否包含大写字母
|
||||
has_upper = any(c.isupper() for c in token)
|
||||
if not has_upper:
|
||||
return False, "Token 必须包含大写字母"
|
||||
|
||||
# 检查是否包含小写字母
|
||||
has_lower = any(c.islower() for c in token)
|
||||
if not has_lower:
|
||||
return False, "Token 必须包含小写字母"
|
||||
|
||||
# 检查是否包含特殊符号
|
||||
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/"
|
||||
has_special = any(c in special_chars for c in token)
|
||||
if not has_special:
|
||||
return False, f"Token 必须包含特殊符号 ({special_chars})"
|
||||
|
||||
return True, "Token 格式正确"
|
||||
|
||||
def is_first_setup(self) -> bool:
|
||||
"""
|
||||
检查是否为首次配置
|
||||
|
||||
Returns:
|
||||
bool: 是否为首次配置
|
||||
"""
|
||||
config = self._load_config()
|
||||
return not config.get("first_setup_completed", False)
|
||||
|
||||
def mark_setup_completed(self) -> bool:
|
||||
"""
|
||||
标记首次配置已完成
|
||||
|
||||
Returns:
|
||||
bool: 是否标记成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = True
|
||||
config["setup_completed_at"] = self._get_current_timestamp()
|
||||
self._save_config(config)
|
||||
logger.info("首次配置已标记为完成")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"标记首次配置完成失败: {e}")
|
||||
return False
|
||||
|
||||
def reset_setup_status(self) -> bool:
|
||||
"""
|
||||
重置首次配置状态,允许重新进入配置向导
|
||||
|
||||
Returns:
|
||||
bool: 是否重置成功
|
||||
"""
|
||||
try:
|
||||
config = self._load_config()
|
||||
config["first_setup_completed"] = False
|
||||
if "setup_completed_at" in config:
|
||||
del config["setup_completed_at"]
|
||||
self._save_config(config)
|
||||
logger.info("首次配置状态已重置")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"重置首次配置状态失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局单例
|
||||
_token_manager_instance: Optional[TokenManager] = None
|
||||
|
||||
|
||||
def get_token_manager() -> TokenManager:
|
||||
"""获取 TokenManager 单例"""
|
||||
global _token_manager_instance
|
||||
if _token_manager_instance is None:
|
||||
_token_manager_instance = TokenManager()
|
||||
return _token_manager_instance
|
||||
|
|
@ -1,295 +0,0 @@
|
|||
"""独立的 WebUI 服务器 - 运行在 0.0.0.0:8001"""
|
||||
|
||||
import asyncio
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from uvicorn import Config, Server as UvicornServer
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("webui_server")
|
||||
|
||||
|
||||
class WebUIServer:
|
||||
"""独立的 WebUI 服务器"""
|
||||
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8001):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.app = FastAPI(title="MaiBot WebUI")
|
||||
self._server = None
|
||||
|
||||
# 配置防爬虫中间件(需要在CORS之前注册)
|
||||
self._setup_anti_crawler()
|
||||
|
||||
# 配置 CORS(支持开发环境跨域请求)
|
||||
self._setup_cors()
|
||||
|
||||
# 显示 Access Token
|
||||
self._show_access_token()
|
||||
|
||||
# 重要:先注册 API 路由,再设置静态文件
|
||||
self._register_api_routes()
|
||||
self._setup_static_files()
|
||||
|
||||
# 注册robots.txt路由
|
||||
self._setup_robots_txt()
|
||||
|
||||
def _setup_cors(self):
|
||||
"""配置 CORS 中间件"""
|
||||
# 开发环境需要允许前端开发服务器的跨域请求
|
||||
self.app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[
|
||||
"http://localhost:5173", # Vite 开发服务器
|
||||
"http://127.0.0.1:5173",
|
||||
"http://localhost:7999", # 前端开发服务器备用端口
|
||||
"http://127.0.0.1:7999",
|
||||
"http://localhost:8001", # 生产环境
|
||||
"http://127.0.0.1:8001",
|
||||
],
|
||||
allow_credentials=True, # 允许携带 Cookie
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], # 明确指定允许的方法
|
||||
allow_headers=[
|
||||
"Content-Type",
|
||||
"Authorization",
|
||||
"Accept",
|
||||
"Origin",
|
||||
"X-Requested-With",
|
||||
], # 明确指定允许的头
|
||||
expose_headers=["Content-Length", "Content-Type"], # 允许前端读取的响应头
|
||||
)
|
||||
logger.debug("✅ CORS 中间件已配置")
|
||||
|
||||
def _show_access_token(self):
|
||||
"""显示 WebUI Access Token"""
|
||||
try:
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
token_manager = get_token_manager()
|
||||
current_token = token_manager.get_token()
|
||||
logger.info(f"🔑 WebUI Access Token: {current_token}")
|
||||
logger.info("💡 请使用此 Token 登录 WebUI")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 获取 Access Token 失败: {e}")
|
||||
|
||||
def _setup_static_files(self):
|
||||
"""设置静态文件服务"""
|
||||
# 确保正确的 MIME 类型映射
|
||||
mimetypes.init()
|
||||
mimetypes.add_type("application/javascript", ".js")
|
||||
mimetypes.add_type("application/javascript", ".mjs")
|
||||
mimetypes.add_type("text/css", ".css")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
static_path = base_dir / "webui" / "dist"
|
||||
|
||||
if not static_path.exists():
|
||||
logger.warning(f"❌ WebUI 静态文件目录不存在: {static_path}")
|
||||
logger.warning("💡 请先构建前端: cd webui && npm run build")
|
||||
return
|
||||
|
||||
if not (static_path / "index.html").exists():
|
||||
logger.warning(f"❌ 未找到 index.html: {static_path / 'index.html'}")
|
||||
logger.warning("💡 请确认前端已正确构建")
|
||||
return
|
||||
|
||||
# 处理 SPA 路由 - 注意:这个路由优先级最低
|
||||
@self.app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_spa(full_path: str):
|
||||
"""服务单页应用 - 只处理非 API 请求"""
|
||||
# 如果是根路径,直接返回 index.html
|
||||
if not full_path or full_path == "/":
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 检查是否是静态文件
|
||||
file_path = static_path / full_path
|
||||
if file_path.is_file() and file_path.exists():
|
||||
# 自动检测 MIME 类型
|
||||
media_type = mimetypes.guess_type(str(file_path))[0]
|
||||
response = FileResponse(file_path, media_type=media_type)
|
||||
# HTML 文件添加防索引头
|
||||
if str(file_path).endswith(".html"):
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
# 其他路径返回 index.html(SPA 路由)
|
||||
response = FileResponse(static_path / "index.html", media_type="text/html")
|
||||
response.headers["X-Robots-Tag"] = "noindex, nofollow, noarchive"
|
||||
return response
|
||||
|
||||
logger.info(f"✅ WebUI 静态文件服务已配置: {static_path}")
|
||||
|
||||
def _setup_anti_crawler(self):
|
||||
"""配置防爬虫中间件"""
|
||||
try:
|
||||
from src.webui.anti_crawler import AntiCrawlerMiddleware
|
||||
from src.config.config import global_config
|
||||
|
||||
# 从配置读取防爬虫模式
|
||||
anti_crawler_mode = global_config.webui.anti_crawler_mode
|
||||
|
||||
# 注意:中间件按注册顺序反向执行,所以先注册的中间件后执行
|
||||
# 我们需要在CORS之前注册,这样防爬虫检查会在CORS之前执行
|
||||
self.app.add_middleware(AntiCrawlerMiddleware, mode=anti_crawler_mode)
|
||||
|
||||
mode_descriptions = {"false": "已禁用", "strict": "严格模式", "loose": "宽松模式", "basic": "基础模式"}
|
||||
mode_desc = mode_descriptions.get(anti_crawler_mode, "基础模式")
|
||||
logger.info(f"🛡️ 防爬虫中间件已配置: {mode_desc}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 配置防爬虫中间件失败: {e}", exc_info=True)
|
||||
|
||||
def _setup_robots_txt(self):
|
||||
"""设置robots.txt路由"""
|
||||
try:
|
||||
from src.webui.anti_crawler import create_robots_txt_response
|
||||
|
||||
@self.app.get("/robots.txt", include_in_schema=False)
|
||||
async def robots_txt():
|
||||
"""返回robots.txt,禁止所有爬虫"""
|
||||
return create_robots_txt_response()
|
||||
|
||||
logger.debug("✅ robots.txt 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册robots.txt路由失败: {e}", exc_info=True)
|
||||
|
||||
def _register_api_routes(self):
|
||||
"""注册所有 WebUI API 路由"""
|
||||
try:
|
||||
# 导入所有 WebUI 路由
|
||||
from src.webui.routes import router as webui_router
|
||||
from src.webui.logs_ws import router as logs_router
|
||||
from src.webui.knowledge_routes import router as knowledge_router
|
||||
|
||||
# 导入本地聊天室路由
|
||||
from src.webui.chat_routes import router as chat_router
|
||||
|
||||
# 导入规划器监控路由
|
||||
from src.webui.api.planner import router as planner_router
|
||||
|
||||
# 导入回复器监控路由
|
||||
from src.webui.api.replier import router as replier_router
|
||||
|
||||
# 注册路由
|
||||
self.app.include_router(webui_router)
|
||||
self.app.include_router(logs_router)
|
||||
self.app.include_router(knowledge_router)
|
||||
self.app.include_router(chat_router)
|
||||
self.app.include_router(planner_router)
|
||||
self.app.include_router(replier_router)
|
||||
|
||||
logger.info("✅ WebUI API 路由已注册")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 注册 WebUI API 路由失败: {e}", exc_info=True)
|
||||
|
||||
async def start(self):
|
||||
"""启动服务器"""
|
||||
# 预先检查端口是否可用
|
||||
if not self._check_port_available():
|
||||
error_msg = f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用"
|
||||
logger.error(error_msg)
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
logger.error(f"💡 Windows 用户可以运行: netstat -ano | findstr :{self.port}")
|
||||
logger.error(f"💡 Linux/Mac 用户可以运行: lsof -i :{self.port}")
|
||||
raise OSError(f"端口 {self.port} 已被占用,无法启动 WebUI 服务器")
|
||||
|
||||
config = Config(
|
||||
app=self.app,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
log_config=None,
|
||||
access_log=False,
|
||||
)
|
||||
self._server = UvicornServer(config=config)
|
||||
|
||||
logger.info("🌐 WebUI 服务器启动中...")
|
||||
|
||||
# 根据地址类型显示正确的访问地址
|
||||
if ':' in self.host:
|
||||
# IPv6 地址需要用方括号包裹
|
||||
logger.info(f"🌐 访问地址: http://[{self.host}]:{self.port}")
|
||||
if self.host == "::":
|
||||
logger.info(f"💡 IPv6 本机访问: http://[::1]:{self.port}")
|
||||
logger.info(f"💡 IPv4 本机访问: http://127.0.0.1:{self.port}")
|
||||
elif self.host == "::1":
|
||||
logger.info("💡 仅支持 IPv6 本地访问")
|
||||
else:
|
||||
# IPv4 地址
|
||||
logger.info(f"🌐 访问地址: http://{self.host}:{self.port}")
|
||||
if self.host == "0.0.0.0":
|
||||
logger.info(f"💡 本机访问: http://localhost:{self.port} 或 http://127.0.0.1:{self.port}")
|
||||
|
||||
try:
|
||||
await self._server.serve()
|
||||
except OSError as e:
|
||||
# 处理端口绑定相关的错误
|
||||
if "address already in use" in str(e).lower() or e.errno in (98, 10048): # 98: Linux, 10048: Windows
|
||||
logger.error(f"❌ WebUI 服务器启动失败: 端口 {self.port} 已被占用")
|
||||
logger.error(f"💡 请检查是否有其他程序正在使用端口 {self.port}")
|
||||
logger.error("💡 可以在 .env 文件中修改 WEBUI_PORT 来更改 WebUI 端口")
|
||||
else:
|
||||
logger.error(f"❌ WebUI 服务器启动失败 (网络错误): {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器运行错误: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def _check_port_available(self) -> bool:
|
||||
"""检查端口是否可用(支持 IPv4 和 IPv6)"""
|
||||
import socket
|
||||
|
||||
# 判断使用 IPv4 还是 IPv6
|
||||
if ':' in self.host:
|
||||
# IPv6 地址
|
||||
family = socket.AF_INET6
|
||||
test_host = self.host if self.host != "::" else "::1"
|
||||
else:
|
||||
# IPv4 地址
|
||||
family = socket.AF_INET
|
||||
test_host = self.host if self.host != "0.0.0.0" else "127.0.0.1"
|
||||
|
||||
try:
|
||||
with socket.socket(family, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(1)
|
||||
# 尝试绑定端口
|
||||
s.bind((test_host, self.port))
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
async def shutdown(self):
|
||||
"""关闭服务器"""
|
||||
if self._server:
|
||||
logger.info("正在关闭 WebUI 服务器...")
|
||||
self._server.should_exit = True
|
||||
try:
|
||||
await asyncio.wait_for(self._server.shutdown(), timeout=3.0)
|
||||
logger.info("✅ WebUI 服务器已关闭")
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("⚠️ WebUI 服务器关闭超时")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WebUI 服务器关闭失败: {e}")
|
||||
finally:
|
||||
self._server = None
|
||||
|
||||
|
||||
# 全局 WebUI 服务器实例
|
||||
_webui_server = None
|
||||
|
||||
|
||||
def get_webui_server() -> WebUIServer:
|
||||
"""获取全局 WebUI 服务器实例"""
|
||||
global _webui_server
|
||||
if _webui_server is None:
|
||||
# 从环境变量读取
|
||||
import os
|
||||
host = os.getenv("WEBUI_HOST", "127.0.0.1")
|
||||
port = int(os.getenv("WEBUI_PORT", "8001"))
|
||||
_webui_server = WebUIServer(host=host, port=port)
|
||||
return _webui_server
|
||||
|
|
@ -1,114 +0,0 @@
|
|||
"""WebSocket 认证模块
|
||||
|
||||
提供所有 WebSocket 端点统一使用的临时 token 认证机制。
|
||||
临时 token 有效期 60 秒,且只能使用一次,用于解决 WebSocket 握手时 Cookie 不可用的问题。
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Cookie, Header
|
||||
from typing import Optional
|
||||
import secrets
|
||||
import time
|
||||
from src.common.logger import get_logger
|
||||
from src.webui.token_manager import get_token_manager
|
||||
|
||||
logger = get_logger("webui.ws_auth")
|
||||
router = APIRouter()
|
||||
|
||||
# WebSocket 临时 token 存储 {token: (expire_time, session_token)}
|
||||
# 临时 token 有效期 60 秒,仅用于 WebSocket 握手
|
||||
_ws_temp_tokens: dict[str, tuple[float, str]] = {}
|
||||
_WS_TOKEN_EXPIRE_SECONDS = 60
|
||||
|
||||
|
||||
def _cleanup_expired_ws_tokens():
|
||||
"""清理过期的临时 token"""
|
||||
now = time.time()
|
||||
expired = [t for t, (exp, _) in _ws_temp_tokens.items() if now > exp]
|
||||
for t in expired:
|
||||
del _ws_temp_tokens[t]
|
||||
|
||||
|
||||
def generate_ws_token(session_token: str) -> str:
|
||||
"""生成 WebSocket 临时 token
|
||||
|
||||
Args:
|
||||
session_token: 原始的 session token
|
||||
|
||||
Returns:
|
||||
临时 token 字符串
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
temp_token = secrets.token_urlsafe(32)
|
||||
_ws_temp_tokens[temp_token] = (time.time() + _WS_TOKEN_EXPIRE_SECONDS, session_token)
|
||||
logger.debug(f"生成 WS 临时 token: {temp_token[:8]}... 有效期 {_WS_TOKEN_EXPIRE_SECONDS}s")
|
||||
return temp_token
|
||||
|
||||
|
||||
def verify_ws_token(temp_token: str) -> bool:
|
||||
"""验证并消费 WebSocket 临时 token(一次性使用)
|
||||
|
||||
Args:
|
||||
temp_token: 临时 token
|
||||
|
||||
Returns:
|
||||
验证是否通过
|
||||
"""
|
||||
_cleanup_expired_ws_tokens()
|
||||
if temp_token not in _ws_temp_tokens:
|
||||
logger.warning(f"WS token 不存在: {temp_token[:8]}...")
|
||||
return False
|
||||
expire_time, session_token = _ws_temp_tokens[temp_token]
|
||||
if time.time() > expire_time:
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 已过期: {temp_token[:8]}...")
|
||||
return False
|
||||
# 验证原始 session token 仍然有效
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.warning(f"WS token 关联的 session 已失效: {temp_token[:8]}...")
|
||||
return False
|
||||
# 消费 token(一次性使用)
|
||||
del _ws_temp_tokens[temp_token]
|
||||
logger.debug(f"WS token 验证成功: {temp_token[:8]}...")
|
||||
return True
|
||||
|
||||
|
||||
@router.get("/ws-token")
|
||||
async def get_ws_token(
|
||||
maibot_session: Optional[str] = Cookie(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
获取 WebSocket 连接用的临时 token
|
||||
|
||||
此端点验证当前会话的 Cookie 或 Authorization header,
|
||||
然后返回一个临时 token 用于 WebSocket 握手认证。
|
||||
临时 token 有效期 60 秒,且只能使用一次。
|
||||
|
||||
注意:在未认证时返回 200 状态码但 success=False,避免前端因 401 刷新页面。
|
||||
"""
|
||||
# 获取当前 session token
|
||||
session_token = None
|
||||
if maibot_session:
|
||||
session_token = maibot_session
|
||||
elif authorization and authorization.startswith("Bearer "):
|
||||
session_token = authorization.replace("Bearer ", "")
|
||||
|
||||
if not session_token:
|
||||
# 返回 200 但 success=False,避免前端因 401 刷新页面
|
||||
# 这在登录页面是正常情况,不应该触发错误处理
|
||||
logger.debug("ws-token 请求:未提供认证信息(可能在登录页面)")
|
||||
return {"success": False, "message": "未提供认证信息,请先登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 验证 session token
|
||||
token_manager = get_token_manager()
|
||||
if not token_manager.verify_token(session_token):
|
||||
# 同样返回 200 但 success=False,避免前端刷新
|
||||
logger.debug("ws-token 请求:认证已过期")
|
||||
return {"success": False, "message": "认证已过期,请重新登录", "token": None, "expires_in": 0}
|
||||
|
||||
# 生成临时 WebSocket token
|
||||
ws_token = generate_ws_token(session_token)
|
||||
|
||||
return {"success": True, "token": ws_token, "expires_in": _WS_TOKEN_EXPIRE_SECONDS}
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue