重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject

pull/1496/head
DrSmoothl 2026-02-13 20:39:11 +08:00
parent c14736ffca
commit 16b16d2ca6
No known key found for this signature in database
29 changed files with 2459 additions and 1737 deletions

2
bot.py
View File

@ -1,4 +1,4 @@
raise RuntimeError("System Not Ready")
# raise RuntimeError("System Not Ready")
import asyncio
import hashlib
import os

View File

@ -1,6 +1,6 @@
[project]
name = "MaiBot"
version = "0.11.6"
version = "1.0.0"
description = "MaiCore 是一个基于大语言模型的可交互智能体"
requires-python = ">=3.10"
dependencies = [
@ -35,6 +35,7 @@ dependencies = [
"tomlkit>=0.13.3",
"urllib3>=2.5.0",
"uvicorn>=0.35.0",
"msgpack>=1.1.2",
]

View File

@ -29,3 +29,4 @@ toml>=0.10.2
tomlkit>=0.13.3
urllib3>=2.5.0
uvicorn>=0.35.0
msgpack>=1.1.2

View File

@ -8,11 +8,15 @@
4. 未通过评估的rejected=1, checked=1
"""
from typing import List
import asyncio
import json
import random
from typing import List
from sqlmodel import select
from src.bw_learner.expression_review_store import get_review_state, set_review_state
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.common.logger import get_logger
from src.config.config import global_config
@ -26,11 +30,11 @@ logger = get_logger("expression_auto_check_task")
def create_evaluation_prompt(situation: str, style: str) -> str:
"""
创建评估提示词
Args:
situation: 情境
style: 风格
Returns:
评估提示词
"""
@ -39,20 +43,20 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
"允许部分语法错误或口头化或缺省出现",
"表达方式不能太过特指,需要具有泛用性",
"一般不涉及具体的人名或名称"
"一般不涉及具体的人名或名称",
]
# 从配置中获取额外的自定义标准
custom_criteria = global_config.expression.expression_auto_check_custom_criteria
# 合并所有评估标准
all_criteria = base_criteria.copy()
if custom_criteria:
all_criteria.extend(custom_criteria)
# 构建评估标准列表字符串
criteria_list = "\n".join([f"{i+1}. {criterion}" for i, criterion in enumerate(all_criteria)])
criteria_list = "\n".join([f"{i + 1}. {criterion}" for i, criterion in enumerate(all_criteria)])
prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适:
使用条件或使用情景{situation}
表达方式或言语风格{style}
@ -68,54 +72,52 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
}}
如果合适suitable设为true如果不合适suitable设为false并在reason中说明原因
请严格按照JSON格式输出不要包含其他内容"""
return prompt
judge_llm = LLMRequest(
model_set=model_config.model_task_config.tool_use,
request_type="expression_check"
)
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str]:
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
    """Run a single LLM evaluation of one expression.

    Args:
        situation: the usage situation/context of the expression.
        style: the expression / speaking-style text itself.

    Returns:
        A ``(suitable, reason, error)`` tuple. On any failure ``suitable`` is
        False and ``error`` carries the exception text; otherwise ``error`` is None.
    """
    # Local import kept at function top (instead of inside the except handler)
    # so the fallback path cannot itself fail on a first-time import.
    import re

    try:
        prompt = create_evaluation_prompt(situation, style)
        logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
        response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
            prompt=prompt, temperature=0.6, max_tokens=1024
        )
        logger.debug(f"LLM响应: {response}")
        # Parse the JSON verdict; if the model wrapped it in extra text,
        # salvage the first {...} object that contains a "suitable" key.
        try:
            evaluation = json.loads(response)
        except json.JSONDecodeError as e:
            json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
            if json_match:
                evaluation = json.loads(json_match.group())
            else:
                raise ValueError("无法从响应中提取JSON格式的评估结果") from e
        suitable = evaluation.get("suitable", False)
        reason = evaluation.get("reason", "未提供理由")
        logger.debug(f"评估结果: {'通过' if suitable else '不通过'}")
        return suitable, reason, None
    except Exception as e:
        logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}")
        return False, f"评估过程出错: {str(e)}", str(e)
@ -130,36 +132,37 @@ class ExpressionAutoCheckTask(AsyncTask):
super().__init__(
task_name="Expression Auto Check Task",
wait_before_start=60, # 启动后等待60秒再开始第一次检查
run_interval=check_interval
run_interval=check_interval,
)
async def _select_expressions(self, count: int) -> List[Expression]:
"""
随机选择指定数量的未检查表达方式
Args:
count: 需要选择的数量
Returns:
选中的表达方式列表
"""
try:
# 查询所有未检查的表达方式checked=False
unevaluated_expressions = list(
Expression.select().where(~Expression.checked)
)
with get_db_session() as session:
statement = select(Expression)
all_expressions = session.exec(statement).all()
unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]]
if not unevaluated_expressions:
logger.info("没有未检查的表达方式")
return []
# 随机选择指定数量
selected_count = min(count, len(unevaluated_expressions))
selected = random.sample(unevaluated_expressions, selected_count)
logger.info(f"{len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count}")
return selected
except Exception as e:
logger.error(f"选择表达方式时出错: {e}")
return []
@ -167,26 +170,23 @@ class ExpressionAutoCheckTask(AsyncTask):
async def _evaluate_expression(self, expression: Expression) -> bool:
"""
评估单个表达方式
Args:
expression: 要评估的表达方式
Returns:
True表示通过False表示不通过
"""
suitable, reason, error = await single_expression_check(
expression.situation,
expression.style,
)
# 更新数据库
try:
expression.checked = True
expression.rejected = not suitable # 通过则rejected=0不通过则rejected=1
expression.modified_by = 'ai' # 标记为AI检查
expression.save()
set_review_state(expression.id, True, not suitable, "ai")
status = "通过" if suitable else "不通过"
logger.info(
f"表达方式评估完成 [ID: {expression.id}] - {status} | "
@ -194,12 +194,12 @@ class ExpressionAutoCheckTask(AsyncTask):
f"Style: {expression.style}... | "
f"Reason: {reason[:50]}..."
)
if error:
logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}")
return suitable
except Exception as e:
logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}")
return False
@ -211,42 +211,39 @@ class ExpressionAutoCheckTask(AsyncTask):
if not global_config.expression.expression_self_reflect:
logger.debug("表达方式自动检查未启用,跳过本次执行")
return
check_count = global_config.expression.expression_auto_check_count
if check_count <= 0:
logger.warning(f"检查数量配置无效: {check_count},跳过本次执行")
return
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count}")
# 选择要检查的表达方式
expressions = await self._select_expressions(check_count)
if not expressions:
logger.info("没有需要检查的表达方式")
return
# 逐个评估
passed_count = 0
failed_count = 0
for i, expression in enumerate(expressions, 1):
logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}")
if await self._evaluate_expression(expression):
passed_count += 1
else:
failed_count += 1
# 避免请求过快
await asyncio.sleep(0.3)
logger.info(
f"表达方式自动检查完成: 总计 {len(expressions)} 条,"
f"通过 {passed_count} 条,不通过 {failed_count}"
f"表达方式自动检查完成: 总计 {len(expressions)} 条,通过 {passed_count} 条,不通过 {failed_count}"
)
except Exception as e:
logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)

View File

@ -0,0 +1,35 @@
from typing import Any, Dict, Optional
from src.manager.local_store_manager import local_storage
def _review_key(expression_id: int) -> str:
return f"expression_review:{expression_id}"
def get_review_state(expression_id: Optional[int]) -> Dict[str, Any]:
    """Read the persisted review state for an expression.

    Returns a dict with ``checked``/``rejected`` booleans and ``modified_by``.
    Falls back to an all-default state when the id is None or when the value
    stored under the key is not a dict.
    """
    default_state: Dict[str, Any] = {"checked": False, "rejected": False, "modified_by": None}
    if expression_id is None:
        return default_state
    stored = local_storage[_review_key(expression_id)]
    if not isinstance(stored, dict):
        return default_state
    return {
        "checked": bool(stored.get("checked", False)),
        "rejected": bool(stored.get("rejected", False)),
        "modified_by": stored.get("modified_by"),
    }
def set_review_state(
    expression_id: Optional[int],
    checked: bool,
    rejected: bool,
    modified_by: Optional[str],
) -> None:
    """Persist the review state for an expression; a no-op when the id is None."""
    if expression_id is None:
        return
    state = {"checked": checked, "rejected": rejected, "modified_by": modified_by}
    local_storage[_review_key(expression_id)] = state

View File

