添加对 peewee 的旧数据库的兼容层,初步重构插件的 database API

pull/1496/head
DrSmoothl 2026-02-09 22:44:56 +08:00
parent bf46d540f1
commit 60f76e4d4e
No known key found for this signature in database
14 changed files with 226 additions and 424 deletions

View File

@ -17,7 +17,6 @@ dependencies = [
"numpy>=2.2.6", "numpy>=2.2.6",
"openai>=1.95.0", "openai>=1.95.0",
"pandas>=2.3.1", "pandas>=2.3.1",
"peewee>=3.18.2",
"pillow>=11.3.0", "pillow>=11.3.0",
"pyarrow>=20.0.0", "pyarrow>=20.0.0",
"pydantic>=2.11.7", "pydantic>=2.11.7",
@ -29,6 +28,8 @@ dependencies = [
"rich>=14.0.0", "rich>=14.0.0",
"ruff>=0.12.2", "ruff>=0.12.2",
"setuptools>=80.9.0", "setuptools>=80.9.0",
"sqlalchemy>=2.0.40",
"sqlmodel>=0.0.24",
"structlog>=25.4.0", "structlog>=25.4.0",
"toml>=0.10.2", "toml>=0.10.2",
"tomlkit>=0.13.3", "tomlkit>=0.13.3",

View File

@ -11,7 +11,6 @@ matplotlib>=3.10.3
numpy>=2.2.6 numpy>=2.2.6
openai>=1.95.0 openai>=1.95.0
pandas>=2.3.1 pandas>=2.3.1
peewee>=3.18.2
pillow>=11.3.0 pillow>=11.3.0
pyarrow>=20.0.0 pyarrow>=20.0.0
pydantic>=2.11.7 pydantic>=2.11.7
@ -23,6 +22,8 @@ quick-algo>=0.1.3
rich>=14.0.0 rich>=14.0.0
ruff>=0.12.2 ruff>=0.12.2
setuptools>=80.9.0 setuptools>=80.9.0
sqlalchemy>=2.0.40
sqlmodel>=0.0.24
structlog>=25.4.0 structlog>=25.4.0
toml>=0.10.2 toml>=0.10.2
tomlkit>=0.13.3 tomlkit>=0.13.3

View File

@ -4,7 +4,7 @@ import random
from collections import OrderedDict from collections import OrderedDict
from typing import List, Dict, Optional, Callable from typing import List, Dict, Optional, Callable
from json_repair import repair_json from json_repair import repair_json
from peewee import fn from sqlalchemy import func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Jargon from src.common.database.database_model import Jargon

View File

@ -1,8 +1,9 @@
from rich.traceback import install from rich.traceback import install
from pathlib import Path from pathlib import Path
from contextlib import contextmanager from contextlib import contextmanager
from sqlalchemy import create_engine, event from sqlalchemy import create_engine, event, text
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy import inspect as sqlalchemy_inspect
from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.orm import Session, sessionmaker
from typing import TYPE_CHECKING, Generator from typing import TYPE_CHECKING, Generator
@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]:
yield session yield session
finally: finally:
session.close() session.close()
class _AtomicContext:
def __init__(self) -> None:
self._session: Session | None = None
def __enter__(self) -> Session:
self._session = SessionLocal()
self._session.begin()
return self._session
def __exit__(self, exc_type, exc, tb) -> None:
if self._session is None:
return
try:
if exc_type is None:
self._session.commit()
else:
self._session.rollback()
finally:
self._session.close()
class DatabaseCompat:
"""兼容旧 db 调用接口Peewee 风格),底层使用 SQLAlchemy。"""
def connect(self, reuse_if_open: bool = True) -> None:
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
_ = reuse_if_open
def create_tables(self, models: list[type], safe: bool = True) -> None:
_ = safe
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
if not tables:
return
from sqlmodel import SQLModel
SQLModel.metadata.create_all(engine, tables=tables)
def atomic(self) -> _AtomicContext:
return _AtomicContext()
def execute_sql(self, sql: str):
with engine.connect() as conn:
result = conn.execute(text(sql))
conn.commit()
return result
def table_exists(self, model: type) -> bool:
if not hasattr(model, "__tablename__"):
return False
inspector = sqlalchemy_inspect(engine)
return inspector.has_table(model.__tablename__)
db = DatabaseCompat()

View File

@ -304,7 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
"websockets", "websockets",
"httpcore", "httpcore",
"requests", "requests",
"peewee", "sqlalchemy",
"openai", "openai",
"uvicorn", "uvicorn",
"jieba", "jieba",

View File

@ -1,7 +1,6 @@
import traceback import traceback
from typing import List, Any, Optional from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from src.config.config import global_config from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
@ -11,11 +10,15 @@ from src.common.logger import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def _model_to_instance(model_instance: Model) -> DatabaseMessages: def _model_to_instance(model_instance: Any) -> DatabaseMessages:
""" """
Peewee 模型实例转换为字典 Peewee 模型实例转换为字典
""" """
return DatabaseMessages(**model_instance.__data__) if isinstance(model_instance, dict):
return DatabaseMessages(**model_instance)
if hasattr(model_instance, "model_dump"):
return DatabaseMessages(**model_instance.model_dump())
return DatabaseMessages(**model_instance.__dict__)
def find_messages( def find_messages(
@ -92,14 +95,17 @@ def find_messages(
if limit > 0: if limit > 0:
if limit_mode == "earliest": if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序 # 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit) query = query.order_by("time").limit(limit)
peewee_results = list(query) peewee_results = list(query)
else: # 默认为 'latest' else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录 # 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit) query = query.order_by("-time").limit(limit)
latest_results_peewee = list(query) latest_results_peewee = list(query)
# 将结果按时间正序排列 # 将结果按时间正序排列
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time) peewee_results = sorted(
latest_results_peewee,
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
)
else: else:
# limit 为 0 时,应用传入的 sort 参数 # limit 为 0 时,应用传入的 sort 参数
if sort: if sort:
@ -108,9 +114,9 @@ def find_messages(
if hasattr(Messages, field_name): if hasattr(Messages, field_name):
field = getattr(Messages, field_name) field = getattr(Messages, field_name)
if direction == 1: # ASC if direction == 1: # ASC
peewee_sort_terms.append(field.asc()) peewee_sort_terms.append(field_name)
elif direction == -1: # DESC elif direction == -1: # DESC
peewee_sort_terms.append(field.desc()) peewee_sort_terms.append(f"-{field_name}")
else: else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else: else:

View File

@ -3,7 +3,7 @@ import random
import time import time
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple
from peewee import fn from sqlalchemy import func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config, model_config

View File

@ -1,228 +1,97 @@
"""数据库API模块 """数据库API模块
提供数据库操作相关功能采用标准Python包设计模式 提供数据库操作相关功能统一使用 SQLModel/SQLAlchemy 兼容接口
使用方式
from src.plugin_system.apis import database_api
records = await database_api.db_query(ActionRecords, query_type="get")
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
""" """
import traceback
import time
import json import json
from typing import Dict, List, Any, Union, Type, Optional import time
import traceback
from typing import Any, Optional
from src.common.logger import get_logger from src.common.logger import get_logger
from peewee import Model, DoesNotExist
logger = get_logger("database_api") logger = get_logger("database_api")
# =============================================================================
# 通用数据库查询API函数 def _to_dict(record: Any) -> dict[str, Any]:
# ============================================================================= if record is None:
return {}
if isinstance(record, dict):
return record
if hasattr(record, "model_dump"):
return record.model_dump()
if hasattr(record, "__dict__"):
return dict(record.__dict__)
return {}
async def db_query( async def db_query(
model_class: Type[Model], model_class,
data: Optional[Dict[str, Any]] = None, data: Optional[dict[str, Any]] = None,
query_type: Optional[str] = "get", query_type: str = "get",
filters: Optional[Dict[str, Any]] = None, filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
order_by: Optional[List[str]] = None, order_by: Optional[list[str]] = None,
single_result: Optional[bool] = False, single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: ):
"""执行数据库查询操作
这个方法提供了一个通用接口来执行数据库操作包括查询创建更新和删除记录
Args:
model_class: Peewee 模型类例如 ActionRecords, Messages
data: 用于创建或更新的数据字典
query_type: 查询类型可选值: "get", "create", "update", "delete", "count"
filters: 过滤条件字典键为字段名值为要匹配的值
limit: 限制结果数量
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回不同的结果:
- "get": 返回查询结果列表或单个结果如果 single_result=True
- "create": 返回创建的记录
- "update": 返回受影响的行数
- "delete": 返回受影响的行数
- "count": 返回记录数量
"""
"""
示例:
# 查询最近10条消息
messages = await database_api.db_query(
Messages,
query_type="get",
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by=["-time"]
)
# 创建一条记录
new_record = await database_api.db_query(
ActionRecords,
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
query_type="create",
)
# 更新记录
updated_count = await database_api.db_query(
ActionRecords,
data={"action_done": True},
query_type="update",
filters={"action_id": "123"},
)
# 删除记录
deleted_count = await database_api.db_query(
ActionRecords,
query_type="delete",
filters={"action_id": "123"}
)
# 计数
count = await database_api.db_query(
Messages,
query_type="count",
filters={"chat_id": chat_stream.stream_id}
)
"""
try: try:
if query_type not in ["get", "create", "update", "delete", "count"]: if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'") raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
# 构建基本查询
if query_type in ["get", "update", "delete", "count"]:
query = model_class.select()
# 应用过滤条件 if query_type == "get":
query = model_class.select()
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
query = query.where(getattr(model_class, field) == value) query = query.where(getattr(model_class, field) == value)
# 执行查询
if query_type == "get":
# 应用排序
if order_by: if order_by:
for field in order_by: query = query.order_by(*order_by)
if field.startswith("-"):
query = query.order_by(getattr(model_class, field[1:]).desc())
else:
query = query.order_by(getattr(model_class, field))
# 应用限制
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
# 执行查询
results = list(query.dicts()) results = list(query.dicts())
# 返回结果
if single_result: if single_result:
return results[0] if results else None return results[0] if results else None
return results return results
elif query_type == "create": if query_type == "create":
if not data: if not data:
raise ValueError("创建记录需要提供data参数") raise ValueError("创建记录需要提供data参数")
# 创建记录
record = model_class.create(**data) record = model_class.create(**data)
# 返回创建的记录 return _to_dict(record)
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
elif query_type == "update": query = model_class.select()
if filters:
for field, value in filters.items():
query = query.where(getattr(model_class, field) == value)
if query_type == "update":
if not data: if not data:
raise ValueError("更新记录需要提供data参数") raise ValueError("更新记录需要提供data参数")
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
# 更新记录 if query_type == "delete":
return query.update(**data).execute() return model_class.delete().where(*query.stmt._where_criteria).execute()
elif query_type == "delete":
# 删除记录
return query.delete().execute()
elif query_type == "count":
# 计数
return query.count() return query.count()
else:
raise ValueError(f"不支持的查询类型: {query_type}")
except DoesNotExist:
# 记录不存在
return None if query_type == "get" and single_result else []
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}") logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
traceback.print_exc() traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get": if query_type == "get":
return None if single_result else [] return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None return None
async def db_save( async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
# sourcery skip: inline-immediately-returned-variable
"""保存数据到数据库(创建或更新)
如果提供了key_field和key_value会先尝试查找匹配的记录进行更新
如果没有找到匹配记录或未提供key_field和key_value则创建新记录
Args:
model_class: Peewee模型类如ActionRecords, Messages等
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名例如"action_id"
key_value: 用于查找现有记录的字段值
Returns:
Dict[str, Any]: 保存后的记录数据
None: 如果操作失败
示例:
# 创建或更新一条记录
record = await database_api.db_save(
ActionRecords,
{
"action_id": "123",
"time": time.time(),
"action_name": "TestAction",
"action_done": True
},
key_field="action_id",
key_value="123"
)
"""
try: try:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None: if key_field and key_value is not None:
if existing_records := list( record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1) if record is not None:
):
# 更新现有记录
existing_record = existing_records[0]
for field, value in data.items(): for field, value in data.items():
setattr(existing_record, field, value) setattr(record, field, value)
existing_record.save() record.save()
return _to_dict(record)
# 返回更新后的记录
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
return updated_record
# 如果没有找到现有记录或未提供key_field和key_value创建新记录
new_record = model_class.create(**data) new_record = model_class.create(**data)
return _to_dict(new_record)
# 返回创建的记录
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
return created_record
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}") logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
traceback.print_exc() traceback.print_exc()
@ -230,71 +99,25 @@ async def db_save(
async def db_get( async def db_get(
model_class: Type[Model], model_class,
filters: Optional[Dict[str, Any]] = None, filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
order_by: Optional[str] = None, order_by: Optional[str] = None,
single_result: Optional[bool] = False, single_result: bool = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]: ):
"""从数据库获取记录
这是db_query方法的简化版本专注于数据检索操作
Args:
model_class: Peewee模型类
filters: 过滤条件字段名和值的字典
limit: 结果数量限制
order_by: 排序字段前缀'-'表示降序例如'-time'表示按时间字段即time字段降序
single_result: 是否只返回单个结果如果为True则返回单个记录字典或None否则返回记录字典列表或空列表
Returns:
如果single_result为True返回单个记录字典或None
否则返回记录字典列表或空列表
示例:
# 获取单个记录
record = await database_api.db_get(
ActionRecords,
filters={"action_id": "123"},
limit=1
)
# 获取最近10条记录
records = await database_api.db_get(
Messages,
filters={"chat_id": chat_stream.stream_id},
limit=10,
order_by="-time",
)
"""
try: try:
# 构建查询
query = model_class.select() query = model_class.select()
# 应用过滤条件
if filters: if filters:
for field, value in filters.items(): for field, value in filters.items():
query = query.where(getattr(model_class, field) == value) query = query.where(getattr(model_class, field) == value)
# 应用排序
if order_by: if order_by:
if order_by.startswith("-"): query = query.order_by(order_by)
query = query.order_by(getattr(model_class, order_by[1:]).desc())
else:
query = query.order_by(getattr(model_class, order_by))
# 应用限制
if limit: if limit:
query = query.limit(limit) query = query.limit(limit)
# 执行查询
results = list(query.dicts()) results = list(query.dicts())
# 返回结果
if single_result: if single_result:
return results[0] if results else None return results[0] if results else None
return results return results
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}") logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
traceback.print_exc() traceback.print_exc()
@ -310,41 +133,12 @@ async def store_action_info(
action_data: Optional[dict] = None, action_data: Optional[dict] = None,
action_name: str = "", action_name: str = "",
action_reasoning: str = "", action_reasoning: str = "",
) -> Optional[Dict[str, Any]]: ):
"""存储动作信息到数据库
将Action执行的相关信息保存到ActionRecords表中用于后续的记忆和上下文构建
Args:
chat_stream: 聊天流对象包含聊天相关信息
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
action_reasoning: 动作执行理由
Returns:
Dict[str, Any]: 保存的记录数据
None: 如果保存失败
示例:
record = await database_api.store_action_info(
chat_stream=chat_stream,
action_build_into_prompt=True,
action_prompt_display="执行了回复动作",
action_done=True,
thinking_id="thinking_123",
action_data={"content": "Hello"},
action_name="reply_action"
)
"""
try: try:
from src.common.database.database_model import ActionRecords from src.common.database.database_model import ActionRecords
# 构建动作记录数据
record_data = { record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID "action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(), "time": time.time(),
"action_name": action_name, "action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False), "action_data": json.dumps(action_data or {}, ensure_ascii=False),
@ -354,7 +148,6 @@ async def store_action_info(
"action_prompt_display": action_prompt_display, "action_prompt_display": action_prompt_display,
} }
# 从chat_stream获取聊天信息
if chat_stream: if chat_stream:
record_data.update( record_data.update(
{ {
@ -364,27 +157,16 @@ async def store_action_info(
} }
) )
else: else:
# 如果没有chat_stream设置默认值 record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
record_data.update(
{
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
}
)
# 使用已有的db_save函数保存记录
saved_record = await db_save( saved_record = await db_save(
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
) )
if saved_record: if saved_record:
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else: else:
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}") logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
return saved_record return saved_record
except Exception as e: except Exception as e:
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}") logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
traceback.print_exc() traceback.print_exc()

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime from datetime import datetime
from peewee import fn from sqlalchemy import func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import ( from src.common.database.database_model import (
@ -151,9 +151,7 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
try: try:
# 1. 年度在线时长 # 1. 年度在线时长
online_records = list( online_records = list(
OnlineTime.select().where( OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt))
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
)
) )
total_seconds = 0 total_seconds = 0
for record in online_records: for record in online_records:
@ -254,9 +252,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
fn.COUNT(Messages.id).alias("count"), fn.COUNT(Messages.id).alias("count"),
) )
.where( .where(
(Messages.time >= start_ts) (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False))
& (Messages.time <= end_ts)
& (Messages.chat_info_group_id.is_null(False))
) )
.group_by(Messages.chat_info_group_id) .group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc()) .order_by(fn.COUNT(Messages.id).desc())
@ -302,22 +298,14 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
# 4. 被@次数 # 4. 被@次数
data.at_count = ( data.at_count = (
Messages.select() Messages.select()
.where( .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True))
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_at == True)
)
.count() .count()
) )
# 5. 被提及次数 # 5. 被提及次数
data.mentioned_count = ( data.mentioned_count = (
Messages.select() Messages.select()
.where( .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True))
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_mentioned == True)
)
.count() .count()
) )
@ -329,8 +317,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"), (ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
) )
.where( .where(
(ChatStreams.user_id.is_null(False)) (ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
) )
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc()) .order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1) .limit(1)
@ -451,20 +438,21 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
.order_by(fn.COUNT(LLMUsage.id).desc()) .order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5) .limit(5)
) )
data.top_reply_models = [ data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()]
{"model": row["model"], "count": row["count"]}
for row in reply_model_query.dicts()
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords # 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = ActionRecords.select().where( total_actions = (
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts) ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count()
).count() )
no_reply_count = ActionRecords.select().where( no_reply_count = (
ActionRecords.select()
.where(
(ActionRecords.time >= start_ts) (ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts) & (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply") & (ActionRecords.action_name == "no_reply")
).count() )
.count()
)
data.total_actions = total_actions data.total_actions = total_actions
data.no_reply_count = no_reply_count data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0 data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
@ -473,11 +461,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
interest_query = Messages.select( interest_query = Messages.select(
fn.AVG(Messages.interest_value).alias("avg_interest"), fn.AVG(Messages.interest_value).alias("avg_interest"),
fn.MAX(Messages.interest_value).alias("max_interest"), fn.MAX(Messages.interest_value).alias("max_interest"),
).where( ).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False)))
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value.is_null(False))
)
interest_result = interest_query.dicts().get() interest_result = interest_query.dicts().get()
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2) data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2) data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
@ -494,20 +478,15 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
.first() .first()
) )
if max_interest_msg: if max_interest_msg:
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime( data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S")
"%Y-%m-%d %H:%M:%S"
)
# 7. 思考深度 (基于 action_reasoning 长度) # 7. 思考深度 (基于 action_reasoning 长度)
reasoning_records = ( reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where(
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
.where(
(ActionRecords.time >= start_ts) (ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts) & (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False)) & (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "") & (ActionRecords.action_reasoning != "")
) )
)
reasoning_lengths = [] reasoning_lengths = []
max_len = 0 max_len = 0
max_len_time = None max_len_time = None
@ -578,17 +557,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
Expression.style, Expression.style,
fn.SUM(Expression.count).alias("total_count"), fn.SUM(Expression.count).alias("total_count"),
) )
.where( .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.group_by(Expression.style) .group_by(Expression.style)
.order_by(fn.SUM(Expression.count).desc()) .order_by(fn.SUM(Expression.count).desc())
.limit(5) .limit(5)
) )
data.top_expressions = [ data.top_expressions = [
{"style": row["style"], "count": row["total_count"]} {"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts()
for row in expression_query.dicts()
] ]
# 3. 被拒绝的表达 # 3. 被拒绝的表达
@ -616,18 +591,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 5. 表达总数 # 5. 表达总数
data.total_expressions = ( data.total_expressions = (
Expression.select() Expression.select()
.where( .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
)
.count() .count()
) )
# 6. 动作类型分布 (过滤无意义的动作) # 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore # 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
excluded_actions = [ excluded_actions = [
"reply", "no_reply", "no_reply_until_call", "make_question", "reply",
"no_action", "wait", "complete_talk", "listening", "block_and_ignore" "no_reply",
"no_reply_until_call",
"make_question",
"no_action",
"wait",
"complete_talk",
"listening",
"block_and_ignore",
] ]
action_query = ( action_query = (
ActionRecords.select( ActionRecords.select(
@ -643,19 +622,12 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
.order_by(fn.COUNT(ActionRecords.id).desc()) .order_by(fn.COUNT(ActionRecords.id).desc())
.limit(10) .limit(10)
) )
data.action_types = [ data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()]
{"action": row["action_name"], "count": row["count"]}
for row in action_query.dicts()
]
# 7. 处理的图片数量 # 7. 处理的图片数量
data.image_processed_count = ( data.image_processed_count = (
Messages.select() Messages.select()
.where( .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True))
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.is_picid == True)
)
.count() .count()
) )
@ -668,11 +640,11 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
if not content: if not content:
return "" return ""
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用 # 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
content = re.sub(r'\[回复<[^>]+>\s*的消息[:][^\]]*\]', '', content) content = re.sub(r"\[回复<[^>]+>\s*的消息[:][^\]]*\]", "", content)
# 移除 [图片] [表情] 等标记 # 移除 [图片] [表情] 等标记
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content) content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content)
# 移除多余的空白 # 移除多余的空白
content = re.sub(r'\s+', ' ', content).strip() content = re.sub(r"\s+", " ", content).strip()
return content return content
# 使用 user_id 判断是否是 bot 发送的消息 # 使用 user_id 判断是否是 bot 发送的消息
@ -683,9 +655,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
Messages.display_message, Messages.display_message,
) )
.where( .where(
(Messages.time >= start_ts) (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息
& (Messages.time <= end_ts)
& (Messages.user_id == bot_qq) # bot 发送的消息
) )
.order_by(Messages.time.desc()) .order_by(Messages.time.desc())
) )
@ -699,13 +669,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
cleaned_content = clean_message_content(raw_content) cleaned_content = clean_message_content(raw_content)
# 只保留有意义的内容 # 只保留有意义的内容
if cleaned_content and len(cleaned_content) > 2: if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append({ late_night_filtered.append(
{
"time": msg.time, "time": msg.time,
"hour": hour, "hour": hour,
"minute": msg_dt.minute, "minute": msg_dt.minute,
"content": cleaned_content, "content": cleaned_content,
"datetime_str": msg_dt.strftime("%H:%M"), "datetime_str": msg_dt.strftime("%H:%M"),
}) }
)
if len(late_night_filtered) >= 10: if len(late_night_filtered) >= 10:
break break
@ -721,16 +693,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
from collections import Counter from collections import Counter
import json as json_lib import json as json_lib
reply_records = ( reply_records = ActionRecords.select(ActionRecords.action_data).where(
ActionRecords.select(ActionRecords.action_data)
.where(
(ActionRecords.time >= start_ts) (ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts) & (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply") & (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False)) & (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "") & (ActionRecords.action_data != "")
) )
)
reply_contents = [] reply_contents = []
for record in reply_records: for record in reply_records:
@ -753,6 +722,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
# 例如: "{'reply_text': '墨白灵不知道哦'}" # 例如: "{'reply_text': '墨白灵不知道哦'}"
if content is None: if content is None:
import ast import ast
try: try:
parsed = ast.literal_eval(action_data) parsed = ast.literal_eval(action_data)
if isinstance(parsed, dict): if isinstance(parsed, dict):
@ -817,20 +787,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
] ]
# 3. 总消息数 # 3. 总消息数
data.total_messages = ( data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count()
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.count()
)
# 4. 总回复数 (有 reply_to 的消息) # 4. 总回复数 (有 reply_to 的消息)
data.total_replies = ( data.total_replies = (
Messages.select() Messages.select()
.where( .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False)))
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.reply_to.is_null(False))
)
.count() .count()
) )

