添加对 peewee 的旧数据库的兼容层,初步重构插件的 database API

pull/1496/head
DrSmoothl 2026-02-09 22:44:56 +08:00
parent bf46d540f1
commit 60f76e4d4e
No known key found for this signature in database
14 changed files with 226 additions and 424 deletions

View File

@ -17,7 +17,6 @@ dependencies = [
"numpy>=2.2.6",
"openai>=1.95.0",
"pandas>=2.3.1",
"peewee>=3.18.2",
"pillow>=11.3.0",
"pyarrow>=20.0.0",
"pydantic>=2.11.7",
@ -29,6 +28,8 @@ dependencies = [
"rich>=14.0.0",
"ruff>=0.12.2",
"setuptools>=80.9.0",
"sqlalchemy>=2.0.40",
"sqlmodel>=0.0.24",
"structlog>=25.4.0",
"toml>=0.10.2",
"tomlkit>=0.13.3",

View File

@ -11,7 +11,6 @@ matplotlib>=3.10.3
numpy>=2.2.6
openai>=1.95.0
pandas>=2.3.1
peewee>=3.18.2
pillow>=11.3.0
pyarrow>=20.0.0
pydantic>=2.11.7
@ -23,8 +22,10 @@ quick-algo>=0.1.3
rich>=14.0.0
ruff>=0.12.2
setuptools>=80.9.0
sqlalchemy>=2.0.40
sqlmodel>=0.0.24
structlog>=25.4.0
toml>=0.10.2
tomlkit>=0.13.3
urllib3>=2.5.0
uvicorn>=0.35.0
uvicorn>=0.35.0

View File

@ -4,7 +4,7 @@ import random
from collections import OrderedDict
from typing import List, Dict, Optional, Callable
from json_repair import repair_json
from peewee import fn
from sqlalchemy import func as fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon

View File

@ -1,8 +1,9 @@
from rich.traceback import install
from pathlib import Path
from contextlib import contextmanager
from sqlalchemy import create_engine, event
from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine
from sqlalchemy import inspect as sqlalchemy_inspect
from sqlalchemy.orm import Session, sessionmaker
from typing import TYPE_CHECKING, Generator
@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]:
yield session
finally:
session.close()
class _AtomicContext:
def __init__(self) -> None:
self._session: Session | None = None
def __enter__(self) -> Session:
self._session = SessionLocal()
self._session.begin()
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
if self._session is None:
return
try:
if exc_type is None:
self._session.commit()
else:
self._session.rollback()
finally:
self._session.close()
class DatabaseCompat:
"""兼容旧 db 调用接口Peewee 风格),底层使用 SQLAlchemy。"""
def connect(self, reuse_if_open: bool = True) -> None:
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
_ = reuse_if_open
def create_tables(self, models: list[type], safe: bool = True) -> None:
_ = safe
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
if not tables:
return
from sqlmodel import SQLModel
SQLModel.metadata.create_all(engine, tables=tables)
def atomic(self) -> _AtomicContext:
return _AtomicContext()
def execute_sql(self, sql: str):
with engine.connect() as conn:
result = conn.execute(text(sql))
conn.commit()
return result
def table_exists(self, model: type) -> bool:
if not hasattr(model, "__tablename__"):
return False
inspector = sqlalchemy_inspect(engine)
return inspector.has_table(model.__tablename__)
db = DatabaseCompat()

View File

@ -304,7 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
"websockets",
"httpcore",
"requests",
"peewee",
"sqlalchemy",
"openai",
"uvicorn",
"jieba",
@ -876,19 +876,19 @@ def initialize_logging(verbose: bool = True):
"""手动初始化日志系统确保所有logger都使用正确的配置
在应用程序的早期调用此函数确保所有模块都使用统一的日志配置
Args:
verbose: 是否输出详细的初始化信息默认为 True
Runner 进程中可以设置为 False 以避免重复的初始化日志
"""
global LOG_CONFIG, _logging_initialized
# 防止重复初始化(在同一进程内)
if _logging_initialized:
return
_logging_initialized = True
LOG_CONFIG = load_log_config()
# print(LOG_CONFIG)
configure_third_party_loggers()
@ -941,16 +941,16 @@ def cleanup_old_logs():
def start_log_cleanup_task(verbose: bool = True):
"""启动日志清理任务
Args:
verbose: 是否输出启动信息默认为 True
"""
global _cleanup_task_started
# 防止重复启动清理任务
if _cleanup_task_started:
return
_cleanup_task_started = True
def cleanup_task():

View File

@ -1,7 +1,6 @@
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
@ -11,11 +10,15 @@ from src.common.logger import get_logger
logger = get_logger(__name__)
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
"""
Peewee 模型实例转换为字典
"""
return DatabaseMessages(**model_instance.__data__)
if isinstance(model_instance, dict):
return DatabaseMessages(**model_instance)
if hasattr(model_instance, "model_dump"):
return DatabaseMessages(**model_instance.model_dump())
return DatabaseMessages(**model_instance.__dict__)
def find_messages(
@ -92,14 +95,17 @@ def find_messages(
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
query = query.order_by("time").limit(limit)
peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
query = query.order_by("-time").limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
peewee_results = sorted(
latest_results_peewee,
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
@ -108,9 +114,9 @@ def find_messages(
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
peewee_sort_terms.append(field_name)
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
peewee_sort_terms.append(f"-{field_name}")
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:

View File

@ -3,7 +3,7 @@ import random
import time
from typing import Any, Dict, List, Optional, Tuple
from peewee import fn
from sqlalchemy import func as fn
from src.common.logger import get_logger
from src.config.config import global_config, model_config

View File

@ -1,228 +1,97 @@
"""数据库API模块
提供数据库操作相关功能采用标准Python包设计模式
使用方式
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
提供数据库操作相关功能统一使用 SQLModel/SQLAlchemy 兼容接口
"""
import traceback
import time
import json
from typing import Dict, List, Any, Union, Type, Optional
import time
import traceback
from typing import Any, Optional
from src.common.logger import get_logger
from peewee import Model, DoesNotExist
logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数
# =============================================================================
def _to_dict(record: Any) -> dict[str, Any]:
if record is None:
return {}
if isinstance(record, dict):
return record
if hasattr(record, "model_dump"):
return record.model_dump()
if hasattr(record, "__dict__"):
return dict(record.__dict__)
return {}
async def db_query(
model_class: Type[Model],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
model_class,
data: Optional[dict[str, Any]] = None,
query_type: str = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作包括查询创建更新和删除记录
Args:
model_class: Peewee 模型类例如 ActionRecords, Messages
data: 用于创建或更新的数据字典
query_type: 查询类型可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典键为字段名值为要匹配的值
limit: 限制结果数量
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
order_by: Optional[list[str]] = None,
single_result: bool = False,
):
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件
if query_type == "get":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by:
for field in order_by:
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
query = query.order_by(*order_by)
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
elif query_type == "create":
if query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data)
# 返回创建的记录
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
return _to_dict(record)
elif query_type == "update":
query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
# 更新记录
return query.update(**data).execute()
if query_type == "delete":
return model_class.delete().where(*query.stmt._where_criteria).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
return None if query_type == "get" and single_result else []
return query.count()
except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
async def db_save(
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if existing_records := list(
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
):
# 更新现有记录
existing_record = existing_records[0]
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
if record is not None:
for field, value in data.items():
setattr(existing_record, field, value)
existing_record.save()
setattr(record, field, value)
record.save()
return _to_dict(record)
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
return _to_dict(new_record)
except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc()
@ -230,71 +99,25 @@ async def db_save(
async def db_get(
model_class: Type[Model],
filters: Optional[Dict[str, Any]] = None,
model_class,
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作
Args:
model_class: Peewee模型类
filters: 过滤条件字段名和值的字典
limit: 结果数量限制
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
)
"""
single_result: bool = False,
):
try:
# 构建查询
query = model_class.select()
# 应用过滤条件
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by:
if order_by.startswith("-"):
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
query = query.order_by(order_by)
if limit:
query = query.limit(limit)
# 执行查询
results = list(query.dicts())
# 返回结果
if single_result:
return results[0] if results else None
return results
except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc()
@ -310,41 +133,12 @@ async def store_action_info(
action_data: Optional[dict] = None,
action_name: str = "",
action_reasoning: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建
Args:
chat_stream: 聊天流对象包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
action_reasoning: 动作执行理由
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
):
try:
from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
@ -354,7 +148,6 @@ async def store_action_info(
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update(
{
@ -364,27 +157,16 @@ async def store_action_info(
}
)
else:
# 如果没有chat_stream设置默认值
record_data.update(
{
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
}
)
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
# 使用已有的db_save函数保存记录
saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc()

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime
from peewee import fn
from sqlalchemy import func as fn
from src.common.logger import get_logger
from src.common.database.database_model import (
@ -151,9 +151,7 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
try:
# 1. 年度在线时长
online_records = list(
OnlineTime.select().where(
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
)
OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt))
)
total_seconds = 0
for record in online_records:
@ -235,10 +233,10 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
bot_qq = str(global_config.bot.qq_account or "")
@ -254,9 +252,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
fn.COUNT(Messages.id).alias("count"),
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.chat_info_group_id.is_null(False))
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False))
)
.group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc())
@ -302,22 +298,14 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
# 4. 被@次数
data.at_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_at == True)
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True))
.count()
)
# 5. 被提及次数
data.mentioned_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_mentioned == True)
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True))
.count()
)
@ -329,8 +317,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
)
.where(
(ChatStreams.user_id.is_null(False))
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
(ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
)
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1)
@ -451,20 +438,21 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5)
)
data.top_reply_models = [
{"model": row["model"], "count": row["count"]}
for row in reply_model_query.dicts()
]
data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = ActionRecords.select().where(
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
).count()
no_reply_count = ActionRecords.select().where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply")
).count()
total_actions = (
ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count()
)
no_reply_count = (
ActionRecords.select()
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply")
)
.count()
)
data.total_actions = total_actions
data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
@ -473,11 +461,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
interest_query = Messages.select(
fn.AVG(Messages.interest_value).alias("avg_interest"),
fn.MAX(Messages.interest_value).alias("max_interest"),
).where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value.is_null(False))
)
).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False)))
interest_result = interest_query.dicts().get()
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
@ -494,19 +478,14 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
.first()
)
if max_interest_msg:
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
"%Y-%m-%d %H:%M:%S"
)
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S")
# 7. 思考深度 (基于 action_reasoning 长度)
reasoning_records = (
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "")
)
reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "")
)
reasoning_lengths = []
max_len = 0
@ -518,7 +497,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
if length > max_len:
max_len = length
max_len_time = record.time
if reasoning_lengths:
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
data.max_reasoning_length = max_len
@ -537,10 +516,10 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
bot_qq = str(global_config.bot.qq_account or "")
@ -578,17 +557,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
Expression.style,
fn.SUM(Expression.count).alias("total_count"),
)
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
.group_by(Expression.style)
.order_by(fn.SUM(Expression.count).desc())
.limit(5)
)
data.top_expressions = [
{"style": row["style"], "count": row["total_count"]}
for row in expression_query.dicts()
{"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts()
]
# 3. 被拒绝的表达
@ -616,18 +591,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 5. 表达总数
data.total_expressions = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
.count()
)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [
"reply", "no_reply", "no_reply_until_call", "make_question",
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
"reply",
"no_reply",
"no_reply_until_call",
"make_question",
"no_action",
"wait",
"complete_talk",
"listening",
"block_and_ignore",
]
action_query = (
ActionRecords.select(
@ -643,38 +622,31 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
.order_by(fn.COUNT(ActionRecords.id).desc())
.limit(10)
)
data.action_types = [
{"action": row["action_name"], "count": row["count"]}
for row in action_query.dicts()
]
data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()]
# 7. 处理的图片数量
data.image_processed_count = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_picid == True)
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True))
.count()
)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
import re
def clean_message_content(content: str) -> str:
"""清理消息内容,移除回复引用等标记"""
if not content:
return ""
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
content = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content)
content = re.sub(r"\[回复<[^>]+>\s*的消息[:][^\]]*\]", "", content)
# 移除 [图片] [表情] 等标记
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content)
# 移除多余的空白
content = re.sub(r'\s+', ' ', content).strip()
content = re.sub(r"\s+", " ", content).strip()
return content
# 使用 user_id 判断是否是 bot 发送的消息
late_night_messages = list(
Messages.select(
@ -683,9 +655,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
Messages.display_message,
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id == bot_qq) # bot 发送的消息
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息
)
.order_by(Messages.time.desc())
)
@ -699,16 +669,18 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append({
"time": msg.time,
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
})
late_night_filtered.append(
{
"time": msg.time,
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"),
}
)
if len(late_night_filtered) >= 10:
break
if late_night_filtered:
selected = random.choice(late_night_filtered)
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
@ -720,18 +692,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
from collections import Counter
import json as json_lib
reply_records = (
ActionRecords.select(ActionRecords.action_data)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
reply_records = ActionRecords.select(ActionRecords.action_data).where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
reply_contents = []
for record in reply_records:
try:
@ -748,11 +717,12 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
content = parsed
except (json_lib.JSONDecodeError, TypeError):
pass
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
# 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None:
import ast
try:
parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict):
@ -762,13 +732,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
except (ValueError, SyntaxError):
# 无法解析,使用原始字符串
content = action_data
# 只统计有意义的回复长度大于2
if content and len(content) > 2:
reply_contents.append(content)
except Exception:
continue
if reply_contents:
content_counter = Counter(reply_contents)
most_common = content_counter.most_common(1)
@ -817,20 +787,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
]
# 3. 总消息数
data.total_messages = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.count()
)
data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count()
# 4. 总回复数 (有 reply_to 的消息)
data.total_replies = (
Messages.select()
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.reply_to.is_null(False))
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False)))
.count()
)
@ -856,9 +818,9 @@ async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require
"""
try:
from src.config.config import global_config
logger.info(f"开始生成 {year} 年度报告...")
# 获取 bot 名称
bot_name = global_config.bot.nickname or "麦麦"

View File

@ -10,6 +10,7 @@ import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from sqlalchemy import case, func as fn
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
@ -290,8 +291,6 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)):
PersonInfo 表中获取所有已知的平台
"""
try:
from peewee import fn
# 查询所有不同的平台
platforms = (
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
@ -337,9 +336,7 @@ async def get_persons_by_platform(
)
# 按最后交互时间排序,优先显示活跃用户
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
query = query.limit(limit)
result = []

View File

@ -3,6 +3,7 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from sqlalchemy import case
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from src.webui.core import verify_auth_token_from_cookie_or_header
@ -231,10 +232,8 @@ async def get_expression_list(
query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
)
# 获取总数
@ -641,9 +640,7 @@ async def get_review_list(
query = query.where(Expression.chat_id == chat_id)
# 排序:创建时间倒序
from peewee import Case
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
total = query.count()
offset = (page - 1) * page_size

View File

@ -4,7 +4,7 @@ import json
from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field
from peewee import fn
from sqlalchemy import func as fn
from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams

View File

@ -3,6 +3,7 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from sqlalchemy import case
from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header
@ -176,9 +177,7 @@ async def get_person_list(
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
# 获取总数
total = query.count()

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from peewee import fn
from sqlalchemy import func as fn
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages