diff --git a/pyproject.toml b/pyproject.toml index 1d432bd7..72fe4984 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,6 @@ dependencies = [ "numpy>=2.2.6", "openai>=1.95.0", "pandas>=2.3.1", - "peewee>=3.18.2", "pillow>=11.3.0", "pyarrow>=20.0.0", "pydantic>=2.11.7", @@ -29,6 +28,8 @@ dependencies = [ "rich>=14.0.0", "ruff>=0.12.2", "setuptools>=80.9.0", + "sqlalchemy>=2.0.40", + "sqlmodel>=0.0.24", "structlog>=25.4.0", "toml>=0.10.2", "tomlkit>=0.13.3", diff --git a/requirements.txt b/requirements.txt index 4cc63bc8..6bd487cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,6 @@ matplotlib>=3.10.3 numpy>=2.2.6 openai>=1.95.0 pandas>=2.3.1 -peewee>=3.18.2 pillow>=11.3.0 pyarrow>=20.0.0 pydantic>=2.11.7 @@ -23,8 +22,10 @@ quick-algo>=0.1.3 rich>=14.0.0 ruff>=0.12.2 setuptools>=80.9.0 +sqlalchemy>=2.0.40 +sqlmodel>=0.0.24 structlog>=25.4.0 toml>=0.10.2 tomlkit>=0.13.3 urllib3>=2.5.0 -uvicorn>=0.35.0 \ No newline at end of file +uvicorn>=0.35.0 diff --git a/src/bw_learner/jargon_miner.py b/src/bw_learner/jargon_miner.py index 449ebe61..0d1622a9 100644 --- a/src/bw_learner/jargon_miner.py +++ b/src/bw_learner/jargon_miner.py @@ -4,7 +4,7 @@ import random from collections import OrderedDict from typing import List, Dict, Optional, Callable from json_repair import repair_json -from peewee import fn +from sqlalchemy import func as fn from src.common.logger import get_logger from src.common.database.database_model import Jargon diff --git a/src/common/database/database.py b/src/common/database/database.py index 2c0598ba..979da54d 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,8 +1,9 @@ from rich.traceback import install from pathlib import Path from contextlib import contextmanager -from sqlalchemy import create_engine, event +from sqlalchemy import create_engine, event, text from sqlalchemy.engine import Engine +from sqlalchemy import inspect as sqlalchemy_inspect from sqlalchemy.orm import Session, sessionmaker from typing import TYPE_CHECKING, Generator @@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]: yield session finally: session.close() + + +class _AtomicContext: + def __init__(self) -> None: + self._session: Session | None = None + + def __enter__(self) -> Session: + self._session = SessionLocal() + self._session.begin() + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: + if self._session is None: + return + try: + if exc_type is None: + self._session.commit() + else: + self._session.rollback() + finally: + self._session.close() + + +class DatabaseCompat: + """兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。""" + + def connect(self, reuse_if_open: bool = True) -> None: + # SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。 + _ = reuse_if_open + + def create_tables(self, models: list[type], safe: bool = True) -> None: + _ = safe + tables = [model.__table__ for model in models if hasattr(model, "__table__")] + if not tables: + return + from sqlmodel import SQLModel + + SQLModel.metadata.create_all(engine, tables=tables) + + def atomic(self) -> _AtomicContext: + return _AtomicContext() + + def execute_sql(self, sql: str): + with engine.connect() as conn: + result = conn.execute(text(sql)) + conn.commit() + return result + + def table_exists(self, model: type) -> bool: + if not hasattr(model, "__tablename__"): + return False + inspector = sqlalchemy_inspect(engine) + return inspector.has_table(model.__tablename__) + + +db = DatabaseCompat() diff --git a/src/common/logger.py b/src/common/logger.py index b57c9dd3..92306f6e 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -304,7 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress "websockets", "httpcore", "requests", - "peewee", + "sqlalchemy", "openai", "uvicorn", "jieba", @@ -876,19 +876,19 @@ def initialize_logging(verbose: bool = True): """手动初始化日志系统,确保所有logger都使用正确的配置 在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置 - + Args: verbose: 是否输出详细的初始化信息。默认为 True。 在 Runner 进程中可以设置为 False 以避免重复的初始化日志。 """ global LOG_CONFIG, _logging_initialized - + # 防止重复初始化(在同一进程内) if _logging_initialized: return - + _logging_initialized = True - + LOG_CONFIG = load_log_config() # print(LOG_CONFIG) configure_third_party_loggers() @@ -941,16 +941,16 @@ def cleanup_old_logs(): def start_log_cleanup_task(verbose: bool = True): """启动日志清理任务 - + Args: verbose: 是否输出启动信息。默认为 True。 """ global _cleanup_task_started - + # 防止重复启动清理任务 if _cleanup_task_started: return - + _cleanup_task_started = True def cleanup_task(): diff --git a/src/common/message_repository.py b/src/common/message_repository.py index fa0126d4..19cb0544 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,7 +1,6 @@ import traceback from typing import List, Any, Optional -from peewee import Model # 添加 Peewee Model 导入 from src.config.config import global_config from src.common.data_models.database_data_model import DatabaseMessages @@ -11,11 +10,15 @@ from src.common.logger import get_logger logger = get_logger(__name__) -def _model_to_instance(model_instance: Model) -> DatabaseMessages: +def _model_to_instance(model_instance: Any) -> DatabaseMessages: """ 将 Peewee 模型实例转换为字典。 """ - return DatabaseMessages(**model_instance.__data__) + if isinstance(model_instance, dict): + return DatabaseMessages(**model_instance) + if hasattr(model_instance, "model_dump"): + return DatabaseMessages(**model_instance.model_dump()) + return DatabaseMessages(**model_instance.__dict__) def find_messages( @@ -92,14 +95,17 @@ def find_messages( if limit > 0: if limit_mode == "earliest": # 获取时间最早的 limit 条记录,已经是正序 - query = query.order_by(Messages.time.asc()).limit(limit) + query = query.order_by("time").limit(limit) peewee_results = list(query) else: # 默认为 'latest' # 获取时间最晚的 limit 条记录 - query = query.order_by(Messages.time.desc()).limit(limit) + query = query.order_by("-time").limit(limit) latest_results_peewee = list(query) # 将结果按时间正序排列 - peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time) + peewee_results = sorted( + latest_results_peewee, + key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0), + ) else: # limit 为 0 时,应用传入的 sort 参数 if sort: @@ -108,9 +114,9 @@ def find_messages( if hasattr(Messages, field_name): field = getattr(Messages, field_name) if direction == 1: # ASC - peewee_sort_terms.append(field.asc()) + peewee_sort_terms.append(field_name) elif direction == -1: # DESC - peewee_sort_terms.append(field.desc()) + peewee_sort_terms.append(f"-{field_name}") else: logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") else: diff --git a/src/dream/dream_agent.py b/src/dream/dream_agent.py index 30176200..ece9e65a 100644 --- a/src/dream/dream_agent.py +++ b/src/dream/dream_agent.py @@ -3,7 +3,7 @@ import random import time from typing import Any, Dict, List, Optional, Tuple -from peewee import fn +from sqlalchemy import func as fn from src.common.logger import get_logger from src.config.config import global_config, model_config diff --git a/src/plugin_system/apis/database_api.py b/src/plugin_system/apis/database_api.py index bb96aeb3..9af4e078 100644 --- a/src/plugin_system/apis/database_api.py +++ b/src/plugin_system/apis/database_api.py @@ -1,228 +1,97 @@ """数据库API模块 -提供数据库操作相关功能,采用标准Python包设计模式 -使用方式: - from src.plugin_system.apis import database_api - records = await database_api.db_query(ActionRecords, query_type="get") - record = await database_api.db_save(ActionRecords, data={"action_id": "123"}) +提供数据库操作相关功能,统一使用 SQLModel/SQLAlchemy 兼容接口。 """ -import traceback -import time import json -from typing import Dict, List, Any, Union, Type, Optional +import time +import traceback +from typing import Any, Optional + from src.common.logger import get_logger -from peewee import Model, DoesNotExist logger = get_logger("database_api") -# ============================================================================= -# 通用数据库查询API函数 -# ============================================================================= + +def _to_dict(record: Any) -> dict[str, Any]: + if record is None: + return {} + if isinstance(record, dict): + return record + if hasattr(record, "model_dump"): + return record.model_dump() + if hasattr(record, "__dict__"): + return dict(record.__dict__) + return {} async def db_query( - model_class: Type[Model], - data: Optional[Dict[str, Any]] = None, - query_type: Optional[str] = "get", - filters: Optional[Dict[str, Any]] = None, + model_class, + data: Optional[dict[str, Any]] = None, + query_type: str = "get", + filters: Optional[dict[str, Any]] = None, limit: Optional[int] = None, - order_by: Optional[List[str]] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """执行数据库查询操作 - - 这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。 - - Args: - model_class: Peewee 模型类,例如 ActionRecords, Messages 等 - data: 用于创建或更新的数据字典 - query_type: 查询类型,可选值: "get", "create", "update", "delete", "count" - filters: 过滤条件字典,键为字段名,值为要匹配的值 - limit: 限制结果数量 - order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 - single_result: 是否只返回单个结果 - - Returns: - 根据查询类型返回不同的结果: - - "get": 返回查询结果列表或单个结果(如果 single_result=True) - - "create": 返回创建的记录 - - "update": 返回受影响的行数 - - "delete": 返回受影响的行数 - - "count": 返回记录数量 - """ - """ - 示例: - # 查询最近10条消息 - messages = await database_api.db_query( - Messages, - query_type="get", - filters={"chat_id": chat_stream.stream_id}, - limit=10, - order_by=["-time"] - ) - - # 创建一条记录 - new_record = await database_api.db_query( - ActionRecords, - data={"action_id": "123", "time": time.time(), "action_name": "TestAction"}, - query_type="create", - ) - - # 更新记录 - updated_count = await database_api.db_query( - ActionRecords, - data={"action_done": True}, - query_type="update", - filters={"action_id": "123"}, - ) - - # 删除记录 - deleted_count = await database_api.db_query( - ActionRecords, - query_type="delete", - filters={"action_id": "123"} - ) - - # 计数 - count = await database_api.db_query( - Messages, - query_type="count", - filters={"chat_id": chat_stream.stream_id} - ) - """ + order_by: Optional[list[str]] = None, + single_result: bool = False, +): try: if query_type not in ["get", "create", "update", "delete", "count"]: raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'") - # 构建基本查询 - if query_type in ["get", "update", "delete", "count"]: - query = model_class.select() - # 应用过滤条件 + if query_type == "get": + query = model_class.select() if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) - - # 执行查询 - if query_type == "get": - # 应用排序 if order_by: - for field in order_by: - if field.startswith("-"): - query = query.order_by(getattr(model_class, field[1:]).desc()) - else: - query = query.order_by(getattr(model_class, field)) - - # 应用限制 + query = query.order_by(*order_by) if limit: query = query.limit(limit) - - # 执行查询 results = list(query.dicts()) - - # 返回结果 if single_result: return results[0] if results else None return results - elif query_type == "create": + if query_type == "create": if not data: raise ValueError("创建记录需要提供data参数") - - # 创建记录 record = model_class.create(**data) - # 返回创建的记录 - return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore + return _to_dict(record) - elif query_type == "update": + query = model_class.select() + if filters: + for field, value in filters.items(): + query = query.where(getattr(model_class, field) == value) + + if query_type == "update": if not data: raise ValueError("更新记录需要提供data参数") + return query.model_class.update(**data).where(*query.stmt._where_criteria).execute() - # 更新记录 - return query.update(**data).execute() + if query_type == "delete": + return model_class.delete().where(*query.stmt._where_criteria).execute() - elif query_type == "delete": - # 删除记录 - return query.delete().execute() - - elif query_type == "count": - # 计数 - return query.count() - - else: - raise ValueError(f"不支持的查询类型: {query_type}") - - except DoesNotExist: - # 记录不存在 - return None if query_type == "get" and single_result else [] + return query.count() except Exception as e: logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") traceback.print_exc() - - # 根据查询类型返回合适的默认值 if query_type == "get": return None if single_result else [] - elif query_type in ["create", "update", "delete", "count"]: - return None return None -async def db_save( - model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None -) -> Optional[Dict[str, Any]]: - # sourcery skip: inline-immediately-returned-variable - """保存数据到数据库(创建或更新) - - 如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新; - 如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。 - - Args: - model_class: Peewee模型类,如ActionRecords, Messages等 - data: 要保存的数据字典 - key_field: 用于查找现有记录的字段名,例如"action_id" - key_value: 用于查找现有记录的字段值 - - Returns: - Dict[str, Any]: 保存后的记录数据 - None: 如果操作失败 - - 示例: - # 创建或更新一条记录 - record = await database_api.db_save( - ActionRecords, - { - "action_id": "123", - "time": time.time(), - "action_name": "TestAction", - "action_done": True - }, - key_field="action_id", - key_value="123" - ) - """ +async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None): try: - # 如果提供了key_field和key_value,尝试更新现有记录 if key_field and key_value is not None: - if existing_records := list( - model_class.select().where(getattr(model_class, key_field) == key_value).limit(1) - ): - # 更新现有记录 - existing_record = existing_records[0] + record = model_class.get_or_none(getattr(model_class, key_field) == key_value) + if record is not None: for field, value in data.items(): - setattr(existing_record, field, value) - existing_record.save() + setattr(record, field, value) + record.save() + return _to_dict(record) - # 返回更新后的记录 - updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore - return updated_record - - # 如果没有找到现有记录或未提供key_field和key_value,创建新记录 new_record = model_class.create(**data) - - # 返回创建的记录 - created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore - return created_record - + return _to_dict(new_record) except Exception as e: logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}") traceback.print_exc() @@ -230,71 +99,25 @@ async def db_save( async def db_get( - model_class: Type[Model], - filters: Optional[Dict[str, Any]] = None, + model_class, + filters: Optional[dict[str, Any]] = None, limit: Optional[int] = None, order_by: Optional[str] = None, - single_result: Optional[bool] = False, -) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: - """从数据库获取记录 - - 这是db_query方法的简化版本,专注于数据检索操作。 - - Args: - model_class: Peewee模型类 - filters: 过滤条件,字段名和值的字典 - limit: 结果数量限制 - order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序 - single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表 - - Returns: - 如果single_result为True,返回单个记录字典或None; - 否则返回记录字典列表或空列表。 - - 示例: - # 获取单个记录 - record = await database_api.db_get( - ActionRecords, - filters={"action_id": "123"}, - limit=1 - ) - - # 获取最近10条记录 - records = await database_api.db_get( - Messages, - filters={"chat_id": chat_stream.stream_id}, - limit=10, - order_by="-time", - ) - """ + single_result: bool = False, +): try: - # 构建查询 query = model_class.select() - - # 应用过滤条件 if filters: for field, value in filters.items(): query = query.where(getattr(model_class, field) == value) - - # 应用排序 if order_by: - if order_by.startswith("-"): - query = query.order_by(getattr(model_class, order_by[1:]).desc()) - else: - query = query.order_by(getattr(model_class, order_by)) - - # 应用限制 + query = query.order_by(order_by) if limit: query = query.limit(limit) - - # 执行查询 results = list(query.dicts()) - - # 返回结果 if single_result: return results[0] if results else None return results - except Exception as e: logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") traceback.print_exc() @@ -310,41 +133,12 @@ async def store_action_info( action_data: Optional[dict] = None, action_name: str = "", action_reasoning: str = "", -) -> Optional[Dict[str, Any]]: - """存储动作信息到数据库 - - 将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。 - - Args: - chat_stream: 聊天流对象,包含聊天相关信息 - action_build_into_prompt: 是否将此动作构建到提示中 - action_prompt_display: 动作的提示显示文本 - action_done: 动作是否完成 - thinking_id: 关联的思考ID - action_data: 动作数据字典 - action_name: 动作名称 - action_reasoning: 动作执行理由 - Returns: - Dict[str, Any]: 保存的记录数据 - None: 如果保存失败 - - 示例: - record = await database_api.store_action_info( - chat_stream=chat_stream, - action_build_into_prompt=True, - action_prompt_display="执行了回复动作", - action_done=True, - thinking_id="thinking_123", - action_data={"content": "Hello"}, - action_name="reply_action" - ) - """ +): try: from src.common.database.database_model import ActionRecords - # 构建动作记录数据 record_data = { - "action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID + "action_id": thinking_id or str(int(time.time() * 1000000)), "time": time.time(), "action_name": action_name, "action_data": json.dumps(action_data or {}, ensure_ascii=False), @@ -354,7 +148,6 @@ async def store_action_info( "action_prompt_display": action_prompt_display, } - # 从chat_stream获取聊天信息 if chat_stream: record_data.update( { @@ -364,27 +157,16 @@ async def store_action_info( } ) else: - # 如果没有chat_stream,设置默认值 - record_data.update( - { - "chat_id": "", - "chat_info_stream_id": "", - "chat_info_platform": "", - } - ) + record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""}) - # 使用已有的db_save函数保存记录 saved_record = await db_save( ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ) - if saved_record: logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") else: logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}") - return saved_record - except Exception as e: logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}") traceback.print_exc() diff --git a/src/webui/routers/annual_report.py b/src/webui/routers/annual_report.py index 68e1f4b9..a0f676f2 100644 --- a/src/webui/routers/annual_report.py +++ b/src/webui/routers/annual_report.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header from pydantic import BaseModel, Field from typing import Dict, Any, List, Optional from datetime import datetime -from peewee import fn +from sqlalchemy import func as fn from src.common.logger import get_logger from src.common.database.database_model import ( @@ -151,9 +151,7 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData: try: # 1. 年度在线时长 online_records = list( - OnlineTime.select().where( - (OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt) - ) + OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)) ) total_seconds = 0 for record in online_records: @@ -235,10 +233,10 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData: async def get_social_network(year: int = 2025) -> SocialNetworkData: """获取社交网络数据""" from src.config.config import global_config - + data = SocialNetworkData() start_ts, end_ts = get_year_time_range(year) - + # 获取 bot 自身的 QQ 账号,用于过滤 bot_qq = str(global_config.bot.qq_account or "") @@ -254,9 +252,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: fn.COUNT(Messages.id).alias("count"), ) .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.chat_info_group_id.is_null(False)) + (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False)) ) .group_by(Messages.chat_info_group_id) .order_by(fn.COUNT(Messages.id).desc()) @@ -302,22 +298,14 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: # 4. 被@次数 data.at_count = ( Messages.select() - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.is_at == True) - ) + .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True)) .count() ) # 5. 被提及次数 data.mentioned_count = ( Messages.select() - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.is_mentioned == True) - ) + .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True)) .count() ) @@ -329,8 +317,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: (ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"), ) .where( - (ChatStreams.user_id.is_null(False)) - & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身 + (ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身 ) .order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc()) .limit(1) @@ -451,20 +438,21 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: .order_by(fn.COUNT(LLMUsage.id).desc()) .limit(5) ) - data.top_reply_models = [ - {"model": row["model"], "count": row["count"]} - for row in reply_model_query.dicts() - ] + data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()] # 6. 高冷指数 (沉默率) - 基于 ActionRecords - total_actions = ActionRecords.select().where( - (ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts) - ).count() - no_reply_count = ActionRecords.select().where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_name == "no_reply") - ).count() + total_actions = ( + ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count() + ) + no_reply_count = ( + ActionRecords.select() + .where( + (ActionRecords.time >= start_ts) + & (ActionRecords.time <= end_ts) + & (ActionRecords.action_name == "no_reply") + ) + .count() + ) data.total_actions = total_actions data.no_reply_count = no_reply_count data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0 @@ -473,11 +461,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: interest_query = Messages.select( fn.AVG(Messages.interest_value).alias("avg_interest"), fn.MAX(Messages.interest_value).alias("max_interest"), - ).where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.interest_value.is_null(False)) - ) + ).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False))) interest_result = interest_query.dicts().get() data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2) data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2) @@ -494,19 +478,14 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: .first() ) if max_interest_msg: - data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime( - "%Y-%m-%d %H:%M:%S" - ) + data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S") # 7. 思考深度 (基于 action_reasoning 长度) - reasoning_records = ( - ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time) - .where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_reasoning.is_null(False)) - & (ActionRecords.action_reasoning != "") - ) + reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where( + (ActionRecords.time >= start_ts) + & (ActionRecords.time <= end_ts) + & (ActionRecords.action_reasoning.is_null(False)) + & (ActionRecords.action_reasoning != "") ) reasoning_lengths = [] max_len = 0 @@ -518,7 +497,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: if length > max_len: max_len = length max_len_time = record.time - + if reasoning_lengths: data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1) data.max_reasoning_length = max_len @@ -537,10 +516,10 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: """获取个性与表达数据""" from src.config.config import global_config - + data = ExpressionVibeData() start_ts, end_ts = get_year_time_range(year) - + # 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息 bot_qq = str(global_config.bot.qq_account or "") @@ -578,17 +557,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: Expression.style, fn.SUM(Expression.count).alias("total_count"), ) - .where( - (Expression.last_active_time >= start_ts) - & (Expression.last_active_time <= end_ts) - ) + .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts)) .group_by(Expression.style) .order_by(fn.SUM(Expression.count).desc()) .limit(5) ) data.top_expressions = [ - {"style": row["style"], "count": row["total_count"]} - for row in expression_query.dicts() + {"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts() ] # 3. 被拒绝的表达 @@ -616,18 +591,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: # 5. 表达总数 data.total_expressions = ( Expression.select() - .where( - (Expression.last_active_time >= start_ts) - & (Expression.last_active_time <= end_ts) - ) + .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts)) .count() ) # 6. 动作类型分布 (过滤无意义的动作) # 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore excluded_actions = [ - "reply", "no_reply", "no_reply_until_call", "make_question", - "no_action", "wait", "complete_talk", "listening", "block_and_ignore" + "reply", + "no_reply", + "no_reply_until_call", + "make_question", + "no_action", + "wait", + "complete_talk", + "listening", + "block_and_ignore", ] action_query = ( ActionRecords.select( @@ -643,38 +622,31 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: .order_by(fn.COUNT(ActionRecords.id).desc()) .limit(10) ) - data.action_types = [ - {"action": row["action_name"], "count": row["count"]} - for row in action_query.dicts() - ] + data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()] # 7. 处理的图片数量 data.image_processed_count = ( Messages.select() - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.is_picid == True) - ) + .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True)) .count() ) # 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条) import random import re - + def clean_message_content(content: str) -> str: """清理消息内容,移除回复引用等标记""" if not content: return "" # 移除 [回复 的消息:...] 格式的引用 - content = re.sub(r'\[回复<[^>]+>\s*的消息[::][^\]]*\]', '', content) + content = re.sub(r"\[回复<[^>]+>\s*的消息[::][^\]]*\]", "", content) # 移除 [图片] [表情] 等标记 - content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content) + content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content) # 移除多余的空白 - content = re.sub(r'\s+', ' ', content).strip() + content = re.sub(r"\s+", " ", content).strip() return content - + # 使用 user_id 判断是否是 bot 发送的消息 late_night_messages = list( Messages.select( @@ -683,9 +655,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: Messages.display_message, ) .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.user_id == bot_qq) # bot 发送的消息 + (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息 ) .order_by(Messages.time.desc()) ) @@ -699,16 +669,18 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: cleaned_content = clean_message_content(raw_content) # 只保留有意义的内容 if cleaned_content and len(cleaned_content) > 2: - late_night_filtered.append({ - "time": msg.time, - "hour": hour, - "minute": msg_dt.minute, - "content": cleaned_content, - "datetime_str": msg_dt.strftime("%H:%M"), - }) + late_night_filtered.append( + { + "time": msg.time, + "hour": hour, + "minute": msg_dt.minute, + "content": cleaned_content, + "datetime_str": msg_dt.strftime("%H:%M"), + } + ) if len(late_night_filtered) >= 10: break - + if late_night_filtered: selected = random.choice(late_night_filtered) content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"] @@ -720,18 +692,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: # 9. 最喜欢的回复(按 action_data 统计回复内容出现次数) from collections import Counter import json as json_lib - - reply_records = ( - ActionRecords.select(ActionRecords.action_data) - .where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_name == "reply") - & (ActionRecords.action_data.is_null(False)) - & (ActionRecords.action_data != "") - ) + + reply_records = ActionRecords.select(ActionRecords.action_data).where( + (ActionRecords.time >= start_ts) + & (ActionRecords.time <= end_ts) + & (ActionRecords.action_name == "reply") + & (ActionRecords.action_data.is_null(False)) + & (ActionRecords.action_data != "") ) - + reply_contents = [] for record in reply_records: try: @@ -748,11 +717,12 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: content = parsed except (json_lib.JSONDecodeError, TypeError): pass - + # 如果 JSON 解析失败,尝试解析 Python 字典字符串格式 # 例如: "{'reply_text': '墨白灵不知道哦'}" if content is None: import ast + try: parsed = ast.literal_eval(action_data) if isinstance(parsed, dict): @@ -762,13 +732,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: except (ValueError, SyntaxError): # 无法解析,使用原始字符串 content = action_data - + # 只统计有意义的回复(长度大于2) if content and len(content) > 2: reply_contents.append(content) except Exception: continue - + if reply_contents: content_counter = Counter(reply_contents) most_common = content_counter.most_common(1) @@ -817,20 +787,12 @@ async def get_achievements(year: int = 2025) -> AchievementData: ] # 3. 总消息数 - data.total_messages = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts)) - .count() - ) + data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count() # 4. 总回复数 (有 reply_to 的消息) data.total_replies = ( Messages.select() - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.reply_to.is_null(False)) - ) + .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False))) .count() ) @@ -856,9 +818,9 @@ async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require """ try: from src.config.config import global_config - + logger.info(f"开始生成 {year} 年度报告...") - + # 获取 bot 名称 bot_name = global_config.bot.nickname or "麦麦" diff --git a/src/webui/routers/chat.py b/src/webui/routers/chat.py index c7f847ea..f666e85c 100644 --- a/src/webui/routers/chat.py +++ b/src/webui/routers/chat.py @@ -10,6 +10,7 @@ import uuid from typing import Dict, Any, Optional, List from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header from pydantic import BaseModel +from sqlalchemy import case, func as fn from src.common.logger import get_logger from src.common.database.database_model import Messages, PersonInfo @@ -290,8 +291,6 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)): 从 PersonInfo 表中获取所有已知的平台 """ try: - from peewee import fn - # 查询所有不同的平台 platforms = ( PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count")) @@ -337,9 +336,7 @@ async def get_persons_by_platform( ) # 按最后交互时间排序,优先显示活跃用户 - from peewee import Case - - query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc()) + query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc()) query = query.limit(limit) result = [] diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index 1e78d982..67495263 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict +from sqlalchemy import case from src.common.logger import get_logger from src.common.database.database_model import Expression, ChatStreams from src.webui.core import verify_auth_token_from_cookie_or_header @@ -231,10 +232,8 @@ async def get_expression_list( query = query.where(Expression.chat_id == chat_id) # 排序:最后活跃时间倒序(NULL 值放在最后) - from peewee import Case - query = query.order_by( - Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc() + case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc() ) # 获取总数 @@ -641,9 +640,7 @@ async def get_review_list( query = query.where(Expression.chat_id == chat_id) # 排序:创建时间倒序 - from peewee import Case - - query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc()) + query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc()) total = query.count() offset = (page - 1) * page_size diff --git a/src/webui/routers/jargon.py b/src/webui/routers/jargon.py index 8d372688..cca15c3b 100644 --- a/src/webui/routers/jargon.py +++ b/src/webui/routers/jargon.py @@ -4,7 +4,7 @@ import json from typing import Optional, List, Annotated from fastapi import APIRouter, HTTPException, Query from pydantic import BaseModel, Field -from peewee import fn +from sqlalchemy import func as fn from src.common.logger import get_logger from src.common.database.database_model import Jargon, ChatStreams diff --git a/src/webui/routers/person.py b/src/webui/routers/person.py index 1368c2a4..de3b8587 100644 --- a/src/webui/routers/person.py +++ b/src/webui/routers/person.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict +from sqlalchemy import case from src.common.logger import get_logger from src.common.database.database_model import PersonInfo from src.webui.core import verify_auth_token_from_cookie_or_header @@ -176,9 +177,7 @@ async def get_person_list( # 排序:最后更新时间倒序(NULL 值放在最后) # Peewee 不支持 nulls_last,使用 CASE WHEN 来实现 - from peewee import Case - - query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc()) + query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc()) # 获取总数 total = query.count() diff --git a/src/webui/routers/statistics.py b/src/webui/routers/statistics.py index 40770bd6..da49883d 100644 --- a/src/webui/routers/statistics.py +++ b/src/webui/routers/statistics.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header from pydantic import BaseModel, Field from typing import Dict, Any, List, Optional from datetime import datetime, timedelta -from peewee import fn +from sqlalchemy import func as fn from src.common.logger import get_logger from src.common.database.database_model import LLMUsage, OnlineTime, Messages