mirror of https://github.com/Mai-with-u/MaiBot.git
添加对 peewee 的旧数据库的兼容层,初步重构插件的 database API
parent
bf46d540f1
commit
60f76e4d4e
|
|
@ -17,7 +17,6 @@ dependencies = [
|
|||
"numpy>=2.2.6",
|
||||
"openai>=1.95.0",
|
||||
"pandas>=2.3.1",
|
||||
"peewee>=3.18.2",
|
||||
"pillow>=11.3.0",
|
||||
"pyarrow>=20.0.0",
|
||||
"pydantic>=2.11.7",
|
||||
|
|
@ -29,6 +28,8 @@ dependencies = [
|
|||
"rich>=14.0.0",
|
||||
"ruff>=0.12.2",
|
||||
"setuptools>=80.9.0",
|
||||
"sqlalchemy>=2.0.40",
|
||||
"sqlmodel>=0.0.24",
|
||||
"structlog>=25.4.0",
|
||||
"toml>=0.10.2",
|
||||
"tomlkit>=0.13.3",
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ matplotlib>=3.10.3
|
|||
numpy>=2.2.6
|
||||
openai>=1.95.0
|
||||
pandas>=2.3.1
|
||||
peewee>=3.18.2
|
||||
pillow>=11.3.0
|
||||
pyarrow>=20.0.0
|
||||
pydantic>=2.11.7
|
||||
|
|
@ -23,8 +22,10 @@ quick-algo>=0.1.3
|
|||
rich>=14.0.0
|
||||
ruff>=0.12.2
|
||||
setuptools>=80.9.0
|
||||
sqlalchemy>=2.0.40
|
||||
sqlmodel>=0.0.24
|
||||
structlog>=25.4.0
|
||||
toml>=0.10.2
|
||||
tomlkit>=0.13.3
|
||||
urllib3>=2.5.0
|
||||
uvicorn>=0.35.0
|
||||
uvicorn>=0.35.0
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import random
|
|||
from collections import OrderedDict
|
||||
from typing import List, Dict, Optional, Callable
|
||||
from json_repair import repair_json
|
||||
from peewee import fn
|
||||
from sqlalchemy import func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
from rich.traceback import install
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing import TYPE_CHECKING, Generator
|
||||
|
||||
|
|
@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]:
|
|||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _AtomicContext:
|
||||
def __init__(self) -> None:
|
||||
self._session: Session | None = None
|
||||
|
||||
def __enter__(self) -> Session:
|
||||
self._session = SessionLocal()
|
||||
self._session.begin()
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if exc_type is None:
|
||||
self._session.commit()
|
||||
else:
|
||||
self._session.rollback()
|
||||
finally:
|
||||
self._session.close()
|
||||
|
||||
|
||||
class DatabaseCompat:
|
||||
"""兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。"""
|
||||
|
||||
def connect(self, reuse_if_open: bool = True) -> None:
|
||||
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
|
||||
_ = reuse_if_open
|
||||
|
||||
def create_tables(self, models: list[type], safe: bool = True) -> None:
|
||||
_ = safe
|
||||
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
|
||||
if not tables:
|
||||
return
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
SQLModel.metadata.create_all(engine, tables=tables)
|
||||
|
||||
def atomic(self) -> _AtomicContext:
|
||||
return _AtomicContext()
|
||||
|
||||
def execute_sql(self, sql: str):
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(sql))
|
||||
conn.commit()
|
||||
return result
|
||||
|
||||
def table_exists(self, model: type) -> bool:
|
||||
if not hasattr(model, "__tablename__"):
|
||||
return False
|
||||
inspector = sqlalchemy_inspect(engine)
|
||||
return inspector.has_table(model.__tablename__)
|
||||
|
||||
|
||||
db = DatabaseCompat()
|
||||
|
|
|
|||
|
|
@ -304,7 +304,7 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
|||
"websockets",
|
||||
"httpcore",
|
||||
"requests",
|
||||
"peewee",
|
||||
"sqlalchemy",
|
||||
"openai",
|
||||
"uvicorn",
|
||||
"jieba",
|
||||
|
|
@ -876,19 +876,19 @@ def initialize_logging(verbose: bool = True):
|
|||
"""手动初始化日志系统,确保所有logger都使用正确的配置
|
||||
|
||||
在应用程序的早期调用此函数,确保所有模块都使用统一的日志配置
|
||||
|
||||
|
||||
Args:
|
||||
verbose: 是否输出详细的初始化信息。默认为 True。
|
||||
在 Runner 进程中可以设置为 False 以避免重复的初始化日志。
|
||||
"""
|
||||
global LOG_CONFIG, _logging_initialized
|
||||
|
||||
|
||||
# 防止重复初始化(在同一进程内)
|
||||
if _logging_initialized:
|
||||
return
|
||||
|
||||
|
||||
_logging_initialized = True
|
||||
|
||||
|
||||
LOG_CONFIG = load_log_config()
|
||||
# print(LOG_CONFIG)
|
||||
configure_third_party_loggers()
|
||||
|
|
@ -941,16 +941,16 @@ def cleanup_old_logs():
|
|||
|
||||
def start_log_cleanup_task(verbose: bool = True):
|
||||
"""启动日志清理任务
|
||||
|
||||
|
||||
Args:
|
||||
verbose: 是否输出启动信息。默认为 True。
|
||||
"""
|
||||
global _cleanup_task_started
|
||||
|
||||
|
||||
# 防止重复启动清理任务
|
||||
if _cleanup_task_started:
|
||||
return
|
||||
|
||||
|
||||
_cleanup_task_started = True
|
||||
|
||||
def cleanup_task():
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
|
@ -11,11 +10,15 @@ from src.common.logger import get_logger
|
|||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _model_to_instance(model_instance: Model) -> DatabaseMessages:
|
||||
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
"""
|
||||
return DatabaseMessages(**model_instance.__data__)
|
||||
if isinstance(model_instance, dict):
|
||||
return DatabaseMessages(**model_instance)
|
||||
if hasattr(model_instance, "model_dump"):
|
||||
return DatabaseMessages(**model_instance.model_dump())
|
||||
return DatabaseMessages(**model_instance.__dict__)
|
||||
|
||||
|
||||
def find_messages(
|
||||
|
|
@ -92,14 +95,17 @@ def find_messages(
|
|||
if limit > 0:
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
query = query.order_by("time").limit(limit)
|
||||
peewee_results = list(query)
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
query = query.order_by("-time").limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
peewee_results = sorted(
|
||||
latest_results_peewee,
|
||||
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
|
||||
)
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
|
|
@ -108,9 +114,9 @@ def find_messages(
|
|||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
peewee_sort_terms.append(field_name)
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
peewee_sort_terms.append(f"-{field_name}")
|
||||
else:
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import random
|
|||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from peewee import fn
|
||||
from sqlalchemy import func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
|
|
|
|||
|
|
@ -1,228 +1,97 @@
|
|||
"""数据库API模块
|
||||
|
||||
提供数据库操作相关功能,采用标准Python包设计模式
|
||||
使用方式:
|
||||
from src.plugin_system.apis import database_api
|
||||
records = await database_api.db_query(ActionRecords, query_type="get")
|
||||
record = await database_api.db_save(ActionRecords, data={"action_id": "123"})
|
||||
提供数据库操作相关功能,统一使用 SQLModel/SQLAlchemy 兼容接口。
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from peewee import Model, DoesNotExist
|
||||
|
||||
logger = get_logger("database_api")
|
||||
|
||||
# =============================================================================
|
||||
# 通用数据库查询API函数
|
||||
# =============================================================================
|
||||
|
||||
def _to_dict(record: Any) -> dict[str, Any]:
|
||||
if record is None:
|
||||
return {}
|
||||
if isinstance(record, dict):
|
||||
return record
|
||||
if hasattr(record, "model_dump"):
|
||||
return record.model_dump()
|
||||
if hasattr(record, "__dict__"):
|
||||
return dict(record.__dict__)
|
||||
return {}
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Model],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
model_class,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
query_type: str = "get",
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
这个方法提供了一个通用接口来执行数据库操作,包括查询、创建、更新和删除记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee 模型类,例如 ActionRecords, Messages 等
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型,可选值: "get", "create", "update", "delete", "count"
|
||||
filters: 过滤条件字典,键为字段名,值为要匹配的值
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回不同的结果:
|
||||
- "get": 返回查询结果列表或单个结果(如果 single_result=True)
|
||||
- "create": 返回创建的记录
|
||||
- "update": 返回受影响的行数
|
||||
- "delete": 返回受影响的行数
|
||||
- "count": 返回记录数量
|
||||
"""
|
||||
"""
|
||||
示例:
|
||||
# 查询最近10条消息
|
||||
messages = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="get",
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by=["-time"]
|
||||
)
|
||||
|
||||
# 创建一条记录
|
||||
new_record = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_id": "123", "time": time.time(), "action_name": "TestAction"},
|
||||
query_type="create",
|
||||
)
|
||||
|
||||
# 更新记录
|
||||
updated_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
data={"action_done": True},
|
||||
query_type="update",
|
||||
filters={"action_id": "123"},
|
||||
)
|
||||
|
||||
# 删除记录
|
||||
deleted_count = await database_api.db_query(
|
||||
ActionRecords,
|
||||
query_type="delete",
|
||||
filters={"action_id": "123"}
|
||||
)
|
||||
|
||||
# 计数
|
||||
count = await database_api.db_query(
|
||||
Messages,
|
||||
query_type="count",
|
||||
filters={"chat_id": chat_stream.stream_id}
|
||||
)
|
||||
"""
|
||||
order_by: Optional[list[str]] = None,
|
||||
single_result: bool = False,
|
||||
):
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get' or 'create' or 'update' or 'delete' or 'count'")
|
||||
# 构建基本查询
|
||||
if query_type in ["get", "update", "delete", "count"]:
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if query_type == "get":
|
||||
query = model_class.select()
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 执行查询
|
||||
if query_type == "get":
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field in order_by:
|
||||
if field.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, field[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, field))
|
||||
|
||||
# 应用限制
|
||||
query = query.order_by(*order_by)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
elif query_type == "create":
|
||||
if query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建记录
|
||||
record = model_class.create(**data)
|
||||
# 返回创建的记录
|
||||
return model_class.select().where(model_class.id == record.id).dicts().get() # type: ignore
|
||||
return _to_dict(record)
|
||||
|
||||
elif query_type == "update":
|
||||
query = model_class.select()
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
if query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
return query.model_class.update(**data).where(*query.stmt._where_criteria).execute()
|
||||
|
||||
# 更新记录
|
||||
return query.update(**data).execute()
|
||||
if query_type == "delete":
|
||||
return model_class.delete().where(*query.stmt._where_criteria).execute()
|
||||
|
||||
elif query_type == "delete":
|
||||
# 删除记录
|
||||
return query.delete().execute()
|
||||
|
||||
elif query_type == "count":
|
||||
# 计数
|
||||
return query.count()
|
||||
|
||||
else:
|
||||
raise ValueError(f"不支持的查询类型: {query_type}")
|
||||
|
||||
except DoesNotExist:
|
||||
# 记录不存在
|
||||
return None if query_type == "get" and single_result else []
|
||||
return query.count()
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Model], data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: inline-immediately-returned-variable
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
如果提供了key_field和key_value,会先尝试查找匹配的记录进行更新;
|
||||
如果没有找到匹配记录,或未提供key_field和key_value,则创建新记录。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类,如ActionRecords, Messages等
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名,例如"action_id"
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: 保存后的记录数据
|
||||
None: 如果操作失败
|
||||
|
||||
示例:
|
||||
# 创建或更新一条记录
|
||||
record = await database_api.db_save(
|
||||
ActionRecords,
|
||||
{
|
||||
"action_id": "123",
|
||||
"time": time.time(),
|
||||
"action_name": "TestAction",
|
||||
"action_done": True
|
||||
},
|
||||
key_field="action_id",
|
||||
key_value="123"
|
||||
)
|
||||
"""
|
||||
async def db_save(model_class, data: dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None):
|
||||
try:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if existing_records := list(
|
||||
model_class.select().where(getattr(model_class, key_field) == key_value).limit(1)
|
||||
):
|
||||
# 更新现有记录
|
||||
existing_record = existing_records[0]
|
||||
record = model_class.get_or_none(getattr(model_class, key_field) == key_value)
|
||||
if record is not None:
|
||||
for field, value in data.items():
|
||||
setattr(existing_record, field, value)
|
||||
existing_record.save()
|
||||
setattr(record, field, value)
|
||||
record.save()
|
||||
return _to_dict(record)
|
||||
|
||||
# 返回更新后的记录
|
||||
updated_record = model_class.select().where(model_class.id == existing_record.id).dicts().get() # type: ignore
|
||||
return updated_record
|
||||
|
||||
# 如果没有找到现有记录或未提供key_field和key_value,创建新记录
|
||||
new_record = model_class.create(**data)
|
||||
|
||||
# 返回创建的记录
|
||||
created_record = model_class.select().where(model_class.id == new_record.id).dicts().get() # type: ignore
|
||||
return created_record
|
||||
|
||||
return _to_dict(new_record)
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
@ -230,71 +99,25 @@ async def db_save(
|
|||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Model],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
model_class,
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
这是db_query方法的简化版本,专注于数据检索操作。
|
||||
|
||||
Args:
|
||||
model_class: Peewee模型类
|
||||
filters: 过滤条件,字段名和值的字典
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序,例如'-time'表示按时间字段(即time字段)降序
|
||||
single_result: 是否只返回单个结果,如果为True,则返回单个记录字典或None;否则返回记录字典列表或空列表
|
||||
|
||||
Returns:
|
||||
如果single_result为True,返回单个记录字典或None;
|
||||
否则返回记录字典列表或空列表。
|
||||
|
||||
示例:
|
||||
# 获取单个记录
|
||||
record = await database_api.db_get(
|
||||
ActionRecords,
|
||||
filters={"action_id": "123"},
|
||||
limit=1
|
||||
)
|
||||
|
||||
# 获取最近10条记录
|
||||
records = await database_api.db_get(
|
||||
Messages,
|
||||
filters={"chat_id": chat_stream.stream_id},
|
||||
limit=10,
|
||||
order_by="-time",
|
||||
)
|
||||
"""
|
||||
single_result: bool = False,
|
||||
):
|
||||
try:
|
||||
# 构建查询
|
||||
query = model_class.select()
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field, value in filters.items():
|
||||
query = query.where(getattr(model_class, field) == value)
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
if order_by.startswith("-"):
|
||||
query = query.order_by(getattr(model_class, order_by[1:]).desc())
|
||||
else:
|
||||
query = query.order_by(getattr(model_class, order_by))
|
||||
|
||||
# 应用限制
|
||||
query = query.order_by(order_by)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = list(query.dicts())
|
||||
|
||||
# 返回结果
|
||||
if single_result:
|
||||
return results[0] if results else None
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 获取数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
@ -310,41 +133,12 @@ async def store_action_info(
|
|||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
action_reasoning: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
将Action执行的相关信息保存到ActionRecords表中,用于后续的记忆和上下文构建。
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象,包含聊天相关信息
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
action_reasoning: 动作执行理由
|
||||
Returns:
|
||||
Dict[str, Any]: 保存的记录数据
|
||||
None: 如果保存失败
|
||||
|
||||
示例:
|
||||
record = await database_api.store_action_info(
|
||||
chat_stream=chat_stream,
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display="执行了回复动作",
|
||||
action_done=True,
|
||||
thinking_id="thinking_123",
|
||||
action_data={"content": "Hello"},
|
||||
action_name="reply_action"
|
||||
)
|
||||
"""
|
||||
):
|
||||
try:
|
||||
from src.common.database.database_model import ActionRecords
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)), # 使用thinking_id或生成唯一ID
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
|
|
@ -354,7 +148,6 @@ async def store_action_info(
|
|||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update(
|
||||
{
|
||||
|
|
@ -364,27 +157,16 @@ async def store_action_info(
|
|||
}
|
||||
)
|
||||
else:
|
||||
# 如果没有chat_stream,设置默认值
|
||||
record_data.update(
|
||||
{
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
}
|
||||
)
|
||||
record_data.update({"chat_id": "", "chat_info_stream_id": "", "chat_info_platform": ""})
|
||||
|
||||
# 使用已有的db_save函数保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[DatabaseAPI] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[DatabaseAPI] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime
|
||||
from peewee import fn
|
||||
from sqlalchemy import func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import (
|
||||
|
|
@ -151,9 +151,7 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
|||
try:
|
||||
# 1. 年度在线时长
|
||||
online_records = list(
|
||||
OnlineTime.select().where(
|
||||
(OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)
|
||||
)
|
||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt))
|
||||
)
|
||||
total_seconds = 0
|
||||
for record in online_records:
|
||||
|
|
@ -235,10 +233,10 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
|||
async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||
"""获取社交网络数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
data = SocialNetworkData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于过滤
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
|
|
@ -254,9 +252,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
|||
fn.COUNT(Messages.id).alias("count"),
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.chat_info_group_id.is_null(False))
|
||||
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False))
|
||||
)
|
||||
.group_by(Messages.chat_info_group_id)
|
||||
.order_by(fn.COUNT(Messages.id).desc())
|
||||
|
|
@ -302,22 +298,14 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
|||
# 4. 被@次数
|
||||
data.at_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_at == True)
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 5. 被提及次数
|
||||
data.mentioned_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_mentioned == True)
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True))
|
||||
.count()
|
||||
)
|
||||
|
||||
|
|
@ -329,8 +317,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
|||
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
|
||||
)
|
||||
.where(
|
||||
(ChatStreams.user_id.is_null(False))
|
||||
& (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
|
||||
(ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
|
||||
)
|
||||
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
|
||||
.limit(1)
|
||||
|
|
@ -451,20 +438,21 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
|||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_reply_models = [
|
||||
{"model": row["model"], "count": row["count"]}
|
||||
for row in reply_model_query.dicts()
|
||||
]
|
||||
data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()]
|
||||
|
||||
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
||||
total_actions = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)
|
||||
).count()
|
||||
no_reply_count = ActionRecords.select().where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "no_reply")
|
||||
).count()
|
||||
total_actions = (
|
||||
ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count()
|
||||
)
|
||||
no_reply_count = (
|
||||
ActionRecords.select()
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "no_reply")
|
||||
)
|
||||
.count()
|
||||
)
|
||||
data.total_actions = total_actions
|
||||
data.no_reply_count = no_reply_count
|
||||
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
||||
|
|
@ -473,11 +461,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
|||
interest_query = Messages.select(
|
||||
fn.AVG(Messages.interest_value).alias("avg_interest"),
|
||||
fn.MAX(Messages.interest_value).alias("max_interest"),
|
||||
).where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.interest_value.is_null(False))
|
||||
)
|
||||
).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False)))
|
||||
interest_result = interest_query.dicts().get()
|
||||
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
|
||||
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
|
||||
|
|
@ -494,19 +478,14 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
|||
.first()
|
||||
)
|
||||
if max_interest_msg:
|
||||
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# 7. 思考深度 (基于 action_reasoning 长度)
|
||||
reasoning_records = (
|
||||
ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_reasoning.is_null(False))
|
||||
& (ActionRecords.action_reasoning != "")
|
||||
)
|
||||
reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_reasoning.is_null(False))
|
||||
& (ActionRecords.action_reasoning != "")
|
||||
)
|
||||
reasoning_lengths = []
|
||||
max_len = 0
|
||||
|
|
@ -518,7 +497,7 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
|||
if length > max_len:
|
||||
max_len = length
|
||||
max_len_time = record.time
|
||||
|
||||
|
||||
if reasoning_lengths:
|
||||
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
||||
data.max_reasoning_length = max_len
|
||||
|
|
@ -537,10 +516,10 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
|||
async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||
"""获取个性与表达数据"""
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
data = ExpressionVibeData()
|
||||
start_ts, end_ts = get_year_time_range(year)
|
||||
|
||||
|
||||
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
||||
bot_qq = str(global_config.bot.qq_account or "")
|
||||
|
||||
|
|
@ -578,17 +557,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
Expression.style,
|
||||
fn.SUM(Expression.count).alias("total_count"),
|
||||
)
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
|
||||
.group_by(Expression.style)
|
||||
.order_by(fn.SUM(Expression.count).desc())
|
||||
.limit(5)
|
||||
)
|
||||
data.top_expressions = [
|
||||
{"style": row["style"], "count": row["total_count"]}
|
||||
for row in expression_query.dicts()
|
||||
{"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts()
|
||||
]
|
||||
|
||||
# 3. 被拒绝的表达
|
||||
|
|
@ -616,18 +591,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
# 5. 表达总数
|
||||
data.total_expressions = (
|
||||
Expression.select()
|
||||
.where(
|
||||
(Expression.last_active_time >= start_ts)
|
||||
& (Expression.last_active_time <= end_ts)
|
||||
)
|
||||
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 6. 动作类型分布 (过滤无意义的动作)
|
||||
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
||||
excluded_actions = [
|
||||
"reply", "no_reply", "no_reply_until_call", "make_question",
|
||||
"no_action", "wait", "complete_talk", "listening", "block_and_ignore"
|
||||
"reply",
|
||||
"no_reply",
|
||||
"no_reply_until_call",
|
||||
"make_question",
|
||||
"no_action",
|
||||
"wait",
|
||||
"complete_talk",
|
||||
"listening",
|
||||
"block_and_ignore",
|
||||
]
|
||||
action_query = (
|
||||
ActionRecords.select(
|
||||
|
|
@ -643,38 +622,31 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
.order_by(fn.COUNT(ActionRecords.id).desc())
|
||||
.limit(10)
|
||||
)
|
||||
data.action_types = [
|
||||
{"action": row["action_name"], "count": row["count"]}
|
||||
for row in action_query.dicts()
|
||||
]
|
||||
data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()]
|
||||
|
||||
# 7. 处理的图片数量
|
||||
data.image_processed_count = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.is_picid == True)
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True))
|
||||
.count()
|
||||
)
|
||||
|
||||
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
||||
import random
|
||||
import re
|
||||
|
||||
|
||||
def clean_message_content(content: str) -> str:
|
||||
"""清理消息内容,移除回复引用等标记"""
|
||||
if not content:
|
||||
return ""
|
||||
# 移除 [回复<xxx:xxx> 的消息:...] 格式的引用
|
||||
content = re.sub(r'\[回复<[^>]+>\s*的消息[::][^\]]*\]', '', content)
|
||||
content = re.sub(r"\[回复<[^>]+>\s*的消息[::][^\]]*\]", "", content)
|
||||
# 移除 [图片] [表情] 等标记
|
||||
content = re.sub(r'\[(图片|表情|语音|视频|文件)\]', '', content)
|
||||
content = re.sub(r"\[(图片|表情|语音|视频|文件)\]", "", content)
|
||||
# 移除多余的空白
|
||||
content = re.sub(r'\s+', ' ', content).strip()
|
||||
content = re.sub(r"\s+", " ", content).strip()
|
||||
return content
|
||||
|
||||
|
||||
# 使用 user_id 判断是否是 bot 发送的消息
|
||||
late_night_messages = list(
|
||||
Messages.select(
|
||||
|
|
@ -683,9 +655,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
Messages.display_message,
|
||||
)
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.user_id == bot_qq) # bot 发送的消息
|
||||
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息
|
||||
)
|
||||
.order_by(Messages.time.desc())
|
||||
)
|
||||
|
|
@ -699,16 +669,18 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
cleaned_content = clean_message_content(raw_content)
|
||||
# 只保留有意义的内容
|
||||
if cleaned_content and len(cleaned_content) > 2:
|
||||
late_night_filtered.append({
|
||||
"time": msg.time,
|
||||
"hour": hour,
|
||||
"minute": msg_dt.minute,
|
||||
"content": cleaned_content,
|
||||
"datetime_str": msg_dt.strftime("%H:%M"),
|
||||
})
|
||||
late_night_filtered.append(
|
||||
{
|
||||
"time": msg.time,
|
||||
"hour": hour,
|
||||
"minute": msg_dt.minute,
|
||||
"content": cleaned_content,
|
||||
"datetime_str": msg_dt.strftime("%H:%M"),
|
||||
}
|
||||
)
|
||||
if len(late_night_filtered) >= 10:
|
||||
break
|
||||
|
||||
|
||||
if late_night_filtered:
|
||||
selected = random.choice(late_night_filtered)
|
||||
content = selected["content"][:50] + "..." if len(selected["content"]) > 50 else selected["content"]
|
||||
|
|
@ -720,18 +692,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
# 9. 最喜欢的回复(按 action_data 统计回复内容出现次数)
|
||||
from collections import Counter
|
||||
import json as json_lib
|
||||
|
||||
reply_records = (
|
||||
ActionRecords.select(ActionRecords.action_data)
|
||||
.where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "reply")
|
||||
& (ActionRecords.action_data.is_null(False))
|
||||
& (ActionRecords.action_data != "")
|
||||
)
|
||||
|
||||
reply_records = ActionRecords.select(ActionRecords.action_data).where(
|
||||
(ActionRecords.time >= start_ts)
|
||||
& (ActionRecords.time <= end_ts)
|
||||
& (ActionRecords.action_name == "reply")
|
||||
& (ActionRecords.action_data.is_null(False))
|
||||
& (ActionRecords.action_data != "")
|
||||
)
|
||||
|
||||
|
||||
reply_contents = []
|
||||
for record in reply_records:
|
||||
try:
|
||||
|
|
@ -748,11 +717,12 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
content = parsed
|
||||
except (json_lib.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
|
||||
# 如果 JSON 解析失败,尝试解析 Python 字典字符串格式
|
||||
# 例如: "{'reply_text': '墨白灵不知道哦'}"
|
||||
if content is None:
|
||||
import ast
|
||||
|
||||
try:
|
||||
parsed = ast.literal_eval(action_data)
|
||||
if isinstance(parsed, dict):
|
||||
|
|
@ -762,13 +732,13 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
|||
except (ValueError, SyntaxError):
|
||||
# 无法解析,使用原始字符串
|
||||
content = action_data
|
||||
|
||||
|
||||
# 只统计有意义的回复(长度大于2)
|
||||
if content and len(content) > 2:
|
||||
reply_contents.append(content)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
|
||||
if reply_contents:
|
||||
content_counter = Counter(reply_contents)
|
||||
most_common = content_counter.most_common(1)
|
||||
|
|
@ -817,20 +787,12 @@ async def get_achievements(year: int = 2025) -> AchievementData:
|
|||
]
|
||||
|
||||
# 3. 总消息数
|
||||
data.total_messages = (
|
||||
Messages.select()
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
||||
.count()
|
||||
)
|
||||
data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count()
|
||||
|
||||
# 4. 总回复数 (有 reply_to 的消息)
|
||||
data.total_replies = (
|
||||
Messages.select()
|
||||
.where(
|
||||
(Messages.time >= start_ts)
|
||||
& (Messages.time <= end_ts)
|
||||
& (Messages.reply_to.is_null(False))
|
||||
)
|
||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False)))
|
||||
.count()
|
||||
)
|
||||
|
||||
|
|
@ -856,9 +818,9 @@ async def get_full_annual_report(year: int = 2025, _auth: bool = Depends(require
|
|||
"""
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
logger.info(f"开始生成 {year} 年度报告...")
|
||||
|
||||
|
||||
# 获取 bot 名称
|
||||
bot_name = global_config.bot.nickname or "麦麦"
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import uuid
|
|||
from typing import Dict, Any, Optional, List
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import case, func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Messages, PersonInfo
|
||||
|
|
@ -290,8 +291,6 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
|||
从 PersonInfo 表中获取所有已知的平台
|
||||
"""
|
||||
try:
|
||||
from peewee import fn
|
||||
|
||||
# 查询所有不同的平台
|
||||
platforms = (
|
||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
||||
|
|
@ -337,9 +336,7 @@ async def get_persons_by_platform(
|
|||
)
|
||||
|
||||
# 按最后交互时间排序,优先显示活跃用户
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
|
||||
query = query.limit(limit)
|
||||
|
||||
result = []
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from sqlalchemy import case
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Expression, ChatStreams
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
|
|
@ -231,10 +232,8 @@ async def get_expression_list(
|
|||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(
|
||||
Case(None, [(Expression.last_active_time.is_null(), 1)], 0), Expression.last_active_time.desc()
|
||||
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
|
||||
)
|
||||
|
||||
# 获取总数
|
||||
|
|
@ -641,9 +640,7 @@ async def get_review_list(
|
|||
query = query.where(Expression.chat_id == chat_id)
|
||||
|
||||
# 排序:创建时间倒序
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(Expression.create_date.is_null(), 1)], 0), Expression.create_date.desc())
|
||||
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
|
||||
|
||||
total = query.count()
|
||||
offset = (page - 1) * page_size
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import json
|
|||
from typing import Optional, List, Annotated
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from peewee import fn
|
||||
from sqlalchemy import func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import Jargon, ChatStreams
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Dict
|
||||
from sqlalchemy import case
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import PersonInfo
|
||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||
|
|
@ -176,9 +177,7 @@ async def get_person_list(
|
|||
|
||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||
from peewee import Case
|
||||
|
||||
query = query.order_by(Case(None, [(PersonInfo.last_know.is_null(), 1)], 0), PersonInfo.last_know.desc())
|
||||
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
|
||||
|
||||
# 获取总数
|
||||
total = query.count()
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from peewee import fn
|
||||
from sqlalchemy import func as fn
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
||||
|
|
|
|||
Loading…
Reference in New Issue