@ -3,11 +3,11 @@ MaiBot模块系统
包含聊天情绪记忆日程等功能模块
"""
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.emoji_system.emoji_manager import get_emoji_manager
# 导出主要组件供外部使用
__all__ = [
"get_chat_manager",
"get_emoji_manager",
"emoji_manager",
]

View File

@ -1,8 +1,11 @@
import time
import asyncio
import traceback
from datetime import datetime
from typing import Optional, Dict, Any, List
from src.common.logger import get_logger
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from maim_message import UserInfo
from src.config.config import global_config
@ -16,17 +19,18 @@ logger = get_logger("chat_observer")
def _message_to_dict(message: Messages) -> Dict[str, Any]:
"""Convert Peewee Message model to dict for PFC compatibility
Args:
message: Peewee Messages model instance
Returns:
Dict[str, Any]: Message dictionary
"""
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
return {
"message_id": message.message_id,
"time": message.time,
"chat_id": message.chat_id,
"time": message_timestamp,
"chat_id": message.session_id,
"user_id": message.user_id,
"user_nickname": message.user_nickname,
"processed_plain_text": message.processed_plain_text,
@ -37,7 +41,7 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]:
"user_info": {
"user_id": message.user_id,
"user_nickname": message.user_nickname,
}
},
}
@ -109,10 +113,13 @@ class ChatObserver:
"""
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
new_message_exists = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time > self.last_check_time)
).exists()
last_check_time = self.last_check_time or 0.0
last_check_dt = datetime.fromtimestamp(last_check_time)
with get_db_session() as session:
statement = select(Messages).where(
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
)
new_message_exists = session.exec(statement).first() is not None
if new_message_exists:
logger.debug(f"[私聊][{self.private_name}]发现新消息")
@ -183,20 +190,21 @@ class ChatObserver:
)
return has_new
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
"""获取新消息
Returns:
List[Dict[str, Any]]: 新消息列表
"""
query = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time > self.last_message_time)
).order_by(Messages.time.asc())
new_messages = [_message_to_dict(msg) for msg in query]
last_message_time = self.last_message_time or 0.0
last_message_dt = datetime.fromtimestamp(last_message_time)
with get_db_session() as session:
statement = (
select(Messages)
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
.order_by(col(Messages.timestamp))
)
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
if new_messages:
self.last_message_read = new_messages[-1]
@ -215,13 +223,16 @@ class ChatObserver:
Returns:
List[Dict[str, Any]]: 最多5条消息
"""
query = Messages.select().where(
(Messages.chat_id == self.stream_id) &
(Messages.time < time_point)
).order_by(Messages.time.desc()).limit(5)
messages = list(query)
messages.reverse() # 需要按时间正序排列
time_point_dt = datetime.fromtimestamp(time_point)
with get_db_session() as session:
statement = (
select(Messages)
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
.order_by(col(Messages.timestamp))
.limit(5)
)
messages = list(session.exec(statement).all())
messages.reverse()
new_messages = [_message_to_dict(msg) for msg in messages]
if new_messages:

View File

@ -10,7 +10,9 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
from src.chat.utils.chat_message_builder import replace_user_references
from src.common.logger import get_logger
from src.person_info.person_info import Person
from src.common.database.database_model import Images
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
if TYPE_CHECKING:
pass
@ -47,6 +49,12 @@ class HeartFCMessageReceiver:
# 1. 消息解析与初始化
userinfo = message.message_info.user_info
chat = message.chat_stream
if userinfo is None or message.message_info.platform is None:
raise ValueError("message userinfo or platform is missing")
if userinfo.user_id is None or userinfo.user_nickname is None:
raise ValueError("message userinfo id or nickname is missing")
user_id = userinfo.user_id
nickname = userinfo.user_nickname
# 2. 计算at信息
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
@ -70,7 +78,15 @@ class HeartFCMessageReceiver:
processed_text = message.processed_plain_text
if picid_list:
for picid in picid_list:
image = Images.get_or_none(Images.image_id == picid)
with get_db_session() as session:
statement = (
select(Images).where(
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
)
if picid.isdigit()
else None
)
image = session.exec(statement).first() if statement is not None else None
if image and image.description:
# 将[picid:xxxx]替换成图片描述
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
@ -80,26 +96,24 @@ class HeartFCMessageReceiver:
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
processed_plain_text = replace_user_references(
processed_text,
message.message_info.platform, # type: ignore
replace_bot_name=True,
processed_text, message.message_info.platform, replace_bot_name=True
)
# if not processed_plain_text:
# print(message)
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
# 如果是群聊,获取群号和群昵称
group_id = None
group_nick_name = None
if chat.group_info:
group_id = chat.group_info.group_id # type: ignore
group_nick_name = userinfo.user_cardname # type: ignore
group_id = chat.group_info.group_id
group_nick_name = userinfo.user_cardname
_ = Person.register_person(
platform=message.message_info.platform, # type: ignore
user_id=message.message_info.user_info.user_id, # type: ignore
nickname=userinfo.user_nickname, # type: ignore
platform=message.message_info.platform,
user_id=user_id,
nickname=nickname,
group_id=group_id,
group_nick_name=group_nick_name,
)

View File

@ -1,10 +1,10 @@
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
__all__ = [
"get_emoji_manager",
"get_chat_manager",
"MessageStorage",
"emoji_manager",
]

View File

@ -2,13 +2,15 @@ import asyncio
import hashlib
import time
import copy
from datetime import datetime
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
from src.common.database.database import get_db_session
from src.common.database.database_model import ChatSession
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
@ -76,7 +78,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
self.context: Optional[ChatMessageContext] = None
def to_dict(self) -> dict:
"""转换为字典格式"""
@ -95,10 +97,13 @@ class ChatStream:
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
if user_info is None:
raise ValueError("user_info is required to build ChatStream")
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info, # type: ignore
user_info=user_info,
group_info=group_info,
data=data,
)
@ -128,12 +133,7 @@ class ChatManager:
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
try:
db.connect(reuse_if_open=True)
# 确保 ChatStreams 表存在
db.create_tables([ChatStreams], safe=True)
except Exception as e:
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
get_db_session()
self._initialized = True
# 在事件循环中启动初始化
@ -161,8 +161,13 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
platform = message.message_info.platform or ""
if not platform:
raise ValueError("platform is required for ChatStream")
if message.message_info.user_info is None and message.message_info.group_info is None:
raise ValueError("user_info or group_info is required for ChatStream")
stream_id = self._generate_stream_id(
message.message_info.platform, # type: ignore
platform,
message.message_info.user_info,
message.message_info.group_info,
)
@ -176,12 +181,18 @@ class ChatManager:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
if group_info is None and user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if group_info:
# 组合关键信息
components = [platform, str(group_info.group_id)]
else:
components = [platform, str(user_info.user_id), "private"] # type: ignore
if user_info is None:
raise ValueError("用户信息或群组信息必须提供")
if user_info.user_id is None:
raise ValueError("user_id is required for private stream")
components = [platform, str(user_info.user_id), "private"]
# 使用MD5生成唯一ID
key = "_".join(components)
@ -231,33 +242,35 @@ class ChatManager:
# 检查数据库中是否存在
def _db_find_stream_sync(s_id: str):
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
return session.exec(statement).first()
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
if model_instance:
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
user_info_data = {
"platform": model_instance.user_platform,
"platform": model_instance.platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
stream = ChatStream.from_dict(data_for_from_dict)
# 更新用户信息和群组信息
@ -329,20 +342,26 @@ class ChatManager:
user_info_d = s_data_dict.get("user_info")
group_info_d = s_data_dict.get("group_info")
fields_to_save = {
"platform": s_data_dict["platform"],
"create_time": s_data_dict["create_time"],
"last_active_time": s_data_dict["last_active_time"],
"user_platform": user_info_d["platform"] if user_info_d else "",
"user_id": user_info_d["user_id"] if user_info_d else "",
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
"group_platform": group_info_d["platform"] if group_info_d else "",
"group_id": group_info_d["group_id"] if group_info_d else "",
"group_name": group_info_d["group_name"] if group_info_d else "",
}
with get_db_session() as session:
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
record = session.exec(statement).first()
if record is None:
record = ChatSession(
session_id=s_data_dict["stream_id"],
platform=s_data_dict["platform"],
user_id=user_info_d["user_id"] if user_info_d else None,
group_id=group_info_d["group_id"] if group_info_d else None,
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
)
session.add(record)
return
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
record.platform = s_data_dict["platform"]
record.user_id = user_info_d["user_id"] if user_info_d else None
record.group_id = group_info_d["group_id"] if group_info_d else None
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
try:
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
@ -361,30 +380,32 @@ class ChatManager:
def _db_load_all_streams_sync():
loaded_streams_data = []
for model_instance in ChatStreams.select():
user_info_data = {
"platform": model_instance.user_platform,
"user_id": model_instance.user_id,
"user_nickname": model_instance.user_nickname,
"user_cardname": model_instance.user_cardname or "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.group_platform,
"group_id": model_instance.group_id,
"group_name": model_instance.group_name,
with get_db_session() as session:
statement = select(ChatSession)
for model_instance in session.exec(statement).all():
user_info_data = {
"platform": model_instance.platform,
"user_id": model_instance.user_id or "",
"user_nickname": "",
"user_cardname": "",
}
group_info_data = None
if model_instance.group_id:
group_info_data = {
"platform": model_instance.platform,
"group_id": model_instance.group_id,
"group_name": "",
}
data_for_from_dict = {
"stream_id": model_instance.stream_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.create_time,
"last_active_time": model_instance.last_active_time,
}
loaded_streams_data.append(data_for_from_dict)
data_for_from_dict = {
"stream_id": model_instance.session_id,
"platform": model_instance.platform,
"user_info": user_info_data,
"group_info": group_info_data,
"create_time": model_instance.created_timestamp.timestamp(),
"last_active_time": model_instance.last_active_timestamp.timestamp(),
}
loaded_streams_data.append(data_for_from_dict)
return loaded_streams_data
try:

View File

@ -1,36 +1,74 @@
import re
import json
import traceback
from typing import Union
from datetime import datetime
from collections.abc import Mapping
from typing import cast
from src.common.database.database_model import Messages, Images
import json
import re
import traceback
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType, Messages
from src.common.logger import get_logger
from src.common.data_models.message_component_model import MessageSequence, TextComponent
from src.common.utils.utils_message import MessageUtils
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from .message import MessageRecv, MessageSending
logger = get_logger("message_storage")
class MessageStorage:
@staticmethod
def _serialize_keywords(keywords) -> str:
def _coerce_str_list(value: object) -> list[str]:
if isinstance(value, list):
return [str(item) for item in value]
if isinstance(value, tuple):
return [str(item) for item in value]
if isinstance(value, set):
return [str(item) for item in value]
if isinstance(value, str):
return [value]
return []
@staticmethod
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
value = mapping.get(key)
if value is None:
return default
return str(value)
@staticmethod
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
value = mapping.get(key)
if value is None:
return None
return str(value)
@staticmethod
def _serialize_keywords(keywords: list[str] | None) -> str:
"""将关键词列表序列化为JSON字符串"""
if isinstance(keywords, list):
return json.dumps(keywords, ensure_ascii=False)
return "[]"
@staticmethod
def _deserialize_keywords(keywords_str: str) -> list:
def _deserialize_keywords(keywords_str: str) -> list[str]:
"""将JSON字符串反序列化为关键词列表"""
if not keywords_str:
return []
try:
return json.loads(keywords_str)
parsed = cast(object, json.loads(keywords_str))
except (json.JSONDecodeError, TypeError):
return []
if isinstance(parsed, list):
return [str(item) for item in parsed]
if isinstance(parsed, str):
return [parsed]
return []
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 通知消息不存储
@ -66,7 +104,7 @@ class MessageStorage:
priority_mode = ""
priority_info = {}
is_emoji = False
is_picid = False
is_picture = False
is_notify = False
is_command = False
key_words = ""
@ -83,66 +121,73 @@ class MessageStorage:
priority_mode = message.priority_mode
priority_info = message.priority_info
is_emoji = message.is_emoji
is_picid = message.is_picid
is_picture = message.is_picid
is_notify = message.is_notify
is_command = message.is_command
intercept_message_level = getattr(message, "intercept_message_level", 0)
# 序列化关键词列表为JSON字符串
key_words = MessageStorage._serialize_keywords(message.key_words)
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
key_words_lite = MessageStorage._serialize_keywords(
MessageStorage._coerce_str_list(message.key_words_lite)
)
selected_expressions = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
if message.message_info.user_info is None:
raise ValueError("message.user_info is required")
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
msg_id = message.message_info.message_id or ""
# 安全地获取 group_info, 如果为 None 则视为空字典
group_info_from_chat = chat_info_dict.get("group_info") or {}
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
user_info_from_chat = chat_info_dict.get("user_info") or {}
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
Messages.create(
message_id=msg_id,
time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id,
# Flattened chat_info
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
additional_config.update(
{
"interest_value": interest_value,
"priority_mode": priority_mode,
"priority_info": priority_info,
"reply_probability_boost": reply_probability_boost,
"intercept_message_level": intercept_message_level,
"key_words": key_words,
"key_words_lite": key_words_lite,
"selected_expressions": selected_expressions,
"is_picid": is_picture,
}
)
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
timestamp_value = message.message_info.time
if timestamp_value is None:
raise ValueError("message.message_info.time is required")
db_message = Messages(
message_id=str(msg_id),
timestamp=datetime.fromtimestamp(float(timestamp_value)),
platform=MessageStorage._get_str(chat_info_dict, "platform"),
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
is_mentioned=bool(is_mentioned),
is_at=bool(is_at),
session_id=chat_stream.stream_id,
reply_to=reply_to,
is_mentioned=is_mentioned,
is_at=is_at,
reply_probability_boost=reply_probability_boost,
chat_info_stream_id=chat_info_dict.get("stream_id"),
chat_info_platform=chat_info_dict.get("platform"),
chat_info_user_platform=user_info_from_chat.get("platform"),
chat_info_user_id=user_info_from_chat.get("user_id"),
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
chat_info_group_platform=group_info_from_chat.get("platform"),
chat_info_group_id=group_info_from_chat.get("group_id"),
chat_info_group_name=group_info_from_chat.get("group_name"),
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
# Flattened user_info (message sender)
user_platform=user_info_dict.get("platform"),
user_id=user_info_dict.get("user_id"),
user_nickname=user_info_dict.get("user_nickname"),
user_cardname=user_info_dict.get("user_cardname"),
# Text content
is_emoji=is_emoji,
is_picture=is_picture,
is_command=is_command,
is_notify=is_notify,
raw_content=raw_content,
processed_plain_text=filtered_processed_plain_text,
display_message=filtered_display_message,
interest_value=interest_value,
priority_mode=priority_mode,
priority_info=priority_info,
is_emoji=is_emoji,
is_picid=is_picid,
is_notify=is_notify,
is_command=is_command,
intercept_message_level=intercept_message_level,
key_words=key_words,
key_words_lite=key_words_lite,
selected_expressions=selected_expressions,
additional_config=json.dumps(additional_config, ensure_ascii=False),
)
with get_db_session() as session:
session.add(db_message)
except Exception:
logger.exception("存储消息失败")
logger.error(f"消息:{message}")
@ -156,16 +201,21 @@ class MessageStorage:
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return False
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
else:
logger.debug("未找到匹配的消息")
return False
with get_db_session() as session:
statement = (
select(Messages)
.where(col(Messages.message_id) == mmc_message_id)
.order_by(col(Messages.timestamp).desc())
.limit(1)
)
matched_message = session.exec(statement).first()
if matched_message:
matched_message.message_id = qq_message_id
session.add(matched_message)
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
return True
logger.debug("未找到匹配的消息")
return False
except Exception as e:
logger.error(f"更新消息ID失败: {e}")
@ -182,13 +232,18 @@ class MessageStorage:
logger.debug("文本中没有图片标记,直接返回原文本")
return text
def replace_match(match):
def replace_match(match: re.Match[str]) -> str:
description = match.group(1).strip()
try:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
with get_db_session() as session:
statement = (
select(Images)
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
.order_by(col(Images.record_time).desc())
.limit(1)
)
image_record = session.exec(statement).first()
return f"[picid:{image_record.id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@ -1,17 +1,17 @@
import time
import random
import re
from datetime import datetime
from typing import List, Dict, Any, Tuple, Optional, Callable
from rich.traceback import install
from sqlmodel import select, col
from src.config.config import global_config
from src.common.logger import get_logger
from src.common.message_repository import find_messages, count_messages
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
from src.common.data_models.message_data_model import MessageAndActionModel
from src.common.database.database_model import ActionRecords
from src.common.database.database_model import Images
from src.common.database.database import get_db_session
from src.common.database.database_model import ActionRecord, Images
from src.person_info.person_info import Person, get_person_id
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
limit_mode: str = "latest",
) -> List[DatabaseActionRecords]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time > timestamp_start) # type: ignore
& (ActionRecords.time < timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
actions.reverse()
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
actions = list(query)
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
return [
DatabaseActionRecords(
action_id=action.action_id,
time=action.time,
time=action.timestamp.timestamp(),
action_name=action.action_name,
action_data=action.action_data,
action_done=action.action_done,
action_build_into_prompt=action.action_build_into_prompt,
action_prompt_display=action.action_prompt_display,
chat_id=action.chat_id,
chat_info_stream_id=action.chat_info_stream_id,
chat_info_platform=action.chat_info_platform,
action_reasoning=action.action_reasoning,
action_data=action.action_data or "{}",
action_done=True,
action_build_into_prompt=bool(action.action_display_prompt),
action_prompt_display=action.action_display_prompt or "",
chat_id=action.session_id,
chat_info_stream_id=action.session_id,
chat_info_platform=global_config.bot.platform,
action_reasoning=action.action_reasoning or "",
)
for action in actions
]
@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
) -> List[Dict[str, Any]]:
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
query = ActionRecords.select().where(
(ActionRecords.chat_id == chat_id)
& (ActionRecords.time >= timestamp_start) # type: ignore
& (ActionRecords.time <= timestamp_end) # type: ignore
)
with get_db_session() as session:
statement = (
select(ActionRecord)
.where((col(ActionRecord.session_id) == chat_id))
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
)
if limit > 0:
if limit_mode == "latest":
query = query.order_by(ActionRecords.time.desc()).limit(limit)
# 获取后需要反转列表,以保持最终输出为时间升序
actions = list(query)
return [action.__data__ for action in reversed(actions)]
else: # earliest
query = query.order_by(ActionRecords.time.asc()).limit(limit)
else:
query = query.order_by(ActionRecords.time.asc())
if limit > 0:
if limit_mode == "latest":
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
actions = list(session.exec(statement).all())
actions = list(reversed(actions))
else:
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
actions = list(session.exec(statement).all())
else:
statement = statement.order_by(col(ActionRecord.timestamp))
actions = session.exec(statement).all()
actions = list(query)
return [action.__data__ for action in actions]
return [action.model_dump() for action in actions]
def get_raw_msg_by_timestamp_random(
@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
def get_raw_msg_by_timestamp_with_users(
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
) -> List[DatabaseMessages]:
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
def get_raw_msg_before_timestamp_with_users(
timestamp: float, person_ids: list, limit: int = 0
timestamp: float, person_ids: List[str], limit: int = 0
) -> List[DatabaseMessages]:
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
limit: 限制返回的消息数量0为不限制
@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
def num_new_messages_since_with_users(
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
) -> int:
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
if not person_ids: # 保持空列表检查
@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
def _build_readable_messages_internal(
messages: List[MessageAndActionModel],
messages: List[DatabaseMessages],
replace_bot_name: bool = True,
timestamp_mode: str = "relative",
truncate: bool = False,
@ -413,7 +416,7 @@ def _build_readable_messages_internal(
# 匹配 [picid:xxxxx] 格式
pic_pattern = r"\[picid:([^\]]+)\]"
def replace_pic_id(match: re.Match) -> str:
def replace_pic_id(match: re.Match[str]) -> str:
nonlocal current_pic_counter
nonlocal pic_counter
pic_id = match.group(1)
@ -421,7 +424,8 @@ def _build_readable_messages_internal(
if pic_id not in pic_description_cache:
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@ -438,16 +442,11 @@ def _build_readable_messages_internal(
# 1: 获取发送者信息并提取消息组件
for message in messages:
if message.is_action_record:
# 对于动作记录也处理图片ID
content = process_pic_ids(message.display_message)
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
continue
platform = message.user_platform
user_id = message.user_id
user_nickname = message.user_nickname
user_cardname = message.user_cardname
user_info = message.user_info
platform = user_info.platform
user_id = user_info.user_id
user_nickname = user_info.user_nickname
user_cardname = user_info.user_cardname
timestamp = message.time
content = message.display_message or message.processed_plain_text or ""
@ -525,12 +524,12 @@ def _build_readable_messages_internal(
if long_time_notice and prev_timestamp is not None:
time_diff = timestamp - prev_timestamp
time_diff_hours = time_diff / 3600
# 检查是否跨天
prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp))
current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp))
is_cross_day = prev_date != current_date
# 如果间隔大于8小时或跨天插入提示
if time_diff_hours > 8 or is_cross_day:
# 格式化日期为中文格式xxxx年xx月xx日去掉前导零
@ -542,20 +541,15 @@ def _build_readable_messages_internal(
hours_str = f"{int(time_diff_hours)}h"
notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n"
output_lines.append(notice)
readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode)
# 查找消息id如果有并构建id_prefix
message_id = timestamp_to_id_mapping.get(timestamp, "")
id_prefix = f"[{message_id}]" if message_id else ""
if is_action:
# 对于动作记录,使用特殊格式
output_lines.append(f"{id_prefix}{readable_time}, {content}")
else:
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
output_lines.append("\n")
prev_timestamp = timestamp
formatted_string = "".join(output_lines).strip()
@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
# 从数据库中获取图片描述
description = "内容正在阅读,请稍等"
try:
image = Images.get_or_none(Images.image_id == pic_id)
with get_db_session() as session:
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
if image and image.description:
description = image.description
except Exception:
@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
允许通过参数控制格式化行为
"""
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
messages,
replace_bot_name,
timestamp_mode,
truncate,
@ -754,7 +749,7 @@ def build_readable_messages(
filtered_messages = []
for msg in messages:
# 获取消息内容
content = msg.processed_plain_text
content = msg.processed_plain_text or ""
# 移除表情包
emoji_pattern = r"\[表情包:[^\]]+\]"
content = re.sub(emoji_pattern, "", content)
@ -765,17 +760,14 @@ def build_readable_messages(
messages = filtered_messages
copy_messages: List[MessageAndActionModel] = []
copy_messages: List[DatabaseMessages] = []
for msg in messages:
if remove_emoji_stickers:
# 创建 MessageAndActionModel 但移除表情包
model = MessageAndActionModel.from_DatabaseMessages(msg)
# 移除表情包
if model.processed_plain_text:
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
copy_messages.append(model)
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
copy_messages.append(msg)
else:
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
copy_messages.append(msg)
if show_actions and copy_messages:
# 获取所有消息的时间范围
@ -786,40 +778,45 @@ def build_readable_messages(
chat_id = messages[0].chat_id if messages else None
# 获取这个时间范围内的动作记录并匹配chat_id
actions_in_range = (
ActionRecords.select()
.where(
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
)
.order_by(ActionRecords.time)
)
with get_db_session() as session:
actions_in_range = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
).all()
# 获取最新消息之后的第一个动作记录
action_after_latest = (
ActionRecords.select()
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
.order_by(ActionRecords.time)
.limit(1)
)
with get_db_session() as session:
action_after_latest = session.exec(
select(ActionRecord)
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
.where(col(ActionRecord.session_id) == chat_id)
.order_by(col(ActionRecord.timestamp))
.limit(1)
).all()
# 合并两部分动作记录
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
# 将动作记录转换为消息格式
for action in actions:
# 只有当build_into_prompt为True时才添加动作记录
if action.action_build_into_prompt:
action_msg = MessageAndActionModel(
time=float(action.time), # type: ignore
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
user_platform=global_config.bot.platform, # 使用机器人的平台
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
user_cardname="", # 机器人没有群名片
processed_plain_text=f"{action.action_prompt_display}",
display_message=f"{action.action_prompt_display}",
chat_info_platform=str(action.chat_info_platform),
is_action_record=True, # 添加标识字段
action_name=str(action.action_name), # 保存动作名称
action_display_prompt = action.action_display_prompt or ""
if action_display_prompt:
action_msg = DatabaseMessages(
message_id=f"action_{action.action_id}",
time=float(action.timestamp.timestamp()),
chat_id=chat_id or "",
processed_plain_text=action_display_prompt,
display_message=action_display_prompt,
user_platform=global_config.bot.platform,
user_id=str(global_config.bot.qq_account),
user_nickname=global_config.bot.nickname,
user_cardname="",
chat_info_platform=str(global_config.bot.platform),
chat_info_stream_id=chat_id or "",
)
copy_messages.append(action_msg)
@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
person_ids_set = set() # 使用集合来自动去重
for msg in messages:
platform: str = msg.get("user_platform") # type: ignore
user_id: str = msg.get("user_id") # type: ignore
platform = msg.get("user_platform") or ""
user_id = msg.get("user_id") or ""
# 检查必要信息是否存在 且 不是机器人自己
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
continue
# 添加空值检查,防止 platform 为 None 时出错
if platform is None:
platform = "unknown"
if person_id := get_person_id(platform, user_id):
person_ids_set.add(person_id)

File diff suppressed because it is too large Load Diff

View File

@ -1,18 +1,21 @@
import base64
from datetime import datetime
from typing import Optional, Tuple
import hashlib
import io
import os
import time
import hashlib
import uuid
import io
import numpy as np
from typing import Optional, Tuple
import numpy as np
from PIL import Image
from rich.traceback import install
from sqlmodel import select, col
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
@ -38,11 +41,7 @@ class ImageManager:
self._initialized = True
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
try:
db.connect(reuse_if_open=True)
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
except Exception as e:
logger.error(f"数据库连接或表创建失败: {e}")
get_db_session()
try:
self._cleanup_invalid_descriptions()
@ -72,10 +71,12 @@ class ImageManager:
Optional[str]: 描述文本如果不存在则返回None
"""
try:
record = ImageDescriptions.get_or_none(
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
)
return record.description if record else None
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
return record.description if record else None
except Exception as e:
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
return None
@ -90,15 +91,27 @@ class ImageManager:
description_type: 描述类型 ('emoji' 'image')
"""
try:
current_timestamp = time.time()
defaults = {"description": description, "timestamp": current_timestamp}
desc_obj, created = ImageDescriptions.get_or_create(
image_description_hash=image_hash, type=description_type, defaults=defaults
)
if not created: # 如果记录已存在,则更新
desc_obj.description = description
desc_obj.timestamp = current_timestamp
desc_obj.save()
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
)
record = session.exec(statement).first()
if record:
record.description = description
session.add(record)
return
new_record = Images(
image_hash=image_hash,
description=description,
full_path="",
image_type=ImageType(description_type),
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(new_record)
except Exception as e:
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
@ -107,20 +120,18 @@ class ImageManager:
"""清理数据库中 description 为空或为 'None' 的记录"""
invalid_values = ["", "None"]
# 清理 Images 表
deleted_images = (
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
)
with get_db_session() as session:
statement = (
select(Images)
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
.limit(1000)
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理 ImageDescriptions 表
deleted_descriptions = (
ImageDescriptions.delete()
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
.execute()
)
if deleted_images or deleted_descriptions:
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions}")
if records:
logger.info(f"[清理完成] 删除 Images: {len(records)}")
else:
logger.info("[清理完成] 未发现无效描述记录")
@ -128,19 +139,15 @@ class ImageManager:
def _cleanup_emoji_from_image_descriptions():
"""清理Images和ImageDescriptions表中type为emoji的记录已迁移到EmojiDescriptionCache"""
try:
# 清理Images表中type为emoji的记录
deleted_images = Images.delete().where(Images.type == "emoji").execute()
with get_db_session() as session:
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
records = session.exec(statement).all()
for record in records:
session.delete(record)
# 清理ImageDescriptions表中type为emoji的记录
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
total_deleted = deleted_images + deleted_descriptions
total_deleted = len(records)
if total_deleted > 0:
logger.info(
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
f"共删除 {total_deleted} 条记录"
)
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
else:
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
except Exception as e:
@ -148,14 +155,14 @@ class ImageManager:
raise
async def get_emoji_tag(self, image_base64: str) -> str:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
emoji_manager = emoji_manager_instance
if isinstance(image_base64, str):
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji = emoji_manager.get_emoji_by_hash(image_hash)
if not emoji:
return "[表情包:未知]"
emotion_list = emoji.emotion
@ -175,14 +182,14 @@ class ImageManager:
try:
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
# 确保目录存在
os.makedirs(EMOJI_DIR, exist_ok=True)
# 检查是否已存在该表情包(通过哈希值)
emoji_manager = get_emoji_manager()
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
emoji_manager = emoji_manager_instance
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
if existing_emoji:
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
return
@ -212,14 +219,15 @@ class ImageManager:
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 优先使用EmojiManager查询已注册表情包的描述
try:
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
emoji_manager = get_emoji_manager()
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
emoji_manager = emoji_manager_instance
emoji = emoji_manager.get_emoji_by_hash(image_hash)
tags = emoji.emotion if emoji else None
if tags:
tag_str = ",".join(tags)
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
@ -227,29 +235,26 @@ class ImageManager:
except Exception as e:
logger.debug(f"查询EmojiManager时出错: {e}")
# 查询EmojiDescriptionCache表的缓存包含描述和情感标签
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
# 优先使用情感标签,如果没有则使用详细描述
result_text = ""
if cache_record.emotion_tags:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
)
result_text = f"[表情包:{cache_record.emotion_tags}]"
if cache_record.emotion:
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
result_text = f"[表情包:{cache_record.emotion}]"
elif cache_record.description:
logger.info(
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
)
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
result_text = f"[表情包:{cache_record.description}]"
# 即使缓存命中如果启用了steal_emoji也检查是否需要保存文件
if result_text:
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
return result_text
except Exception as e:
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"查询Images缓存时出错: {e}")
# === 二步走识别流程 ===
@ -309,33 +314,42 @@ class ImageManager:
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
# 再次检查缓存(防止并发情况下其他线程已经保存)
try:
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
if cache_record and cache_record.emotion_tags:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
return f"[表情包:{cache_record.emotion_tags}]"
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record and cache_record.emotion:
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
return f"[表情包:{cache_record.emotion}]"
except Exception as e:
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
logger.debug(f"再次查询Images缓存时出错: {e}")
# 保存识别出的详细描述和情感标签到 emoji_description_cache
try:
current_timestamp = time.time()
cache_record, created = EmojiDescriptionCache.get_or_create(
emoji_hash=image_hash,
defaults={
"description": detailed_description,
"emotion_tags": final_emotion,
"timestamp": current_timestamp,
},
)
if not created:
# 更新已有记录
cache_record.description = detailed_description
cache_record.emotion_tags = final_emotion
cache_record.timestamp = current_timestamp
cache_record.save()
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
)
cache_record = session.exec(statement).first()
if cache_record:
cache_record.description = detailed_description
cache_record.emotion = final_emotion
session.add(cache_record)
else:
cache_record = Images(
image_hash=image_hash,
description=detailed_description,
full_path="",
image_type=ImageType.EMOJI,
emotion=final_emotion,
query_count=0,
is_registered=False,
is_banned=False,
vlm_processed=True,
)
session.add(cache_record)
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
@ -358,14 +372,13 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 优先检查Images表中是否已有完整的描述
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
with get_db_session() as session:
statement = select(Images).where(col(Images.image_hash) == image_hash)
existing_image = session.exec(statement).first()
if existing_image:
# 更新计数
if hasattr(existing_image, "count") and existing_image.count is not None:
existing_image.count += 1
else:
existing_image.count = 1
existing_image.save()
existing_image.query_count += 1
with get_db_session() as session:
session.add(existing_image)
# 如果已有描述,直接返回
if existing_image.description:
@ -377,7 +390,7 @@ class ImageManager:
return f"[图片:{cached_description}]"
# 调用AI获取描述
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
prompt = global_config.personality.visual_style
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
description, _ = await self.vlm.generate_response_for_image(
@ -402,26 +415,27 @@ class ImageManager:
# 保存到数据库,补充缺失字段
if existing_image:
existing_image.path = file_path
existing_image.full_path = file_path
existing_image.description = description
existing_image.timestamp = current_timestamp
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
existing_image.vlm_processed = True
existing_image.save()
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
existing_image.vlm_processed = True
with get_db_session() as session:
session.add(existing_image)
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
else:
Images.create(
image_id=str(uuid.uuid4()),
emoji_hash=image_hash,
path=file_path,
type="image",
description=description,
timestamp=current_timestamp,
vlm_processed=True,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description=description,
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=True,
)
session.add(new_record)
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
except Exception as e:
logger.error(f"保存图片文件或元数据失败: {str(e)}")
@ -575,30 +589,17 @@ class ImageManager:
image_bytes = base64.b64decode(image_base64)
image_hash = hashlib.md5(image_bytes).hexdigest()
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
# 检查是否缺少必要字段,如果缺少则创建新记录
if (
not hasattr(existing_image, "image_id")
or not existing_image.image_id
or not hasattr(existing_image, "count")
or existing_image.count is None
or not hasattr(existing_image, "vlm_processed")
or existing_image.vlm_processed is None
):
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
if not existing_image.image_id:
existing_image.image_id = str(uuid.uuid4())
if existing_image.count is None:
existing_image.count = 0
if existing_image.vlm_processed is None:
existing_image.vlm_processed = False
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
)
existing_image = session.exec(statement).first()
if existing_image:
existing_image.query_count += 1
session.add(existing_image)
return str(existing_image.id), f"[picid:{existing_image.id}]"
existing_image.count += 1
existing_image.save()
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
else:
# print(f"图片不存在: {image_hash}")
image_id = str(uuid.uuid4())
image_id = str(uuid.uuid4())
# 保存新图片
current_timestamp = time.time()
@ -612,15 +613,19 @@ class ImageManager:
f.write(image_bytes)
# 保存到数据库
Images.create(
image_id=image_id,
emoji_hash=image_hash,
path=file_path,
type="image",
timestamp=current_timestamp,
vlm_processed=False,
count=1,
)
with get_db_session() as session:
new_record = Images(
image_hash=image_hash,
description="",
full_path=file_path,
image_type=ImageType.IMAGE,
query_count=1,
is_registered=False,
is_banned=False,
record_time=datetime.fromtimestamp(current_timestamp),
vlm_processed=False,
)
session.add(new_record)
# 启动异步VLM处理
await self._process_image_with_vlm(image_id, image_base64)
@ -647,17 +652,26 @@ class ImageManager:
image_hash = hashlib.md5(image_bytes).hexdigest()
# 获取当前图片记录
image = Images.get(Images.image_id == image_id)
with get_db_session() as session:
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
if image is None:
logger.warning(f"未找到图片记录: {image_id}")
return
# 优先检查是否已有其他相同哈希的图片记录包含描述
existing_with_description = Images.get_or_none(
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
)
with get_db_session() as session:
statement = select(Images).where(
(col(Images.image_hash) == image_hash)
& (col(Images.description).is_not(None))
& (col(Images.description) != "")
)
existing_with_description = session.exec(statement).first()
if existing_with_description and existing_with_description.id != image.id:
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
image.description = existing_with_description.description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 同时保存到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, existing_with_description.description, "image")
return
@ -667,11 +681,12 @@ class ImageManager:
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
image.description = cached_description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
return
# 获取图片格式
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
# 构建prompt
prompt = global_config.personality.visual_style
@ -692,7 +707,8 @@ class ImageManager:
# 更新数据库
image.description = description
image.vlm_processed = True
image.save()
with get_db_session() as session:
session.add(image)
# 保存描述到ImageDescriptions表作为备用缓存
self._save_description_to_db(image_hash, description, "image")

