diff --git a/src/webui/emoji_routes.py b/src/webui/emoji_routes.py
index 96899bf3..2deaa6e6 100644
--- a/src/webui/emoji_routes.py
+++ b/src/webui/emoji_routes.py
@@ -76,6 +76,22 @@ class EmojiDeleteResponse(BaseModel):
     message: str
 
 
+class BatchDeleteRequest(BaseModel):
+    """Batch delete request"""
+
+    emoji_ids: List[int]
+
+
+class BatchDeleteResponse(BaseModel):
+    """Batch delete response"""
+
+    success: bool
+    message: str
+    deleted_count: int
+    failed_count: int
+    failed_ids: List[int] = []
+
+
 def verify_auth_token(authorization: Optional[str]) -> bool:
     """Verify the auth token"""
     if not authorization or not authorization.startswith("Bearer "):
@@ -503,3 +519,59 @@ async def get_emoji_thumbnail(
     except Exception as e:
         logger.exception(f"Failed to fetch emoji thumbnail: {e}")
         raise HTTPException(status_code=500, detail=f"Failed to fetch emoji thumbnail: {str(e)}") from e
+
+
+@router.post("/batch/delete", response_model=BatchDeleteResponse)
+async def batch_delete_emojis(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
+    """
+    Batch delete emojis
+
+    Args:
+        request: Request containing the list of emoji_ids
+        authorization: Authorization header
+
+    Returns:
+        Batch delete result
+    """
+    try:
+        verify_auth_token(authorization)
+
+        if not request.emoji_ids:
+            raise HTTPException(status_code=400, detail="No emoji IDs provided for deletion")
+
+        deleted_count = 0
+        failed_count = 0
+        failed_ids = []
+
+        for emoji_id in request.emoji_ids:
+            try:
+                emoji = Emoji.get_or_none(Emoji.id == emoji_id)
+                if emoji:
+                    emoji.delete_instance()
+                    deleted_count += 1
+                    logger.info(f"Batch deleted emoji: {emoji_id}")
+                else:
+                    failed_count += 1
+                    failed_ids.append(emoji_id)
+            except Exception as e:
+                logger.error(f"Failed to delete emoji {emoji_id}: {e}")
+                failed_count += 1
+                failed_ids.append(emoji_id)
+
+        message = f"Successfully deleted {deleted_count} emojis"
+        if failed_count > 0:
+            message += f", {failed_count} failed"
+
+        return BatchDeleteResponse(
+            success=True,
+            message=message,
+            deleted_count=deleted_count,
+            failed_count=failed_count,
+            failed_ids=failed_ids,
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Batch delete of emojis failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Batch delete failed: {str(e)}") from e
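The three /batch/delete endpoints added in this patch share the same request/response shape, so a single client-side sketch covers all of them. Everything below is illustrative: the base URL, mount prefix, and token are assumptions, not part of this diff.

import httpx

# Hypothetical values; the real prefix and port depend on how the routers are mounted.
BASE_URL = "http://localhost:8000/api/emoji"
TOKEN = "your-token-here"

resp = httpx.post(
    f"{BASE_URL}/batch/delete",
    json={"emoji_ids": [1, 2, 3]},
    headers={"Authorization": f"Bearer {TOKEN}"},
)
data = resp.json()
# BatchDeleteResponse fields: success, message, deleted_count, failed_count, failed_ids
print(data["deleted_count"], data["failed_ids"])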
diff --git a/src/webui/expression_routes.py b/src/webui/expression_routes.py
index aa9261d2..983918cf 100644
--- a/src/webui/expression_routes.py
+++ b/src/webui/expression_routes.py
@@ -338,6 +338,53 @@ async def delete_expression(expression_id: int, authorization: Optional[str] = Header(None)):
         raise HTTPException(status_code=500, detail=f"Failed to delete expression: {str(e)}") from e
 
 
+class BatchDeleteRequest(BaseModel):
+    """Batch delete request"""
+
+    ids: List[int]
+
+
+@router.post("/batch/delete", response_model=ExpressionDeleteResponse)
+async def batch_delete_expressions(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
+    """
+    Batch delete expressions
+
+    Args:
+        request: Request containing the list of IDs to delete
+        authorization: Authorization header
+
+    Returns:
+        Delete result
+    """
+    try:
+        verify_auth_token(authorization)
+
+        if not request.ids:
+            raise HTTPException(status_code=400, detail="No expression IDs provided for deletion")
+
+        # Find all expressions to delete
+        expressions = Expression.select().where(Expression.id.in_(request.ids))
+        found_ids = [expr.id for expr in expressions]
+
+        # Check for IDs that were not found
+        not_found_ids = set(request.ids) - set(found_ids)
+        if not_found_ids:
+            logger.warning(f"Some expressions were not found: {not_found_ids}")
+
+        # Perform the bulk delete
+        deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
+
+        logger.info(f"Batch deleted {deleted_count} expressions")
+
+        return ExpressionDeleteResponse(success=True, message=f"Successfully deleted {deleted_count} expressions")
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Batch delete of expressions failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Batch delete of expressions failed: {str(e)}") from e
+
+
 @router.get("/stats/summary")
 async def get_expression_stats(authorization: Optional[str] = Header(None)):
     """
diff --git a/src/webui/person_routes.py b/src/webui/person_routes.py
index 24855aba..5935a2fa 100644
--- a/src/webui/person_routes.py
+++ b/src/webui/person_routes.py
@@ -75,6 +75,22 @@ class PersonDeleteResponse(BaseModel):
     message: str
 
 
+class BatchDeleteRequest(BaseModel):
+    """Batch delete request"""
+
+    person_ids: List[str]
+
+
+class BatchDeleteResponse(BaseModel):
+    """Batch delete response"""
+
+    success: bool
+    message: str
+    deleted_count: int
+    failed_count: int
+    failed_ids: List[str] = []
+
+
 def verify_auth_token(authorization: Optional[str]) -> bool:
     """Verify the auth token"""
     if not authorization or not authorization.startswith("Bearer "):
@@ -334,3 +350,59 @@ async def get_person_stats(authorization: Optional[str] = Header(None)):
     except Exception as e:
         logger.exception(f"Failed to fetch statistics: {e}")
         raise HTTPException(status_code=500, detail=f"Failed to fetch statistics: {str(e)}") from e
+
+
+@router.post("/batch/delete", response_model=BatchDeleteResponse)
+async def batch_delete_persons(request: BatchDeleteRequest, authorization: Optional[str] = Header(None)):
+    """
+    Batch delete person info
+
+    Args:
+        request: Request containing the list of person_ids
+        authorization: Authorization header
+
+    Returns:
+        Batch delete result
+    """
+    try:
+        verify_auth_token(authorization)
+
+        if not request.person_ids:
+            raise HTTPException(status_code=400, detail="No person IDs provided for deletion")
+
+        deleted_count = 0
+        failed_count = 0
+        failed_ids = []
+
+        for person_id in request.person_ids:
+            try:
+                person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
+                if person:
+                    person.delete_instance()
+                    deleted_count += 1
+                    logger.info(f"Batch deleted: {person_id}")
+                else:
+                    failed_count += 1
+                    failed_ids.append(person_id)
+            except Exception as e:
+                logger.error(f"Failed to delete {person_id}: {e}")
+                failed_count += 1
+                failed_ids.append(person_id)
+
+        message = f"Successfully deleted {deleted_count} persons"
+        if failed_count > 0:
+            message += f", {failed_count} failed"
+
+        return BatchDeleteResponse(
+            success=True,
+            message=message,
+            deleted_count=deleted_count,
+            failed_count=failed_count,
+            failed_ids=failed_ids,
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"Batch delete of person info failed: {e}")
+        raise HTTPException(status_code=500, detail=f"Batch delete failed: {str(e)}") from e
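Two deletion strategies coexist in this patch: emoji_routes.py and person_routes.py delete row by row (get_or_none plus delete_instance) so they can report exactly which IDs failed, while expression_routes.py issues one bulk DELETE and only logs the missing IDs. A minimal sketch of the bulk form, reusing the models from the diff (person_ids is a hypothetical input):

person_ids = ["p1", "p2"]  # hypothetical IDs

# Bulk form (the expression_routes approach): a single SQL DELETE that returns the
# row count, trading away the per-ID failed_ids reporting of the loop-based endpoints.
deleted = PersonInfo.delete().where(PersonInfo.person_id.in_(person_ids)).execute()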
diff --git a/src/webui/statistics_routes.py b/src/webui/statistics_routes.py
index 45855475..b0a3664c 100644
--- a/src/webui/statistics_routes.py
+++ b/src/webui/statistics_routes.py
@@ -4,7 +4,7 @@
 from fastapi import APIRouter, HTTPException
 from pydantic import BaseModel, Field
 from typing import Dict, Any, List
 from datetime import datetime, timedelta
-from collections import defaultdict
+from peewee import fn
 from src.common.logger import get_logger
 from src.common.database.database_model import LLMUsage, OnlineTime, Messages
@@ -101,29 +101,24 @@ async def get_dashboard_data(hours: int = 24):
 
 
 async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
-    """Fetch summary statistics"""
+    """Fetch summary statistics (optimized: database aggregation)"""
     summary = StatisticsSummary()
 
-    # Query LLM usage records
-    llm_records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
+    # Use an aggregate query instead of loading every row
+    query = LLMUsage.select(
+        fn.COUNT(LLMUsage.id).alias("total_requests"),
+        fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
+        fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
+        fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
+    ).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
 
-    total_time_cost = 0.0
-    time_cost_count = 0
+    result = query.dicts().get()
+    summary.total_requests = result["total_requests"]
+    summary.total_cost = result["total_cost"]
+    summary.total_tokens = result["total_tokens"]
+    summary.avg_response_time = result["avg_response_time"] or 0.0
 
-    for record in llm_records:
-        summary.total_requests += 1
-        summary.total_cost += record.cost or 0.0
-        summary.total_tokens += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
-
-        if record.time_cost and record.time_cost > 0:
-            total_time_cost += record.time_cost
-            time_cost_count += 1
-
-    # Compute the average response time
-    if time_cost_count > 0:
-        summary.avg_response_time = total_time_cost / time_cost_count
-
-    # Query online time
+    # Query online time - this dataset is usually small, keep the original logic
     online_records = list(
         OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
     )
@@ -134,14 +129,19 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
         if end > start:
             summary.online_time += (end - start).total_seconds()
 
-    # Count messages
-    messages = list(
-        Messages.select().where(Messages.time >= start_time.timestamp()).where(Messages.time <= end_time.timestamp())
+    # Count messages - optimized with aggregation
+    messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
+        (Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
     )
+    summary.total_messages = messages_query.scalar() or 0
 
-    summary.total_messages = len(messages)
-    # Simple heuristic: treat a message as a reply if reply_to is non-empty
-    summary.total_replies = len([m for m in messages if m.reply_to])
+    # Count replies
+    replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
+        (Messages.time >= start_time.timestamp())
+        & (Messages.time <= end_time.timestamp())
+        & (Messages.reply_to.is_null(False))
+    )
+    summary.total_replies = replies_query.scalar() or 0
 
     # Compute derived metrics
     if summary.online_time > 0:
@@ -153,93 +153,101 @@
 
 
 async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
-    """Fetch model statistics"""
-    model_data = defaultdict(lambda: {"request_count": 0, "total_cost": 0.0, "total_tokens": 0, "time_costs": []})
+    """Fetch model statistics (optimized: database aggregation and grouping)"""
+    # Aggregate with GROUP BY to avoid loading every row
+    query = (
+        LLMUsage.select(
+            fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
+            fn.COUNT(LLMUsage.id).alias("request_count"),
+            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
+            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
+            fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
+        )
+        .where(LLMUsage.timestamp >= start_time)
+        .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
+        .order_by(fn.COUNT(LLMUsage.id).desc())
+        .limit(10)  # Only take the top 10
+    )
 
-    records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time))
-
-    for record in records:
-        model_name = record.model_assign_name or record.model_name or "unknown"
-        model_data[model_name]["request_count"] += 1
-        model_data[model_name]["total_cost"] += record.cost or 0.0
-        model_data[model_name]["total_tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
-
-        if record.time_cost and record.time_cost > 0:
-            model_data[model_name]["time_costs"].append(record.time_cost)
-
-    # Convert to a list and sort
     result = []
-    for model_name, data in model_data.items():
-        avg_time = sum(data["time_costs"]) / len(data["time_costs"]) if data["time_costs"] else 0.0
+    for row in query.dicts():
         result.append(
             ModelStatistics(
-                model_name=model_name,
-                request_count=data["request_count"],
-                total_cost=data["total_cost"],
-                total_tokens=data["total_tokens"],
-                avg_response_time=avg_time,
+                model_name=row["model_name"],
+                request_count=row["request_count"],
+                total_cost=row["total_cost"],
+                total_tokens=row["total_tokens"],
+                avg_response_time=row["avg_response_time"] or 0.0,
             )
         )
 
-    # Sort by request count
-    result.sort(key=lambda x: x.request_count, reverse=True)
-    return result[:10]  # Return the top 10
+    return result
 
 
 async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
-    """Fetch hourly statistics"""
-    # Create hourly buckets
-    hourly_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
+    """Fetch hourly statistics (optimized: database aggregation)"""
+    # Group by hour using SQLite's date/time functions:
+    # strftime formats the timestamp at hour granularity
+    query = (
+        LLMUsage.select(
+            fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
+            fn.COUNT(LLMUsage.id).alias("requests"),
+            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
+            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
+        )
+        .where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
+        .group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
+    )
 
-    records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
-
-    for record in records:
-        # Get the hour key (strip minutes and seconds)
-        hour_key = record.timestamp.replace(minute=0, second=0, microsecond=0)
-        hour_str = hour_key.isoformat()
-
-        hourly_buckets[hour_str]["requests"] += 1
-        hourly_buckets[hour_str]["cost"] += record.cost or 0.0
-        hourly_buckets[hour_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
+    # Convert to a dict for fast lookup
+    data_dict = {row["hour"]: row for row in query.dicts()}
 
     # Fill in every hour (including those with no data)
     result = []
     current = start_time.replace(minute=0, second=0, microsecond=0)
     while current <= end_time:
-        hour_str = current.isoformat()
-        data = hourly_buckets.get(hour_str, {"requests": 0, "cost": 0.0, "tokens": 0})
-        result.append(
-            TimeSeriesData(timestamp=hour_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
-        )
+        hour_str = current.strftime("%Y-%m-%dT%H:00:00")
+        if hour_str in data_dict:
+            row = data_dict[hour_str]
+            result.append(
+                TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
+            )
+        else:
+            result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
        current += timedelta(hours=1)
 
     return result
 
 
 async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
-    """Fetch daily statistics"""
-    daily_buckets = defaultdict(lambda: {"requests": 0, "cost": 0.0, "tokens": 0})
+    """Fetch daily statistics (optimized: database aggregation)"""
+    # Group by day using strftime
+    query = (
+        LLMUsage.select(
+            fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
+            fn.COUNT(LLMUsage.id).alias("requests"),
+            fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
+            fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
+        )
+        .where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
+        .group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
+    )
 
-    records = list(LLMUsage.select().where(LLMUsage.timestamp >= start_time).where(LLMUsage.timestamp <= end_time))
-
-    for record in records:
-        # Get the day key
-        day_key = record.timestamp.replace(hour=0, minute=0, second=0, microsecond=0)
-        day_str = day_key.isoformat()
-
-        daily_buckets[day_str]["requests"] += 1
-        daily_buckets[day_str]["cost"] += record.cost or 0.0
-        daily_buckets[day_str]["tokens"] += (record.prompt_tokens or 0) + (record.completion_tokens or 0)
+    # Convert to a dict
+    data_dict = {row["day"]: row for row in query.dicts()}
 
     # Fill in every day
     result = []
     current = start_time.replace(hour=0, minute=0, second=0, microsecond=0)
     while current <= end_time:
-        day_str = current.isoformat()
-        data = daily_buckets.get(day_str, {"requests": 0, "cost": 0.0, "tokens": 0})
-        result.append(
-            TimeSeriesData(timestamp=day_str, requests=data["requests"], cost=data["cost"], tokens=data["tokens"])
-        )
+        day_str = current.strftime("%Y-%m-%dT00:00:00")
+        if day_str in data_dict:
+            row = data_dict[day_str]
+            result.append(
+                TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
+            )
+        else:
+            result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
         current += timedelta(days=1)
 
     return result
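The statistics refactor leans on two peewee patterns: aggregates wrapped in fn.COALESCE so an empty time range yields 0 rather than NULL, and fn.strftime for time bucketing. Note that strftime-based grouping is SQLite-specific (other backends need their own date-truncation function), and that SQL AVG includes zero time_cost values the old Python loop filtered out with time_cost > 0, so avg_response_time can shift slightly. A standalone sketch of the pattern, reusing the LLMUsage model from the diff; the 24-hour window is an illustrative assumption:

from datetime import datetime, timedelta

from peewee import fn

from src.common.database.database_model import LLMUsage

start_time = datetime.now() - timedelta(hours=24)

# One aggregate row: COALESCE turns SUM over zero rows into 0 instead of NULL.
row = (
    LLMUsage.select(
        fn.COUNT(LLMUsage.id).alias("requests"),
        fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
    )
    .where(LLMUsage.timestamp >= start_time)
    .dicts()
    .get()
)

# Hour buckets via SQLite's strftime; GROUP BY reuses the same expression as SELECT.
hour = fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp)
per_hour = LLMUsage.select(hour.alias("hour"), fn.COUNT(LLMUsage.id).alias("requests")).group_by(hour)
for bucket in per_hour.dicts():
    print(bucket["hour"], bucket["requests"])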