View File

@ -10,6 +10,7 @@ import uuid
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy import case, func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo from src.common.database.database_model import Messages, PersonInfo
@ -290,8 +291,6 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)):
PersonInfo 表中获取所有已知的平台 PersonInfo 表中获取所有已知的平台
""" """
try: try:
from peewee import fn
# 查询所有不同的平台 # 查询所有不同的平台
platforms = ( platforms = (
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count")) PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
@ -337,9 +336,7 @@ async def get_persons_by_platform(
) )
# 按最后交互时间排序,优先显示活跃用户 # 按最后交互时间排序,优先显示活跃用户
from peewee import Case query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
query = query.limit(limit) query = query.limit(limit)
result = [] result = []

View File

@ -3,6 +3,7 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List, Dict from typing import Optional, List, Dict
from sqlalchemy import case
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams from src.common.database.database_model import Expression, ChatStreams
from src.webui.core import verify_auth_token_from_cookie_or_header from src.webui.core import verify_auth_token_from_cookie_or_header
@ -231,10 +232,8 @@ async def get_expression_list(
query = query.where(Expression.chat_id == chat_id) query = query.where(Expression.chat_id == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后) # 排序最后活跃时间倒序NULL 值放在最后)
from peewee import Case
query = query.order_by( query = query.order_by(
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc() case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
) )
# 获取总数 # 获取总数
@ -641,9 +640,7 @@ async def get_review_list(
query = query.where(Expression.chat_id == chat_id) query = query.where(Expression.chat_id == chat_id)
# 排序:创建时间倒序 # 排序:创建时间倒序
from peewee import Case query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
total = query.count() total = query.count()
offset = (page - 1) * page_size offset = (page - 1) * page_size

View File

@ -4,7 +4,7 @@ import json
from typing import Optional, List, Annotated from typing import Optional, List, Annotated
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from peewee import fn from sqlalchemy import func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import Jargon, ChatStreams from src.common.database.database_model import Jargon, ChatStreams

View File

@ -3,6 +3,7 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List, Dict from typing import Optional, List, Dict
from sqlalchemy import case
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import PersonInfo from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header from src.webui.core import verify_auth_token_from_cookie_or_header
@ -176,9 +177,7 @@ async def get_person_list(
# 排序最后更新时间倒序NULL 值放在最后) # 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现 # Peewee 不支持 nulls_last使用 CASE WHEN 来实现
from peewee import Case query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
# 获取总数 # 获取总数
total = query.count() total = query.count()

View File

@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta from datetime import datetime, timedelta
from peewee import fn from sqlalchemy import func as fn
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages from src.common.database.database_model import LLMUsage, OnlineTime, Messages