View File

@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
from typing import Self, TypeVar, Generic, TYPE_CHECKING
from dataclasses import is_dataclass
from typing import Any, Dict, Self, TypeVar, Generic, TYPE_CHECKING
import copy
@ -15,9 +16,23 @@ class BaseDataModel:
return copy.deepcopy(self)
def transform_class_to_dict(obj: Any) -> Dict[str, Any]:
    """Best-effort conversion of an arbitrary object into a plain dict.

    Resolution order:
      1. ``None``           -> empty dict
      2. dataclass instance -> shallow per-field mapping
      3. ``model_dump()``   -> Pydantic v2 models (preferred over the
                               deprecated v1-style ``dict()`` shim)
      4. ``dict()``         -> Pydantic v1 / dict-like objects
      5. ``__dict__``       -> ordinary objects
      6. anything else      -> wrapped as ``{"value": obj}``
    """
    if obj is None:
        return {}
    if is_dataclass(obj):
        # Shallow per-field copy. Unlike reading obj.__dict__ directly this
        # also works for slots dataclasses and does not alias internal state.
        from dataclasses import fields

        return {f.name: getattr(obj, f.name) for f in fields(obj)}
    # Check model_dump before dict: on Pydantic v2 models dict() is only a
    # deprecated compatibility shim that emits a DeprecationWarning.
    if hasattr(obj, "model_dump"):
        return obj.model_dump()
    if hasattr(obj, "dict"):
        return obj.dict()
    if hasattr(obj, "__dict__"):
        return obj.__dict__
    return {"value": obj}
class BaseDatabaseDataModel(ABC, Generic[T]):
@abstractmethod
@classmethod
@abstractmethod
def from_db_instance(cls, db_record: T) -> Self:
"""从数据库实例创建数据模型对象"""
raise NotImplementedError

View File

@ -0,0 +1,79 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Iterable, List, Optional, Tuple, Union
from . import BaseDataModel
class ReplyContentType(Enum):
    """Enumeration of the payload kinds a reply segment may carry."""

    TEXT = "text"
    IMAGE = "image"
    EMOJI = "emoji"
    COMMAND = "command"
    VOICE = "voice"
    HYBRID = "hybrid"
    FORWARD = "forward"

    def __str__(self) -> str:
        # Render as the raw wire value, e.g. str(ReplyContentType.TEXT) == "text".
        return str(self.value)
@dataclass
class ReplyContent:
    """One segment of a reply.

    ``content_type`` is usually a ReplyContentType member but may remain a
    raw string for unrecognized types; ``content`` holds the payload
    (text, base64 audio, nested segments, forward nodes, ...).
    """

    content_type: ReplyContentType | str
    content: Any
@dataclass
class ForwardNode:
    """A node inside a forwarded-message bundle.

    Two flavors exist: an id-reference node (only ``content`` set, holding
    an existing message id) and a synthesized node (sender info plus a list
    of ReplyContent segments).
    """

    user_id: Optional[str] = None
    user_nickname: Optional[str] = None
    content: Union[str, List[ReplyContent], None] = None

    @classmethod
    def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
        """Build a node that merely references an existing message by id."""
        node = cls()
        node.content = message_id
        return node

    @classmethod
    def construct_as_created_node(
        cls,
        user_id: str,
        user_nickname: str,
        content: List[ReplyContent],
    ) -> "ForwardNode":
        """Build a synthesized node carrying its own sender and segments."""
        return cls(user_id, user_nickname, content)
class ReplySetModel(BaseDataModel):
    """Ordered collection of ReplyContent segments composing one reply."""

    def __init__(self) -> None:
        # Segments are kept in insertion order.
        self.reply_data: List[ReplyContent] = []

    def __len__(self) -> int:
        return len(self.reply_data)

    def add_text_content(self, text: str) -> None:
        """Append a plain-text segment."""
        self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))

    def add_voice_content(self, voice_base64: str) -> None:
        """Append a voice segment carrying base64-encoded audio."""
        self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))

    def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None:
        """Append one HYBRID segment built from (type, content) pairs.

        Each pair's type is normalized to a ReplyContentType member when it
        matches a known value; unknown values are kept as-is.
        """
        hybrid_contents: List[ReplyContent] = [
            ReplyContent(content_type=self._normalize_content_type(content_type), content=content)
            for content_type, content in message_tuple_list
        ]
        self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_contents))

    def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None:
        """Append a FORWARD segment wrapping a list of forward nodes."""
        self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes))

    @staticmethod
    def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str:
        """Coerce a raw value to a ReplyContentType member when possible.

        Fix: the previous linear scan silently returned None for an input
        that was neither a member nor a known string; unknown inputs are now
        returned unchanged instead.
        """
        try:
            # Calling the Enum returns the member itself when one is passed in,
            # and looks the member up by value for strings.
            return ReplyContentType(content_type)
        except ValueError:
            return content_type

View File

@ -1,11 +1,12 @@
from rich.traceback import install
from pathlib import Path
from contextlib import contextmanager
from sqlalchemy.orm import sessionmaker
from pathlib import Path
from typing import Generator, TYPE_CHECKING
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlmodel import create_engine, Session
from typing import TYPE_CHECKING, Generator
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, Session, create_engine
if TYPE_CHECKING:
from sqlite3 import Connection as SQLite3Connection
@ -53,6 +54,19 @@ SessionLocal = sessionmaker(
class_=Session,
)
# Process-wide latch: ensures the schema bootstrap below runs at most once.
_db_initialized = False


def initialize_database() -> None:
    """Create the database directory and all tables on first use.

    Subsequent calls are no-ops. NOTE(review): the check-then-set on the
    module global is not thread-safe; concurrent first calls could both run
    ``create_all`` — presumably idempotent here, but worth confirming.
    """
    global _db_initialized
    if _db_initialized:
        return
    _DB_DIR.mkdir(parents=True, exist_ok=True)
    # Imported for its side effect: registers every table model on
    # SQLModel.metadata before create_all runs.
    import src.common.database.database_model  # noqa: F401

    SQLModel.metadata.create_all(engine)
    _db_initialized = True
@contextmanager
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
@ -87,6 +101,7 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
- auto_commit=True ,成功执行完会自动提交
- auto_commit=False ,需要手动调用 session.commit()
"""
initialize_database()
session = SessionLocal()
try:
yield session
@ -120,6 +135,7 @@ def get_db() -> Generator[Session, None, None]:
Yields:
Session: SQLAlchemy 数据库会话
"""
initialize_database()
session = SessionLocal()
try:
yield session

View File

@ -1,35 +1,153 @@
import traceback
from datetime import datetime
from typing import Any
from typing import List, Any, Optional
import json
from src.config.config import global_config
from src.common.data_models.database_data_model import DatabaseMessages
from sqlalchemy import func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
from src.config.config import global_config
logger = get_logger(__name__)
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
    """
    Convert a dict, pydantic model, or ORM row into a DatabaseMessages object.
    """
    if isinstance(model_instance, dict):
        # Already a plain mapping — feed it straight to the constructor.
        return DatabaseMessages(**model_instance)
    if hasattr(model_instance, "model_dump"):
        # Pydantic-style model.
        return DatabaseMessages(**model_instance.model_dump())
    # Fallback: any object exposing its fields via __dict__.
    return DatabaseMessages(**model_instance.__dict__)
# Translation table from legacy filter/sort keys (Mongo-era names such as
# "time" and "chat_id") onto the current SQLModel ``Messages`` columns.
# Several aliases intentionally point at the same column.
FIELD_MAP: dict[str, Any] = {
    "time": Messages.timestamp,
    "timestamp": Messages.timestamp,
    "chat_id": Messages.session_id,
    "session_id": Messages.session_id,
    "user_id": Messages.user_id,
    "message_id": Messages.message_id,
    "group_id": Messages.group_id,
    "platform": Messages.platform,
    "is_command": Messages.is_command,
    "is_mentioned": Messages.is_mentioned,
    "is_at": Messages.is_at,
    "is_emoji": Messages.is_emoji,
    "is_picid": Messages.is_picture,
    "is_picture": Messages.is_picture,
    "reply_to": Messages.reply_to,
}
def _parse_additional_config(message: Messages) -> dict[str, Any]:
    """Decode the JSON ``additional_config`` column, returning {} on any failure."""
    raw = message.additional_config
    if not raw:
        return {}
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, TypeError):
        return {}
    # Only dict payloads are meaningful; discard lists/scalars.
    return decoded if isinstance(decoded, dict) else {}
def _normalize_optional_str(value: object) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, ensure_ascii=False)
except (TypeError, ValueError):
return str(value)
def _message_to_instance(message: Messages) -> DatabaseMessages:
    """Flatten an ORM ``Messages`` row into the legacy ``DatabaseMessages`` DTO.

    Fields without dedicated columns are recovered from the JSON
    ``additional_config`` blob; chat-info fields are duplicated from the row.
    """
    config = _parse_additional_config(message)
    timestamp_value = message.timestamp
    # The column may hold a datetime (current schema) or a raw epoch number.
    if isinstance(timestamp_value, datetime):
        time_value = timestamp_value.timestamp()
    else:
        time_value = float(timestamp_value)
    selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
    priority_info = _normalize_optional_str(config.get("priority_info"))
    return DatabaseMessages(
        message_id=message.message_id,
        time=time_value,
        chat_id=message.session_id,
        reply_to=message.reply_to,
        interest_value=config.get("interest_value"),
        key_words=_normalize_optional_str(config.get("key_words")),
        key_words_lite=_normalize_optional_str(config.get("key_words_lite")),
        is_mentioned=message.is_mentioned,
        is_at=message.is_at,
        reply_probability_boost=config.get("reply_probability_boost"),
        processed_plain_text=message.processed_plain_text,
        display_message=message.display_message,
        priority_mode=_normalize_optional_str(config.get("priority_mode")),
        priority_info=priority_info,
        additional_config=message.additional_config,
        is_emoji=message.is_emoji,
        is_picid=message.is_picture,
        is_command=message.is_command,
        intercept_message_level=config.get("intercept_message_level", 0),
        is_notify=message.is_notify,
        selected_expressions=selected_expressions,
        user_id=message.user_id,
        user_nickname=message.user_nickname,
        user_cardname=message.user_cardname,
        user_platform=message.platform,
        chat_info_group_id=message.group_id,
        chat_info_group_name=message.group_name,
        chat_info_group_platform=message.platform,
        chat_info_user_id=message.user_id,
        chat_info_user_nickname=message.user_nickname,
        chat_info_user_cardname=message.user_cardname,
        chat_info_user_platform=message.platform,
        chat_info_stream_id=message.session_id,
        chat_info_platform=message.platform,
        # NOTE(review): stream create/last-active times are not stored on the
        # row, so they are zero-filled — confirm consumers tolerate 0.0 here.
        chat_info_create_time=0.0,
        chat_info_last_active_time=0.0,
    )
def _coerce_datetime(value: Any) -> Any:
if isinstance(value, (int, float)):
return datetime.fromtimestamp(value)
return value
def _cast_value_for_field(field: Any, value: Any) -> Any:
    """Normalise a filter value for *field* (epoch numbers -> datetime for timestamps)."""
    if field is not Messages.timestamp:
        return value
    return _coerce_datetime(value)
def _ensure_list(value: Any) -> list[Any]:
if value is None:
return []
if isinstance(value, list):
return value
if isinstance(value, tuple):
return list(value)
if isinstance(value, set):
return list(value)
return [value]
def _resolve_field(field_name: str) -> Any | None:
    """Resolve an external field name onto a ``Messages`` column, or None if unknown."""
    mapped = FIELD_MAP.get(field_name)
    if mapped is not None:
        return mapped
    # Fall back to a same-named attribute on the model itself.
    return getattr(Messages, field_name, None)
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
sort: list[tuple[str, int]] | None = None,
limit: int = 0,
limit_mode: str = "latest",
filter_bot=False,
filter_command=False,
filter_intercept_message_level: Optional[int] = None,
) -> List[DatabaseMessages]:
filter_bot: bool = False,
filter_command: bool = False,
filter_intercept_message_level: int | None = None,
) -> list[DatabaseMessages]:
"""
根据提供的过滤器排序和限制条件查找消息
@ -43,92 +161,79 @@ def find_messages(
消息字典列表如果出错则返回空列表
"""
try:
query = Messages.select()
# 应用过滤器
conditions: list[Any] = []
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
field = _resolve_field(key)
if field is None:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
conditions.append(Messages.message_id != "notice")
if filter_bot:
query = query.where(Messages.user_id != global_config.bot.qq_account)
conditions.append(Messages.user_id != global_config.bot.qq_account)
if filter_command:
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
query = query.where(~Messages.is_command)
if filter_intercept_message_level is not None:
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
conditions.append(Messages.is_command == False) # noqa: E712
statement = select(Messages).where(*conditions)
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by("time").limit(limit)
peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by("-time").limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(
latest_results_peewee,
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
)
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
with get_db_session() as session:
results = list(session.exec(statement).all())
else:
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
with get_db_session() as session:
results = list(session.exec(statement).all())
results = list(reversed(results))
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
peewee_sort_terms = []
order_terms: list[Any] = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field_name)
elif direction == -1: # DESC
peewee_sort_terms.append(f"-{field_name}")
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
sort_field = _resolve_field(field_name)
if sort_field is None:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
continue
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
if order_terms:
statement = statement.order_by(*order_terms)
with get_db_session() as session:
results = list(session.exec(statement).all())
return [_model_to_instance(msg) for msg in peewee_results]
if filter_intercept_message_level is not None:
filtered_results = []
for msg in results:
config = _parse_additional_config(msg)
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
filtered_results.append(msg)
results = filtered_results
return [_message_to_instance(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@ -146,54 +251,42 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量如果出错则返回 0
"""
try:
query = Messages.select()
# 应用过滤器
conditions: list[Any] = []
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
field = _resolve_field(key)
if field is None:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
continue
if isinstance(value, dict):
for op, op_value in value.items():
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
if op == "$gt":
conditions.append(field > coerced_value)
elif op == "$lt":
conditions.append(field < coerced_value)
elif op == "$gte":
conditions.append(field >= coerced_value)
elif op == "$lte":
conditions.append(field <= coerced_value)
elif op == "$ne":
conditions.append(field != coerced_value)
elif op == "$in":
conditions.append(field.in_(_ensure_list(coerced_value)))
elif op == "$nin":
conditions.append(field.not_in(_ensure_list(coerced_value)))
else:
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
conditions.append(field == coerced_value)
# 排除 id 为 "notice" 的消息
query = query.where(Messages.message_id != "notice")
count = query.count()
return count
conditions.append(Messages.message_id != "notice")
statement = select(func.count()).select_from(Messages).where(*conditions)
with get_db_session() as session:
result = session.exec(statement).one()
return int(result or 0)
except Exception as e:
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@ -5,8 +5,8 @@ from PIL import Image
from datetime import datetime
from src.common.logger import get_logger
from src.common.database.database import db # 确保 db 被导入用于 create_tables
from src.common.database.database_model import LLMUsage
from src.common.database.database import get_db_session
from src.common.database.database_model import ModelUsage, ModelUser
from src.config.model_configs import ModelInfo
from .payload_content.message import Message, MessageBuilder
from .model_client.base_client import UsageRecord
@ -158,12 +158,7 @@ class LLMUsageRecorder:
"""
def __init__(self):
try:
# 使用 Peewee 创建表safe=True 表示如果表已存在则不会抛出错误
db.create_tables([LLMUsage], safe=True)
# logger.debug("LLMUsage 表已初始化/确保存在。")
except Exception as e:
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
pass
def record_usage_to_database(
self,
@ -178,22 +173,22 @@ class LLMUsageRecorder:
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
total_cost = round(input_cost + output_cost, 6)
try:
# 使用 Peewee 模型创建记录
LLMUsage.create(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
model_api_provider=model_info.api_provider,
user_id=user_id,
request_type=request_type,
endpoint=endpoint,
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
time_cost=round(time_cost or 0.0, 3),
status="success",
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
)
with get_db_session() as session:
record = ModelUsage(
model_name=model_info.model_identifier,
model_assign_name=model_info.name,
model_api_provider_name=model_info.api_provider,
endpoint=endpoint,
user_type=ModelUser.SYSTEM,
request_type=request_type,
time_cost=round(time_cost or 0.0, 3),
timestamp=datetime.now(),
prompt_tokens=model_usage.prompt_tokens or 0,
completion_tokens=model_usage.completion_tokens or 0,
total_tokens=model_usage.total_tokens or 0,
cost=total_cost or 0.0,
)
session.add(record)
logger.debug(
f"Token使用情况 - 模型: {model_usage.model_name}, "
f"用户: {user_id}, 类型: {request_type}, "

View File

@ -7,7 +7,7 @@ from src.manager.async_task_manager import async_task_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
# from src.chat.utils.token_statistics import TokenStatisticsTask
from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
@ -107,7 +107,7 @@ class MainSystem:
plugin_manager.load_all_plugins()
# 初始化表情管理器
get_emoji_manager().initialize()
emoji_manager.load_emojis_from_db()
logger.info("表情包管理器初始化成功")
# 初始化聊天管理器
@ -121,7 +121,7 @@ class MainSystem:
# 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中
self.app.register_message_handler(chat_bot.message_process)
self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process)
prompt_manager.load_prompts()
# 触发 ON_START 事件
@ -141,7 +141,7 @@ class MainSystem:
"""调度定时任务"""
try:
tasks = [
get_emoji_manager().start_periodic_check_register(),
emoji_manager.periodic_emoji_maintenance(),
start_dream_scheduler(),
self.app.run(),
self.server.run(),

View File

@ -1,12 +1,15 @@
import time
import json
import asyncio
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.prompt.prompt_manager import prompt_manager
from src.plugin_system.apis import llm_api
from src.common.database.database_model import ThinkingBack
from sqlmodel import select, col
from src.common.database.database import get_db_session
from src.common.database.database_model import ThinkingQuestion
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
from src.chat.message_receive.chat_stream import get_chat_manager
@ -29,13 +32,16 @@ def _cleanup_stale_not_found_thinking_back() -> None:
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
try:
deleted_rows = (
ThinkingBack.delete()
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
.execute()
)
if deleted_rows:
logger.info(f"清理过期的未找到答案thinking_back记录 {deleted_rows}")
with get_db_session() as session:
statement = select(ThinkingQuestion).where(
(ThinkingQuestion.found_answer == False)
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
)
records = session.exec(statement).all()
for record in records:
session.delete(record)
if records:
logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)}")
_last_not_found_cleanup_ts = now
except Exception as e:
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
@ -249,12 +255,12 @@ async def _react_agent_solve_question(
# 后续迭代都复用第一次构建的head_prompt
head_prompt = first_head_prompt
def message_factory(
def _build_messages(
_client,
*,
_head_prompt: str = head_prompt,
_conversation_messages: List[Message] = conversation_messages,
) -> List[Message]:
):
messages: List[Message] = []
system_builder = MessageBuilder()
@ -266,6 +272,7 @@ async def _react_agent_solve_question(
return messages
message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues]
(
success,
response,
@ -273,7 +280,7 @@ async def _react_agent_solve_question(
model_name,
tool_calls,
) = await llm_api.generate_with_model_with_tools_by_message_factory(
message_factory,
message_factory_fn, # type: ignore[arg-type]
model_config=model_config.model_task_config.tool_use,
tool_options=tool_definitions,
request_type="memory.react",
@ -304,7 +311,12 @@ async def _react_agent_solve_question(
assistant_message = assistant_builder.build()
# 记录思考步骤
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
step: Dict[str, Any] = {
"iteration": iteration + 1,
"thought": response,
"actions": [],
"observations": [],
}
if assistant_message:
conversation_messages.append(assistant_message)
@ -417,20 +429,21 @@ async def _react_agent_solve_question(
"action_params": {"information": parsed_information or ""},
}
)
if parsed_information and parsed_information.strip():
parsed_info_text = parsed_information if isinstance(parsed_information, str) else ""
if parsed_info_text.strip():
step["observations"] = [f"检测到return_information{format_type}调用,返回信息"]
thinking_steps.append(step)
logger.info(
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_information[:100]}..."
f"{react_log_prefix}{iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..."
)
_log_conversation_messages(
conversation_messages,
head_prompt=first_head_prompt,
final_status=f"返回信息:{parsed_information}",
final_status=f"返回信息:{parsed_info_text}",
)
return True, parsed_information, thinking_steps, False
return True, parsed_info_text, thinking_steps, False
else:
# 信息为空,直接退出查询
step["observations"] = [f"检测到return_information{format_type}调用,信息为空"]
@ -776,15 +789,16 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
current_time = time.time()
start_time = current_time - time_window_seconds
# 查询最近时间窗口内的记录,按更新时间倒序
records = (
ThinkingBack.select()
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
.order_by(ThinkingBack.update_time.desc())
.limit(5) # 最多返回5条最近的记录
)
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(5)
)
records = session.exec(statement).all()
if not records.exists():
if not records:
return ""
history_lines = []
@ -828,20 +842,19 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
start_time = current_time - time_window_seconds
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
records = (
ThinkingBack.select()
.where(
(ThinkingBack.chat_id == chat_id)
& (ThinkingBack.update_time >= start_time)
& (ThinkingBack.found_answer == 1)
& (ThinkingBack.answer.is_null(False))
& (ThinkingBack.answer != "")
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.found_answer) == True)
.where(col(ThinkingQuestion.answer).is_not(None))
.where(col(ThinkingQuestion.answer) != "")
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(3)
)
.order_by(ThinkingBack.update_time.desc())
.limit(3) # 最多返回5条最近的记录
)
records = session.exec(statement).all()
if not records.exists():
if not records:
return []
found_answers = []
@ -873,36 +886,35 @@ def _store_thinking_back(
now = time.time()
# 先查询是否已存在相同chat_id和问题的记录
existing = (
ThinkingBack.select()
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
.order_by(ThinkingBack.update_time.desc())
.limit(1)
)
with get_db_session() as session:
statement = (
select(ThinkingQuestion)
.where(col(ThinkingQuestion.context) == chat_id)
.where(col(ThinkingQuestion.question) == question)
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
.limit(1)
)
record = session.exec(statement).first()
if record:
record.context = context
record.found_answer = found_answer
record.answer = answer
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
record.updated_timestamp = datetime.fromtimestamp(now)
session.add(record)
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
return
if existing.exists():
# 更新现有记录
record = existing.get()
record.context = context
record.found_answer = found_answer
record.answer = answer
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
record.update_time = now
record.save()
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
else:
# 创建新记录
ThinkingBack.create(
chat_id=chat_id,
new_record = ThinkingQuestion(
question=question,
context=context,
context=chat_id,
found_answer=found_answer,
answer=answer,
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
create_time=now,
update_time=now,
created_timestamp=datetime.fromtimestamp(now),
updated_timestamp=datetime.fromtimestamp(now),
)
# logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
session.add(new_record)
except Exception as e:
logger.error(f"存储思考过程失败: {e}")

View File

@ -6,10 +6,13 @@ import random
import math
from json_repair import repair_json
from typing import Union, Optional
from typing import Union, Optional, Dict
from datetime import datetime
from sqlmodel import col, select
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config, model_config
@ -35,24 +38,37 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
def get_person_id_by_person_name(person_name: str) -> str:
"""根据用户名获取用户ID"""
try:
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
record = session.exec(statement).first()
return record.person_id if record else ""
except Exception as e:
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
return ""
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
def is_person_known(
person_id: Optional[str] = None,
user_id: Optional[str] = None,
platform: Optional[str] = None,
person_name: Optional[str] = None,
) -> bool:
if person_id:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
elif user_id and platform:
person_id = get_person_id(platform, user_id)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
elif person_name:
person_id = get_person_id_by_person_name(person_name)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
return person.is_known if person else False
else:
return False
@ -442,17 +458,18 @@ class Person:
def load_from_database(self):
"""从数据库加载个人信息数据"""
try:
# 查询数据库中的记录
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
record = session.exec(statement).first()
if record:
self.user_id = record.user_id or ""
self.platform = record.platform or ""
self.is_known = record.is_known or False
self.nickname = record.nickname or ""
self.nickname = record.user_nickname or ""
self.person_name = record.person_name or self.nickname
self.name_reason = record.name_reason or None
self.know_times = record.know_times or 0
self.know_times = record.know_counts or 0
# 处理points字段JSON格式的列表
if record.memory_points:
@ -470,16 +487,16 @@ class Person:
self.memory_points = []
# 处理group_nick_name字段JSON格式的列表
if record.group_nick_name:
if record.group_nickname:
try:
loaded_group_nick_names = json.loads(record.group_nick_name)
loaded_group_nick_names = json.loads(record.group_nickname)
# 确保是列表格式
if isinstance(loaded_group_nick_names, list):
self.group_nick_name = loaded_group_nick_names
else:
self.group_nick_name = []
except (json.JSONDecodeError, TypeError):
logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败使用默认值")
logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败使用默认值")
self.group_nick_name = []
else:
self.group_nick_name = []
@ -498,42 +515,55 @@ class Person:
if not self.is_known:
return
try:
# 准备数据
data = {
"person_id": self.person_id,
"is_known": self.is_known,
"platform": self.platform,
"user_id": self.user_id,
"nickname": self.nickname,
"person_name": self.person_name,
"name_reason": self.name_reason,
"know_times": self.know_times,
"know_since": self.know_since,
"last_know": self.last_know,
"memory_points": json.dumps(
[point for point in self.memory_points if point is not None], ensure_ascii=False
)
memory_points_value = (
json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False)
if self.memory_points
else json.dumps([], ensure_ascii=False),
"group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False)
else json.dumps([], ensure_ascii=False)
)
group_nickname_value = (
json.dumps(self.group_nick_name, ensure_ascii=False)
if self.group_nick_name
else json.dumps([], ensure_ascii=False),
}
else json.dumps([], ensure_ascii=False)
)
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
# 检查记录是否存在
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
record = session.exec(statement).first()
if record:
# 更新现有记录
for field, value in data.items():
if hasattr(record, field):
setattr(record, field, value)
record.save()
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
# 创建新记录
PersonInfo.create(**data)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
if record:
record.person_id = self.person_id
record.is_known = self.is_known
record.platform = self.platform
record.user_id = self.user_id
record.user_nickname = self.nickname
record.person_name = self.person_name
record.name_reason = self.name_reason
record.know_counts = self.know_times
record.first_known_time = first_known_time
record.last_known_time = last_known_time
record.memory_points = memory_points_value
record.group_nickname = group_nickname_value
session.add(record)
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
else:
record = PersonInfo(
person_id=self.person_id,
is_known=self.is_known,
platform=self.platform,
user_id=self.user_id,
user_nickname=self.nickname,
person_name=self.person_name,
name_reason=self.name_reason,
know_counts=self.know_times,
first_known_time=first_known_time,
last_known_time=last_known_time,
memory_points=memory_points_value,
group_nickname=group_nickname_value,
)
session.add(record)
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
except Exception as e:
logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}")
@ -621,30 +651,26 @@ class PersonInfoManager:
self.person_name_list = {}
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
try:
db.connect(reuse_if_open=True)
# 设置连接池参数
if hasattr(db, "execute_sql"):
# 设置SQLite优化参数
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
db.create_tables([PersonInfo], safe=True)
with get_db_session() as _:
pass
except Exception as e:
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
# 初始化时读取所有person_name
try:
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
PersonInfo.person_name.is_null(False)
):
if record.person_name:
self.person_name_list[record.person_id] = record.person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
with get_db_session() as session:
statement = select(PersonInfo.person_id, PersonInfo.person_name).where(
col(PersonInfo.person_name).is_not(None)
)
for person_id, person_name in session.exec(statement).all():
if person_name:
self.person_name_list[person_id] = person_name
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
except Exception as e:
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
logger.error(f"加载 person_name_list 失败: {e}")
@staticmethod
def _extract_json_from_text(text: str) -> dict:
def _extract_json_from_text(text: str) -> Dict[str, str]:
"""从文本中提取JSON数据的高容错方法"""
try:
fixed_json = repair_json(text)
@ -744,7 +770,9 @@ class PersonInfoManager:
else:
def _db_check_name_exists_sync(name_to_check):
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
with get_db_session() as session:
statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == name_to_check)
return session.exec(statement).first() is not None
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
is_duplicate = True
@ -804,7 +832,7 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
if not person_id:
# 如果通过person_name找不到尝试从chat_stream获取user_info
if chat_stream.user_info:
if platform and chat_stream.user_info and chat_stream.user_info.user_id:
user_id = chat_stream.user_info.user_id
person_id = get_person_id(platform, user_id)
else:

View File

@ -15,8 +15,9 @@ import uuid
from typing import Optional, Tuple, List, Dict, Any
from src.common.logger import get_logger
from src.chat.emoji_system.emoji_manager import get_emoji_manager, EMOJI_DIR
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
from src.config.config import global_config
logger = get_logger("emoji_api")
@ -46,14 +47,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
try:
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
emoji_manager = get_emoji_manager()
emoji_result = await emoji_manager.get_emoji_for_text(description)
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
if not emoji_result:
if not emoji_obj:
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
return None
emoji_path, emoji_description, matched_emotion = emoji_result
emoji_path = str(emoji_obj.full_path)
emoji_description = emoji_obj.description
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
emoji_base64 = image_path_to_base64(emoji_path)
if not emoji_base64:
@ -90,8 +92,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
return []
try:
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
@ -114,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
results = []
for selected_emoji in selected_emojis:
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
emoji_base64 = image_path_to_base64(str(selected_emoji.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
@ -123,7 +124,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
emoji_manager.update_emoji_usage(selected_emoji)
results.append((emoji_base64, selected_emoji.description, matched_emotion))
if not results and count > 0:
@ -158,8 +159,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
try:
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
all_emojis = emoji_manager.emojis
# 筛选匹配情感的表情包
matching_emojis = []
@ -181,7 +181,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
return None
# 记录使用次数
emoji_manager.record_usage(selected_emoji.hash)
emoji_manager.update_emoji_usage(selected_emoji)
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
return emoji_base64, selected_emoji.description, emotion
@ -203,8 +203,7 @@ def get_count() -> int:
int: 当前可用的表情包数量
"""
try:
emoji_manager = get_emoji_manager()
return emoji_manager.emoji_num
return len(emoji_manager.emojis)
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
return 0
@ -217,11 +216,10 @@ def get_info():
dict: 包含表情包数量最大数量可用数量信息
"""
try:
emoji_manager = get_emoji_manager()
return {
"current_count": emoji_manager.emoji_num,
"max_count": emoji_manager.emoji_num_max,
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
"current_count": len(emoji_manager.emojis),
"max_count": global_config.emoji.max_reg_num,
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
}
except Exception as e:
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
@ -235,10 +233,9 @@ def get_emotions() -> List[str]:
list: 所有表情包的情感标签列表去重
"""
try:
emoji_manager = get_emoji_manager()
emotions = set()
for emoji_obj in emoji_manager.emoji_objects:
for emoji_obj in emoji_manager.emojis:
if not emoji_obj.is_deleted and emoji_obj.emotion:
emotions.update(emoji_obj.emotion)
@ -255,8 +252,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
"""
try:
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
all_emojis = emoji_manager.emojis
if not all_emojis:
logger.warning("[EmojiAPI] 没有可用的表情包")
@ -267,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
if emoji_obj.is_deleted:
continue
emoji_base64 = image_path_to_base64(emoji_obj.full_path)
emoji_base64 = image_path_to_base64(str(emoji_obj.full_path))
if not emoji_base64:
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
@ -291,12 +287,11 @@ def get_descriptions() -> List[str]:
list: 所有可用表情包的描述列表
"""
try:
emoji_manager = get_emoji_manager()
descriptions = []
descriptions.extend(
emoji_obj.description
for emoji_obj in emoji_manager.emoji_objects
for emoji_obj in emoji_manager.emojis
if not emoji_obj.is_deleted and emoji_obj.description
)
return descriptions
@ -341,14 +336,11 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
# 1. 获取emoji管理器并检查容量
emoji_manager = get_emoji_manager()
count_before = emoji_manager.emoji_num
max_count = emoji_manager.emoji_num_max
count_before = len(emoji_manager.emojis)
max_count = global_config.emoji.max_reg_num
# 2. 检查是否可以注册(未达到上限或启用替换)
can_register = count_before < max_count or (
count_before >= max_count and emoji_manager.emoji_num_max_reach_deletion
)
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
if not can_register:
return {
@ -474,7 +466,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
# 8. 构建返回结果
if register_success:
count_after = emoji_manager.emoji_num
count_after = len(emoji_manager.emojis)
replaced = count_after <= count_before # 如果数量没增加,说明是替换
# 尝试获取新注册的表情包信息
@ -483,10 +475,10 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
# 获取最新的表情包信息
try:
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
for emoji_obj in reversed(emoji_manager.emoji_objects):
for emoji_obj in reversed(emoji_manager.emojis):
if not emoji_obj.is_deleted and (
emoji_obj.filename == filename # 直接匹配
or (hasattr(emoji_obj, "full_path") and filename in emoji_obj.full_path) # 路径包含匹配
emoji_obj.file_name == filename
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
):
new_emoji_info = emoji_obj
break
@ -495,7 +487,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
description = new_emoji_info.description if new_emoji_info else None
emotions = new_emoji_info.emotion if new_emoji_info else None
emoji_hash = new_emoji_info.hash if new_emoji_info else None
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
return {
"success": True,
@ -560,12 +552,14 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
# 1. 获取emoji管理器和删除前的数量
emoji_manager = get_emoji_manager()
count_before = emoji_manager.emoji_num
count_before = len(emoji_manager.emojis)
# 2. 获取被删除表情包的信息(用于返回结果)
deleted_emoji = None
try:
deleted_emoji = await emoji_manager.get_emoji_from_manager(emoji_hash)
deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db(
emoji_hash
)
description = deleted_emoji.description if deleted_emoji else None
emotions = deleted_emoji.emotion if deleted_emoji else None
except Exception as info_error:
@ -574,10 +568,12 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
emotions = None
# 3. 执行删除操作
delete_success = await emoji_manager.delete_emoji(emoji_hash)
delete_success = False
if deleted_emoji:
delete_success = emoji_manager.delete_emoji(deleted_emoji)
# 4. 获取删除后的数量
count_after = emoji_manager.emoji_num
count_after = len(emoji_manager.emojis)
# 5. 构建返回结果
if delete_success:
@ -638,8 +634,7 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals
try:
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
emoji_manager = get_emoji_manager()
all_emojis = emoji_manager.emoji_objects
all_emojis = emoji_manager.emojis
# 筛选匹配的表情包
matching_emojis = []
@ -669,12 +664,12 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals
deleted_hashes = []
for emoji_obj in matching_emojis:
try:
delete_success = await emoji_manager.delete_emoji(emoji_obj.hash)
delete_success = emoji_manager.delete_emoji(emoji_obj)
if delete_success:
deleted_count += 1
deleted_hashes.append(emoji_obj.hash)
deleted_hashes.append(emoji_obj.emoji_hash)
except Exception as delete_error:
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.hash}): {delete_error}")
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}")
# 构建返回结果
if deleted_count > 0:

View File

@ -10,8 +10,10 @@
import time
from typing import List, Dict, Any, Tuple, Optional
from sqlmodel import col, select
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.database.database_model import Images
from src.common.database.database import get_db_session
from src.common.database.database_model import Images, ImageType
from src.chat.utils.utils import is_bot_self
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
@ -516,7 +518,13 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
def translate_pid_to_description(pid: str) -> str:
image = Images.get_or_none(Images.image_id == pid)
with get_db_session() as session:
statement = (
select(Images).where((col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE))
if pid.isdigit()
else None
)
image = session.exec(statement).first() if statement is not None else None
description = ""
if image and image.description and image.description.strip():
description = image.description.strip()

View File

@ -1,23 +1,25 @@
"""麦麦 2025 年度总结 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime
from sqlalchemy import func as fn
from typing import Any, Optional
from src.common.logger import get_logger
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import desc, func
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import (
LLMUsage,
OnlineTime,
Messages,
ChatStreams,
PersonInfo,
Emoji,
ActionRecord,
Expression,
ActionRecords,
Images,
Jargon,
Messages,
ModelUsage,
OnlineTime,
PersonInfo,
)
from src.common.logger import get_logger
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.annual_report")
@ -45,7 +47,7 @@ class TimeFootprintData(BaseModel):
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
is_night_owl: bool = Field(False, description="是否是夜猫子")
@ -54,8 +56,8 @@ class SocialNetworkData(BaseModel):
"""社交网络数据"""
total_groups: int = Field(0, description="加入的群组总数")
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
at_count: int = Field(0, description="被@次数")
mentioned_count: int = Field(0, description="被提及次数")
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
@ -69,11 +71,11 @@ class BrainPowerData(BaseModel):
total_cost: float = Field(0.0, description="年度总花费")
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
total_actions: int = Field(0, description="总动作数")
no_reply_count: int = Field(0, description="选择沉默的次数")
@ -88,23 +90,23 @@ class BrainPowerData(BaseModel):
class ExpressionVibeData(BaseModel):
"""个性与表达数据"""
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王")
top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
checked_expression_count: int = Field(0, description="已检查的表达次数")
total_expressions: int = Field(0, description="表达总数")
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
image_processed_count: int = Field(0, description="处理的图片数量")
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复")
favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复")
class AchievementData(BaseModel):
"""趣味成就数据"""
new_jargon_count: int = Field(0, description="新学到的黑话数量")
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
total_messages: int = Field(0, description="总消息数")
total_replies: int = Field(0, description="总回复数")
@ -115,11 +117,11 @@ class AnnualReportData(BaseModel):
year: int = Field(2025, description="报告年份")
bot_name: str = Field("麦麦", description="Bot名称")
generated_at: str = Field(..., description="报告生成时间")
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
achievements: AchievementData = Field(default_factory=AchievementData)
time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct())
social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct())
brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct())
expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct())
achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct())
# ==================== 辅助函数 ====================
@ -144,15 +146,18 @@ def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
"""获取时光足迹数据"""
data = TimeFootprintData()
data = TimeFootprintData.model_construct()
start_ts, end_ts = get_year_time_range(year)
start_dt, end_dt = get_year_datetime_range(year)
try:
# 1. 年度在线时长
online_records = list(
OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt))
)
with get_db_session() as session:
statement = select(OnlineTime).where(
col(OnlineTime.start_timestamp) >= start_dt,
col(OnlineTime.end_timestamp) <= end_dt,
)
online_records = session.exec(statement).all()
total_seconds = 0
for record in online_records:
try:
@ -165,50 +170,66 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
data.total_online_hours = round(total_seconds / 3600, 2)
# 2. 初次相遇 - 年度第一条消息
first_msg = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.order_by(Messages.time.asc())
.first()
)
with get_db_session() as session:
statement = (
select(Messages)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.order_by(col(Messages.timestamp).asc())
.limit(1)
)
first_msg = session.exec(statement).first()
if first_msg:
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S")
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
content = first_msg.processed_plain_text or first_msg.display_message or ""
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
# 3. 最忙碌的一天
# 使用 SQLite 的 date 函数按日期分组
busiest_query = (
Messages.select(
fn.date(Messages.time, "unixepoch").alias("day"),
fn.COUNT(Messages.id).alias("count"),
day_expr = func.date(col(Messages.timestamp))
with get_db_session() as session:
statement = (
select(
day_expr.label("day"),
func.count().label("count"),
)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(day_expr)
.order_by(func.count().desc())
.limit(1)
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.date(Messages.time, "unixepoch"))
.order_by(fn.COUNT(Messages.id).desc())
.limit(1)
)
busiest_result = list(busiest_query.dicts())
busiest_result = session.exec(statement).all()
if busiest_result:
data.busiest_day = busiest_result[0].get("day")
data.busiest_day_count = busiest_result[0].get("count", 0)
data.busiest_day = busiest_result[0][0]
data.busiest_day_count = busiest_result[0][1] or 0
# 4. 昼夜节律 - 24小时活跃分布
hourly_query = (
Messages.select(
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
fn.COUNT(Messages.id).alias("count"),
hour_expr = func.strftime("%H", col(Messages.timestamp))
with get_db_session() as session:
statement = (
select(
hour_expr.label("hour"),
func.count().label("count"),
)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(hour_expr)
)
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
)
hourly_rows = session.exec(statement).all()
hourly_distribution = [0] * 24
for row in hourly_query.dicts():
for row in hourly_rows:
try:
hour = int(row.get("hour", 0))
hour = int(row[0] or 0)
if 0 <= hour < 24:
hourly_distribution[hour] = row.get("count", 0)
hourly_distribution[hour] = row[1] or 0
except (ValueError, TypeError):
continue
data.hourly_distribution = hourly_distribution
@ -234,7 +255,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
"""获取社交网络数据"""
from src.config.config import global_config
data = SocialNetworkData()
data = SocialNetworkData.model_construct()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于过滤
@ -242,91 +263,110 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
try:
# 1. 加入的群组总数
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
with get_db_session() as session:
statement = select(func.count(func.distinct(col(Messages.group_id)))).where(
col(Messages.group_id).is_not(None),
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
data.total_groups = int(session.exec(statement).first() or 0)
# 2. 话痨群组 TOP3
top_groups_query = (
Messages.select(
Messages.chat_info_group_id,
Messages.chat_info_group_name,
fn.COUNT(Messages.id).alias("count"),
with get_db_session() as session:
statement = (
select(
col(Messages.group_id),
func.max(col(Messages.group_name)).label("group_name"),
func.count().label("count"),
)
.where(
col(Messages.group_id).is_not(None),
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(col(Messages.group_id))
.order_by(func.count().desc())
.limit(5)
)
.where(
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False))
)
.group_by(Messages.chat_info_group_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
top_groups_rows = session.exec(statement).all()
data.top_groups = [
{
"group_id": row["chat_info_group_id"],
"group_name": row["chat_info_group_name"] or "未知群组",
"message_count": row["count"],
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
"group_id": row[0],
"group_name": row[1] or "未知群组",
"message_count": row[2] or 0,
"is_webui": str(row[0]).startswith("webui_"),
}
for row in top_groups_query.dicts()
for row in top_groups_rows
]
# 3. 互动最多的用户 TOP5过滤 bot 自身)
top_users_query = (
Messages.select(
Messages.user_id,
Messages.user_nickname,
fn.COUNT(Messages.id).alias("count"),
with get_db_session() as session:
statement = (
select(
col(Messages.user_id),
func.max(col(Messages.user_nickname)).label("user_nickname"),
func.count().label("count"),
)
.where(
col(Messages.user_id).is_not(None),
col(Messages.user_id) != bot_qq,
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
.group_by(col(Messages.user_id))
.order_by(func.count().desc())
.limit(5)
)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.user_id.is_null(False))
& (Messages.user_id != bot_qq) # 过滤 bot 自身
)
.group_by(Messages.user_id)
.order_by(fn.COUNT(Messages.id).desc())
.limit(5)
)
top_users_rows = session.exec(statement).all()
data.top_users = [
{
"user_id": row["user_id"],
"user_nickname": row["user_nickname"] or "未知用户",
"message_count": row["count"],
"is_webui": str(row["user_id"]).startswith("webui_"),
"user_id": row[0],
"user_nickname": row[1] or "未知用户",
"message_count": row[2] or 0,
"is_webui": str(row[0]).startswith("webui_"),
}
for row in top_users_query.dicts()
for row in top_users_rows
]
# 4. 被@次数
data.at_count = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True))
.count()
)
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_at) == True,
)
data.at_count = int(session.exec(statement).first() or 0)
# 5. 被提及次数
data.mentioned_count = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True))
.count()
)
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_mentioned) == True,
)
data.mentioned_count = int(session.exec(statement).first() or 0)
# 6. 最长情陪伴的用户(过滤 bot 自身)
companion_query = (
ChatStreams.select(
ChatStreams.user_id,
ChatStreams.user_nickname,
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
with get_db_session() as session:
statement = select(PersonInfo).where(
col(PersonInfo.user_id) != bot_qq,
col(PersonInfo.first_known_time).is_not(None),
col(PersonInfo.last_known_time).is_not(None),
)
.where(
(ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
)
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
.limit(1)
)
companion_result = list(companion_query.dicts())
if companion_result:
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
duration = companion_result[0].get("duration", 0) or 0
data.longest_companion_days = int(duration / 86400) # 转换为天
persons = session.exec(statement).all()
if persons:
def _companion_days(person: PersonInfo) -> float:
if not person.first_known_time or not person.last_known_time:
return 0.0
return (person.last_known_time - person.first_known_time).total_seconds()
longest = max(persons, key=_companion_days)
data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id
data.longest_companion_days = int(_companion_days(longest) / 86400)
else:
data.longest_companion_user = None
data.longest_companion_days = 0
except Exception as e:
logger.error(f"获取社交网络数据失败: {e}")
@ -339,154 +379,139 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
async def get_brain_power(year: int = 2025) -> BrainPowerData:
"""获取最强大脑数据"""
data = BrainPowerData()
data = BrainPowerData.model_construct()
start_dt, end_dt = get_year_datetime_range(year)
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 年度消耗 Token 总量和总花费
token_query = LLMUsage.select(
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
result = token_query.dicts().get()
data.total_tokens = int(result.get("total_tokens", 0) or 0)
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
with get_db_session() as session:
statement = select(
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
func.sum(col(ModelUsage.cost)).label("total_cost"),
).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
result = session.exec(statement).first()
if result:
data.total_tokens = int(result[0] or 0)
data.total_cost = round(float(result[1] or 0), 4)
# 2. 最爱用的模型
model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
with get_db_session() as session:
statement = (
select(ModelUsage)
.where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
.order_by(desc(col(ModelUsage.timestamp)))
)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10)
)
model_results = list(model_query.dicts())
records = session.exec(statement).all()
model_agg: dict[str, dict[str, float | int]] = {}
for record in records:
model_name = record.model_assign_name or record.model_name or "unknown"
if model_name not in model_agg:
model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0}
bucket = model_agg[model_name]
bucket["count"] = int(bucket["count"]) + 1
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
model_results = sorted(
model_agg.items(),
key=lambda item: float(item[1]["count"]),
reverse=True,
)[:10]
if model_results:
data.favorite_model = model_results[0].get("model")
data.favorite_model_count = model_results[0].get("count", 0)
data.favorite_model = model_results[0][0]
data.favorite_model_count = int(model_results[0][1]["count"])
data.model_distribution = [
{
"model": row["model"],
"count": row["count"],
"tokens": row["tokens"],
"cost": round(row["cost"], 4),
"model": model_name,
"count": int(bucket["count"]),
"tokens": int(bucket["tokens"]),
"cost": round(float(bucket["cost"]), 4),
}
for row in model_results
for model_name, bucket in model_results
]
# 3. 最昂贵的一次思考
expensive_query = (
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
.order_by(LLMUsage.cost.desc())
.limit(1)
)
expensive_result = expensive_query.first()
if expensive_result:
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
if records:
expensive_record = max(records, key=lambda record: record.cost or 0.0)
data.most_expensive_cost = round(expensive_record.cost or 0.0, 4)
data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S")
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
consumer_query = (
LLMUsage.select(
LLMUsage.user_id,
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (LLMUsage.user_id != "system") # 过滤 system 用户
& (LLMUsage.user_id.is_null(False))
)
.group_by(LLMUsage.user_id)
.order_by(fn.SUM(LLMUsage.cost).desc())
.limit(3)
)
consumer_agg: dict[str, dict[str, float | int]] = {}
for record in records:
user_id = record.model_api_provider_name
if not user_id or user_id == "system":
continue
if user_id not in consumer_agg:
consumer_agg[user_id] = {"cost": 0.0, "tokens": 0}
bucket = consumer_agg[user_id]
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
data.top_token_consumers = [
{
"user_id": row["user_id"],
"cost": round(row["cost"], 4),
"tokens": row["tokens"],
"user_id": user_id,
"cost": round(float(bucket["cost"]), 4),
"tokens": int(bucket["tokens"]),
}
for row in consumer_query.dicts()
for user_id, bucket in sorted(
consumer_agg.items(),
key=lambda item: float(item[1]["cost"]),
reverse=True,
)[:3]
]
# 5. 最喜欢的回复模型 TOP5按模型的回复次数统计只统计 replyer 调用)
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
reply_model_query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
fn.COUNT(LLMUsage.id).alias("count"),
)
.where(
(LLMUsage.timestamp >= start_dt)
& (LLMUsage.timestamp <= end_dt)
& (
LLMUsage.model_assign_name.contains("replyer")
| LLMUsage.model_assign_name.contains("回复")
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
)
)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(5)
)
data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()]
reply_model_agg: dict[str, int] = {}
for record in records:
model_assign_name = record.model_assign_name or ""
if "replyer" not in model_assign_name and "回复" not in model_assign_name:
continue
model_name = model_assign_name or record.model_name or "unknown"
reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1
data.top_reply_models = [
{"model": model_name, "count": count}
for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5]
]
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
total_actions = (
ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count()
)
no_reply_count = (
ActionRecords.select()
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "no_reply")
with get_db_session() as session:
statement = select(func.count()).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
)
.count()
)
total_actions = int(session.exec(statement).first() or 0)
with get_db_session() as session:
statement = select(func.count()).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name) == "no_reply",
)
no_reply_count = int(session.exec(statement).first() or 0)
data.total_actions = total_actions
data.no_reply_count = no_reply_count
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
# 6. 情绪波动 (兴趣值)
interest_query = Messages.select(
fn.AVG(Messages.interest_value).alias("avg_interest"),
fn.MAX(Messages.interest_value).alias("max_interest"),
).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False)))
interest_result = interest_query.dicts().get()
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
data.avg_interest_value = 0.0
data.max_interest_value = 0.0
# 找到最高兴趣值的时间
if data.max_interest_value > 0:
max_interest_msg = (
Messages.select(Messages.time)
.where(
(Messages.time >= start_ts)
& (Messages.time <= end_ts)
& (Messages.interest_value == data.max_interest_value)
)
.first()
)
if max_interest_msg:
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S")
data.max_interest_time = None
# 7. 思考深度 (基于 action_reasoning 长度)
reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_reasoning.is_null(False))
& (ActionRecords.action_reasoning != "")
)
with get_db_session() as session:
statement = select(ActionRecord).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_reasoning).is_not(None),
col(ActionRecord.action_reasoning) != "",
)
reasoning_records = session.exec(statement).all()
reasoning_lengths = []
max_len = 0
max_len_time = None
@ -496,13 +521,13 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
reasoning_lengths.append(length)
if length > max_len:
max_len = length
max_len_time = record.time
max_len_time = record.timestamp
if reasoning_lengths:
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
data.max_reasoning_length = max_len
if max_len_time:
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S")
except Exception as e:
logger.error(f"获取最强大脑数据失败: {e}")
@ -517,7 +542,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"""获取个性与表达数据"""
from src.config.config import global_config
data = ExpressionVibeData()
data = ExpressionVibeData.model_construct()
start_ts, end_ts = get_year_time_range(year)
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
@ -525,75 +550,58 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
try:
# 1. 表情包之王 - 使用次数最多的表情包
top_emoji_query = (
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
.where(Emoji.is_registered == True)
.order_by(Emoji.usage_count.desc())
.limit(5)
)
top_emojis = list(top_emoji_query.dicts())
with get_db_session() as session:
statement = (
select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5)
)
top_emojis = session.exec(statement).all()
if top_emojis:
data.top_emoji = {
"id": top_emojis[0].get("id"),
"path": top_emojis[0].get("full_path"),
"description": top_emojis[0].get("description"),
"usage_count": top_emojis[0].get("usage_count", 0),
"hash": top_emojis[0].get("emoji_hash"),
"id": top_emojis[0].id,
"path": top_emojis[0].full_path,
"description": top_emojis[0].description,
"usage_count": top_emojis[0].query_count,
"hash": top_emojis[0].image_hash,
}
data.top_emojis = [
{
"id": e.get("id"),
"path": e.get("full_path"),
"description": e.get("description"),
"usage_count": e.get("usage_count", 0),
"hash": e.get("emoji_hash"),
"id": e.id,
"path": e.full_path,
"description": e.description,
"usage_count": e.query_count,
"hash": e.image_hash,
}
for e in top_emojis
]
# 2. 百变麦麦 - 最常用的表达风格
expression_query = (
Expression.select(
Expression.style,
fn.SUM(Expression.count).alias("total_count"),
with get_db_session() as session:
statement = (
select(Expression.style, func.sum(col(Expression.count)).label("total_count"))
.where(
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
)
.group_by(Expression.style)
.order_by(func.sum(col(Expression.count)).desc())
.limit(5)
)
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
.group_by(Expression.style)
.order_by(fn.SUM(Expression.count).desc())
.limit(5)
)
data.top_expressions = [
{"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts()
]
expression_rows = session.exec(statement).all()
data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows]
# 3. 被拒绝的表达
data.rejected_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.rejected == True)
)
.count()
)
data.rejected_expression_count = 0
# 4. 已检查的表达
data.checked_expression_count = (
Expression.select()
.where(
(Expression.last_active_time >= start_ts)
& (Expression.last_active_time <= end_ts)
& (Expression.checked == True)
)
.count()
)
data.checked_expression_count = 0
# 5. 表达总数
data.total_expressions = (
Expression.select()
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
.count()
)
with get_db_session() as session:
statement = select(func.count()).where(
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
)
data.total_expressions = int(session.exec(statement).first() or 0)
# 6. 动作类型分布 (过滤无意义的动作)
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
@ -608,28 +616,29 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
"listening",
"block_and_ignore",
]
action_query = (
ActionRecords.select(
ActionRecords.action_name,
fn.COUNT(ActionRecords.id).alias("count"),
with get_db_session() as session:
statement = (
select(ActionRecord.action_name, func.count().label("count"))
.where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name).not_in(excluded_actions),
)
.group_by(ActionRecord.action_name)
.order_by(func.count().desc())
.limit(10)
)
.where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name.not_in(excluded_actions))
)
.group_by(ActionRecords.action_name)
.order_by(fn.COUNT(ActionRecords.id).desc())
.limit(10)
)
data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()]
action_rows = session.exec(statement).all()
data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows]
# 7. 处理的图片数量
data.image_processed_count = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True))
.count()
)
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.is_picture) == True,
)
data.image_processed_count = int(session.exec(statement).first() or 0)
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
import random
@ -648,21 +657,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
return content
# 使用 user_id 判断是否是 bot 发送的消息
late_night_messages = list(
Messages.select(
Messages.time,
Messages.processed_plain_text,
Messages.display_message,
with get_db_session() as session:
statement = (
select(Messages)
.where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.user_id) == bot_qq,
)
.order_by(desc(col(Messages.timestamp)))
.limit(200)
)
.where(
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息
)
.order_by(Messages.time.desc())
)
late_night_messages = session.exec(statement).all()
# 筛选出0-6点的消息
late_night_filtered = []
for msg in late_night_messages:
msg_dt = datetime.fromtimestamp(msg.time)
msg_dt = msg.timestamp
hour = msg_dt.hour
if 0 <= hour < 6: # 0点到6点
raw_content = msg.processed_plain_text or msg.display_message or ""
@ -671,7 +681,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
if cleaned_content and len(cleaned_content) > 2:
late_night_filtered.append(
{
"time": msg.time,
"time": msg_dt.timestamp(),
"hour": hour,
"minute": msg_dt.minute,
"content": cleaned_content,
@ -693,13 +703,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
from collections import Counter
import json as json_lib
reply_records = ActionRecords.select(ActionRecords.action_data).where(
(ActionRecords.time >= start_ts)
& (ActionRecords.time <= end_ts)
& (ActionRecords.action_name == "reply")
& (ActionRecords.action_data.is_null(False))
& (ActionRecords.action_data != "")
)
with get_db_session() as session:
statement = select(ActionRecord).where(
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
col(ActionRecord.action_name) == "reply",
col(ActionRecord.action_data).is_not(None),
col(ActionRecord.action_data) != "",
)
reply_records = session.exec(statement).all()
reply_contents = []
for record in reply_records:
@ -762,21 +774,20 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
async def get_achievements(year: int = 2025) -> AchievementData:
"""获取趣味成就数据"""
data = AchievementData()
data = AchievementData.model_construct()
start_ts, end_ts = get_year_time_range(year)
try:
# 1. 新学到的黑话数量
# Jargon 表没有时间字段,统计全部已确认的黑话
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
with get_db_session() as session:
statement = select(func.count()).where(col(Jargon.is_jargon) == True)
data.new_jargon_count = int(session.exec(statement).first() or 0)
# 2. 代表性黑话示例
jargon_samples = (
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
.where(Jargon.is_jargon == True)
.order_by(Jargon.count.desc())
.limit(5)
)
with get_db_session() as session:
statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5)
jargon_samples = session.exec(statement).all()
data.sample_jargons = [
{
"content": j.content,
@ -787,14 +798,21 @@ async def get_achievements(year: int = 2025) -> AchievementData:
]
# 3. 总消息数
data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count()
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
)
data.total_messages = int(session.exec(statement).first() or 0)
# 4. 总回复数 (有 reply_to 的消息)
data.total_replies = (
Messages.select()
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False)))
.count()
)
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
col(Messages.reply_to).is_not(None),
)
data.total_replies = int(session.exec(statement).first() or 0)
except Exception as e:
logger.error(f"获取趣味成就数据失败: {e}")

View File

@ -7,16 +7,19 @@
import time
import uuid
from typing import Dict, Any, Optional, List
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
from pydantic import BaseModel
from sqlalchemy import case, func as fn
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from sqlalchemy import case, desc, func
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database_model import Messages, PersonInfo
from src.config.config import global_config
from src.chat.message_receive.bot import chat_bot
from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, PersonInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
from src.webui.routers.websocket.auth import verify_ws_token
logger = get_logger("webui.chat")
@ -97,7 +100,7 @@ class ChatHistoryManager:
"id": msg.message_id,
"type": "bot" if is_bot else "user",
"content": msg.processed_plain_text or msg.display_message or "",
"timestamp": msg.time,
"timestamp": msg.timestamp.timestamp(),
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
"sender_id": "bot" if is_bot else user_id,
"is_bot": is_bot,
@ -113,12 +116,14 @@ class ChatHistoryManager:
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
# 查询指定群的消息,按时间排序
messages = (
Messages.select()
.where(Messages.chat_info_group_id == target_group_id)
.order_by(Messages.time.desc())
.limit(limit)
)
with get_db_session() as session:
statement = (
select(Messages)
.where(col(Messages.group_id) == target_group_id)
.order_by(desc(col(Messages.timestamp)))
.limit(limit)
)
messages = session.exec(statement).all()
# 转换为列表并反转(使最旧的消息在前)
# 传递 group_id 以便正确判断虚拟群中的机器人消息
@ -139,7 +144,10 @@ class ChatHistoryManager:
"""
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
try:
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
with get_db_session() as session:
statement = delete(Messages).where(col(Messages.group_id) == target_group_id)
result = session.exec(statement)
deleted = result.rowcount or 0
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
return deleted
except Exception as e:
@ -172,14 +180,14 @@ class ChatConnectionManager:
del self.user_sessions[user_id]
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
async def send_message(self, session_id: str, message: dict[str, Any]):
    """Send a JSON payload to one active WebSocket session.

    Unknown session ids are silently ignored; send failures are logged
    but never raised, so a dead socket cannot break the caller.

    Args:
        session_id: Key into ``self.active_connections``.
        message: JSON-serializable payload forwarded via ``send_json``.
    """
    if session_id in self.active_connections:
        try:
            await self.active_connections[session_id].send_json(message)
        except Exception as e:
            logger.error(f"发送消息失败: {e}")
async def broadcast(self, message: dict[str, Any]):
    """Broadcast a message to every active connection.

    Iterates over a snapshot of the session ids (``list(...)``) so that
    connections may disconnect and be removed while broadcasting.

    Args:
        message: JSON-serializable payload sent to each session.
    """
    for session_id in list(self.active_connections.keys()):
        await self.send_message(session_id, message)
@ -292,16 +300,18 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)):
"""
try:
# 查询所有不同的平台
platforms = (
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
.group_by(PersonInfo.platform)
.order_by(fn.COUNT(PersonInfo.id).desc())
)
with get_db_session() as session:
statement = (
select(PersonInfo.platform, func.count().label("count"))
.group_by(PersonInfo.platform)
.order_by(func.count().desc())
)
platforms = session.exec(statement).all()
result = []
for p in platforms:
if p.platform: # 排除空平台
result.append({"platform": p.platform, "count": p.count})
for platform, count in platforms:
if platform:
result.append({"platform": platform, "count": count})
return {"success": True, "platforms": result}
except Exception as e:
@ -325,31 +335,36 @@ async def get_persons_by_platform(
"""
try:
# 构建查询
query = PersonInfo.select().where(PersonInfo.platform == platform)
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
statement = statement.where(
(col(PersonInfo.person_name).contains(search))
| (col(PersonInfo.user_nickname).contains(search))
| (col(PersonInfo.user_id).contains(search))
)
# 按最后交互时间排序,优先显示活跃用户
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
query = query.limit(limit)
statement = statement.order_by(
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
col(PersonInfo.last_known_time).desc(),
).limit(limit)
with get_db_session() as session:
persons = session.exec(statement).all()
result = []
for person in query:
for person in persons:
result.append(
{
"person_id": person.person_id,
"user_id": person.user_id,
"person_name": person.person_name,
"nickname": person.nickname,
"nickname": person.user_nickname,
"is_known": person.is_known,
"platform": person.platform,
"display_name": person.person_name or person.nickname or person.user_id,
"display_name": person.person_name or person.user_nickname or person.user_id,
}
)
@ -448,7 +463,9 @@ async def websocket_chat(
# 如果 URL 参数中提供了虚拟身份信息,自动配置
if platform and person_id:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
if person:
# 使用前端传递的 group_id如果没有则生成一个稳定的
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
@ -457,7 +474,7 @@ async def websocket_chat(
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
user_nickname=person.person_name or person.user_nickname or person.user_id,
group_id=virtual_group_id,
group_name=group_name or "WebUI虚拟群聊",
)
@ -471,7 +488,7 @@ async def websocket_chat(
try:
# 构建会话信息
session_info_data = {
session_info_data: dict[str, Any] = {
"type": "session_info",
"session_id": session_id,
"user_id": user_id,
@ -641,7 +658,13 @@ async def websocket_chat(
# 获取用户信息
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
with get_db_session() as session:
statement = (
select(PersonInfo)
.where(col(PersonInfo.person_id) == virtual_data.get("person_id"))
.limit(1)
)
person = session.exec(statement).first()
if not person:
await chat_manager.send_message(
session_id,
@ -665,7 +688,7 @@ async def websocket_chat(
platform=person.platform,
person_id=person.person_id,
user_id=person.user_id,
user_nickname=person.person_name or person.nickname or person.user_id,
user_nickname=person.person_name or person.user_nickname or person.user_id,
group_id=group_id,
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
)
@ -769,7 +792,7 @@ async def get_chat_info(_auth: bool = Depends(require_auth)):
}
def get_webui_chat_broadcaster() -> tuple:
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
"""获取 WebUI 聊天广播器,供外部模块使用
Returns:

View File

@ -3,11 +3,16 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from sqlalchemy import case
from datetime import datetime, timedelta
from sqlalchemy import case, func
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database_model import Expression, ChatStreams
from src.common.database.database import get_db_session
from src.common.database.database_model import Expression
from src.chat.message_receive.chat_stream import get_chat_manager
from src.webui.core import verify_auth_token_from_cookie_or_header
import time
logger = get_logger("webui.expression")
@ -98,30 +103,32 @@ def verify_auth_token(
def expression_to_response(expression: Expression) -> ExpressionResponse:
    """Convert an ``Expression`` ORM row into the WebUI response model.

    Datetime columns are converted to Unix timestamps (floats) for the
    frontend. Columns that were dropped from the new data model
    (``checked``, ``rejected``, ``modified_by``) are reported with
    neutral defaults so the response schema stays backward-compatible.

    Args:
        expression: The database row to convert.

    Returns:
        ExpressionResponse: API-facing representation of the row.
    """
    # last_active_time / create_time are datetimes on the new model; use
    # 0.0 / None respectively when unset.
    last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
    create_date = expression.create_time.timestamp() if expression.create_time else None
    return ExpressionResponse(
        # id may be None before the row is flushed; the response requires an int.
        id=expression.id if expression.id is not None else 0,
        situation=expression.situation,
        style=expression.style,
        last_active_time=last_active_time,
        # The new schema stores the chat identifier as session_id.
        chat_id=expression.session_id or "",
        create_date=create_date,
        # Review-state fields no longer exist on the model; expose stable defaults.
        checked=False,
        rejected=False,
        modified_by=None,
    )
def get_chat_name(chat_id: str) -> str:
    """Resolve a human-readable chat name for ``chat_id``.

    Preference order: group name, then user nickname. Falls back to the
    raw ``chat_id`` when the stream is unknown or any lookup fails.

    Args:
        chat_id: Stream identifier to resolve.

    Returns:
        str: Display name, or ``chat_id`` itself when unresolvable.
    """
    try:
        chat_stream = get_chat_manager().get_stream(chat_id)
        if not chat_stream:
            return chat_id
        if chat_stream.group_info and chat_stream.group_info.group_name:
            return chat_stream.group_info.group_name
        if chat_stream.user_info and chat_stream.user_info.user_nickname:
            return chat_stream.user_info.user_nickname
        return chat_id
    except Exception:
        # Best-effort lookup: a naming failure must never break the caller.
        return chat_id
def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
    """Resolve display names for many chat ids in one pass.

    The returned mapping always contains every requested id; ids that
    cannot be resolved (unknown stream, lookup error) map back to
    themselves, so callers can index it unconditionally.

    Args:
        chat_ids: Stream identifiers to resolve.

    Returns:
        Dict[str, str]: chat_id -> display name (group name preferred,
        then user nickname, then the id itself).
    """
    # Seed with the identity mapping so every id has a value even on failure.
    result = {cid: cid for cid in chat_ids}
    try:
        chat_manager = get_chat_manager()
        for chat_id in chat_ids:
            chat_stream = chat_manager.get_stream(chat_id)
            if not chat_stream:
                continue
            if chat_stream.group_info and chat_stream.group_info.group_name:
                result[chat_id] = chat_stream.group_info.group_name
            elif chat_stream.user_info and chat_stream.user_info.user_nickname:
                result[chat_id] = chat_stream.user_info.user_nickname
    except Exception as e:
        logger.warning(f"批量获取聊天名称失败: {e}")
    return result
@ -172,14 +182,17 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
verify_auth_token(maibot_session, authorization)
chat_list = []
for cs in ChatStreams.select():
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
for stream_id, stream in get_chat_manager().streams.items():
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None
if not chat_name and stream.user_info and stream.user_info.user_nickname:
chat_name = stream.user_info.user_nickname
chat_name = chat_name or stream_id
chat_list.append(
ChatInfo(
chat_id=cs.stream_id,
chat_id=stream_id,
chat_name=chat_name,
platform=cs.platform,
is_group=bool(cs.group_id),
platform=stream.platform,
is_group=bool(stream.group_info and stream.group_info.group_id),
)
)
@ -221,29 +234,39 @@ async def get_expression_list(
verify_auth_token(maibot_session, authorization)
# 构建查询
query = Expression.select()
statement = select(Expression)
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
statement = statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序最后活跃时间倒序NULL 值放在最后)
query = query.order_by(
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
statement = statement.order_by(
case((col(Expression.last_active_time).is_(None), 1), else_=0),
col(Expression.last_active_time).desc(),
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
# 转换为响应对象
data = [expression_to_response(expr) for expr in expressions]
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@ -272,7 +295,9 @@ async def get_expression_detail(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
@ -305,16 +330,22 @@ async def create_expression(
try:
verify_auth_token(maibot_session, authorization)
current_time = time.time()
current_time = datetime.now()
# 创建表达方式
expression = Expression.create(
situation=request.situation,
style=request.style,
chat_id=request.chat_id,
last_active_time=current_time,
create_date=current_time,
)
with get_db_session() as session:
expression = Expression(
situation=request.situation,
style=request.style,
context="",
up_content="",
content_list="[]",
count=0,
last_active_time=current_time,
create_time=current_time,
session_id=request.chat_id,
)
session.add(expression)
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
@ -350,16 +381,18 @@ async def update_expression(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
# 冲突检测:如果要求未检查状态,但已经被检查了
if request.require_unchecked and expression.checked:
if request.require_unchecked and getattr(expression, "checked", False):
raise HTTPException(
status_code=409,
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
)
# 只更新提供的字段
@ -376,13 +409,18 @@ async def update_expression(
update_data["modified_by"] = "user"
# 更新最后活跃时间
update_data["last_active_time"] = time.time()
update_data["last_active_time"] = datetime.now()
# 执行更新
for field, value in update_data.items():
setattr(expression, field, value)
expression.save()
with get_db_session() as session:
db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
if not db_expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
for field, value in update_data.items():
if hasattr(db_expression, field):
setattr(db_expression, field, value)
session.add(db_expression)
expression = db_expression
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
@ -414,7 +452,9 @@ async def delete_expression(
try:
verify_auth_token(maibot_session, authorization)
expression = Expression.get_or_none(Expression.id == expression_id)
with get_db_session() as session:
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
expression = session.exec(statement).first()
if not expression:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
@ -423,7 +463,8 @@ async def delete_expression(
situation = expression.situation
# 执行删除
expression.delete_instance()
with get_db_session() as session:
session.exec(delete(Expression).where(col(Expression.id) == expression_id))
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
@ -465,8 +506,9 @@ async def batch_delete_expressions(
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
# 查找所有要删除的表达方式
expressions = Expression.select().where(Expression.id.in_(request.ids))
found_ids = [expr.id for expr in expressions]
with get_db_session() as session:
statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
found_ids = [expr_id for expr_id in session.exec(statements).all()]
# 检查是否有未找到的ID
not_found_ids = set(request.ids) - set(found_ids)
@ -474,7 +516,9 @@ async def batch_delete_expressions(
logger.warning(f"部分表达方式未找到: {not_found_ids}")
# 执行批量删除
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
with get_db_session() as session:
result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
deleted_count = result.rowcount or 0
logger.info(f"批量删除了 {deleted_count} 个表达方式")
@ -503,21 +547,21 @@ async def get_expression_stats(
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
# 按 chat_id 统计
chat_stats = {}
for expr in Expression.select(Expression.chat_id):
chat_id = expr.chat_id
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
chat_stats = {}
for chat_id in session.exec(select(Expression.session_id)).all():
if chat_id:
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
# 获取最近创建的记录数7天内
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
recent = (
Expression.select()
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
.count()
)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_statement = (
select(func.count())
.select_from(Expression)
.where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
)
recent = session.exec(recent_statement).one()
return {
"success": True,
@ -561,12 +605,13 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
try:
verify_auth_token(maibot_session, authorization)
total = Expression.select().count()
unchecked = Expression.select().where(Expression.checked == False).count()
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
user_checked = Expression.select().where(Expression.modified_by == "user").count()
with get_db_session() as session:
total = len(session.exec(select(Expression.id)).all())
unchecked = 0
passed = 0
rejected = 0
ai_checked = 0
user_checked = 0
return ReviewStatsResponse(
total=total,
@ -620,31 +665,44 @@ async def get_review_list(
try:
verify_auth_token(maibot_session, authorization)
query = Expression.select()
statement = select(Expression)
# 根据筛选类型过滤
if filter_type == "unchecked":
query = query.where(Expression.checked == False)
elif filter_type == "passed":
query = query.where((Expression.checked == True) & (Expression.rejected == False))
elif filter_type == "rejected":
query = query.where((Expression.checked == True) & (Expression.rejected == True))
if filter_type in {"unchecked", "passed", "rejected"}:
statement = statement.where(col(Expression.id) == -1)
# all 不需要额外过滤
# 搜索过滤
if search:
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
statement = statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
# 聊天ID过滤
if chat_id:
query = query.where(Expression.chat_id == chat_id)
statement = statement.where(col(Expression.session_id) == chat_id)
# 排序:创建时间倒序
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
statement = statement.order_by(
case((col(Expression.create_time).is_(None), 1), else_=0),
col(Expression.create_time).desc(),
)
total = query.count()
offset = (page - 1) * page_size
expressions = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
expressions = session.exec(statement).all()
count_statement = select(Expression.id)
if filter_type in {"unchecked", "passed", "rejected"}:
count_statement = count_statement.where(col(Expression.id) == -1)
if search:
count_statement = count_statement.where(
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
)
if chat_id:
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
total = len(session.exec(count_statement).all())
return ReviewListResponse(
success=True,
@ -720,7 +778,8 @@ async def batch_review_expressions(
for item in request.items:
try:
expression = Expression.get_or_none(Expression.id == item.id)
with get_db_session() as session:
expression = session.exec(select(Expression).where(col(Expression.id) == item.id).limit(1)).first()
if not expression:
results.append(
@ -730,23 +789,28 @@ async def batch_review_expressions(
continue
# 冲突检测
if item.require_unchecked and expression.checked:
if item.require_unchecked:
results.append(
BatchReviewResultItem(
id=item.id,
success=False,
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
)
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
)
failed += 1
continue
# 更新状态
expression.checked = True
expression.rejected = item.rejected
expression.modified_by = "user"
expression.last_active_time = time.time()
expression.save()
with get_db_session() as session:
db_expression = session.exec(
select(Expression).where(col(Expression.id) == item.id).limit(1)
).first()
if not db_expression:
results.append(
BatchReviewResultItem(
id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式"
)
)
failed += 1
continue
db_expression.last_active_time = datetime.now()
session.add(db_expression)
results.append(
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")

View File

@ -3,12 +3,16 @@
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
from pydantic import BaseModel
from typing import Optional, List, Dict
from datetime import datetime
from sqlalchemy import case
from sqlmodel import col, select, delete
from src.common.logger import get_logger
from src.common.database.database import get_db_session
from src.common.database.database_model import PersonInfo
from src.webui.core import verify_auth_token_from_cookie_or_header
import json
import time
logger = get_logger("webui.person")
@ -29,7 +33,7 @@ class PersonInfoResponse(BaseModel):
nickname: Optional[str]
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
memory_points: Optional[str]
know_times: Optional[float]
know_times: Optional[int]
know_since: Optional[float]
last_know: Optional[float]
@ -112,20 +116,22 @@ def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[D
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
"""将 PersonInfo 模型转换为响应对象"""
know_since = person.first_known_time.timestamp() if person.first_known_time else None
last_know = person.last_known_time.timestamp() if person.last_known_time else None
return PersonInfoResponse(
id=person.id,
id=person.id or 0,
is_known=person.is_known,
person_id=person.person_id,
person_name=person.person_name,
name_reason=person.name_reason,
platform=person.platform,
user_id=person.user_id,
nickname=person.nickname,
group_nick_name=parse_group_nick_name(person.group_nick_name),
nickname=person.user_nickname,
group_nick_name=parse_group_nick_name(person.group_nickname),
memory_points=person.memory_points,
know_times=person.know_times,
know_since=person.know_since,
last_know=person.last_know,
know_times=person.know_counts,
know_since=know_since,
last_know=last_know,
)
@ -157,36 +163,50 @@ async def get_person_list(
verify_auth_token(maibot_session, authorization)
# 构建查询
query = PersonInfo.select()
statement = select(PersonInfo)
# 搜索过滤
if search:
query = query.where(
(PersonInfo.person_name.contains(search))
| (PersonInfo.nickname.contains(search))
| (PersonInfo.user_id.contains(search))
statement = statement.where(
(col(PersonInfo.person_name).contains(search))
| (col(PersonInfo.user_nickname).contains(search))
| (col(PersonInfo.user_id).contains(search))
)
# 已认识状态过滤
if is_known is not None:
query = query.where(PersonInfo.is_known == is_known)
statement = statement.where(col(PersonInfo.is_known) == is_known)
# 平台过滤
if platform:
query = query.where(PersonInfo.platform == platform)
statement = statement.where(col(PersonInfo.platform) == platform)
# 排序最后更新时间倒序NULL 值放在最后)
# Peewee 不支持 nulls_last使用 CASE WHEN 来实现
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
statement = statement.order_by(
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
col(PersonInfo.last_known_time).desc(),
)
# 获取总数
total = query.count()
# 分页
offset = (page - 1) * page_size
persons = query.offset(offset).limit(page_size)
statement = statement.offset(offset).limit(page_size)
with get_db_session() as session:
persons = session.exec(statement).all()
count_statement = select(PersonInfo.id)
if search:
count_statement = count_statement.where(
(col(PersonInfo.person_name).contains(search))
| (col(PersonInfo.user_nickname).contains(search))
| (col(PersonInfo.user_id).contains(search))
)
if is_known is not None:
count_statement = count_statement.where(col(PersonInfo.is_known) == is_known)
if platform:
count_statement = count_statement.where(col(PersonInfo.platform) == platform)
total = len(session.exec(count_statement).all())
# 转换为响应对象
data = [person_to_response(person) for person in persons]
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
@ -215,7 +235,9 @@ async def get_person_detail(
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
@ -250,7 +272,9 @@ async def update_person(
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
@ -262,13 +286,18 @@ async def update_person(
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
# 更新最后修改时间
update_data["last_know"] = time.time()
update_data["last_known_time"] = datetime.now()
# 执行更新
for field, value in update_data.items():
setattr(person, field, value)
person.save()
with get_db_session() as session:
db_person = session.exec(select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)).first()
if not db_person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
for field, value in update_data.items():
if hasattr(db_person, field):
setattr(db_person, field, value)
session.add(db_person)
person = db_person
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
@ -300,16 +329,19 @@ async def delete_person(
try:
verify_auth_token(maibot_session, authorization)
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
with get_db_session() as session:
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
person = session.exec(statement).first()
if not person:
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
# 记录删除信息
person_name = person.person_name or person.nickname or person.user_id
person_name = person.person_name or person.user_nickname or person.user_id
# 执行删除
person.delete_instance()
with get_db_session() as session:
session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id))
logger.info(f"人物信息已删除: {person_id} ({person_name})")
@ -336,15 +368,17 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
try:
verify_auth_token(maibot_session, authorization)
total = PersonInfo.select().count()
known = PersonInfo.select().where(PersonInfo.is_known).count()
with get_db_session() as session:
total = len(session.exec(select(PersonInfo.id)).all())
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all())
unknown = total - known
# 按平台统计
platforms = {}
for person in PersonInfo.select(PersonInfo.platform):
platform = person.platform
platforms[platform] = platforms.get(platform, 0) + 1
with get_db_session() as session:
for platform in session.exec(select(PersonInfo.platform)).all():
if platform:
platforms[platform] = platforms.get(platform, 0) + 1
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
@ -383,14 +417,17 @@ async def batch_delete_persons(
for person_id in request.person_ids:
try:
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
if person:
person.delete_instance()
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
with get_db_session() as session:
person = session.exec(
select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
).first()
if person:
session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id))
deleted_count += 1
logger.info(f"批量删除: {person_id}")
else:
failed_count += 1
failed_ids.append(person_id)
except Exception as e:
logger.error(f"删除 {person_id} 失败: {e}")
failed_count += 1

View File

@ -1,13 +1,16 @@
"""统计数据 API 路由"""
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
from pydantic import BaseModel, Field
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from sqlalchemy import func as fn
from typing import Any, Optional
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import desc, func, or_
from sqlmodel import col, select
from src.common.database.database import get_db_session
from src.common.database.database_model import Messages, ModelUsage, OnlineTime
from src.common.logger import get_logger
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
from src.webui.core import verify_auth_token_from_cookie_or_header
logger = get_logger("webui.statistics")
@ -60,10 +63,10 @@ class DashboardData(BaseModel):
"""仪表盘数据"""
summary: StatisticsSummary
model_stats: List[ModelStatistics]
hourly_data: List[TimeSeriesData]
daily_data: List[TimeSeriesData]
recent_activity: List[Dict[str, Any]]
model_stats: list[ModelStatistics]
hourly_data: list[TimeSeriesData]
daily_data: list[TimeSeriesData]
recent_activity: list[dict[str, Any]]
@router.get("/dashboard", response_model=DashboardData)
@ -111,26 +114,44 @@ async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
"""获取摘要统计数据(优化:使用数据库聚合)"""
summary = StatisticsSummary()
summary = StatisticsSummary(
total_requests=0,
total_cost=0.0,
total_tokens=0,
online_time=0.0,
total_messages=0,
total_replies=0,
avg_response_time=0.0,
cost_per_hour=0.0,
tokens_per_hour=0.0,
)
# 使用聚合查询替代全量加载
query = LLMUsage.select(
fn.COUNT(LLMUsage.id).alias("total_requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
with get_db_session() as session:
statement = select(
func.count().label("total_requests"),
func.sum(col(ModelUsage.cost)).label("total_cost"),
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
result = session.execute(statement).first()
result = query.dicts().get()
summary.total_requests = result["total_requests"]
summary.total_cost = result["total_cost"]
summary.total_tokens = result["total_tokens"]
summary.avg_response_time = result["avg_response_time"] or 0.0
if result:
total_requests, total_cost, total_tokens, avg_response_time = result
summary.total_requests = total_requests or 0
summary.total_cost = float(total_cost or 0.0)
summary.total_tokens = total_tokens or 0
summary.avg_response_time = float(avg_response_time or 0.0)
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
online_records = list(
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
)
with get_db_session() as session:
statement = select(OnlineTime).where(
or_(
col(OnlineTime.start_timestamp) >= start_time,
col(OnlineTime.end_timestamp) >= start_time,
)
)
online_records = session.execute(statement).scalars().all()
for record in online_records:
start = max(record.start_timestamp, start_time)
@ -139,18 +160,23 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
summary.online_time += (end - start).total_seconds()
# 查询消息数量 - 使用聚合优化
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
)
summary.total_messages = messages_query.scalar() or 0
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= start_time,
col(Messages.timestamp) <= end_time,
)
total_messages = session.execute(statement).scalar()
summary.total_messages = int(total_messages or 0)
# 统计回复数量
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
(Messages.time >= start_time.timestamp())
& (Messages.time <= end_time.timestamp())
& (Messages.reply_to.is_null(False))
)
summary.total_replies = replies_query.scalar() or 0
with get_db_session() as session:
statement = select(func.count()).where(
col(Messages.timestamp) >= start_time,
col(Messages.timestamp) <= end_time,
col(Messages.reply_to).is_not(None),
)
total_replies = session.execute(statement).scalar()
summary.total_replies = int(total_replies or 0)
# 计算派生指标
if summary.online_time > 0:
@ -161,55 +187,80 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
return summary
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
# 使用GROUP BY聚合避免全量加载
query = (
LLMUsage.select(
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
fn.COUNT(LLMUsage.id).alias("request_count"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
)
.where(LLMUsage.timestamp >= start_time)
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
.order_by(fn.COUNT(LLMUsage.id).desc())
.limit(10) # 只取前10个
statement = (
select(ModelUsage)
.where(col(ModelUsage.timestamp) >= start_time)
.order_by(desc(col(ModelUsage.timestamp)))
.limit(200)
)
result = []
for row in query.dicts():
with get_db_session() as session:
rows = session.execute(statement).all()
aggregates: dict[str, dict[str, float | int]] = {}
for record in rows:
model_name = record.model_assign_name or record.model_name or "unknown"
if model_name not in aggregates:
aggregates[model_name] = {
"request_count": 0,
"total_cost": 0.0,
"total_tokens": 0,
"total_time_cost": 0.0,
"time_cost_count": 0,
}
bucket = aggregates[model_name]
bucket["request_count"] = int(bucket["request_count"]) + 1
bucket["total_cost"] = float(bucket["total_cost"]) + float(record.cost or 0.0)
bucket["total_tokens"] = int(bucket["total_tokens"]) + int(record.total_tokens or 0)
if record.time_cost:
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
result: list[ModelStatistics] = []
for model_name, bucket in sorted(
aggregates.items(),
key=lambda item: float(item[1]["request_count"]),
reverse=True,
)[:10]:
time_cost_count = int(bucket["time_cost_count"])
avg_time_cost = float(bucket["total_time_cost"]) / time_cost_count if time_cost_count > 0 else 0.0
result.append(
ModelStatistics(
model_name=row["model_name"],
request_count=row["request_count"],
total_cost=row["total_cost"],
total_tokens=row["total_tokens"],
avg_response_time=row["avg_response_time"] or 0.0,
model_name=model_name,
request_count=int(bucket["request_count"]),
total_cost=float(bucket["total_cost"]),
total_tokens=int(bucket["total_tokens"]),
avg_response_time=avg_time_cost,
)
)
return result
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
"""获取小时级统计数据(优化:使用数据库聚合)"""
# SQLite的日期时间函数进行小时分组
# 使用strftime将timestamp格式化为小时级别
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
hour_expr = func.strftime("%Y-%m-%dT%H:00:00", col(ModelUsage.timestamp))
statement = (
select(
hour_expr.label("hour"),
func.count().label("requests"),
func.sum(col(ModelUsage.cost)).label("cost"),
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
.group_by(hour_expr)
)
with get_db_session() as session:
rows = session.execute(statement).all()
# 转换为字典以快速查找
data_dict = {row["hour"]: row for row in query.dicts()}
data_dict = {row[0]: row for row in rows}
# 填充所有小时(包括没有数据的)
result = []
@ -219,7 +270,12 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
if hour_str in data_dict:
row = data_dict[hour_str]
result.append(
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
TimeSeriesData(
timestamp=hour_str,
requests=row[1] or 0,
cost=float(row[2] or 0.0),
tokens=row[3] or 0,
)
)
else:
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
@ -228,22 +284,26 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
return result
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
"""获取日级统计数据(优化:使用数据库聚合)"""
# 使用strftime按日期分组
query = (
LLMUsage.select(
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
fn.COUNT(LLMUsage.id).alias("requests"),
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
statement = (
select(
day_expr.label("day"),
func.count().label("requests"),
func.sum(col(ModelUsage.cost)).label("cost"),
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
)
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
.group_by(day_expr)
)
with get_db_session() as session:
rows = session.execute(statement).all()
# 转换为字典
data_dict = {row["day"]: row for row in query.dicts()}
data_dict = {row[0]: row for row in rows}
# 填充所有天
result = []
@ -253,7 +313,12 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
if day_str in data_dict:
row = data_dict[day_str]
result.append(
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
TimeSeriesData(
timestamp=day_str,
requests=row[1] or 0,
cost=float(row[2] or 0.0),
tokens=row[3] or 0,
)
)
else:
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
@ -262,9 +327,11 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
return result
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
"""获取最近活动"""
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
with get_db_session() as session:
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
records = session.execute(statement).scalars().all()
activities = []
for record in records:
@ -273,10 +340,10 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
"timestamp": record.timestamp.isoformat(),
"model": record.model_assign_name or record.model_name,
"request_type": record.request_type,
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
"tokens": record.total_tokens or 0,
"cost": record.cost or 0.0,
"time_cost": record.time_cost or 0.0,
"status": record.status,
"status": None,
}
)