diff --git a/bot.py b/bot.py index 33fcbdd1..3f3a4e9c 100644 --- a/bot.py +++ b/bot.py @@ -1,4 +1,4 @@ -raise RuntimeError("System Not Ready") +# raise RuntimeError("System Not Ready") import asyncio import hashlib import os diff --git a/pyproject.toml b/pyproject.toml index 72fe4984..dcee3892 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "MaiBot" -version = "0.11.6" +version = "1.0.0" description = "MaiCore 是一个基于大语言模型的可交互智能体" requires-python = ">=3.10" dependencies = [ @@ -35,6 +35,7 @@ dependencies = [ "tomlkit>=0.13.3", "urllib3>=2.5.0", "uvicorn>=0.35.0", + "msgpack>=1.1.2", ] diff --git a/requirements.txt b/requirements.txt index 6bd487cb..65105819 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,3 +29,4 @@ toml>=0.10.2 tomlkit>=0.13.3 urllib3>=2.5.0 uvicorn>=0.35.0 +msgpack>=1.1.2 \ No newline at end of file diff --git a/src/bw_learner/expression_auto_check_task.py b/src/bw_learner/expression_auto_check_task.py index 5fa7cbdb..d90eb4da 100644 --- a/src/bw_learner/expression_auto_check_task.py +++ b/src/bw_learner/expression_auto_check_task.py @@ -8,11 +8,15 @@ 4. 
未通过评估的:rejected=1, checked=1 """ +from typing import List import asyncio import json import random -from typing import List +from sqlmodel import select + +from src.bw_learner.expression_review_store import get_review_state, set_review_state +from src.common.database.database import get_db_session from src.common.database.database_model import Expression from src.common.logger import get_logger from src.config.config import global_config @@ -26,11 +30,11 @@ logger = get_logger("expression_auto_check_task") def create_evaluation_prompt(situation: str, style: str) -> str: """ 创建评估提示词 - + Args: situation: 情境 style: 风格 - + Returns: 评估提示词 """ @@ -39,20 +43,20 @@ def create_evaluation_prompt(situation: str, style: str) -> str: "表达方式或言语风格 是否与使用条件或使用情景 匹配", "允许部分语法错误或口头化或缺省出现", "表达方式不能太过特指,需要具有泛用性", - "一般不涉及具体的人名或名称" + "一般不涉及具体的人名或名称", ] - + # 从配置中获取额外的自定义标准 custom_criteria = global_config.expression.expression_auto_check_custom_criteria - + # 合并所有评估标准 all_criteria = base_criteria.copy() if custom_criteria: all_criteria.extend(custom_criteria) - + # 构建评估标准列表字符串 - criteria_list = "\n".join([f"{i+1}. {criterion}" for i, criterion in enumerate(all_criteria)]) - + criteria_list = "\n".join([f"{i + 1}. 
{criterion}" for i, criterion in enumerate(all_criteria)]) + prompt = f"""请评估以下表达方式或语言风格以及使用条件或使用情景是否合适: 使用条件或使用情景:{situation} 表达方式或言语风格:{style} @@ -68,54 +72,52 @@ def create_evaluation_prompt(situation: str, style: str) -> str: }} 如果合适,suitable设为true;如果不合适,suitable设为false,并在reason中说明原因。 请严格按照JSON格式输出,不要包含其他内容。""" - + return prompt -judge_llm = LLMRequest( - model_set=model_config.model_task_config.tool_use, - request_type="expression_check" -) -async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str]: +judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check") + + +async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]: """ 执行单次LLM评估 - + Args: situation: 情境 style: 风格 - + Returns: (suitable, reason, error) 元组,如果出错则 suitable 为 False,error 包含错误信息 """ try: prompt = create_evaluation_prompt(situation, style) logger.debug(f"正在评估表达方式: situation={situation}, style={style}") - + response, (reasoning, model_name, _) = await judge_llm.generate_response_async( - prompt=prompt, - temperature=0.6, - max_tokens=1024 + prompt=prompt, temperature=0.6, max_tokens=1024 ) - + logger.debug(f"LLM响应: {response}") - + # 解析JSON响应 try: evaluation = json.loads(response) except json.JSONDecodeError as e: import re + json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL) if json_match: evaluation = json.loads(json_match.group()) else: raise ValueError("无法从响应中提取JSON格式的评估结果") from e - + suitable = evaluation.get("suitable", False) reason = evaluation.get("reason", "未提供理由") - + logger.debug(f"评估结果: {'通过' if suitable else '不通过'}") return suitable, reason, None - + except Exception as e: logger.error(f"评估表达方式 (situation={situation}, style={style}) 时出错: {e}") return False, f"评估过程出错: {str(e)}", str(e) @@ -130,36 +132,37 @@ class ExpressionAutoCheckTask(AsyncTask): super().__init__( task_name="Expression Auto Check Task", wait_before_start=60, # 启动后等待60秒再开始第一次检查 - 
run_interval=check_interval + run_interval=check_interval, ) async def _select_expressions(self, count: int) -> List[Expression]: """ 随机选择指定数量的未检查表达方式 - + Args: count: 需要选择的数量 - + Returns: 选中的表达方式列表 """ try: - # 查询所有未检查的表达方式(checked=False) - unevaluated_expressions = list( - Expression.select().where(~Expression.checked) - ) - + with get_db_session() as session: + statement = select(Expression) + all_expressions = session.exec(statement).all() + + unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]] + if not unevaluated_expressions: logger.info("没有未检查的表达方式") return [] - + # 随机选择指定数量 selected_count = min(count, len(unevaluated_expressions)) selected = random.sample(unevaluated_expressions, selected_count) - + logger.info(f"从 {len(unevaluated_expressions)} 条未检查表达方式中随机选择了 {selected_count} 条") return selected - + except Exception as e: logger.error(f"选择表达方式时出错: {e}") return [] @@ -167,26 +170,23 @@ class ExpressionAutoCheckTask(AsyncTask): async def _evaluate_expression(self, expression: Expression) -> bool: """ 评估单个表达方式 - + Args: expression: 要评估的表达方式 - + Returns: True表示通过,False表示不通过 """ - + suitable, reason, error = await single_expression_check( expression.situation, expression.style, ) - + # 更新数据库 try: - expression.checked = True - expression.rejected = not suitable # 通过则rejected=0,不通过则rejected=1 - expression.modified_by = 'ai' # 标记为AI检查 - expression.save() - + set_review_state(expression.id, True, not suitable, "ai") + status = "通过" if suitable else "不通过" logger.info( f"表达方式评估完成 [ID: {expression.id}] - {status} | " @@ -194,12 +194,12 @@ class ExpressionAutoCheckTask(AsyncTask): f"Style: {expression.style}... | " f"Reason: {reason[:50]}..." 
) - + if error: logger.warning(f"表达方式评估时出现错误 [ID: {expression.id}]: {error}") - + return suitable - + except Exception as e: logger.error(f"更新表达方式状态失败 [ID: {expression.id}]: {e}") return False @@ -211,42 +211,39 @@ class ExpressionAutoCheckTask(AsyncTask): if not global_config.expression.expression_self_reflect: logger.debug("表达方式自动检查未启用,跳过本次执行") return - + check_count = global_config.expression.expression_auto_check_count if check_count <= 0: logger.warning(f"检查数量配置无效: {check_count},跳过本次执行") return - + logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条") - - + # 选择要检查的表达方式 expressions = await self._select_expressions(check_count) - + if not expressions: logger.info("没有需要检查的表达方式") return - + # 逐个评估 passed_count = 0 failed_count = 0 - + for i, expression in enumerate(expressions, 1): logger.info(f"正在评估 [{i}/{len(expressions)}]: ID={expression.id}") - + if await self._evaluate_expression(expression): passed_count += 1 else: failed_count += 1 - + # 避免请求过快 await asyncio.sleep(0.3) - + logger.info( - f"表达方式自动检查完成: 总计 {len(expressions)} 条," - f"通过 {passed_count} 条,不通过 {failed_count} 条" + f"表达方式自动检查完成: 总计 {len(expressions)} 条,通过 {passed_count} 条,不通过 {failed_count} 条" ) - + except Exception as e: logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True) - diff --git a/src/bw_learner/expression_review_store.py b/src/bw_learner/expression_review_store.py new file mode 100644 index 00000000..b470b9fe --- /dev/null +++ b/src/bw_learner/expression_review_store.py @@ -0,0 +1,35 @@ +from typing import Any, Dict, Optional + +from src.manager.local_store_manager import local_storage + + +def _review_key(expression_id: int) -> str: + return f"expression_review:{expression_id}" + + +def get_review_state(expression_id: Optional[int]) -> Dict[str, Any]: + if expression_id is None: + return {"checked": False, "rejected": False, "modified_by": None} + value = local_storage[_review_key(expression_id)] + if isinstance(value, dict): + return { + "checked": bool(value.get("checked", False)), + 
"rejected": bool(value.get("rejected", False)), + "modified_by": value.get("modified_by"), + } + return {"checked": False, "rejected": False, "modified_by": None} + + +def set_review_state( + expression_id: Optional[int], + checked: bool, + rejected: bool, + modified_by: Optional[str], +) -> None: + if expression_id is None: + return + local_storage[_review_key(expression_id)] = { + "checked": checked, + "rejected": rejected, + "modified_by": modified_by, + } diff --git a/src/chat/__init__.py b/src/chat/__init__.py index a569c022..35bd5e02 100644 --- a/src/chat/__init__.py +++ b/src/chat/__init__.py @@ -3,11 +3,11 @@ MaiBot模块系统 包含聊天、情绪、记忆、日程等功能模块 """ +from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager -from src.chat.emoji_system.emoji_manager import get_emoji_manager # 导出主要组件供外部使用 __all__ = [ "get_chat_manager", - "get_emoji_manager", + "emoji_manager", ] diff --git a/src/chat/brain_chat/PFC/chat_observer.py b/src/chat/brain_chat/PFC/chat_observer.py index 91df8333..60426d4c 100644 --- a/src/chat/brain_chat/PFC/chat_observer.py +++ b/src/chat/brain_chat/PFC/chat_observer.py @@ -1,8 +1,11 @@ import time import asyncio import traceback +from datetime import datetime from typing import Optional, Dict, Any, List from src.common.logger import get_logger +from sqlmodel import select, col +from src.common.database.database import get_db_session from src.common.database.database_model import Messages from maim_message import UserInfo from src.config.config import global_config @@ -16,17 +19,18 @@ logger = get_logger("chat_observer") def _message_to_dict(message: Messages) -> Dict[str, Any]: """Convert Peewee Message model to dict for PFC compatibility - + Args: message: Peewee Messages model instance - + Returns: Dict[str, Any]: Message dictionary """ + message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp return { "message_id": 
message.message_id, - "time": message.time, - "chat_id": message.chat_id, + "time": message_timestamp, + "chat_id": message.session_id, "user_id": message.user_id, "user_nickname": message.user_nickname, "processed_plain_text": message.processed_plain_text, @@ -37,7 +41,7 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]: "user_info": { "user_id": message.user_id, "user_nickname": message.user_nickname, - } + }, } @@ -109,10 +113,13 @@ class ChatObserver: """ logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}") - new_message_exists = Messages.select().where( - (Messages.chat_id == self.stream_id) & - (Messages.time > self.last_check_time) - ).exists() + last_check_time = self.last_check_time or 0.0 + last_check_dt = datetime.fromtimestamp(last_check_time) + with get_db_session() as session: + statement = select(Messages).where( + (col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt) + ) + new_message_exists = session.exec(statement).first() is not None if new_message_exists: logger.debug(f"[私聊][{self.private_name}]发现新消息") @@ -183,20 +190,21 @@ class ChatObserver: ) return has_new - - async def _fetch_new_messages(self) -> List[Dict[str, Any]]: """获取新消息 Returns: List[Dict[str, Any]]: 新消息列表 """ - query = Messages.select().where( - (Messages.chat_id == self.stream_id) & - (Messages.time > self.last_message_time) - ).order_by(Messages.time.asc()) - - new_messages = [_message_to_dict(msg) for msg in query] + last_message_time = self.last_message_time or 0.0 + last_message_dt = datetime.fromtimestamp(last_message_time) + with get_db_session() as session: + statement = ( + select(Messages) + .where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt)) + .order_by(col(Messages.timestamp)) + ) + new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()] if new_messages: self.last_message_read = new_messages[-1] @@ -215,13 +223,16 @@ class 
ChatObserver: Returns: List[Dict[str, Any]]: 最多5条消息 """ - query = Messages.select().where( - (Messages.chat_id == self.stream_id) & - (Messages.time < time_point) - ).order_by(Messages.time.desc()).limit(5) - - messages = list(query) - messages.reverse() # 需要按时间正序排列 + time_point_dt = datetime.fromtimestamp(time_point) + with get_db_session() as session: + statement = ( + select(Messages) + .where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt)) + .order_by(col(Messages.timestamp)) + .limit(5) + ) + messages = list(session.exec(statement).all()) + messages.reverse() new_messages = [_message_to_dict(msg) for msg in messages] if new_messages: diff --git a/src/chat/heart_flow/heartflow_message_processor.py b/src/chat/heart_flow/heartflow_message_processor.py index 09a3a94d..9301980b 100644 --- a/src/chat/heart_flow/heartflow_message_processor.py +++ b/src/chat/heart_flow/heartflow_message_processor.py @@ -10,7 +10,9 @@ from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.chat_message_builder import replace_user_references from src.common.logger import get_logger from src.person_info.person_info import Person -from src.common.database.database_model import Images +from sqlmodel import select, col +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType if TYPE_CHECKING: pass @@ -47,6 +49,12 @@ class HeartFCMessageReceiver: # 1. 消息解析与初始化 userinfo = message.message_info.user_info chat = message.chat_stream + if userinfo is None or message.message_info.platform is None: + raise ValueError("message userinfo or platform is missing") + if userinfo.user_id is None or userinfo.user_nickname is None: + raise ValueError("message userinfo id or nickname is missing") + user_id = userinfo.user_id + nickname = userinfo.user_nickname # 2. 
计算at信息 is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message) @@ -70,7 +78,15 @@ class HeartFCMessageReceiver: processed_text = message.processed_plain_text if picid_list: for picid in picid_list: - image = Images.get_or_none(Images.image_id == picid) + with get_db_session() as session: + statement = ( + select(Images).where( + (col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE) + ) + if picid.isdigit() + else None + ) + image = session.exec(statement).first() if statement is not None else None if image and image.description: # 将[picid:xxxx]替换成图片描述 processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]") @@ -80,26 +96,24 @@ class HeartFCMessageReceiver: # 应用用户引用格式替换,将回复和@格式转换为可读格式 processed_plain_text = replace_user_references( - processed_text, - message.message_info.platform, # type: ignore - replace_bot_name=True, + processed_text, message.message_info.platform, replace_bot_name=True ) # if not processed_plain_text: # print(message) - logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore + logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # 如果是群聊,获取群号和群昵称 group_id = None group_nick_name = None if chat.group_info: - group_id = chat.group_info.group_id # type: ignore - group_nick_name = userinfo.user_cardname # type: ignore + group_id = chat.group_info.group_id + group_nick_name = userinfo.user_cardname _ = Person.register_person( - platform=message.message_info.platform, # type: ignore - user_id=message.message_info.user_info.user_id, # type: ignore - nickname=userinfo.user_nickname, # type: ignore + platform=message.message_info.platform, + user_id=user_id, + nickname=nickname, group_id=group_id, group_nick_name=group_nick_name, ) diff --git a/src/chat/message_receive/__init__.py b/src/chat/message_receive/__init__.py index 44b9eee3..fad126f4 100644 --- a/src/chat/message_receive/__init__.py +++ 
b/src/chat/message_receive/__init__.py @@ -1,10 +1,10 @@ -from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.storage import MessageStorage __all__ = [ - "get_emoji_manager", "get_chat_manager", "MessageStorage", + "emoji_manager", ] diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 81f78901..3e59b802 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -2,13 +2,15 @@ import asyncio import hashlib import time import copy +from datetime import datetime from typing import Dict, Optional, TYPE_CHECKING from rich.traceback import install from maim_message import GroupInfo, UserInfo +from sqlmodel import select, col from src.common.logger import get_logger -from src.common.database.database import db -from src.common.database.database_model import ChatStreams # 新增导入 +from src.common.database.database import get_db_session +from src.common.database.database_model import ChatSession # 避免循环导入,使用TYPE_CHECKING进行类型提示 if TYPE_CHECKING: @@ -76,7 +78,7 @@ class ChatStream: self.create_time = data.get("create_time", time.time()) if data else time.time() self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time self.saved = False - self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息 + self.context: Optional[ChatMessageContext] = None def to_dict(self) -> dict: """转换为字典格式""" @@ -95,10 +97,13 @@ class ChatStream: user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None + if user_info is None: + raise ValueError("user_info is required to build ChatStream") + return cls( stream_id=data["stream_id"], platform=data["platform"], 
- user_info=user_info, # type: ignore + user_info=user_info, group_info=group_info, data=data, ) @@ -128,12 +133,7 @@ class ChatManager: if not self._initialized: self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message - try: - db.connect(reuse_if_open=True) - # 确保 ChatStreams 表存在 - db.create_tables([ChatStreams], safe=True) - except Exception as e: - logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}") + get_db_session() self._initialized = True # 在事件循环中启动初始化 @@ -161,8 +161,13 @@ class ChatManager: def register_message(self, message: "MessageRecv"): """注册消息到聊天流""" + platform = message.message_info.platform or "" + if not platform: + raise ValueError("platform is required for ChatStream") + if message.message_info.user_info is None and message.message_info.group_info is None: + raise ValueError("user_info or group_info is required for ChatStream") stream_id = self._generate_stream_id( - message.message_info.platform, # type: ignore + platform, message.message_info.user_info, message.message_info.group_info, ) @@ -176,12 +181,18 @@ class ChatManager: """生成聊天流唯一ID""" if not user_info and not group_info: raise ValueError("用户信息或群组信息必须提供") + if group_info is None and user_info is None: + raise ValueError("用户信息或群组信息必须提供") if group_info: # 组合关键信息 components = [platform, str(group_info.group_id)] else: - components = [platform, str(user_info.user_id), "private"] # type: ignore + if user_info is None: + raise ValueError("用户信息或群组信息必须提供") + if user_info.user_id is None: + raise ValueError("user_id is required for private stream") + components = [platform, str(user_info.user_id), "private"] # 使用MD5生成唯一ID key = "_".join(components) @@ -231,33 +242,35 @@ class ChatManager: # 检查数据库中是否存在 def _db_find_stream_sync(s_id: str): - return ChatStreams.get_or_none(ChatStreams.stream_id == s_id) + with get_db_session() as session: + statement = select(ChatSession).where(col(ChatSession.session_id) == s_id) + 
return session.exec(statement).first() model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id) if model_instance: # 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式 user_info_data = { - "platform": model_instance.user_platform, + "platform": model_instance.platform, "user_id": model_instance.user_id, - "user_nickname": model_instance.user_nickname, - "user_cardname": model_instance.user_cardname or "", + "user_nickname": "", + "user_cardname": "", } group_info_data = None - if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息 + if model_instance.group_id: group_info_data = { - "platform": model_instance.group_platform, + "platform": model_instance.platform, "group_id": model_instance.group_id, - "group_name": model_instance.group_name, + "group_name": "", } data_for_from_dict = { - "stream_id": model_instance.stream_id, + "stream_id": model_instance.session_id, "platform": model_instance.platform, "user_info": user_info_data, "group_info": group_info_data, - "create_time": model_instance.create_time, - "last_active_time": model_instance.last_active_time, + "create_time": model_instance.created_timestamp.timestamp(), + "last_active_time": model_instance.last_active_timestamp.timestamp(), } stream = ChatStream.from_dict(data_for_from_dict) # 更新用户信息和群组信息 @@ -329,20 +342,26 @@ class ChatManager: user_info_d = s_data_dict.get("user_info") group_info_d = s_data_dict.get("group_info") - fields_to_save = { - "platform": s_data_dict["platform"], - "create_time": s_data_dict["create_time"], - "last_active_time": s_data_dict["last_active_time"], - "user_platform": user_info_d["platform"] if user_info_d else "", - "user_id": user_info_d["user_id"] if user_info_d else "", - "user_nickname": user_info_d["user_nickname"] if user_info_d else "", - "user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None, - "group_platform": group_info_d["platform"] if group_info_d else "", - "group_id": group_info_d["group_id"] if group_info_d else "", - 
"group_name": group_info_d["group_name"] if group_info_d else "", - } + with get_db_session() as session: + statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"]) + record = session.exec(statement).first() + if record is None: + record = ChatSession( + session_id=s_data_dict["stream_id"], + platform=s_data_dict["platform"], + user_id=user_info_d["user_id"] if user_info_d else None, + group_id=group_info_d["group_id"] if group_info_d else None, + created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]), + last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]), + ) + session.add(record) + return - ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute() + record.platform = s_data_dict["platform"] + record.user_id = user_info_d["user_id"] if user_info_d else None + record.group_id = group_info_d["group_id"] if group_info_d else None + record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"]) + record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"]) try: await asyncio.to_thread(_db_save_stream_sync, stream_data_dict) @@ -361,30 +380,32 @@ class ChatManager: def _db_load_all_streams_sync(): loaded_streams_data = [] - for model_instance in ChatStreams.select(): - user_info_data = { - "platform": model_instance.user_platform, - "user_id": model_instance.user_id, - "user_nickname": model_instance.user_nickname, - "user_cardname": model_instance.user_cardname or "", - } - group_info_data = None - if model_instance.group_id: - group_info_data = { - "platform": model_instance.group_platform, - "group_id": model_instance.group_id, - "group_name": model_instance.group_name, + with get_db_session() as session: + statement = select(ChatSession) + for model_instance in session.exec(statement).all(): + user_info_data = { + "platform": model_instance.platform, + "user_id": model_instance.user_id or "", + "user_nickname": "", + 
"user_cardname": "", } + group_info_data = None + if model_instance.group_id: + group_info_data = { + "platform": model_instance.platform, + "group_id": model_instance.group_id, + "group_name": "", + } - data_for_from_dict = { - "stream_id": model_instance.stream_id, - "platform": model_instance.platform, - "user_info": user_info_data, - "group_info": group_info_data, - "create_time": model_instance.create_time, - "last_active_time": model_instance.last_active_time, - } - loaded_streams_data.append(data_for_from_dict) + data_for_from_dict = { + "stream_id": model_instance.session_id, + "platform": model_instance.platform, + "user_info": user_info_data, + "group_info": group_info_data, + "create_time": model_instance.created_timestamp.timestamp(), + "last_active_time": model_instance.last_active_timestamp.timestamp(), + } + loaded_streams_data.append(data_for_from_dict) return loaded_streams_data try: diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index b73b04a3..6defe19d 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -1,36 +1,74 @@ -import re -import json -import traceback -from typing import Union +from datetime import datetime +from collections.abc import Mapping +from typing import cast -from src.common.database.database_model import Messages, Images +import json +import re +import traceback + +from sqlmodel import col, select +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType, Messages from src.common.logger import get_logger +from src.common.data_models.message_component_model import MessageSequence, TextComponent +from src.common.utils.utils_message import MessageUtils from .chat_stream import ChatStream -from .message import MessageSending, MessageRecv +from .message import MessageRecv, MessageSending logger = get_logger("message_storage") class MessageStorage: @staticmethod - def 
_serialize_keywords(keywords) -> str: + def _coerce_str_list(value: object) -> list[str]: + if isinstance(value, list): + return [str(item) for item in value] + if isinstance(value, tuple): + return [str(item) for item in value] + if isinstance(value, set): + return [str(item) for item in value] + if isinstance(value, str): + return [value] + return [] + + @staticmethod + def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str: + value = mapping.get(key) + if value is None: + return default + return str(value) + + @staticmethod + def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None: + value = mapping.get(key) + if value is None: + return None + return str(value) + + @staticmethod + def _serialize_keywords(keywords: list[str] | None) -> str: """将关键词列表序列化为JSON字符串""" if isinstance(keywords, list): return json.dumps(keywords, ensure_ascii=False) return "[]" @staticmethod - def _deserialize_keywords(keywords_str: str) -> list: + def _deserialize_keywords(keywords_str: str) -> list[str]: """将JSON字符串反序列化为关键词列表""" if not keywords_str: return [] try: - return json.loads(keywords_str) + parsed = cast(object, json.loads(keywords_str)) except (json.JSONDecodeError, TypeError): return [] + if isinstance(parsed, list): + return [str(item) for item in parsed] + if isinstance(parsed, str): + return [parsed] + return [] @staticmethod - async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None: + async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None: """存储消息到数据库""" try: # 通知消息不存储 @@ -66,7 +104,7 @@ class MessageStorage: priority_mode = "" priority_info = {} is_emoji = False - is_picid = False + is_picture = False is_notify = False is_command = False key_words = "" @@ -83,66 +121,73 @@ class MessageStorage: priority_mode = message.priority_mode priority_info = message.priority_info is_emoji = message.is_emoji - is_picid = message.is_picid + is_picture 
= message.is_picid is_notify = message.is_notify is_command = message.is_command intercept_message_level = getattr(message, "intercept_message_level", 0) # 序列化关键词列表为JSON字符串 - key_words = MessageStorage._serialize_keywords(message.key_words) - key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite) + key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words)) + key_words_lite = MessageStorage._serialize_keywords( + MessageStorage._coerce_str_list(message.key_words_lite) + ) selected_expressions = "" - chat_info_dict = chat_stream.to_dict() - user_info_dict = message.message_info.user_info.to_dict() # type: ignore + chat_info_dict = cast(dict[str, object], chat_stream.to_dict()) + if message.message_info.user_info is None: + raise ValueError("message.user_info is required") + user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict()) # message_id 现在是 TextField,直接使用字符串值 - msg_id = message.message_info.message_id + msg_id = message.message_info.message_id or "" # 安全地获取 group_info, 如果为 None 则视为空字典 - group_info_from_chat = chat_info_dict.get("group_info") or {} - # 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一) - user_info_from_chat = chat_info_dict.get("user_info") or {} + group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {}) - Messages.create( - message_id=msg_id, - time=float(message.message_info.time), # type: ignore - chat_id=chat_stream.stream_id, - # Flattened chat_info + additional_config: dict[str, object] = dict(message.message_info.additional_config or {}) + additional_config.update( + { + "interest_value": interest_value, + "priority_mode": priority_mode, + "priority_info": priority_info, + "reply_probability_boost": reply_probability_boost, + "intercept_message_level": intercept_message_level, + "key_words": key_words, + "key_words_lite": key_words_lite, + "selected_expressions": selected_expressions, + "is_picid": is_picture, + } + ) + 
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or "" + raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else []) + raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence) + + timestamp_value = message.message_info.time + if timestamp_value is None: + raise ValueError("message.message_info.time is required") + db_message = Messages( + message_id=str(msg_id), + timestamp=datetime.fromtimestamp(float(timestamp_value)), + platform=MessageStorage._get_str(chat_info_dict, "platform"), + user_id=MessageStorage._get_str(user_info_dict, "user_id"), + user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"), + user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"), + group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"), + group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"), + is_mentioned=bool(is_mentioned), + is_at=bool(is_at), + session_id=chat_stream.stream_id, reply_to=reply_to, - is_mentioned=is_mentioned, - is_at=is_at, - reply_probability_boost=reply_probability_boost, - chat_info_stream_id=chat_info_dict.get("stream_id"), - chat_info_platform=chat_info_dict.get("platform"), - chat_info_user_platform=user_info_from_chat.get("platform"), - chat_info_user_id=user_info_from_chat.get("user_id"), - chat_info_user_nickname=user_info_from_chat.get("user_nickname"), - chat_info_user_cardname=user_info_from_chat.get("user_cardname"), - chat_info_group_platform=group_info_from_chat.get("platform"), - chat_info_group_id=group_info_from_chat.get("group_id"), - chat_info_group_name=group_info_from_chat.get("group_name"), - chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)), - chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)), - # Flattened user_info (message sender) - user_platform=user_info_dict.get("platform"), - 
user_id=user_info_dict.get("user_id"), - user_nickname=user_info_dict.get("user_nickname"), - user_cardname=user_info_dict.get("user_cardname"), - # Text content + is_emoji=is_emoji, + is_picture=is_picture, + is_command=is_command, + is_notify=is_notify, + raw_content=raw_content, processed_plain_text=filtered_processed_plain_text, display_message=filtered_display_message, - interest_value=interest_value, - priority_mode=priority_mode, - priority_info=priority_info, - is_emoji=is_emoji, - is_picid=is_picid, - is_notify=is_notify, - is_command=is_command, - intercept_message_level=intercept_message_level, - key_words=key_words, - key_words_lite=key_words_lite, - selected_expressions=selected_expressions, + additional_config=json.dumps(additional_config, ensure_ascii=False), ) + with get_db_session() as session: + session.add(db_message) except Exception: logger.exception("存储消息失败") logger.error(f"消息:{message}") @@ -156,16 +201,21 @@ class MessageStorage: if not qq_message_id: logger.info("消息不存在message_id,无法更新") return False - if matched_message := ( - Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first() - ): - # 更新找到的消息记录 - Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore - logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") - return True - else: - logger.debug("未找到匹配的消息") - return False + with get_db_session() as session: + statement = ( + select(Messages) + .where(col(Messages.message_id) == mmc_message_id) + .order_by(col(Messages.timestamp).desc()) + .limit(1) + ) + matched_message = session.exec(statement).first() + if matched_message: + matched_message.message_id = qq_message_id + session.add(matched_message) + logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}") + return True + logger.debug("未找到匹配的消息") + return False except Exception as e: logger.error(f"更新消息ID失败: {e}") @@ -182,13 +232,18 @@ class 
MessageStorage: logger.debug("文本中没有图片标记,直接返回原文本") return text - def replace_match(match): + def replace_match(match: re.Match[str]) -> str: description = match.group(1).strip() try: - image_record = ( - Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first() - ) - return f"[picid:{image_record.image_id}]" if image_record else match.group(0) + with get_db_session() as session: + statement = ( + select(Images) + .where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE)) + .order_by(col(Images.record_time).desc()) + .limit(1) + ) + image_record = session.exec(statement).first() + return f"[picid:{image_record.id}]" if image_record else match.group(0) except Exception: return match.group(0) diff --git a/src/chat/utils/chat_message_builder.py b/src/chat/utils/chat_message_builder.py index b0ba919c..11113f97 100644 --- a/src/chat/utils/chat_message_builder.py +++ b/src/chat/utils/chat_message_builder.py @@ -1,17 +1,17 @@ import time import random import re - +from datetime import datetime from typing import List, Dict, Any, Tuple, Optional, Callable from rich.traceback import install +from sqlmodel import select, col from src.config.config import global_config from src.common.logger import get_logger from src.common.message_repository import find_messages, count_messages from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords -from src.common.data_models.message_data_model import MessageAndActionModel -from src.common.database.database_model import ActionRecords -from src.common.database.database_model import Images +from src.common.database.database import get_db_session +from src.common.database.database_model import ActionRecord, Images from src.person_info.person_info import Person, get_person_id from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self @@ -198,37 +198,38 @@ def 
get_actions_by_timestamp_with_chat( limit_mode: str = "latest", ) -> List[DatabaseActionRecords]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表""" - query = ActionRecords.select().where( - (ActionRecords.chat_id == chat_id) - & (ActionRecords.time > timestamp_start) # type: ignore - & (ActionRecords.time < timestamp_end) # type: ignore - ) + with get_db_session() as session: + statement = ( + select(ActionRecord) + .where((col(ActionRecord.session_id) == chat_id)) + .where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start)) + .where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end)) + ) - if limit > 0: - if limit_mode == "latest": - query = query.order_by(ActionRecords.time.desc()).limit(limit) - # 获取后需要反转列表,以保持最终输出为时间升序 - actions = list(query) - actions.reverse() - else: # earliest - query = query.order_by(ActionRecords.time.asc()).limit(limit) - else: - query = query.order_by(ActionRecords.time.asc()) - - actions = list(query) + if limit > 0: + if limit_mode == "latest": + statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit) + actions = list(session.exec(statement).all()) + actions = list(reversed(actions)) + else: + statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit) + actions = list(session.exec(statement).all()) + else: + statement = statement.order_by(col(ActionRecord.timestamp)) + actions = session.exec(statement).all() return [ DatabaseActionRecords( action_id=action.action_id, - time=action.time, + time=action.timestamp.timestamp(), action_name=action.action_name, - action_data=action.action_data, - action_done=action.action_done, - action_build_into_prompt=action.action_build_into_prompt, - action_prompt_display=action.action_prompt_display, - chat_id=action.chat_id, - chat_info_stream_id=action.chat_info_stream_id, - chat_info_platform=action.chat_info_platform, - action_reasoning=action.action_reasoning, + action_data=action.action_data or "{}", + action_done=True, + 
action_build_into_prompt=bool(action.action_display_prompt), + action_prompt_display=action.action_display_prompt or "", + chat_id=action.session_id, + chat_info_stream_id=action.session_id, + chat_info_platform=global_config.bot.platform, + action_reasoning=action.action_reasoning or "", ) for action in actions ] @@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive( chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest" ) -> List[Dict[str, Any]]: """获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表""" - query = ActionRecords.select().where( - (ActionRecords.chat_id == chat_id) - & (ActionRecords.time >= timestamp_start) # type: ignore - & (ActionRecords.time <= timestamp_end) # type: ignore - ) + with get_db_session() as session: + statement = ( + select(ActionRecord) + .where((col(ActionRecord.session_id) == chat_id)) + .where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start)) + .where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end)) + ) - if limit > 0: - if limit_mode == "latest": - query = query.order_by(ActionRecords.time.desc()).limit(limit) - # 获取后需要反转列表,以保持最终输出为时间升序 - actions = list(query) - return [action.__data__ for action in reversed(actions)] - else: # earliest - query = query.order_by(ActionRecords.time.asc()).limit(limit) - else: - query = query.order_by(ActionRecords.time.asc()) + if limit > 0: + if limit_mode == "latest": + statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit) + actions = list(session.exec(statement).all()) + actions = list(reversed(actions)) + else: + statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit) + actions = list(session.exec(statement).all()) + else: + statement = statement.order_by(col(ActionRecord.timestamp)) + actions = session.exec(statement).all() - actions = list(query) - return [action.__data__ for action in actions] + return [action.model_dump() for action in 
actions] def get_raw_msg_by_timestamp_random( @@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random( def get_raw_msg_by_timestamp_with_users( - timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest" + timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest" ) -> List[DatabaseMessages]: """获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat( def get_raw_msg_before_timestamp_with_users( - timestamp: float, person_ids: list, limit: int = 0 + timestamp: float, person_ids: List[str], limit: int = 0 ) -> List[DatabaseMessages]: """获取指定时间戳之前的消息,按时间升序排序,返回消息列表 limit: 限制返回的消息数量,0为不限制 @@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp def num_new_messages_since_with_users( - chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list + chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str] ) -> int: """检查某些特定用户在特定聊天在指定时间戳之间有多少新消息""" if not person_ids: # 保持空列表检查 @@ -358,7 +361,7 @@ def num_new_messages_since_with_users( def _build_readable_messages_internal( - messages: List[MessageAndActionModel], + messages: List[DatabaseMessages], replace_bot_name: bool = True, timestamp_mode: str = "relative", truncate: bool = False, @@ -413,7 +416,7 @@ def _build_readable_messages_internal( # 匹配 [picid:xxxxx] 格式 pic_pattern = r"\[picid:([^\]]+)\]" - def replace_pic_id(match: re.Match) -> str: + def replace_pic_id(match: re.Match[str]) -> str: nonlocal current_pic_counter nonlocal pic_counter pic_id = match.group(1) @@ -421,7 +424,8 @@ def _build_readable_messages_internal( if pic_id not in pic_description_cache: description = "内容正在阅读,请稍等" try: - image = Images.get_or_none(Images.image_id == pic_id) + with get_db_session() as session: + image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None 
if image and image.description: description = image.description except Exception: @@ -438,16 +442,11 @@ def _build_readable_messages_internal( # 1: 获取发送者信息并提取消息组件 for message in messages: - if message.is_action_record: - # 对于动作记录,也处理图片ID - content = process_pic_ids(message.display_message) - detailed_messages_raw.append((message.time, message.user_nickname, content, True)) - continue - - platform = message.user_platform - user_id = message.user_id - user_nickname = message.user_nickname - user_cardname = message.user_cardname + user_info = message.user_info + platform = user_info.platform + user_id = user_info.user_id + user_nickname = user_info.user_nickname + user_cardname = user_info.user_cardname timestamp = message.time content = message.display_message or message.processed_plain_text or "" @@ -525,12 +524,12 @@ def _build_readable_messages_internal( if long_time_notice and prev_timestamp is not None: time_diff = timestamp - prev_timestamp time_diff_hours = time_diff / 3600 - + # 检查是否跨天 prev_date = time.strftime("%Y-%m-%d", time.localtime(prev_timestamp)) current_date = time.strftime("%Y-%m-%d", time.localtime(timestamp)) is_cross_day = prev_date != current_date - + # 如果间隔大于8小时或跨天,插入提示 if time_diff_hours > 8 or is_cross_day: # 格式化日期为中文格式:xxxx年xx月xx日(去掉前导零) @@ -542,20 +541,15 @@ def _build_readable_messages_internal( hours_str = f"{int(time_diff_hours)}h" notice = f"以下聊天开始时间:{date_str}。距离上一条消息过去了{hours_str}\n" output_lines.append(notice) - + readable_time = translate_timestamp_to_human_readable(timestamp, mode=timestamp_mode) # 查找消息id(如果有)并构建id_prefix message_id = timestamp_to_id_mapping.get(timestamp, "") id_prefix = f"[{message_id}]" if message_id else "" - if is_action: - # 对于动作记录,使用特殊格式 - output_lines.append(f"{id_prefix}{readable_time}, {content}") - else: - output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}") - output_lines.append("\n") # 在每个消息块后添加换行,保持可读性 - + output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}") + 
output_lines.append("\n") prev_timestamp = timestamp formatted_string = "".join(output_lines).strip() @@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str: # 从数据库中获取图片描述 description = "内容正在阅读,请稍等" try: - image = Images.get_or_none(Images.image_id == pic_id) + with get_db_session() as session: + image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None if image and image.description: description = image.description except Exception: @@ -663,7 +658,7 @@ async def build_readable_messages_with_list( 允许通过参数控制格式化行为。 """ formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal( - [MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages], + messages, replace_bot_name, timestamp_mode, truncate, @@ -754,7 +749,7 @@ def build_readable_messages( filtered_messages = [] for msg in messages: # 获取消息内容 - content = msg.processed_plain_text + content = msg.processed_plain_text or "" # 移除表情包 emoji_pattern = r"\[表情包:[^\]]+\]" content = re.sub(emoji_pattern, "", content) @@ -765,17 +760,14 @@ def build_readable_messages( messages = filtered_messages - copy_messages: List[MessageAndActionModel] = [] + copy_messages: List[DatabaseMessages] = [] for msg in messages: if remove_emoji_stickers: - # 创建 MessageAndActionModel 但移除表情包 - model = MessageAndActionModel.from_DatabaseMessages(msg) # 移除表情包 - if model.processed_plain_text: - model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text) - copy_messages.append(model) + msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "") + copy_messages.append(msg) else: - copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg)) + copy_messages.append(msg) if show_actions and copy_messages: # 获取所有消息的时间范围 @@ -786,40 +778,45 @@ def build_readable_messages( chat_id = messages[0].chat_id if messages else None # 获取这个时间范围内的动作记录,并匹配chat_id - actions_in_range = ( - ActionRecords.select() - .where( - 
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id) - ) - .order_by(ActionRecords.time) - ) + with get_db_session() as session: + actions_in_range = session.exec( + select(ActionRecord) + .where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time)) + .where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time)) + .where(col(ActionRecord.session_id) == chat_id) + .order_by(col(ActionRecord.timestamp)) + ).all() # 获取最新消息之后的第一个动作记录 - action_after_latest = ( - ActionRecords.select() - .where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id)) - .order_by(ActionRecords.time) - .limit(1) - ) + with get_db_session() as session: + action_after_latest = session.exec( + select(ActionRecord) + .where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time)) + .where(col(ActionRecord.session_id) == chat_id) + .order_by(col(ActionRecord.timestamp)) + .limit(1) + ).all() # 合并两部分动作记录 - actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest) + actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest) # 将动作记录转换为消息格式 for action in actions: # 只有当build_into_prompt为True时才添加动作记录 - if action.action_build_into_prompt: - action_msg = MessageAndActionModel( - time=float(action.time), # type: ignore - user_id=global_config.bot.qq_account, # 使用机器人的QQ账号 - user_platform=global_config.bot.platform, # 使用机器人的平台 - user_nickname=global_config.bot.nickname, # 使用机器人的用户名 - user_cardname="", # 机器人没有群名片 - processed_plain_text=f"{action.action_prompt_display}", - display_message=f"{action.action_prompt_display}", - chat_info_platform=str(action.chat_info_platform), - is_action_record=True, # 添加标识字段 - action_name=str(action.action_name), # 保存动作名称 + action_display_prompt = action.action_display_prompt or "" + if action_display_prompt: + action_msg = DatabaseMessages( + message_id=f"action_{action.action_id}", + time=float(action.timestamp.timestamp()), + 
chat_id=chat_id or "", + processed_plain_text=action_display_prompt, + display_message=action_display_prompt, + user_platform=global_config.bot.platform, + user_id=str(global_config.bot.qq_account), + user_nickname=global_config.bot.nickname, + user_cardname="", + chat_info_platform=str(global_config.bot.platform), + chat_info_stream_id=chat_id or "", ) copy_messages.append(action_msg) @@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]: person_ids_set = set() # 使用集合来自动去重 for msg in messages: - platform: str = msg.get("user_platform") # type: ignore - user_id: str = msg.get("user_id") # type: ignore + platform = msg.get("user_platform") or "" + user_id = msg.get("user_id") or "" # 检查必要信息是否存在 且 不是机器人自己 if not all([platform, user_id]) or user_id == global_config.bot.qq_account: continue - # 添加空值检查,防止 platform 为 None 时出错 - if platform is None: - platform = "unknown" - if person_id := get_person_id(platform, user_id): person_ids_set.add(person_id) diff --git a/src/chat/utils/statistic.py b/src/chat/utils/statistic.py index b23e4503..21eef538 100644 --- a/src/chat/utils/statistic.py +++ b/src/chat/utils/statistic.py @@ -4,17 +4,65 @@ import json from collections import defaultdict from datetime import datetime, timedelta -from typing import Any, Dict, Tuple, List +from typing import cast + +from typing_extensions import TypedDict + +from sqlmodel import col, select from src.common.logger import get_logger -from src.common.database.database import db -from src.common.database.database_model import OnlineTime, LLMUsage, Messages, ActionRecords +from src.common.database.database import get_db_session +from src.common.database.database_model import OnlineTime, ModelUsage, Messages, ActionRecord from src.manager.async_task_manager import AsyncTask from src.manager.local_store_manager import local_storage from src.config.config import global_config logger = get_logger("maibot_statistic") + +class StatPeriodData(TypedDict): + 
total_requests: int + total_cost: float + requests_by_type: defaultdict[str, int] + requests_by_user: defaultdict[str, int] + requests_by_model: defaultdict[str, int] + requests_by_module: defaultdict[str, int] + in_tokens_by_type: defaultdict[str, int] + in_tokens_by_user: defaultdict[str, int] + in_tokens_by_model: defaultdict[str, int] + in_tokens_by_module: defaultdict[str, int] + out_tokens_by_type: defaultdict[str, int] + out_tokens_by_user: defaultdict[str, int] + out_tokens_by_model: defaultdict[str, int] + out_tokens_by_module: defaultdict[str, int] + tokens_by_type: defaultdict[str, int] + tokens_by_user: defaultdict[str, int] + tokens_by_model: defaultdict[str, int] + tokens_by_module: defaultdict[str, int] + costs_by_type: defaultdict[str, float] + costs_by_user: defaultdict[str, float] + costs_by_model: defaultdict[str, float] + costs_by_module: defaultdict[str, float] + time_costs_by_type: defaultdict[str, list[float]] + time_costs_by_user: defaultdict[str, list[float]] + time_costs_by_model: defaultdict[str, list[float]] + time_costs_by_module: defaultdict[str, list[float]] + avg_time_costs_by_type: defaultdict[str, float] + avg_time_costs_by_user: defaultdict[str, float] + avg_time_costs_by_model: defaultdict[str, float] + avg_time_costs_by_module: defaultdict[str, float] + std_time_costs_by_type: defaultdict[str, float] + std_time_costs_by_user: defaultdict[str, float] + std_time_costs_by_model: defaultdict[str, float] + std_time_costs_by_module: defaultdict[str, float] + online_time: float + total_messages: int + messages_by_chat: defaultdict[str, int] + total_replies: int + + +StatPeriodMapping = dict[str, StatPeriodData] + # 统计数据的键 TOTAL_REQ_CNT = "total_requests" TOTAL_COST = "total_cost" @@ -70,8 +118,8 @@ class OnlineTimeRecordTask(AsyncTask): @staticmethod def _init_database(): """初始化数据库""" - with db.atomic(): # Use atomic operations for schema changes - OnlineTime.create_table(safe=True) # Creates table if it doesn't exist, Peewee handles 
indexes from model + with get_db_session() as _: + return async def run(self): # sourcery skip: use-named-expression try: @@ -80,36 +128,41 @@ class OnlineTimeRecordTask(AsyncTask): if self.record_id: # 如果有记录,则更新结束时间 - query = OnlineTime.update(end_timestamp=extended_end_time).where(OnlineTime.id == self.record_id) # type: ignore - updated_rows = query.execute() - if updated_rows == 0: - # Record might have been deleted or ID is stale, try to find/create - self.record_id = None # Reset record_id to trigger find/create logic below + with get_db_session() as session: + statement = select(OnlineTime).where(col(OnlineTime.id) == self.record_id).limit(1) + existing_record = session.exec(statement).first() + if existing_record: + existing_record.end_timestamp = extended_end_time + session.add(existing_record) + else: + self.record_id = None if not self.record_id: # Check again if record_id was reset or initially None # 如果没有记录,检查一分钟以内是否已有记录 # Look for a record whose end_timestamp is recent enough to be considered ongoing - recent_record = ( - OnlineTime.select() - .where(OnlineTime.end_timestamp >= (current_time - timedelta(minutes=1))) # type: ignore - .order_by(OnlineTime.end_timestamp.desc()) - .first() - ) - - if recent_record: - # 如果有记录,则更新结束时间 - self.record_id = recent_record.id - recent_record.end_timestamp = extended_end_time - recent_record.save() - else: - # 若没有记录,则插入新的在线时间记录 - new_record = OnlineTime.create( - timestamp=current_time.timestamp(), # 添加此行 - start_timestamp=current_time, - end_timestamp=extended_end_time, - duration=5, # 初始时长为5分钟 + with get_db_session() as session: + statement = ( + select(OnlineTime) + .where(col(OnlineTime.end_timestamp) >= (current_time - timedelta(minutes=1))) + .order_by(col(OnlineTime.end_timestamp).desc()) + .limit(1) ) - self.record_id = new_record.id + recent_record = session.exec(statement).first() + + if recent_record: + self.record_id = recent_record.id + recent_record.end_timestamp = extended_end_time + 
session.add(recent_record) + else: + new_record = OnlineTime( + timestamp=current_time, + start_timestamp=current_time, + end_timestamp=extended_end_time, + duration_minutes=5, + ) + session.add(new_record) + session.flush() + self.record_id = new_record.id except Exception as e: logger.error(f"在线时间记录失败,错误信息:{e}") @@ -177,7 +230,7 @@ class StatisticOutputTask(AsyncTask): # 延迟300秒启动,运行间隔300秒 super().__init__(task_name="Statistics Data Output Task", wait_before_start=0, run_interval=300) - self.name_mapping: Dict[str, Tuple[str, float]] = {} + self.name_mapping: dict[str, tuple[str, float]] = {} """ 联系人/群聊名称映射 {聊天ID: (联系人/群聊名称, 记录时间(timestamp))} 注:设计记录时间的目的是方便更新名称,使联系人/群聊名称保持最新 @@ -191,13 +244,13 @@ class StatisticOutputTask(AsyncTask): now = datetime.now() if "deploy_time" in local_storage: # 如果存在部署时间,则使用该时间作为全量统计的起始时间 - deploy_time = datetime.fromtimestamp(local_storage["deploy_time"]) # type: ignore + deploy_time = datetime.fromtimestamp(self._to_float_timestamp(local_storage["deploy_time"])) else: # 否则,使用最大时间范围,并记录部署时间为当前时间 deploy_time = datetime(2000, 1, 1) local_storage["deploy_time"] = now.timestamp() - self.stat_period: List[Tuple[str, timedelta, str]] = [ + self.stat_period: list[tuple[str, timedelta, str]] = [ ("all_time", now - deploy_time, "自部署以来"), # 必须保留"all_time" ("last_30_days", timedelta(days=30), "近30天"), ("last_7_days", timedelta(days=7), "近7天"), @@ -211,7 +264,7 @@ class StatisticOutputTask(AsyncTask): 统计时间段 [(统计名称, 统计时间段, 统计描述), ...] 
""" - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): + def _statistic_console_output(self, stats: StatPeriodMapping, now: datetime) -> None: """ 输出统计数据到控制台 :param stats: 统计数据 @@ -281,18 +334,13 @@ class StatisticOutputTask(AsyncTask): with concurrent.futures.ThreadPoolExecutor() as executor: logger.info("正在后台收集统计数据...") - # 创建后台任务,不等待完成 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task + stats = await loop.run_in_executor(executor, self._collect_all_statistics, now) logger.info("统计数据收集完成") # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore + loop.run_in_executor(executor, self._statistic_console_output, stats, now), + loop.run_in_executor(executor, self._generate_html_report, stats, now), ] # 等待所有输出任务完成 @@ -308,7 +356,113 @@ class StatisticOutputTask(AsyncTask): # -- 以下为统计数据收集方法 -- @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + def _build_stat_period_data() -> StatPeriodData: + time_costs_by_type: defaultdict[str, list[float]] = defaultdict(list) + time_costs_by_user: defaultdict[str, list[float]] = defaultdict(list) + time_costs_by_model: defaultdict[str, list[float]] = defaultdict(list) + time_costs_by_module: defaultdict[str, list[float]] = defaultdict(list) + avg_time_costs_by_type: defaultdict[str, float] = defaultdict(float) + avg_time_costs_by_user: defaultdict[str, float] = defaultdict(float) + avg_time_costs_by_model: defaultdict[str, float] = defaultdict(float) + avg_time_costs_by_module: defaultdict[str, float] = defaultdict(float) + std_time_costs_by_type: defaultdict[str, float] = defaultdict(float) + std_time_costs_by_user: defaultdict[str, float] = 
defaultdict(float) + std_time_costs_by_model: defaultdict[str, float] = defaultdict(float) + std_time_costs_by_module: defaultdict[str, float] = defaultdict(float) + + return { + TOTAL_REQ_CNT: 0, + REQ_CNT_BY_TYPE: defaultdict(int), + REQ_CNT_BY_USER: defaultdict(int), + REQ_CNT_BY_MODEL: defaultdict(int), + REQ_CNT_BY_MODULE: defaultdict(int), + IN_TOK_BY_TYPE: defaultdict(int), + IN_TOK_BY_USER: defaultdict(int), + IN_TOK_BY_MODEL: defaultdict(int), + IN_TOK_BY_MODULE: defaultdict(int), + OUT_TOK_BY_TYPE: defaultdict(int), + OUT_TOK_BY_USER: defaultdict(int), + OUT_TOK_BY_MODEL: defaultdict(int), + OUT_TOK_BY_MODULE: defaultdict(int), + TOTAL_TOK_BY_TYPE: defaultdict(int), + TOTAL_TOK_BY_USER: defaultdict(int), + TOTAL_TOK_BY_MODEL: defaultdict(int), + TOTAL_TOK_BY_MODULE: defaultdict(int), + TOTAL_COST: 0.0, + COST_BY_TYPE: defaultdict(float), + COST_BY_USER: defaultdict(float), + COST_BY_MODEL: defaultdict(float), + COST_BY_MODULE: defaultdict(float), + TIME_COST_BY_TYPE: time_costs_by_type, + TIME_COST_BY_USER: time_costs_by_user, + TIME_COST_BY_MODEL: time_costs_by_model, + TIME_COST_BY_MODULE: time_costs_by_module, + AVG_TIME_COST_BY_TYPE: avg_time_costs_by_type, + AVG_TIME_COST_BY_USER: avg_time_costs_by_user, + AVG_TIME_COST_BY_MODEL: avg_time_costs_by_model, + AVG_TIME_COST_BY_MODULE: avg_time_costs_by_module, + STD_TIME_COST_BY_TYPE: std_time_costs_by_type, + STD_TIME_COST_BY_USER: std_time_costs_by_user, + STD_TIME_COST_BY_MODEL: std_time_costs_by_model, + STD_TIME_COST_BY_MODULE: std_time_costs_by_module, + ONLINE_TIME: 0.0, + TOTAL_MSG_CNT: 0, + MSG_CNT_BY_CHAT: defaultdict(int), + TOTAL_REPLY_CNT: 0, + } + + @staticmethod + def _add_int_stat(stats_period: StatPeriodData, key: str, amount: int) -> None: + stats_period[key] = cast(int, stats_period.get(key, 0)) + amount + + @staticmethod + def _add_float_stat(stats_period: StatPeriodData, key: str, amount: float) -> None: + stats_period[key] = cast(float, stats_period.get(key, 0.0)) + amount + + 
@staticmethod + def _add_defaultdict_int(stats_period: StatPeriodData, key: str, subkey: str, amount: int) -> None: + counter = cast(defaultdict[str, int], stats_period[key]) + counter[subkey] += amount + + @staticmethod + def _add_defaultdict_float(stats_period: StatPeriodData, key: str, subkey: str, amount: float) -> None: + counter = cast(defaultdict[str, float], stats_period[key]) + counter[subkey] += amount + + @staticmethod + def _append_defaultdict_list(stats_period: StatPeriodData, key: str, subkey: str, value: float) -> None: + counter = cast(defaultdict[str, list[float]], stats_period[key]) + counter[subkey].append(value) + + @staticmethod + def _fetch_online_time_since(query_start_time: datetime) -> list[tuple[datetime, datetime]]: + with get_db_session() as session: + statement = select(OnlineTime).where(col(OnlineTime.end_timestamp) >= query_start_time) + records = session.exec(statement).all() + return [(record.start_timestamp, record.end_timestamp) for record in records] + + @staticmethod + def _fetch_model_usage_since(query_start_time: datetime) -> list[dict[str, object]]: + with get_db_session() as session: + statement = select(ModelUsage).where(col(ModelUsage.timestamp) >= query_start_time) + records = session.exec(statement).all() + return [ + { + "timestamp": record.timestamp, + "request_type": record.request_type, + "model_api_provider_name": record.model_api_provider_name, + "model_assign_name": record.model_assign_name, + "model_name": record.model_name, + "prompt_tokens": record.prompt_tokens, + "completion_tokens": record.completion_tokens, + "cost": record.cost, + "time_cost": record.time_cost, + } + for record in records + ] + + @staticmethod + def _collect_model_request_for_period(collect_period: list[tuple[str, datetime]]) -> StatPeriodMapping: """ 收集指定时间段的LLM请求统计数据 @@ -320,101 +474,99 @@ class StatisticOutputTask(AsyncTask): # 排序-按照时间段开始时间降序排列(最晚的时间段在前) collect_period.sort(key=lambda x: x[1], reverse=True) - stats = { - period_key: { - 
TOTAL_REQ_CNT: 0, - REQ_CNT_BY_TYPE: defaultdict(int), - REQ_CNT_BY_USER: defaultdict(int), - REQ_CNT_BY_MODEL: defaultdict(int), - REQ_CNT_BY_MODULE: defaultdict(int), - IN_TOK_BY_TYPE: defaultdict(int), - IN_TOK_BY_USER: defaultdict(int), - IN_TOK_BY_MODEL: defaultdict(int), - IN_TOK_BY_MODULE: defaultdict(int), - OUT_TOK_BY_TYPE: defaultdict(int), - OUT_TOK_BY_USER: defaultdict(int), - OUT_TOK_BY_MODEL: defaultdict(int), - OUT_TOK_BY_MODULE: defaultdict(int), - TOTAL_TOK_BY_TYPE: defaultdict(int), - TOTAL_TOK_BY_USER: defaultdict(int), - TOTAL_TOK_BY_MODEL: defaultdict(int), - TOTAL_TOK_BY_MODULE: defaultdict(int), - TOTAL_COST: 0.0, - COST_BY_TYPE: defaultdict(float), - COST_BY_USER: defaultdict(float), - COST_BY_MODEL: defaultdict(float), - COST_BY_MODULE: defaultdict(float), - TIME_COST_BY_TYPE: defaultdict(list), - TIME_COST_BY_USER: defaultdict(list), - TIME_COST_BY_MODEL: defaultdict(list), - TIME_COST_BY_MODULE: defaultdict(list), - AVG_TIME_COST_BY_TYPE: defaultdict(float), - AVG_TIME_COST_BY_USER: defaultdict(float), - AVG_TIME_COST_BY_MODEL: defaultdict(float), - AVG_TIME_COST_BY_MODULE: defaultdict(float), - STD_TIME_COST_BY_TYPE: defaultdict(float), - STD_TIME_COST_BY_USER: defaultdict(float), - STD_TIME_COST_BY_MODEL: defaultdict(float), - STD_TIME_COST_BY_MODULE: defaultdict(float), - } - for period_key, _ in collect_period + stats: StatPeriodMapping = { + period_key: StatisticOutputTask._build_stat_period_data() for period_key, _ in collect_period } # 以最早的时间戳为起始时间获取记录 # Assuming LLMUsage.timestamp is a DateTimeField query_start_time = collect_period[-1][1] - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore - record_timestamp = record.timestamp # This is already a datetime object + records = StatisticOutputTask._fetch_model_usage_since(query_start_time) + for record in records: + record_timestamp = cast(datetime, record["timestamp"]) for idx, (_, period_start) in enumerate(collect_period): if 
record_timestamp >= period_start: for period_key, _ in collect_period[idx:]: - stats[period_key][TOTAL_REQ_CNT] += 1 + StatisticOutputTask._add_int_stat(stats[period_key], TOTAL_REQ_CNT, 1) - request_type = record.request_type or "unknown" - user_id = record.user_id or "unknown" # user_id is TextField, already string - model_name = record.model_assign_name or record.model_name or "unknown" + request_type = cast(str | None, record["request_type"]) or "unknown" + user_id = cast(str | None, record["model_api_provider_name"]) or "unknown" + model_assign_name = cast(str | None, record["model_assign_name"]) + model_name = model_assign_name or cast(str | None, record["model_name"]) or "unknown" # 提取模块名:如果请求类型包含".",取第一个"."之前的部分 module_name = request_type.split(".")[0] if "." in request_type else request_type - stats[period_key][REQ_CNT_BY_TYPE][request_type] += 1 - stats[period_key][REQ_CNT_BY_USER][user_id] += 1 - stats[period_key][REQ_CNT_BY_MODEL][model_name] += 1 - stats[period_key][REQ_CNT_BY_MODULE][module_name] += 1 + StatisticOutputTask._add_defaultdict_int(stats[period_key], REQ_CNT_BY_TYPE, request_type, 1) + StatisticOutputTask._add_defaultdict_int(stats[period_key], REQ_CNT_BY_USER, user_id, 1) + StatisticOutputTask._add_defaultdict_int(stats[period_key], REQ_CNT_BY_MODEL, model_name, 1) + StatisticOutputTask._add_defaultdict_int(stats[period_key], REQ_CNT_BY_MODULE, module_name, 1) - prompt_tokens = record.prompt_tokens or 0 - completion_tokens = record.completion_tokens or 0 + prompt_tokens = cast(int | None, record["prompt_tokens"]) or 0 + completion_tokens = cast(int | None, record["completion_tokens"]) or 0 total_tokens = prompt_tokens + completion_tokens - stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens - stats[period_key][IN_TOK_BY_USER][user_id] += prompt_tokens - stats[period_key][IN_TOK_BY_MODEL][model_name] += prompt_tokens - stats[period_key][IN_TOK_BY_MODULE][module_name] += prompt_tokens + StatisticOutputTask._add_defaultdict_int( 
+ stats[period_key], IN_TOK_BY_TYPE, request_type, prompt_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], IN_TOK_BY_USER, user_id, prompt_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], IN_TOK_BY_MODEL, model_name, prompt_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], IN_TOK_BY_MODULE, module_name, prompt_tokens + ) - stats[period_key][OUT_TOK_BY_TYPE][request_type] += completion_tokens - stats[period_key][OUT_TOK_BY_USER][user_id] += completion_tokens - stats[period_key][OUT_TOK_BY_MODEL][model_name] += completion_tokens - stats[period_key][OUT_TOK_BY_MODULE][module_name] += completion_tokens + StatisticOutputTask._add_defaultdict_int( + stats[period_key], OUT_TOK_BY_TYPE, request_type, completion_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], OUT_TOK_BY_USER, user_id, completion_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], OUT_TOK_BY_MODEL, model_name, completion_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], OUT_TOK_BY_MODULE, module_name, completion_tokens + ) - stats[period_key][TOTAL_TOK_BY_TYPE][request_type] += total_tokens - stats[period_key][TOTAL_TOK_BY_USER][user_id] += total_tokens - stats[period_key][TOTAL_TOK_BY_MODEL][model_name] += total_tokens - stats[period_key][TOTAL_TOK_BY_MODULE][module_name] += total_tokens + StatisticOutputTask._add_defaultdict_int( + stats[period_key], TOTAL_TOK_BY_TYPE, request_type, total_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], TOTAL_TOK_BY_USER, user_id, total_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], TOTAL_TOK_BY_MODEL, model_name, total_tokens + ) + StatisticOutputTask._add_defaultdict_int( + stats[period_key], TOTAL_TOK_BY_MODULE, module_name, total_tokens + ) - cost = record.cost or 0.0 - stats[period_key][TOTAL_COST] += cost - stats[period_key][COST_BY_TYPE][request_type] += 
cost - stats[period_key][COST_BY_USER][user_id] += cost - stats[period_key][COST_BY_MODEL][model_name] += cost - stats[period_key][COST_BY_MODULE][module_name] += cost + cost = cast(float | None, record["cost"]) or 0.0 + StatisticOutputTask._add_float_stat(stats[period_key], TOTAL_COST, cost) + StatisticOutputTask._add_defaultdict_float(stats[period_key], COST_BY_TYPE, request_type, cost) + StatisticOutputTask._add_defaultdict_float(stats[period_key], COST_BY_USER, user_id, cost) + StatisticOutputTask._add_defaultdict_float(stats[period_key], COST_BY_MODEL, model_name, cost) + StatisticOutputTask._add_defaultdict_float(stats[period_key], COST_BY_MODULE, module_name, cost) # 收集time_cost数据 - time_cost = record.time_cost or 0.0 + time_cost = cast(float | None, record["time_cost"]) or 0.0 if time_cost > 0: # 只记录有效的time_cost - stats[period_key][TIME_COST_BY_TYPE][request_type].append(time_cost) - stats[period_key][TIME_COST_BY_USER][user_id].append(time_cost) - stats[period_key][TIME_COST_BY_MODEL][model_name].append(time_cost) - stats[period_key][TIME_COST_BY_MODULE][module_name].append(time_cost) + StatisticOutputTask._append_defaultdict_list( + stats[period_key], TIME_COST_BY_TYPE, request_type, time_cost + ) + StatisticOutputTask._append_defaultdict_list( + stats[period_key], TIME_COST_BY_USER, user_id, time_cost + ) + StatisticOutputTask._append_defaultdict_list( + stats[period_key], TIME_COST_BY_MODEL, model_name, time_cost + ) + StatisticOutputTask._append_defaultdict_list( + stats[period_key], TIME_COST_BY_MODULE, module_name, time_cost + ) break # 计算平均耗时和标准差 @@ -424,28 +576,39 @@ class StatisticOutputTask(AsyncTask): avg_key = f"avg_time_costs_by_{category.split('_')[-1]}" std_key = f"std_time_costs_by_{category.split('_')[-1]}" - for item_name in stats[period_key][category]: - time_costs = stats[period_key][time_cost_key].get(item_name, []) + category_data = cast(dict[str, int], stats[period_key].get(category, {})) + time_cost_data = cast(dict[str, 
list[float]], stats[period_key].get(time_cost_key, {})) + avg_cost_data = cast(dict[str, float], stats[period_key].get(avg_key, {})) + std_cost_data = cast(dict[str, float], stats[period_key].get(std_key, {})) + + for item_name in category_data: + time_costs = time_cost_data.get(item_name, []) if time_costs: # 计算平均耗时 avg_time_cost = sum(time_costs) / len(time_costs) - stats[period_key][avg_key][item_name] = round(avg_time_cost, 3) + avg_cost_data[item_name] = round(avg_time_cost, 3) # 计算标准差 if len(time_costs) > 1: variance = sum((x - avg_time_cost) ** 2 for x in time_costs) / len(time_costs) std_time_cost = variance**0.5 - stats[period_key][std_key][item_name] = round(std_time_cost, 3) + std_cost_data[item_name] = round(std_time_cost, 3) else: - stats[period_key][std_key][item_name] = 0.0 + std_cost_data[item_name] = 0.0 else: - stats[period_key][avg_key][item_name] = 0.0 - stats[period_key][std_key][item_name] = 0.0 + avg_cost_data[item_name] = 0.0 + std_cost_data[item_name] = 0.0 + + stats[period_key][avg_key] = avg_cost_data + stats[period_key][std_key] = std_cost_data return stats @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: + def _collect_online_time_for_period( + collect_period: list[tuple[str, datetime]], + now: datetime, + ) -> dict[str, dict[str, float]]: """ 收集指定时间段的在线时间统计数据 @@ -465,11 +628,8 @@ class StatisticOutputTask(AsyncTask): query_start_time = collect_period[-1][1] # Assuming OnlineTime.end_timestamp is a DateTimeField - for record in OnlineTime.select().where(OnlineTime.end_timestamp >= query_start_time): # type: ignore - # record.end_timestamp and record.start_timestamp are datetime objects - record_end_timestamp = record.end_timestamp - record_start_timestamp = record.start_timestamp - + records = StatisticOutputTask._fetch_online_time_since(query_start_time) + for record_start_timestamp, record_end_timestamp in records: for idx, (_, period_boundary_start) in 
enumerate(collect_period): if record_end_timestamp >= period_boundary_start: # Calculate effective end time for this record in relation to 'now' @@ -485,7 +645,10 @@ class StatisticOutputTask(AsyncTask): break return stats - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: + def _collect_message_count_for_period( + self, + collect_period: list[tuple[str, datetime]], + ) -> StatPeriodMapping: """ 收集指定时间段的消息统计数据 @@ -496,30 +659,28 @@ class StatisticOutputTask(AsyncTask): collect_period.sort(key=lambda x: x[1], reverse=True) - stats = { - period_key: { - TOTAL_MSG_CNT: 0, - MSG_CNT_BY_CHAT: defaultdict(int), - TOTAL_REPLY_CNT: 0, - } - for period_key, _ in collect_period + stats: StatPeriodMapping = { + period_key: StatisticOutputTask._build_stat_period_data() for period_key, _ in collect_period } - query_start_timestamp = collect_period[-1][1].timestamp() # Messages.time is a DoubleField (timestamp) - for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore - message_time_ts = message.time # This is a float timestamp + query_start_timestamp = collect_period[-1][1] + with get_db_session() as session: + statement = select(Messages).where(col(Messages.timestamp) >= query_start_timestamp) + messages = session.exec(statement).all() + for message in messages: + message_time_ts = message.timestamp.timestamp() chat_id = None chat_name = None # Logic based on Peewee model structure, aiming to replicate original intent - if message.chat_info_group_id: - chat_id = f"g{message.chat_info_group_id}" - chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" - elif message.user_id: # Fallback to sender's info for chat_id if not a group_info based chat + if message.group_id: + chat_id = f"g{message.group_id}" + chat_name = message.group_name or f"群{message.group_id}" + elif message.user_id: # This uses the message SENDER's ID as per original logic's fallback - chat_id 
= f"u{message.user_id}" # SENDER's user_id - chat_name = message.user_nickname # SENDER's nickname + chat_id = f"u{message.user_id}" + chat_name = message.user_nickname else: # If neither group_id nor sender_id is available for chat identification logger.warning( @@ -545,45 +706,47 @@ class StatisticOutputTask(AsyncTask): for idx, (_, period_start_dt) in enumerate(collect_period): if message_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: - stats[period_key][TOTAL_MSG_CNT] += 1 - stats[period_key][MSG_CNT_BY_CHAT][chat_id] += 1 + StatisticOutputTask._add_int_stat(stats[period_key], TOTAL_MSG_CNT, 1) + StatisticOutputTask._add_defaultdict_int(stats[period_key], MSG_CNT_BY_CHAT, chat_id, 1) break # 使用 ActionRecords 中的 reply 动作次数作为回复数基准 try: - action_query_start_timestamp = collect_period[-1][1].timestamp() - for action in ActionRecords.select().where(ActionRecords.time >= action_query_start_timestamp): # type: ignore - # 仅统计已完成的 reply 动作 - if action.action_name != "reply" or not action.action_done: + action_query_start_timestamp = collect_period[-1][1] + with get_db_session() as session: + statement = select(ActionRecord).where(col(ActionRecord.timestamp) >= action_query_start_timestamp) + actions = session.exec(statement).all() + for action in actions: + if action.action_name != "reply": continue - action_time_ts = action.time + action_time_ts = action.timestamp.timestamp() for idx, (_, period_start_dt) in enumerate(collect_period): if action_time_ts >= period_start_dt.timestamp(): for period_key, _ in collect_period[idx:]: - stats[period_key][TOTAL_REPLY_CNT] += 1 + StatisticOutputTask._add_int_stat(stats[period_key], TOTAL_REPLY_CNT, 1) break except Exception as e: logger.warning(f"统计 reply 动作次数失败,将回复数视为 0,错误信息:{e}") return stats - def _collect_all_statistics(self, now: datetime) -> Dict[str, Dict[str, Any]]: + def _collect_all_statistics(self, now: datetime) -> StatPeriodMapping: """ 收集各时间段的统计数据 :param now: 基准当前时间 """ - 
last_all_time_stat = None + last_all_time_stat: dict[str, object] | None = None try: if "last_full_statistics" in local_storage: # 如果存在上次完整统计数据,则使用该数据进行增量统计 - last_stat: Dict[str, Any] = local_storage["last_full_statistics"] # 上次完整统计数据 # type: ignore + last_stat = cast(dict[str, object], local_storage["last_full_statistics"]) # 修复 name_mapping 数据类型不匹配问题 # JSON 中存储为列表,但代码期望为元组 - raw_name_mapping = last_stat["name_mapping"] + raw_name_mapping = cast(dict[str, object], last_stat["name_mapping"]) self.name_mapping = {} for chat_id, value in raw_name_mapping.items(): if isinstance(value, list) and len(value) == 2: @@ -596,8 +759,8 @@ class StatisticOutputTask(AsyncTask): # 数据格式不正确,跳过或使用默认值 logger.warning(f"name_mapping 中 chat_id {chat_id} 的数据格式不正确: {value}") continue - last_all_time_stat = last_stat["stat_data"] # 上次完整统计的统计数据 - last_stat_timestamp = datetime.fromtimestamp(last_stat["timestamp"]) # 上次完整统计数据的时间戳 + last_all_time_stat = cast(dict[str, object], last_stat["stat_data"]) # 上次完整统计的统计数据 + last_stat_timestamp = datetime.fromtimestamp(self._to_float_timestamp(last_stat["timestamp"])) self.stat_period = [ item for item in self.stat_period if item[0] != "all_time" ] # 删除"所有时间"的统计时段 @@ -664,9 +827,9 @@ class StatisticOutputTask(AsyncTask): "timestamp": now.timestamp(), } - return stat + return cast(StatPeriodMapping, stat) - def _convert_defaultdict_to_dict(self, data): + def _convert_defaultdict_to_dict(self, data: object) -> object: # sourcery skip: dict-comprehension, extract-duplicate-method, inline-immediately-returned-variable, merge-duplicate-blocks """递归转换defaultdict为普通dict""" if isinstance(data, defaultdict): @@ -685,10 +848,21 @@ class StatisticOutputTask(AsyncTask): # 其他类型直接返回 return data + @staticmethod + def _to_float_timestamp(value: object) -> float: + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str): + try: + return float(value) + except ValueError: + return 0.0 + return 0.0 + # -- 以下为统计数据格式化方法 -- @staticmethod - 
def _format_total_stat(stats: Dict[str, Any]) -> str: + def _format_total_stat(stats: StatPeriodData) -> str: """ 格式化总统计数据 """ @@ -718,7 +892,7 @@ class StatisticOutputTask(AsyncTask): ) output = [ - f"总在线时间: {_format_online_time(stats[ONLINE_TIME])}", + f"总在线时间: {_format_online_time(int(stats[ONLINE_TIME]))}", f"总消息数: {_format_large_number(stats[TOTAL_MSG_CNT])}", f"总回复数: {_format_large_number(total_replies)}", f"总请求数: {_format_large_number(stats[TOTAL_REQ_CNT])}", @@ -737,7 +911,7 @@ class StatisticOutputTask(AsyncTask): return "\n".join(output) @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: + def _format_model_classified_stat(stats: StatPeriodData) -> str: """ 格式化按模型分类的统计数据 """ @@ -796,7 +970,7 @@ class StatisticOutputTask(AsyncTask): return "\n".join(output) @staticmethod - def _format_module_classified_stat(stats: Dict[str, Any]) -> str: + def _format_module_classified_stat(stats: StatPeriodData) -> str: """ 格式化按模块分类的统计数据 """ @@ -854,7 +1028,7 @@ class StatisticOutputTask(AsyncTask): output.append("") return "\n".join(output) - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: + def _format_chat_stat(self, stats: StatPeriodData) -> str: """ 格式化聊天统计数据 """ @@ -905,7 +1079,7 @@ class StatisticOutputTask(AsyncTask): # 移除_generate_versions_tab方法 - def _generate_html_report(self, stat: dict[str, Any], now: datetime): + def _generate_html_report(self, stat: StatPeriodMapping, now: datetime): """ 生成HTML格式的统计报告 :param stat: 统计数据 @@ -921,7 +1095,7 @@ class StatisticOutputTask(AsyncTask): tab_list.append('') tab_list.append('') - def _format_stat_data(stat_data: dict[str, Any], div_id: str, start_time: datetime) -> str: + def _format_stat_data(stat_data: StatPeriodData, div_id: str, start_time: datetime) -> str: """ 格式化一个时间段的统计数据到html div块 :param stat_data: 统计数据 @@ -1020,7 +1194,7 @@ class StatisticOutputTask(AsyncTask):
总在线时间
-
{_format_online_time(stat_data[ONLINE_TIME])}
+
{_format_online_time(int(stat_data[ONLINE_TIME]))}
总消息数
@@ -1307,7 +1481,11 @@ class StatisticOutputTask(AsyncTask): ] tab_content_list.append( - _format_stat_data(stat["all_time"], "all_time", datetime.fromtimestamp(local_storage["deploy_time"])) # type: ignore + _format_stat_data( + stat["all_time"], + "all_time", + datetime.fromtimestamp(self._to_float_timestamp(local_storage["deploy_time"])), + ) ) # 不再添加版本对比内容 @@ -1507,10 +1685,10 @@ class StatisticOutputTask(AsyncTask): with open(self.record_file_path, "w", encoding="utf-8") as f: f.write(html_template) - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: + def _generate_chart_data(self, stat: StatPeriodMapping) -> dict[str, dict[str, object]]: """生成图表数据""" now = datetime.now() - chart_data = {} + chart_data: dict[str, dict[str, object]] = {} # 支持多个时间范围 time_ranges = [ @@ -1526,7 +1704,7 @@ class StatisticOutputTask(AsyncTask): return chart_data - def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: + def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict[str, object]: """收集指定时间范围内每个间隔的数据""" # 生成时间点 start_time = now - timedelta(hours=hours) @@ -1538,18 +1716,19 @@ class StatisticOutputTask(AsyncTask): current_time += timedelta(minutes=interval_minutes) # 初始化数据结构 - total_cost_data = [0] * len(time_points) - cost_by_model = {} - cost_by_module = {} - message_by_chat = {} + total_cost_data: list[float] = [0.0] * len(time_points) + cost_by_model: dict[str, list[float]] = {} + cost_by_module: dict[str, list[float]] = {} + message_by_chat: dict[str, list[int]] = {} time_labels = [t.strftime("%H:%M") for t in time_points] interval_seconds = interval_minutes * 60 # 查询LLM使用记录 query_start_time = start_time - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore - record_time = record.timestamp + records = StatisticOutputTask._fetch_model_usage_since(query_start_time) + for record in records: + record_time = cast(datetime, record["timestamp"]) # 
找到对应的时间间隔索引 time_diff = (record_time - start_time).total_seconds() @@ -1557,26 +1736,30 @@ class StatisticOutputTask(AsyncTask): if 0 <= interval_index < len(time_points): # 累加总花费数据 - cost = record.cost or 0.0 - total_cost_data[interval_index] += cost # type: ignore + cost = cast(float | None, record["cost"]) or 0.0 + total_cost_data[interval_index] += cost # 累加按模型分类的花费 - model_name = record.model_assign_name or record.model_name or "unknown" + model_assign_name = cast(str | None, record["model_assign_name"]) + model_name = model_assign_name or cast(str | None, record["model_name"]) or "unknown" if model_name not in cost_by_model: - cost_by_model[model_name] = [0] * len(time_points) + cost_by_model[model_name] = [0.0] * len(time_points) cost_by_model[model_name][interval_index] += cost # 累加按模块分类的花费 - request_type = record.request_type or "unknown" + request_type = cast(str | None, record["request_type"]) or "unknown" module_name = request_type.split(".")[0] if "." in request_type else request_type if module_name not in cost_by_module: - cost_by_module[module_name] = [0] * len(time_points) + cost_by_module[module_name] = [0.0] * len(time_points) cost_by_module[module_name][interval_index] += cost # 查询消息记录 query_start_timestamp = start_time.timestamp() - for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore - message_time_ts = message.time + with get_db_session() as session: + statement = select(Messages).where(col(Messages.timestamp) >= start_time) + messages = session.exec(statement).all() + for message in messages: + message_time_ts = message.timestamp.timestamp() # 找到对应的时间间隔索引 time_diff = message_time_ts - query_start_timestamp @@ -1585,8 +1768,8 @@ class StatisticOutputTask(AsyncTask): if 0 <= interval_index < len(time_points): # 确定聊天流名称 chat_name = None - if message.chat_info_group_id: - chat_name = message.chat_info_group_name or f"群{message.chat_info_group_id}" + if message.group_id: + chat_name = message.group_name or 
f"群{message.group_id}" elif message.user_id: chat_name = message.user_nickname or f"用户{message.user_id}" else: @@ -1608,7 +1791,7 @@ class StatisticOutputTask(AsyncTask): "message_by_chat": message_by_chat, } - def _generate_chart_tab(self, chart_data: dict) -> str: + def _generate_chart_tab(self, chart_data: dict[str, dict[str, object]]) -> str: # sourcery skip: extract-duplicate-method, move-assign-in-block """生成图表选项卡HTML内容""" @@ -1627,11 +1810,14 @@ class StatisticOutputTask(AsyncTask): ] # 默认使用24小时数据生成数据集 - default_data = chart_data["24h"] + default_data = cast(dict[str, object], chart_data["24h"]) + cost_by_model = cast(dict[str, list[float]], default_data.get("cost_by_model", {})) + cost_by_module = cast(dict[str, list[float]], default_data.get("cost_by_module", {})) + message_by_chat = cast(dict[str, list[int]], default_data.get("message_by_chat", {})) # 为每个模型生成数据集 model_datasets = [] - for i, (model_name, cost_data) in enumerate(default_data["cost_by_model"].items()): + for i, (model_name, cost_data) in enumerate(cost_by_model.items()): color = colors[i % len(colors)] model_datasets.append(f"""{{ label: '{model_name}', @@ -1646,7 +1832,7 @@ class StatisticOutputTask(AsyncTask): # 为每个模块生成数据集 module_datasets = [] - for i, (module_name, cost_data) in enumerate(default_data["cost_by_module"].items()): + for i, (module_name, cost_data) in enumerate(cost_by_module.items()): color = colors[i % len(colors)] module_datasets.append(f"""{{ label: '{module_name}', @@ -1661,7 +1847,7 @@ class StatisticOutputTask(AsyncTask): # 为每个聊天流生成消息数据集 message_datasets = [] - for i, (chat_name, message_data) in enumerate(default_data["message_by_chat"].items()): + for i, (chat_name, message_data) in enumerate(message_by_chat.items()): color = colors[i % len(colors)] message_datasets.append(f"""{{ label: '{chat_name}', @@ -1886,7 +2072,7 @@ class StatisticOutputTask(AsyncTask):
""" - def _generate_metrics_data(self, now: datetime) -> dict: + def _generate_metrics_data(self, now: datetime) -> dict[str, object]: """生成指标趋势数据""" metrics_data = {} @@ -1901,7 +2087,7 @@ class StatisticOutputTask(AsyncTask): return metrics_data - def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict: + def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict[str, object]: """收集指定时间范围内每个间隔的指标数据""" start_time = now - timedelta(hours=hours) time_points = [] @@ -1936,17 +2122,18 @@ class StatisticOutputTask(AsyncTask): # 查询LLM使用记录 query_start_time = start_time - for record in LLMUsage.select().where(LLMUsage.timestamp >= query_start_time): # type: ignore - record_time = record.timestamp + records = StatisticOutputTask._fetch_model_usage_since(query_start_time) + for record in records: + record_time = cast(datetime, record["timestamp"]) # 找到对应的时间间隔索引 time_diff = (record_time - start_time).total_seconds() interval_index = int(time_diff // interval_seconds) if 0 <= interval_index < len(time_points): - cost = record.cost or 0.0 - prompt_tokens = record.prompt_tokens or 0 - completion_tokens = record.completion_tokens or 0 + cost = cast(float | None, record["cost"]) or 0.0 + prompt_tokens = cast(int | None, record["prompt_tokens"]) or 0 + completion_tokens = cast(int | None, record["completion_tokens"]) or 0 total_token = prompt_tokens + completion_tokens total_costs[interval_index] += cost @@ -1954,8 +2141,11 @@ class StatisticOutputTask(AsyncTask): # 查询消息记录 query_start_timestamp = start_time.timestamp() - for message in Messages.select().where(Messages.time >= query_start_timestamp): # type: ignore - message_time_ts = message.time + with get_db_session() as session: + statement = select(Messages).where(col(Messages.timestamp) >= start_time) + messages = session.exec(statement).all() + for message in messages: + message_time_ts = message.timestamp.timestamp() time_diff = message_time_ts - 
query_start_timestamp interval_index = int(time_diff // interval_seconds) @@ -1967,10 +2157,8 @@ class StatisticOutputTask(AsyncTask): total_replies[interval_index] += 1 # 查询在线时间记录 - for record in OnlineTime.select().where(OnlineTime.end_timestamp >= start_time): # type: ignore - record_start = record.start_timestamp - record_end = record.end_timestamp - + records = StatisticOutputTask._fetch_online_time_since(start_time) + for record_start, record_end in records: # 找到记录覆盖的所有时间间隔 for idx, time_point in enumerate(time_points): interval_start = time_point @@ -2016,7 +2204,7 @@ class StatisticOutputTask(AsyncTask): "cost_per_100_replies": cost_per_100_replies, } - def _generate_metrics_tab(self, metrics_data: dict) -> str: + def _generate_metrics_tab(self, metrics_data: dict[str, object]) -> str: """生成指标趋势图表选项卡HTML内容""" colors = { "cost_per_100_messages": "#8b5cf6", @@ -2213,6 +2401,7 @@ class AsyncStatisticOutputTask(AsyncTask): self.name_mapping = temp_stat_task.name_mapping self.record_file_path = temp_stat_task.record_file_path self.stat_period = temp_stat_task.stat_period + self._statistic_task = temp_stat_task async def run(self): """完全异步执行统计任务""" @@ -2226,17 +2415,13 @@ class AsyncStatisticOutputTask(AsyncTask): logger.info("正在后台收集统计数据...") # 数据收集任务 - collect_task = asyncio.create_task( - loop.run_in_executor(executor, self._collect_all_statistics, now) # type: ignore - ) - - stats = await collect_task + stats = await loop.run_in_executor(executor, self._statistic_task._collect_all_statistics, now) logger.info("统计数据收集完成") # 创建并发的输出任务 output_tasks = [ - asyncio.create_task(loop.run_in_executor(executor, self._statistic_console_output, stats, now)), # type: ignore - asyncio.create_task(loop.run_in_executor(executor, self._generate_html_report, stats, now)), # type: ignore + loop.run_in_executor(executor, self._statistic_task._statistic_console_output, stats, now), + loop.run_in_executor(executor, self._statistic_task._generate_html_report, stats, now), ] # 
等待所有输出任务完成 @@ -2248,60 +2433,3 @@ class AsyncStatisticOutputTask(AsyncTask): # 创建后台任务,立即返回 asyncio.create_task(_async_collect_and_output()) - - # 复用 StatisticOutputTask 的所有方法 - def _collect_all_statistics(self, now: datetime): - return StatisticOutputTask._collect_all_statistics(self, now) # type: ignore - - def _statistic_console_output(self, stats: Dict[str, Any], now: datetime): - return StatisticOutputTask._statistic_console_output(self, stats, now) # type: ignore - - def _generate_html_report(self, stats: dict[str, Any], now: datetime): - return StatisticOutputTask._generate_html_report(self, stats, now) # type: ignore - - # 其他需要的方法也可以类似复用... - @staticmethod - def _collect_model_request_for_period(collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_model_request_for_period(collect_period) - - @staticmethod - def _collect_online_time_for_period(collect_period: List[Tuple[str, datetime]], now: datetime) -> Dict[str, Any]: - return StatisticOutputTask._collect_online_time_for_period(collect_period, now) - - def _collect_message_count_for_period(self, collect_period: List[Tuple[str, datetime]]) -> Dict[str, Any]: - return StatisticOutputTask._collect_message_count_for_period(self, collect_period) # type: ignore - - @staticmethod - def _format_total_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_total_stat(stats) - - @staticmethod - def _format_model_classified_stat(stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_model_classified_stat(stats) - - def _format_chat_stat(self, stats: Dict[str, Any]) -> str: - return StatisticOutputTask._format_chat_stat(self, stats) # type: ignore - - def _generate_chart_data(self, stat: dict[str, Any]) -> dict: - return StatisticOutputTask._generate_chart_data(self, stat) # type: ignore - - def _collect_interval_data(self, now: datetime, hours: int, interval_minutes: int) -> dict: - return StatisticOutputTask._collect_interval_data(self, 
now, hours, interval_minutes) # type: ignore - - def _generate_chart_tab(self, chart_data: dict) -> str: - return StatisticOutputTask._generate_chart_tab(self, chart_data) # type: ignore - - def _generate_metrics_data(self, now: datetime) -> dict: - return StatisticOutputTask._generate_metrics_data(self, now) # type: ignore - - def _collect_metrics_interval_data(self, now: datetime, hours: int, interval_hours: int) -> dict: - return StatisticOutputTask._collect_metrics_interval_data(self, now, hours, interval_hours) # type: ignore - - def _generate_metrics_tab(self, metrics_data: dict) -> str: - return StatisticOutputTask._generate_metrics_tab(self, metrics_data) # type: ignore - - def _get_chat_display_name_from_id(self, chat_id: str) -> str: - return StatisticOutputTask._get_chat_display_name_from_id(self, chat_id) # type: ignore - - def _convert_defaultdict_to_dict(self, data): - return StatisticOutputTask._convert_defaultdict_to_dict(self, data) # type: ignore diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 1145cc83..65eec4f0 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -1,18 +1,21 @@ import base64 +from datetime import datetime +from typing import Optional, Tuple + +import hashlib +import io import os import time -import hashlib import uuid -import io -import numpy as np -from typing import Optional, Tuple +import numpy as np from PIL import Image from rich.traceback import install +from sqlmodel import select, col from src.common.logger import get_logger -from src.common.database.database import db -from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest @@ -38,11 +41,7 @@ class ImageManager: self._initialized = True 
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image") - try: - db.connect(reuse_if_open=True) - db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True) - except Exception as e: - logger.error(f"数据库连接或表创建失败: {e}") + get_db_session() try: self._cleanup_invalid_descriptions() @@ -72,10 +71,12 @@ class ImageManager: Optional[str]: 描述文本,如果不存在则返回None """ try: - record = ImageDescriptions.get_or_none( - (ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type) - ) - return record.description if record else None + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type)) + ) + record = session.exec(statement).first() + return record.description if record else None except Exception as e: logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}") return None @@ -90,15 +91,27 @@ class ImageManager: description_type: 描述类型 ('emoji' 或 'image') """ try: - current_timestamp = time.time() - defaults = {"description": description, "timestamp": current_timestamp} - desc_obj, created = ImageDescriptions.get_or_create( - image_description_hash=image_hash, type=description_type, defaults=defaults - ) - if not created: # 如果记录已存在,则更新 - desc_obj.description = description - desc_obj.timestamp = current_timestamp - desc_obj.save() + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type)) + ) + record = session.exec(statement).first() + if record: + record.description = description + session.add(record) + return + + new_record = Images( + image_hash=image_hash, + description=description, + full_path="", + image_type=ImageType(description_type), + query_count=0, + is_registered=False, + is_banned=False, + vlm_processed=True, + ) + session.add(new_record) except Exception as e: 
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}") @@ -107,20 +120,18 @@ class ImageManager: """清理数据库中 description 为空或为 'None' 的记录""" invalid_values = ["", "None"] - # 清理 Images 表 - deleted_images = ( - Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute() - ) + with get_db_session() as session: + statement = ( + select(Images) + .where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values)) + .limit(1000) + ) + records = session.exec(statement).all() + for record in records: + session.delete(record) - # 清理 ImageDescriptions 表 - deleted_descriptions = ( - ImageDescriptions.delete() - .where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values)) - .execute() - ) - - if deleted_images or deleted_descriptions: - logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条") + if records: + logger.info(f"[清理完成] 删除 Images: {len(records)} 条") else: logger.info("[清理完成] 未发现无效描述记录") @@ -128,19 +139,15 @@ class ImageManager: def _cleanup_emoji_from_image_descriptions(): """清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)""" try: - # 清理Images表中type为emoji的记录 - deleted_images = Images.delete().where(Images.type == "emoji").execute() + with get_db_session() as session: + statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI) + records = session.exec(statement).all() + for record in records: + session.delete(record) - # 清理ImageDescriptions表中type为emoji的记录 - deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute() - - total_deleted = deleted_images + deleted_descriptions + total_deleted = len(records) if total_deleted > 0: - logger.info( - f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, " - f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, " - f"共删除 {total_deleted} 条记录" - ) + logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录") else: 
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录") except Exception as e: @@ -148,14 +155,14 @@ class ImageManager: raise async def get_emoji_tag(self, image_base64: str) -> str: - from src.chat.emoji_system.emoji_manager import get_emoji_manager + from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance - emoji_manager = get_emoji_manager() + emoji_manager = emoji_manager_instance if isinstance(image_base64, str): image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - emoji = await emoji_manager.get_emoji_from_manager(image_hash) + emoji = emoji_manager.get_emoji_by_hash(image_hash) if not emoji: return "[表情包:未知]" emotion_list = emoji.emotion @@ -175,14 +182,14 @@ class ImageManager: try: from src.chat.emoji_system.emoji_manager import EMOJI_DIR - from src.chat.emoji_system.emoji_manager import get_emoji_manager + from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance # 确保目录存在 os.makedirs(EMOJI_DIR, exist_ok=True) # 检查是否已存在该表情包(通过哈希值) - emoji_manager = get_emoji_manager() - existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash) + emoji_manager = emoji_manager_instance + existing_emoji = emoji_manager.get_emoji_by_hash(image_hash) if existing_emoji: logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...") return @@ -212,14 +219,15 @@ class ImageManager: image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii") image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore + image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore # 优先使用EmojiManager查询已注册表情包的描述 try: - from src.chat.emoji_system.emoji_manager import get_emoji_manager + from src.chat.emoji_system.emoji_manager import emoji_manager 
as emoji_manager_instance - emoji_manager = get_emoji_manager() - tags = await emoji_manager.get_emoji_tag_by_hash(image_hash) + emoji_manager = emoji_manager_instance + emoji = emoji_manager.get_emoji_by_hash(image_hash) + tags = emoji.emotion if emoji else None if tags: tag_str = ",".join(tags) logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...") @@ -227,29 +235,26 @@ class ImageManager: except Exception as e: logger.debug(f"查询EmojiManager时出错: {e}") - # 查询EmojiDescriptionCache表的缓存(包含描述和情感标签) try: - cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash) + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) + ) + cache_record = session.exec(statement).first() if cache_record: - # 优先使用情感标签,如果没有则使用详细描述 result_text = "" - if cache_record.emotion_tags: - logger.info( - f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..." - ) - result_text = f"[表情包:{cache_record.emotion_tags}]" + if cache_record.emotion: + logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...") + result_text = f"[表情包:{cache_record.emotion}]" elif cache_record.description: - logger.info( - f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..." - ) + logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...") result_text = f"[表情包:{cache_record.description}]" - # 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件 if result_text: await self._save_emoji_file_if_needed(image_base64, image_hash, image_format) return result_text except Exception as e: - logger.debug(f"查询EmojiDescriptionCache时出错: {e}") + logger.debug(f"查询Images缓存时出错: {e}") # === 二步走识别流程 === @@ -309,33 +314,42 @@ class ImageManager: logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... 
-> 情感标签: {final_emotion}") - # 再次检查缓存(防止并发情况下其他线程已经保存) try: - cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash) - if cache_record and cache_record.emotion_tags: - logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}") - return f"[表情包:{cache_record.emotion_tags}]" + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) + ) + cache_record = session.exec(statement).first() + if cache_record and cache_record.emotion: + logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}") + return f"[表情包:{cache_record.emotion}]" except Exception as e: - logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}") + logger.debug(f"再次查询Images缓存时出错: {e}") - # 保存识别出的详细描述和情感标签到 emoji_description_cache try: - current_timestamp = time.time() - cache_record, created = EmojiDescriptionCache.get_or_create( - emoji_hash=image_hash, - defaults={ - "description": detailed_description, - "emotion_tags": final_emotion, - "timestamp": current_timestamp, - }, - ) - if not created: - # 更新已有记录 - cache_record.description = detailed_description - cache_record.emotion_tags = final_emotion - cache_record.timestamp = current_timestamp - cache_record.save() - logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...") + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI) + ) + cache_record = session.exec(statement).first() + if cache_record: + cache_record.description = detailed_description + cache_record.emotion = final_emotion + session.add(cache_record) + else: + cache_record = Images( + image_hash=image_hash, + description=detailed_description, + full_path="", + image_type=ImageType.EMOJI, + emotion=final_emotion, + query_count=0, + is_registered=False, + is_banned=False, + vlm_processed=True, + ) + 
session.add(cache_record) + logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...") except Exception as e: logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}") @@ -358,14 +372,13 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 优先检查Images表中是否已有完整的描述 - existing_image = Images.get_or_none(Images.emoji_hash == image_hash) + with get_db_session() as session: + statement = select(Images).where(col(Images.image_hash) == image_hash) + existing_image = session.exec(statement).first() if existing_image: - # 更新计数 - if hasattr(existing_image, "count") and existing_image.count is not None: - existing_image.count += 1 - else: - existing_image.count = 1 - existing_image.save() + existing_image.query_count += 1 + with get_db_session() as session: + session.add(existing_image) # 如果已有描述,直接返回 if existing_image.description: @@ -377,7 +390,7 @@ class ImageManager: return f"[图片:{cached_description}]" # 调用AI获取描述 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore + image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore prompt = global_config.personality.visual_style logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)") description, _ = await self.vlm.generate_response_for_image( @@ -402,26 +415,27 @@ class ImageManager: # 保存到数据库,补充缺失字段 if existing_image: - existing_image.path = file_path + existing_image.full_path = file_path existing_image.description = description - existing_image.timestamp = current_timestamp - if not hasattr(existing_image, "image_id") or not existing_image.image_id: - existing_image.image_id = str(uuid.uuid4()) - if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None: - existing_image.vlm_processed = True - existing_image.save() + existing_image.record_time = datetime.fromtimestamp(current_timestamp) + existing_image.vlm_processed = True + with get_db_session() as session: + session.add(existing_image) logger.debug(f"[数据库] 更新已有图片记录: 
{image_hash[:8]}...") else: - Images.create( - image_id=str(uuid.uuid4()), - emoji_hash=image_hash, - path=file_path, - type="image", - description=description, - timestamp=current_timestamp, - vlm_processed=True, - count=1, - ) + with get_db_session() as session: + new_record = Images( + image_hash=image_hash, + description=description, + full_path=file_path, + image_type=ImageType.IMAGE, + query_count=1, + is_registered=False, + is_banned=False, + record_time=datetime.fromtimestamp(current_timestamp), + vlm_processed=True, + ) + session.add(new_record) logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...") except Exception as e: logger.error(f"保存图片文件或元数据失败: {str(e)}") @@ -575,30 +589,17 @@ class ImageManager: image_bytes = base64.b64decode(image_base64) image_hash = hashlib.md5(image_bytes).hexdigest() - if existing_image := Images.get_or_none(Images.emoji_hash == image_hash): - # 检查是否缺少必要字段,如果缺少则创建新记录 - if ( - not hasattr(existing_image, "image_id") - or not existing_image.image_id - or not hasattr(existing_image, "count") - or existing_image.count is None - or not hasattr(existing_image, "vlm_processed") - or existing_image.vlm_processed is None - ): - logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}") - if not existing_image.image_id: - existing_image.image_id = str(uuid.uuid4()) - if existing_image.count is None: - existing_image.count = 0 - if existing_image.vlm_processed is None: - existing_image.vlm_processed = False + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE) + ) + existing_image = session.exec(statement).first() + if existing_image: + existing_image.query_count += 1 + session.add(existing_image) + return str(existing_image.id), f"[picid:{existing_image.id}]" - existing_image.count += 1 - existing_image.save() - return existing_image.image_id, f"[picid:{existing_image.image_id}]" - else: - # print(f"图片不存在: {image_hash}") - image_id = 
str(uuid.uuid4()) + image_id = str(uuid.uuid4()) # 保存新图片 current_timestamp = time.time() @@ -612,15 +613,19 @@ class ImageManager: f.write(image_bytes) # 保存到数据库 - Images.create( - image_id=image_id, - emoji_hash=image_hash, - path=file_path, - type="image", - timestamp=current_timestamp, - vlm_processed=False, - count=1, - ) + with get_db_session() as session: + new_record = Images( + image_hash=image_hash, + description="", + full_path=file_path, + image_type=ImageType.IMAGE, + query_count=1, + is_registered=False, + is_banned=False, + record_time=datetime.fromtimestamp(current_timestamp), + vlm_processed=False, + ) + session.add(new_record) # 启动异步VLM处理 await self._process_image_with_vlm(image_id, image_base64) @@ -647,17 +652,26 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 获取当前图片记录 - image = Images.get(Images.image_id == image_id) + with get_db_session() as session: + image = session.get(Images, int(image_id)) if image_id.isdigit() else None + if image is None: + logger.warning(f"未找到图片记录: {image_id}") + return # 优先检查是否已有其他相同哈希的图片记录包含描述 - existing_with_description = Images.get_or_none( - (Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "") - ) + with get_db_session() as session: + statement = select(Images).where( + (col(Images.image_hash) == image_hash) + & (col(Images.description).is_not(None)) + & (col(Images.description) != "") + ) + existing_with_description = session.exec(statement).first() if existing_with_description and existing_with_description.id != image.id: logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...") image.description = existing_with_description.description image.vlm_processed = True - image.save() + with get_db_session() as session: + session.add(image) # 同时保存到ImageDescriptions表作为备用缓存 self._save_description_to_db(image_hash, existing_with_description.description, "image") return @@ -667,11 +681,12 @@ class ImageManager: 
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...") image.description = cached_description image.vlm_processed = True - image.save() + with get_db_session() as session: + session.add(image) return # 获取图片格式 - image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore + image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore # 构建prompt prompt = global_config.personality.visual_style @@ -692,7 +707,8 @@ class ImageManager: # 更新数据库 image.description = description image.vlm_processed = True - image.save() + with get_db_session() as session: + session.add(image) # 保存描述到ImageDescriptions表作为备用缓存 self._save_description_to_db(image_hash, description, "image") diff --git a/src/common/data_models/__init__.py b/src/common/data_models/__init__.py index a1032806..c83a9528 100644 --- a/src/common/data_models/__init__.py +++ b/src/common/data_models/__init__.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod -from typing import Self, TypeVar, Generic, TYPE_CHECKING +from dataclasses import is_dataclass +from typing import Any, Dict, Self, TypeVar, Generic, TYPE_CHECKING import copy @@ -15,9 +16,23 @@ class BaseDataModel: return copy.deepcopy(self) +def transform_class_to_dict(obj: Any) -> Dict[str, Any]: + if obj is None: + return {} + if is_dataclass(obj): + return obj.__dict__ + if hasattr(obj, "dict"): + return obj.dict() + if hasattr(obj, "model_dump"): + return obj.model_dump() + if hasattr(obj, "__dict__"): + return obj.__dict__ + return {"value": obj} + + class BaseDatabaseDataModel(ABC, Generic[T]): - @abstractmethod @classmethod + @abstractmethod def from_db_instance(cls, db_record: T) -> Self: """从数据库实例创建数据模型对象""" raise NotImplementedError diff --git a/src/common/data_models/message_data_model.py b/src/common/data_models/message_data_model.py new file mode 100644 index 00000000..938e0013 --- /dev/null +++ b/src/common/data_models/message_data_model.py @@ -0,0 +1,79 @@ +from dataclasses 
import dataclass, field +from enum import Enum +from typing import Any, Iterable, List, Optional, Tuple, Union + +from . import BaseDataModel + + +class ReplyContentType(Enum): + TEXT = "text" + IMAGE = "image" + EMOJI = "emoji" + COMMAND = "command" + VOICE = "voice" + HYBRID = "hybrid" + FORWARD = "forward" + + def __str__(self) -> str: + return self.value + + +@dataclass +class ReplyContent: + content_type: ReplyContentType | str + content: Any + + +@dataclass +class ForwardNode: + user_id: Optional[str] = None + user_nickname: Optional[str] = None + content: Union[str, List[ReplyContent], None] = None + + @classmethod + def construct_as_id_reference(cls, message_id: str) -> "ForwardNode": + return cls(content=message_id) + + @classmethod + def construct_as_created_node( + cls, + user_id: str, + user_nickname: str, + content: List[ReplyContent], + ) -> "ForwardNode": + return cls(user_id=user_id, user_nickname=user_nickname, content=content) + + +class ReplySetModel(BaseDataModel): + def __init__(self) -> None: + self.reply_data: List[ReplyContent] = [] + + def __len__(self) -> int: + return len(self.reply_data) + + def add_text_content(self, text: str) -> None: + self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text)) + + def add_voice_content(self, voice_base64: str) -> None: + self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64)) + + def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None: + hybrid_contents: List[ReplyContent] = [] + for content_type, content in message_tuple_list: + hybrid_contents.append( + ReplyContent(content_type=self._normalize_content_type(content_type), content=content) + ) + self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_contents)) + + def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None: + 
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes)) + + @staticmethod + def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str: + if isinstance(content_type, ReplyContentType): + return content_type + if isinstance(content_type, str): + for item in ReplyContentType: + if item.value == content_type: + return item + return content_type diff --git a/src/common/database/database.py b/src/common/database/database.py index e985466e..e88be9ec 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,11 +1,12 @@ from rich.traceback import install -from pathlib import Path from contextlib import contextmanager -from sqlalchemy.orm import sessionmaker +from pathlib import Path +from typing import Generator, TYPE_CHECKING + from sqlalchemy import event from sqlalchemy.engine import Engine -from sqlmodel import create_engine, Session -from typing import TYPE_CHECKING, Generator +from sqlalchemy.orm import sessionmaker +from sqlmodel import SQLModel, Session, create_engine if TYPE_CHECKING: from sqlite3 import Connection as SQLite3Connection @@ -53,6 +54,19 @@ SessionLocal = sessionmaker( class_=Session, ) +_db_initialized = False + + +def initialize_database() -> None: + global _db_initialized + if _db_initialized: + return + _DB_DIR.mkdir(parents=True, exist_ok=True) + import src.common.database.database_model # noqa: F401 + + SQLModel.metadata.create_all(engine) + _db_initialized = True + @contextmanager def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: @@ -87,6 +101,7 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: - auto_commit=True 时,成功执行完会自动提交 - auto_commit=False 时,需要手动调用 session.commit() """ + initialize_database() session = SessionLocal() try: yield session @@ -120,6 +135,7 @@ def get_db() -> Generator[Session, None, None]: Yields: Session: SQLAlchemy 数据库会话 """ + initialize_database() session 
= SessionLocal() try: yield session diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 19cb0544..8ade577a 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -1,35 +1,153 @@ import traceback +from datetime import datetime +from typing import Any -from typing import List, Any, Optional +import json -from src.config.config import global_config -from src.common.data_models.database_data_model import DatabaseMessages +from sqlalchemy import func +from sqlmodel import col, select + +from src.common.database.database import get_db_session from src.common.database.database_model import Messages +from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger +from src.config.config import global_config logger = get_logger(__name__) -def _model_to_instance(model_instance: Any) -> DatabaseMessages: - """ - 将 Peewee 模型实例转换为字典。 - """ - if isinstance(model_instance, dict): - return DatabaseMessages(**model_instance) - if hasattr(model_instance, "model_dump"): - return DatabaseMessages(**model_instance.model_dump()) - return DatabaseMessages(**model_instance.__dict__) +FIELD_MAP: dict[str, Any] = { + "time": Messages.timestamp, + "timestamp": Messages.timestamp, + "chat_id": Messages.session_id, + "session_id": Messages.session_id, + "user_id": Messages.user_id, + "message_id": Messages.message_id, + "group_id": Messages.group_id, + "platform": Messages.platform, + "is_command": Messages.is_command, + "is_mentioned": Messages.is_mentioned, + "is_at": Messages.is_at, + "is_emoji": Messages.is_emoji, + "is_picid": Messages.is_picture, + "is_picture": Messages.is_picture, + "reply_to": Messages.reply_to, +} + + +def _parse_additional_config(message: Messages) -> dict[str, Any]: + if not message.additional_config: + return {} + try: + parsed = json.loads(message.additional_config) + except (json.JSONDecodeError, TypeError): + return {} + if isinstance(parsed, 
dict): + return parsed + return {} + + +def _normalize_optional_str(value: object) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, ensure_ascii=False) + except (TypeError, ValueError): + return str(value) + + +def _message_to_instance(message: Messages) -> DatabaseMessages: + config = _parse_additional_config(message) + timestamp_value = message.timestamp + if isinstance(timestamp_value, datetime): + time_value = timestamp_value.timestamp() + else: + time_value = float(timestamp_value) + selected_expressions = _normalize_optional_str(config.get("selected_expressions")) + priority_info = _normalize_optional_str(config.get("priority_info")) + return DatabaseMessages( + message_id=message.message_id, + time=time_value, + chat_id=message.session_id, + reply_to=message.reply_to, + interest_value=config.get("interest_value"), + key_words=_normalize_optional_str(config.get("key_words")), + key_words_lite=_normalize_optional_str(config.get("key_words_lite")), + is_mentioned=message.is_mentioned, + is_at=message.is_at, + reply_probability_boost=config.get("reply_probability_boost"), + processed_plain_text=message.processed_plain_text, + display_message=message.display_message, + priority_mode=_normalize_optional_str(config.get("priority_mode")), + priority_info=priority_info, + additional_config=message.additional_config, + is_emoji=message.is_emoji, + is_picid=message.is_picture, + is_command=message.is_command, + intercept_message_level=config.get("intercept_message_level", 0), + is_notify=message.is_notify, + selected_expressions=selected_expressions, + user_id=message.user_id, + user_nickname=message.user_nickname, + user_cardname=message.user_cardname, + user_platform=message.platform, + chat_info_group_id=message.group_id, + chat_info_group_name=message.group_name, + chat_info_group_platform=message.platform, + chat_info_user_id=message.user_id, + 
chat_info_user_nickname=message.user_nickname, + chat_info_user_cardname=message.user_cardname, + chat_info_user_platform=message.platform, + chat_info_stream_id=message.session_id, + chat_info_platform=message.platform, + chat_info_create_time=0.0, + chat_info_last_active_time=0.0, + ) + + +def _coerce_datetime(value: Any) -> Any: + if isinstance(value, (int, float)): + return datetime.fromtimestamp(value) + return value + + +def _cast_value_for_field(field: Any, value: Any) -> Any: + if field is Messages.timestamp: + return _coerce_datetime(value) + return value + + +def _ensure_list(value: Any) -> list[Any]: + if value is None: + return [] + if isinstance(value, list): + return value + if isinstance(value, tuple): + return list(value) + if isinstance(value, set): + return list(value) + return [value] + + +def _resolve_field(field_name: str) -> Any | None: + if field_name in FIELD_MAP: + return FIELD_MAP[field_name] + if hasattr(Messages, field_name): + return getattr(Messages, field_name) + return None def find_messages( message_filter: dict[str, Any], - sort: Optional[List[tuple[str, int]]] = None, + sort: list[tuple[str, int]] | None = None, limit: int = 0, limit_mode: str = "latest", - filter_bot=False, - filter_command=False, - filter_intercept_message_level: Optional[int] = None, -) -> List[DatabaseMessages]: + filter_bot: bool = False, + filter_command: bool = False, + filter_intercept_message_level: int | None = None, +) -> list[DatabaseMessages]: """ 根据提供的过滤器、排序和限制条件查找消息。 @@ -43,92 +161,79 @@ def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: - query = Messages.select() - - # 应用过滤器 + conditions: list[Any] = [] if message_filter: - conditions = [] for key, value in message_filter.items(): - if hasattr(Messages, key): - field = getattr(Messages, key) - if isinstance(value, dict): - # 处理 MongoDB 风格的操作符 - for op, op_value in value.items(): - if op == "$gt": - conditions.append(field > op_value) - elif op == "$lt": - conditions.append(field < op_value) - elif op 
== "$gte": - conditions.append(field >= op_value) - elif op == "$lte": - conditions.append(field <= op_value) - elif op == "$ne": - conditions.append(field != op_value) - elif op == "$in": - conditions.append(field.in_(op_value)) - elif op == "$nin": - conditions.append(field.not_in(op_value)) - else: - logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") - else: - # 直接相等比较 - conditions.append(field == value) - else: + field = _resolve_field(key) + if field is None: logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - if conditions: - query = query.where(*conditions) - - # 排除 id 为 "notice" 的消息 - query = query.where(Messages.message_id != "notice") + continue + if isinstance(value, dict): + for op, op_value in value.items(): + coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value + if op == "$gt": + conditions.append(field > coerced_value) + elif op == "$lt": + conditions.append(field < coerced_value) + elif op == "$gte": + conditions.append(field >= coerced_value) + elif op == "$lte": + conditions.append(field <= coerced_value) + elif op == "$ne": + conditions.append(field != coerced_value) + elif op == "$in": + conditions.append(field.in_(_ensure_list(coerced_value))) + elif op == "$nin": + conditions.append(field.not_in(_ensure_list(coerced_value))) + else: + logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + else: + coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value + conditions.append(field == coerced_value) + conditions.append(Messages.message_id != "notice") if filter_bot: - query = query.where(Messages.user_id != global_config.bot.qq_account) - + conditions.append(Messages.user_id != global_config.bot.qq_account) if filter_command: - # 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较 - query = query.where(~Messages.is_command) - - if filter_intercept_message_level is not None: - # 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息 - query = 
query.where(Messages.intercept_message_level <= filter_intercept_message_level) + conditions.append(Messages.is_command == False) # noqa: E712 + statement = select(Messages).where(*conditions) if limit > 0: if limit_mode == "earliest": - # 获取时间最早的 limit 条记录,已经是正序 - query = query.order_by("time").limit(limit) - peewee_results = list(query) - else: # 默认为 'latest' - # 获取时间最晚的 limit 条记录 - query = query.order_by("-time").limit(limit) - latest_results_peewee = list(query) - # 将结果按时间正序排列 - peewee_results = sorted( - latest_results_peewee, - key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0), - ) + statement = statement.order_by(col(Messages.timestamp)).limit(limit) + with get_db_session() as session: + results = list(session.exec(statement).all()) + else: + statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit) + with get_db_session() as session: + results = list(session.exec(statement).all()) + results = list(reversed(results)) else: - # limit 为 0 时,应用传入的 sort 参数 if sort: - peewee_sort_terms = [] + order_terms: list[Any] = [] for field_name, direction in sort: - if hasattr(Messages, field_name): - field = getattr(Messages, field_name) - if direction == 1: # ASC - peewee_sort_terms.append(field_name) - elif direction == -1: # DESC - peewee_sort_terms.append(f"-{field_name}") - else: - logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。") - else: + sort_field = _resolve_field(field_name) + if sort_field is None: logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。") - if peewee_sort_terms: - query = query.order_by(*peewee_sort_terms) - peewee_results = list(query) + continue + order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc()) + if order_terms: + statement = statement.order_by(*order_terms) + with get_db_session() as session: + results = list(session.exec(statement).all()) - return [_model_to_instance(msg) for msg in peewee_results] + if 
filter_intercept_message_level is not None: + filtered_results = [] + for msg in results: + config = _parse_additional_config(msg) + if config.get("intercept_message_level", 0) <= filter_intercept_message_level: + filtered_results.append(msg) + results = filtered_results + + return [_message_to_instance(msg) for msg in results] except Exception as e: log_message = ( - f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n" + traceback.format_exc() ) logger.error(log_message) @@ -146,54 +251,42 @@ def count_messages(message_filter: dict[str, Any]) -> int: 符合条件的消息数量,如果出错则返回 0。 """ try: - query = Messages.select() - - # 应用过滤器 + conditions: list[Any] = [] if message_filter: - conditions = [] for key, value in message_filter.items(): - if hasattr(Messages, key): - field = getattr(Messages, key) - if isinstance(value, dict): - # 处理 MongoDB 风格的操作符 - for op, op_value in value.items(): - if op == "$gt": - conditions.append(field > op_value) - elif op == "$lt": - conditions.append(field < op_value) - elif op == "$gte": - conditions.append(field >= op_value) - elif op == "$lte": - conditions.append(field <= op_value) - elif op == "$ne": - conditions.append(field != op_value) - elif op == "$in": - conditions.append(field.in_(op_value)) - elif op == "$nin": - conditions.append(field.not_in(op_value)) - else: - logger.warning( - f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。" - ) - else: - # 直接相等比较 - conditions.append(field == value) - else: + field = _resolve_field(key) + if field is None: logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。") - if conditions: - query = query.where(*conditions) + continue + if isinstance(value, dict): + for op, op_value in value.items(): + coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value + if op == "$gt": + conditions.append(field > 
coerced_value) + elif op == "$lt": + conditions.append(field < coerced_value) + elif op == "$gte": + conditions.append(field >= coerced_value) + elif op == "$lte": + conditions.append(field <= coerced_value) + elif op == "$ne": + conditions.append(field != coerced_value) + elif op == "$in": + conditions.append(field.in_(_ensure_list(coerced_value))) + elif op == "$nin": + conditions.append(field.not_in(_ensure_list(coerced_value))) + else: + logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。") + else: + coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value + conditions.append(field == coerced_value) - # 排除 id 为 "notice" 的消息 - query = query.where(Messages.message_id != "notice") - - count = query.count() - return count + conditions.append(Messages.message_id != "notice") + statement = select(func.count()).select_from(Messages).where(*conditions) + with get_db_session() as session: + result = session.exec(statement).one() + return int(result or 0) except Exception as e: - log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" + log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}" logger.error(log_message) return 0 - - -# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。 -# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。 -# 查找单个消息可以是 Messages.get_or_none(...) 
或 query.first()。 diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py index 4bef8bf7..99a754e6 100644 --- a/src/llm_models/utils.py +++ b/src/llm_models/utils.py @@ -5,8 +5,8 @@ from PIL import Image from datetime import datetime from src.common.logger import get_logger -from src.common.database.database import db # 确保 db 被导入用于 create_tables -from src.common.database.database_model import LLMUsage +from src.common.database.database import get_db_session +from src.common.database.database_model import ModelUsage, ModelUser from src.config.model_configs import ModelInfo from .payload_content.message import Message, MessageBuilder from .model_client.base_client import UsageRecord @@ -158,12 +158,7 @@ class LLMUsageRecorder: """ def __init__(self): - try: - # 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误 - db.create_tables([LLMUsage], safe=True) - # logger.debug("LLMUsage 表已初始化/确保存在。") - except Exception as e: - logger.error(f"创建 LLMUsage 表失败: {str(e)}") + pass def record_usage_to_database( self, @@ -178,22 +173,22 @@ class LLMUsageRecorder: output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out total_cost = round(input_cost + output_cost, 6) try: - # 使用 Peewee 模型创建记录 - LLMUsage.create( - model_name=model_info.model_identifier, - model_assign_name=model_info.name, - model_api_provider=model_info.api_provider, - user_id=user_id, - request_type=request_type, - endpoint=endpoint, - prompt_tokens=model_usage.prompt_tokens or 0, - completion_tokens=model_usage.completion_tokens or 0, - total_tokens=model_usage.total_tokens or 0, - cost=total_cost or 0.0, - time_cost=round(time_cost or 0.0, 3), - status="success", - timestamp=datetime.now(), # Peewee 会处理 DateTimeField - ) + with get_db_session() as session: + record = ModelUsage( + model_name=model_info.model_identifier, + model_assign_name=model_info.name, + model_api_provider_name=model_info.api_provider, + endpoint=endpoint, + user_type=ModelUser.SYSTEM, + request_type=request_type, + 
time_cost=round(time_cost or 0.0, 3), + timestamp=datetime.now(), + prompt_tokens=model_usage.prompt_tokens or 0, + completion_tokens=model_usage.completion_tokens or 0, + total_tokens=model_usage.total_tokens or 0, + cost=total_cost or 0.0, + ) + session.add(record) logger.debug( f"Token使用情况 - 模型: {model_usage.model_name}, " f"用户: {user_id}, 类型: {request_type}, " diff --git a/src/main.py b/src/main.py index 16494686..97e58578 100644 --- a/src/main.py +++ b/src/main.py @@ -7,7 +7,7 @@ from src.manager.async_task_manager import async_task_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask # from src.chat.utils.token_statistics import TokenStatisticsTask -from src.chat.emoji_system.emoji_manager import get_emoji_manager +from src.chat.emoji_system.emoji_manager import emoji_manager from src.chat.message_receive.chat_stream import get_chat_manager from src.config.config import global_config from src.chat.message_receive.bot import chat_bot @@ -107,7 +107,7 @@ class MainSystem: plugin_manager.load_all_plugins() # 初始化表情管理器 - get_emoji_manager().initialize() + emoji_manager.load_emojis_from_db() logger.info("表情包管理器初始化成功") # 初始化聊天管理器 @@ -121,7 +121,7 @@ class MainSystem: # 将bot.py中的chat_bot.message_process消息处理函数注册到api.py的消息处理基类中 self.app.register_message_handler(chat_bot.message_process) self.app.register_custom_message_handler("message_id_echo", chat_bot.echo_message_process) - + prompt_manager.load_prompts() # 触发 ON_START 事件 @@ -141,7 +141,7 @@ class MainSystem: """调度定时任务""" try: tasks = [ - get_emoji_manager().start_periodic_check_register(), + emoji_manager.periodic_emoji_maintenance(), start_dream_scheduler(), self.app.run(), self.server.run(), diff --git a/src/memory_system/memory_retrieval.py b/src/memory_system/memory_retrieval.py index e77eb78c..b356663a 100644 --- a/src/memory_system/memory_retrieval.py +++ b/src/memory_system/memory_retrieval.py @@ -1,12 +1,15 @@ import time import json import asyncio -from typing import 
List, Dict, Any, Optional, Tuple +from datetime import datetime +from typing import List, Dict, Any, Optional, Tuple, Callable, cast from src.common.logger import get_logger from src.config.config import global_config, model_config from src.prompt.prompt_manager import prompt_manager from src.plugin_system.apis import llm_api -from src.common.database.database_model import ThinkingBack +from sqlmodel import select, col +from src.common.database.database import get_db_session +from src.common.database.database_model import ThinkingQuestion from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message from src.chat.message_receive.chat_stream import get_chat_manager @@ -29,13 +32,16 @@ def _cleanup_stale_not_found_thinking_back() -> None: threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS try: - deleted_rows = ( - ThinkingBack.delete() - .where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time)) - .execute() - ) - if deleted_rows: - logger.info(f"清理过期的未找到答案thinking_back记录 {deleted_rows} 条") + with get_db_session() as session: + statement = select(ThinkingQuestion).where( + (ThinkingQuestion.found_answer == False) + & (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time)) + ) + records = session.exec(statement).all() + for record in records: + session.delete(record) + if records: + logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条") _last_not_found_cleanup_ts = now except Exception as e: logger.error(f"清理未找到答案的thinking_back记录失败: {e}") @@ -249,12 +255,12 @@ async def _react_agent_solve_question( # 后续迭代都复用第一次构建的head_prompt head_prompt = first_head_prompt - def message_factory( + def _build_messages( _client, *, _head_prompt: str = head_prompt, _conversation_messages: List[Message] = conversation_messages, - ) -> List[Message]: + ): messages: List[Message] = [] system_builder = MessageBuilder() 
@@ -266,6 +272,7 @@ async def _react_agent_solve_question( return messages + message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues] ( success, response, @@ -273,7 +280,7 @@ async def _react_agent_solve_question( model_name, tool_calls, ) = await llm_api.generate_with_model_with_tools_by_message_factory( - message_factory, + message_factory_fn, # type: ignore[arg-type] model_config=model_config.model_task_config.tool_use, tool_options=tool_definitions, request_type="memory.react", @@ -304,7 +311,12 @@ async def _react_agent_solve_question( assistant_message = assistant_builder.build() # 记录思考步骤 - step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []} + step: Dict[str, Any] = { + "iteration": iteration + 1, + "thought": response, + "actions": [], + "observations": [], + } if assistant_message: conversation_messages.append(assistant_message) @@ -417,20 +429,21 @@ async def _react_agent_solve_question( "action_params": {"information": parsed_information or ""}, } ) - if parsed_information and parsed_information.strip(): + parsed_info_text = parsed_information if isinstance(parsed_information, str) else "" + if parsed_info_text.strip(): step["observations"] = [f"检测到return_information{format_type}调用,返回信息"] thinking_steps.append(step) logger.info( - f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_information[:100]}..." + f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..." 
) _log_conversation_messages( conversation_messages, head_prompt=first_head_prompt, - final_status=f"返回信息:{parsed_information}", + final_status=f"返回信息:{parsed_info_text}", ) - return True, parsed_information, thinking_steps, False + return True, parsed_info_text, thinking_steps, False else: # 信息为空,直接退出查询 step["observations"] = [f"检测到return_information{format_type}调用,信息为空"] @@ -776,15 +789,16 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0) current_time = time.time() start_time = current_time - time_window_seconds - # 查询最近时间窗口内的记录,按更新时间倒序 - records = ( - ThinkingBack.select() - .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time)) - .order_by(ThinkingBack.update_time.desc()) - .limit(5) # 最多返回5条最近的记录 - ) + with get_db_session() as session: + statement = ( + select(ThinkingQuestion) + .where(col(ThinkingQuestion.context) == chat_id) + .order_by(col(ThinkingQuestion.updated_timestamp).desc()) + .limit(5) + ) + records = session.exec(statement).all() - if not records.exists(): + if not records: return "" history_lines = [] @@ -828,20 +842,19 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0) start_time = current_time - time_window_seconds # 查询最近时间窗口内已找到答案的记录,按更新时间倒序 - records = ( - ThinkingBack.select() - .where( - (ThinkingBack.chat_id == chat_id) - & (ThinkingBack.update_time >= start_time) - & (ThinkingBack.found_answer == 1) - & (ThinkingBack.answer.is_null(False)) - & (ThinkingBack.answer != "") + with get_db_session() as session: + statement = ( + select(ThinkingQuestion) + .where(col(ThinkingQuestion.context) == chat_id) + .where(col(ThinkingQuestion.found_answer) == True) + .where(col(ThinkingQuestion.answer).is_not(None)) + .where(col(ThinkingQuestion.answer) != "") + .order_by(col(ThinkingQuestion.updated_timestamp).desc()) + .limit(3) ) - .order_by(ThinkingBack.update_time.desc()) - .limit(3) # 最多返回5条最近的记录 - ) + records = session.exec(statement).all() - if not 
records.exists(): + if not records: return [] found_answers = [] @@ -873,36 +886,35 @@ def _store_thinking_back( now = time.time() # 先查询是否已存在相同chat_id和问题的记录 - existing = ( - ThinkingBack.select() - .where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question)) - .order_by(ThinkingBack.update_time.desc()) - .limit(1) - ) + with get_db_session() as session: + statement = ( + select(ThinkingQuestion) + .where(col(ThinkingQuestion.context) == chat_id) + .where(col(ThinkingQuestion.question) == question) + .order_by(col(ThinkingQuestion.updated_timestamp).desc()) + .limit(1) + ) + record = session.exec(statement).first() + if record: + record.context = context + record.found_answer = found_answer + record.answer = answer + record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False) + record.updated_timestamp = datetime.fromtimestamp(now) + session.add(record) + logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...") + return - if existing.exists(): - # 更新现有记录 - record = existing.get() - record.context = context - record.found_answer = found_answer - record.answer = answer - record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False) - record.update_time = now - record.save() - logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...") - else: - # 创建新记录 - ThinkingBack.create( - chat_id=chat_id, + new_record = ThinkingQuestion( question=question, - context=context, + context=chat_id, found_answer=found_answer, answer=answer, thinking_steps=json.dumps(thinking_steps, ensure_ascii=False), - create_time=now, - update_time=now, + created_timestamp=datetime.fromtimestamp(now), + updated_timestamp=datetime.fromtimestamp(now), ) - # logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...") + session.add(new_record) except Exception as e: logger.error(f"存储思考过程失败: {e}") diff --git a/src/person_info/person_info.py b/src/person_info/person_info.py index fe4e2116..4e245b67 100644 --- a/src/person_info/person_info.py +++ b/src/person_info/person_info.py @@ -6,10 
+6,13 @@ import random import math from json_repair import repair_json -from typing import Union, Optional +from typing import Union, Optional, Dict +from datetime import datetime + +from sqlmodel import col, select from src.common.logger import get_logger -from src.common.database.database import db +from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.llm_models.utils_model import LLMRequest from src.config.config import global_config, model_config @@ -35,24 +38,37 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str: def get_person_id_by_person_name(person_name: str) -> str: """根据用户名获取用户ID""" try: - record = PersonInfo.get_or_none(PersonInfo.person_name == person_name) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1) + record = session.exec(statement).first() return record.person_id if record else "" except Exception as e: - logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}") + logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}") return "" -def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore +def is_person_known( + person_id: Optional[str] = None, + user_id: Optional[str] = None, + platform: Optional[str] = None, + person_name: Optional[str] = None, +) -> bool: if person_id: - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() return person.is_known if person else False elif user_id and platform: person_id = get_person_id(platform, user_id) - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + 
person = session.exec(statement).first() return person.is_known if person else False elif person_name: person_id = get_person_id_by_person_name(person_name) - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() return person.is_known if person else False else: return False @@ -442,17 +458,18 @@ class Person: def load_from_database(self): """从数据库加载个人信息数据""" try: - # 查询数据库中的记录 - record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1) + record = session.exec(statement).first() if record: self.user_id = record.user_id or "" self.platform = record.platform or "" self.is_known = record.is_known or False - self.nickname = record.nickname or "" + self.nickname = record.user_nickname or "" self.person_name = record.person_name or self.nickname self.name_reason = record.name_reason or None - self.know_times = record.know_times or 0 + self.know_times = record.know_counts or 0 # 处理points字段(JSON格式的列表) if record.memory_points: @@ -470,16 +487,16 @@ class Person: self.memory_points = [] # 处理group_nick_name字段(JSON格式的列表) - if record.group_nick_name: + if record.group_nickname: try: - loaded_group_nick_names = json.loads(record.group_nick_name) + loaded_group_nick_names = json.loads(record.group_nickname) # 确保是列表格式 if isinstance(loaded_group_nick_names, list): self.group_nick_name = loaded_group_nick_names else: self.group_nick_name = [] except (json.JSONDecodeError, TypeError): - logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败,使用默认值") + logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败,使用默认值") self.group_nick_name = [] else: self.group_nick_name = [] @@ -498,42 +515,55 @@ class Person: if not self.is_known: return try: - # 准备数据 - data = { - 
"person_id": self.person_id, - "is_known": self.is_known, - "platform": self.platform, - "user_id": self.user_id, - "nickname": self.nickname, - "person_name": self.person_name, - "name_reason": self.name_reason, - "know_times": self.know_times, - "know_since": self.know_since, - "last_know": self.last_know, - "memory_points": json.dumps( - [point for point in self.memory_points if point is not None], ensure_ascii=False - ) + memory_points_value = ( + json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False) if self.memory_points - else json.dumps([], ensure_ascii=False), - "group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False) + else json.dumps([], ensure_ascii=False) + ) + group_nickname_value = ( + json.dumps(self.group_nick_name, ensure_ascii=False) if self.group_nick_name - else json.dumps([], ensure_ascii=False), - } + else json.dumps([], ensure_ascii=False) + ) + first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None + last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None - # 检查记录是否存在 - record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1) + record = session.exec(statement).first() - if record: - # 更新现有记录 - for field, value in data.items(): - if hasattr(record, field): - setattr(record, field, value) - record.save() - logger.debug(f"已同步用户 {self.person_id} 的信息到数据库") - else: - # 创建新记录 - PersonInfo.create(**data) - logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") + if record: + record.person_id = self.person_id + record.is_known = self.is_known + record.platform = self.platform + record.user_id = self.user_id + record.user_nickname = self.nickname + record.person_name = self.person_name + record.name_reason = self.name_reason + record.know_counts = self.know_times + record.first_known_time = 
first_known_time + record.last_known_time = last_known_time + record.memory_points = memory_points_value + record.group_nickname = group_nickname_value + session.add(record) + logger.debug(f"已同步用户 {self.person_id} 的信息到数据库") + else: + record = PersonInfo( + person_id=self.person_id, + is_known=self.is_known, + platform=self.platform, + user_id=self.user_id, + user_nickname=self.nickname, + person_name=self.person_name, + name_reason=self.name_reason, + know_counts=self.know_times, + first_known_time=first_known_time, + last_known_time=last_known_time, + memory_points=memory_points_value, + group_nickname=group_nickname_value, + ) + session.add(record) + logger.debug(f"已创建用户 {self.person_id} 的信息到数据库") except Exception as e: logger.error(f"同步用户 {self.person_id} 信息到数据库时出错: {e}") @@ -621,30 +651,26 @@ class PersonInfoManager: self.person_name_list = {} self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name") try: - db.connect(reuse_if_open=True) - # 设置连接池参数 - if hasattr(db, "execute_sql"): - # 设置SQLite优化参数 - db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存 - db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中 - db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射 - db.create_tables([PersonInfo], safe=True) + with get_db_session() as _: + pass except Exception as e: logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}") # 初始化时读取所有person_name try: - for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where( - PersonInfo.person_name.is_null(False) - ): - if record.person_name: - self.person_name_list[record.person_id] = record.person_name - logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)") + with get_db_session() as session: + statement = select(PersonInfo.person_id, PersonInfo.person_name).where( + col(PersonInfo.person_name).is_not(None) + ) + for person_id, person_name in session.exec(statement).all(): + if person_name: + self.person_name_list[person_id] = person_name + 
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称") except Exception as e: - logger.error(f"从 Peewee 加载 person_name_list 失败: {e}") + logger.error(f"加载 person_name_list 失败: {e}") @staticmethod - def _extract_json_from_text(text: str) -> dict: + def _extract_json_from_text(text: str) -> Dict[str, str]: """从文本中提取JSON数据的高容错方法""" try: fixed_json = repair_json(text) @@ -744,7 +770,9 @@ class PersonInfoManager: else: def _db_check_name_exists_sync(name_to_check): - return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists() + with get_db_session() as session: + statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == name_to_check) + return session.exec(statement).first() is not None if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname): is_duplicate = True @@ -804,7 +832,7 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str, if not person_id: # 如果通过person_name找不到,尝试从chat_stream获取user_info - if chat_stream.user_info: + if platform and chat_stream.user_info and chat_stream.user_info.user_id: user_id = chat_stream.user_info.user_id person_id = get_person_id(platform, user_id) else: diff --git a/src/plugin_system/apis/emoji_api.py b/src/plugin_system/apis/emoji_api.py index e4d6fcd7..2a99d25c 100644 --- a/src/plugin_system/apis/emoji_api.py +++ b/src/plugin_system/apis/emoji_api.py @@ -15,8 +15,9 @@ import uuid from typing import Optional, Tuple, List, Dict, Any from src.common.logger import get_logger -from src.chat.emoji_system.emoji_manager import get_emoji_manager, EMOJI_DIR +from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR from src.chat.utils.utils_image import image_path_to_base64, base64_to_image +from src.config.config import global_config logger = get_logger("emoji_api") @@ -46,14 +47,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]] try: logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}") - emoji_manager = 
get_emoji_manager() - emoji_result = await emoji_manager.get_emoji_for_text(description) + emoji_obj = await emoji_manager.get_emoji_for_emotion(description) - if not emoji_result: + if not emoji_obj: logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包") return None - emoji_path, emoji_description, matched_emotion = emoji_result + emoji_path = str(emoji_obj.full_path) + emoji_description = emoji_obj.description + matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else "" emoji_base64 = image_path_to_base64(emoji_path) if not emoji_base64: @@ -90,8 +92,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: return [] try: - emoji_manager = get_emoji_manager() - all_emojis = emoji_manager.emoji_objects + all_emojis = emoji_manager.emojis if not all_emojis: logger.warning("[EmojiAPI] 没有可用的表情包") @@ -114,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: results = [] for selected_emoji in selected_emojis: - emoji_base64 = image_path_to_base64(selected_emoji.full_path) + emoji_base64 = image_path_to_base64(str(selected_emoji.full_path)) if not emoji_base64: logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}") @@ -123,7 +124,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]: matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情" # 记录使用次数 - emoji_manager.record_usage(selected_emoji.hash) + emoji_manager.update_emoji_usage(selected_emoji) results.append((emoji_base64, selected_emoji.description, matched_emotion)) if not results and count > 0: @@ -158,8 +159,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]: try: logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}") - emoji_manager = get_emoji_manager() - all_emojis = emoji_manager.emoji_objects + all_emojis = emoji_manager.emojis # 筛选匹配情感的表情包 matching_emojis = [] @@ -181,7 +181,7 @@ async def get_by_emotion(emotion: str) -> 
Optional[Tuple[str, str, str]]: return None # 记录使用次数 - emoji_manager.record_usage(selected_emoji.hash) + emoji_manager.update_emoji_usage(selected_emoji) logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}") return emoji_base64, selected_emoji.description, emotion @@ -203,8 +203,7 @@ def get_count() -> int: int: 当前可用的表情包数量 """ try: - emoji_manager = get_emoji_manager() - return emoji_manager.emoji_num + return len(emoji_manager.emojis) except Exception as e: logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}") return 0 @@ -217,11 +216,10 @@ def get_info(): dict: 包含表情包数量、最大数量、可用数量信息 """ try: - emoji_manager = get_emoji_manager() return { - "current_count": emoji_manager.emoji_num, - "max_count": emoji_manager.emoji_num_max, - "available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]), + "current_count": len(emoji_manager.emojis), + "max_count": global_config.emoji.max_reg_num, + "available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]), } except Exception as e: logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}") @@ -235,10 +233,9 @@ def get_emotions() -> List[str]: list: 所有表情包的情感标签列表(去重) """ try: - emoji_manager = get_emoji_manager() emotions = set() - for emoji_obj in emoji_manager.emoji_objects: + for emoji_obj in emoji_manager.emojis: if not emoji_obj.is_deleted and emoji_obj.emotion: emotions.update(emoji_obj.emotion) @@ -255,8 +252,7 @@ async def get_all() -> List[Tuple[str, str, str]]: List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表 """ try: - emoji_manager = get_emoji_manager() - all_emojis = emoji_manager.emoji_objects + all_emojis = emoji_manager.emojis if not all_emojis: logger.warning("[EmojiAPI] 没有可用的表情包") @@ -267,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]: if emoji_obj.is_deleted: continue - emoji_base64 = image_path_to_base64(emoji_obj.full_path) + emoji_base64 = image_path_to_base64(str(emoji_obj.full_path)) if not emoji_base64: logger.error(f"[EmojiAPI] 无法转换表情包为base64: 
{emoji_obj.full_path}") @@ -291,12 +287,11 @@ def get_descriptions() -> List[str]: list: 所有可用表情包的描述列表 """ try: - emoji_manager = get_emoji_manager() descriptions = [] descriptions.extend( emoji_obj.description - for emoji_obj in emoji_manager.emoji_objects + for emoji_obj in emoji_manager.emojis if not emoji_obj.is_deleted and emoji_obj.description ) return descriptions @@ -341,14 +336,11 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}") # 1. 获取emoji管理器并检查容量 - emoji_manager = get_emoji_manager() - count_before = emoji_manager.emoji_num - max_count = emoji_manager.emoji_num_max + count_before = len(emoji_manager.emojis) + max_count = global_config.emoji.max_reg_num # 2. 检查是否可以注册(未达到上限或启用替换) - can_register = count_before < max_count or ( - count_before >= max_count and emoji_manager.emoji_num_max_reach_deletion - ) + can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace) if not can_register: return { @@ -474,7 +466,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D # 8. 
构建返回结果 if register_success: - count_after = emoji_manager.emoji_num + count_after = len(emoji_manager.emojis) replaced = count_after <= count_before # 如果数量没增加,说明是替换 # 尝试获取新注册的表情包信息 @@ -483,10 +475,10 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D # 获取最新的表情包信息 try: # 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变) - for emoji_obj in reversed(emoji_manager.emoji_objects): + for emoji_obj in reversed(emoji_manager.emojis): if not emoji_obj.is_deleted and ( - emoji_obj.filename == filename # 直接匹配 - or (hasattr(emoji_obj, "full_path") and filename in emoji_obj.full_path) # 路径包含匹配 + emoji_obj.file_name == filename + or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path)) ): new_emoji_info = emoji_obj break @@ -495,7 +487,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D description = new_emoji_info.description if new_emoji_info else None emotions = new_emoji_info.emotion if new_emoji_info else None - emoji_hash = new_emoji_info.hash if new_emoji_info else None + emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None return { "success": True, @@ -560,12 +552,14 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]: logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}") # 1. 获取emoji管理器和删除前的数量 - emoji_manager = get_emoji_manager() - count_before = emoji_manager.emoji_num + count_before = len(emoji_manager.emojis) # 2. 获取被删除表情包的信息(用于返回结果) + deleted_emoji = None try: - deleted_emoji = await emoji_manager.get_emoji_from_manager(emoji_hash) + deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db( + emoji_hash + ) description = deleted_emoji.description if deleted_emoji else None emotions = deleted_emoji.emotion if deleted_emoji else None except Exception as info_error: @@ -574,10 +568,12 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]: emotions = None # 3. 
执行删除操作 - delete_success = await emoji_manager.delete_emoji(emoji_hash) + delete_success = False + if deleted_emoji: + delete_success = emoji_manager.delete_emoji(deleted_emoji) # 4. 获取删除后的数量 - count_after = emoji_manager.emoji_num + count_after = len(emoji_manager.emojis) # 5. 构建返回结果 if delete_success: @@ -638,8 +634,7 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals try: logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})") - emoji_manager = get_emoji_manager() - all_emojis = emoji_manager.emoji_objects + all_emojis = emoji_manager.emojis # 筛选匹配的表情包 matching_emojis = [] @@ -669,12 +664,12 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals deleted_hashes = [] for emoji_obj in matching_emojis: try: - delete_success = await emoji_manager.delete_emoji(emoji_obj.hash) + delete_success = emoji_manager.delete_emoji(emoji_obj) if delete_success: deleted_count += 1 - deleted_hashes.append(emoji_obj.hash) + deleted_hashes.append(emoji_obj.emoji_hash) except Exception as delete_error: - logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.hash}): {delete_error}") + logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}") # 构建返回结果 if deleted_count > 0: diff --git a/src/plugin_system/apis/message_api.py b/src/plugin_system/apis/message_api.py index 12578000..b7d1d2cf 100644 --- a/src/plugin_system/apis/message_api.py +++ b/src/plugin_system/apis/message_api.py @@ -10,8 +10,10 @@ import time from typing import List, Dict, Any, Tuple, Optional +from sqlmodel import col, select from src.common.data_models.database_data_model import DatabaseMessages -from src.common.database.database_model import Images +from src.common.database.database import get_db_session +from src.common.database.database_model import Images, ImageType from src.chat.utils.utils import is_bot_self from src.chat.utils.chat_message_builder import ( get_raw_msg_by_timestamp, @@ -516,7 +518,13 @@ def 
filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag def translate_pid_to_description(pid: str) -> str: - image = Images.get_or_none(Images.image_id == pid) + with get_db_session() as session: + statement = ( + select(Images).where((col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE)) + if pid.isdigit() + else None + ) + image = session.exec(statement).first() if statement is not None else None description = "" if image and image.description and image.description.strip(): description = image.description.strip() diff --git a/src/webui/routers/annual_report.py b/src/webui/routers/annual_report.py index a0f676f2..a277c070 100644 --- a/src/webui/routers/annual_report.py +++ b/src/webui/routers/annual_report.py @@ -1,23 +1,25 @@ """麦麦 2025 年度总结 API 路由""" -from fastapi import APIRouter, HTTPException, Depends, Cookie, Header -from pydantic import BaseModel, Field -from typing import Dict, Any, List, Optional from datetime import datetime -from sqlalchemy import func as fn +from typing import Any, Optional -from src.common.logger import get_logger +from fastapi import APIRouter, Cookie, Depends, Header, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import desc, func +from sqlmodel import col, select + +from src.common.database.database import get_db_session from src.common.database.database_model import ( - LLMUsage, - OnlineTime, - Messages, - ChatStreams, - PersonInfo, - Emoji, + ActionRecord, Expression, - ActionRecords, + Images, Jargon, + Messages, + ModelUsage, + OnlineTime, + PersonInfo, ) +from src.common.logger import get_logger from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui.annual_report") @@ -45,7 +47,7 @@ class TimeFootprintData(BaseModel): first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)") busiest_day: Optional[str] = Field(None, description="最忙碌的一天") busiest_day_count: int = Field(0, description="最忙碌那天的消息数") - 
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布") + hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布") midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数") is_night_owl: bool = Field(False, description="是否是夜猫子") @@ -54,8 +56,8 @@ class SocialNetworkData(BaseModel): """社交网络数据""" total_groups: int = Field(0, description="加入的群组总数") - top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5") - top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5") + top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5") + top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5") at_count: int = Field(0, description="被@次数") mentioned_count: int = Field(0, description="被提及次数") longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户") @@ -69,11 +71,11 @@ class BrainPowerData(BaseModel): total_cost: float = Field(0.0, description="年度总花费") favorite_model: Optional[str] = Field(None, description="最爱用的模型") favorite_model_count: int = Field(0, description="最爱模型的调用次数") - model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布") - top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5") + model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布") + top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5") most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费") most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间") - top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3") + top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3") silence_rate: float = Field(0.0, description="高冷指数(沉默率)") total_actions: 
int = Field(0, description="总动作数") no_reply_count: int = Field(0, description="选择沉默的次数") @@ -88,23 +90,23 @@ class BrainPowerData(BaseModel): class ExpressionVibeData(BaseModel): """个性与表达数据""" - top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王") - top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包") - top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格") + top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王") + top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包") + top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格") rejected_expression_count: int = Field(0, description="被拒绝的表达次数") checked_expression_count: int = Field(0, description="已检查的表达次数") total_expressions: int = Field(0, description="表达总数") - action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布") + action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布") image_processed_count: int = Field(0, description="处理的图片数量") - late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复") - favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复") + late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复") + favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复") class AchievementData(BaseModel): """趣味成就数据""" new_jargon_count: int = Field(0, description="新学到的黑话数量") - sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例") + sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例") total_messages: int = Field(0, description="总消息数") total_replies: int = Field(0, description="总回复数") @@ -115,11 +117,11 @@ class AnnualReportData(BaseModel): year: int = Field(2025, description="报告年份") bot_name: str = Field("麦麦", description="Bot名称") 
generated_at: str = Field(..., description="报告生成时间") - time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData) - social_network: SocialNetworkData = Field(default_factory=SocialNetworkData) - brain_power: BrainPowerData = Field(default_factory=BrainPowerData) - expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData) - achievements: AchievementData = Field(default_factory=AchievementData) + time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct()) + social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct()) + brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct()) + expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct()) + achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct()) # ==================== 辅助函数 ==================== @@ -144,15 +146,18 @@ def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]: async def get_time_footprint(year: int = 2025) -> TimeFootprintData: """获取时光足迹数据""" - data = TimeFootprintData() + data = TimeFootprintData.model_construct() start_ts, end_ts = get_year_time_range(year) start_dt, end_dt = get_year_datetime_range(year) try: # 1. 年度在线时长 - online_records = list( - OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt)) - ) + with get_db_session() as session: + statement = select(OnlineTime).where( + col(OnlineTime.start_timestamp) >= start_dt, + col(OnlineTime.end_timestamp) <= end_dt, + ) + online_records = session.exec(statement).all() total_seconds = 0 for record in online_records: try: @@ -165,50 +170,66 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData: data.total_online_hours = round(total_seconds / 3600, 2) # 2. 
初次相遇 - 年度第一条消息 - first_msg = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts)) - .order_by(Messages.time.asc()) - .first() - ) + with get_db_session() as session: + statement = ( + select(Messages) + .where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + .order_by(col(Messages.timestamp).asc()) + .limit(1) + ) + first_msg = session.exec(statement).first() if first_msg: - data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S") + data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S") data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户" content = first_msg.processed_plain_text or first_msg.display_message or "" data.first_message_content = content[:50] + "..." if len(content) > 50 else content # 3. 最忙碌的一天 # 使用 SQLite 的 date 函数按日期分组 - busiest_query = ( - Messages.select( - fn.date(Messages.time, "unixepoch").alias("day"), - fn.COUNT(Messages.id).alias("count"), + day_expr = func.date(col(Messages.timestamp)) + with get_db_session() as session: + statement = ( + select( + day_expr.label("day"), + func.count().label("count"), + ) + .where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + .group_by(day_expr) + .order_by(func.count().desc()) + .limit(1) ) - .where((Messages.time >= start_ts) & (Messages.time <= end_ts)) - .group_by(fn.date(Messages.time, "unixepoch")) - .order_by(fn.COUNT(Messages.id).desc()) - .limit(1) - ) - busiest_result = list(busiest_query.dicts()) + busiest_result = session.exec(statement).all() if busiest_result: - data.busiest_day = busiest_result[0].get("day") - data.busiest_day_count = busiest_result[0].get("count", 0) + data.busiest_day = busiest_result[0][0] + data.busiest_day_count = busiest_result[0][1] or 0 # 4. 
昼夜节律 - 24小时活跃分布 - hourly_query = ( - Messages.select( - fn.strftime("%H", Messages.time, "unixepoch").alias("hour"), - fn.COUNT(Messages.id).alias("count"), + hour_expr = func.strftime("%H", col(Messages.timestamp)) + with get_db_session() as session: + statement = ( + select( + hour_expr.label("hour"), + func.count().label("count"), + ) + .where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + .group_by(hour_expr) ) - .where((Messages.time >= start_ts) & (Messages.time <= end_ts)) - .group_by(fn.strftime("%H", Messages.time, "unixepoch")) - ) + hourly_rows = session.exec(statement).all() hourly_distribution = [0] * 24 - for row in hourly_query.dicts(): + for row in hourly_rows: try: - hour = int(row.get("hour", 0)) + hour = int(row[0] or 0) if 0 <= hour < 24: - hourly_distribution[hour] = row.get("count", 0) + hourly_distribution[hour] = row[1] or 0 except (ValueError, TypeError): continue data.hourly_distribution = hourly_distribution @@ -234,7 +255,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: """获取社交网络数据""" from src.config.config import global_config - data = SocialNetworkData() + data = SocialNetworkData.model_construct() start_ts, end_ts = get_year_time_range(year) # 获取 bot 自身的 QQ 账号,用于过滤 @@ -242,91 +263,110 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData: try: # 1. 加入的群组总数 - data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count() + with get_db_session() as session: + statement = select(func.count(func.distinct(col(Messages.group_id)))).where( + col(Messages.group_id).is_not(None), + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + data.total_groups = int(session.exec(statement).first() or 0) # 2. 
话痨群组 TOP3 - top_groups_query = ( - Messages.select( - Messages.chat_info_group_id, - Messages.chat_info_group_name, - fn.COUNT(Messages.id).alias("count"), + with get_db_session() as session: + statement = ( + select( + col(Messages.group_id), + func.max(col(Messages.group_name)).label("group_name"), + func.count().label("count"), + ) + .where( + col(Messages.group_id).is_not(None), + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + .group_by(col(Messages.group_id)) + .order_by(func.count().desc()) + .limit(5) ) - .where( - (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False)) - ) - .group_by(Messages.chat_info_group_id) - .order_by(fn.COUNT(Messages.id).desc()) - .limit(5) - ) + top_groups_rows = session.exec(statement).all() data.top_groups = [ { - "group_id": row["chat_info_group_id"], - "group_name": row["chat_info_group_name"] or "未知群组", - "message_count": row["count"], - "is_webui": str(row["chat_info_group_id"]).startswith("webui_"), + "group_id": row[0], + "group_name": row[1] or "未知群组", + "message_count": row[2] or 0, + "is_webui": str(row[0]).startswith("webui_"), } - for row in top_groups_query.dicts() + for row in top_groups_rows ] # 3. 
互动最多的用户 TOP5(过滤 bot 自身) - top_users_query = ( - Messages.select( - Messages.user_id, - Messages.user_nickname, - fn.COUNT(Messages.id).alias("count"), + with get_db_session() as session: + statement = ( + select( + col(Messages.user_id), + func.max(col(Messages.user_nickname)).label("user_nickname"), + func.count().label("count"), + ) + .where( + col(Messages.user_id).is_not(None), + col(Messages.user_id) != bot_qq, + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + .group_by(col(Messages.user_id)) + .order_by(func.count().desc()) + .limit(5) ) - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.user_id.is_null(False)) - & (Messages.user_id != bot_qq) # 过滤 bot 自身 - ) - .group_by(Messages.user_id) - .order_by(fn.COUNT(Messages.id).desc()) - .limit(5) - ) + top_users_rows = session.exec(statement).all() data.top_users = [ { - "user_id": row["user_id"], - "user_nickname": row["user_nickname"] or "未知用户", - "message_count": row["count"], - "is_webui": str(row["user_id"]).startswith("webui_"), + "user_id": row[0], + "user_nickname": row[1] or "未知用户", + "message_count": row[2] or 0, + "is_webui": str(row[0]).startswith("webui_"), } - for row in top_users_query.dicts() + for row in top_users_rows ] # 4. 被@次数 - data.at_count = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True)) - .count() - ) + with get_db_session() as session: + statement = select(func.count()).where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + col(Messages.is_at) == True, + ) + data.at_count = int(session.exec(statement).first() or 0) # 5. 
被提及次数 - data.mentioned_count = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True)) - .count() - ) + with get_db_session() as session: + statement = select(func.count()).where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + col(Messages.is_mentioned) == True, + ) + data.mentioned_count = int(session.exec(statement).first() or 0) # 6. 最长情陪伴的用户(过滤 bot 自身) - companion_query = ( - ChatStreams.select( - ChatStreams.user_id, - ChatStreams.user_nickname, - (ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"), + with get_db_session() as session: + statement = select(PersonInfo).where( + col(PersonInfo.user_id) != bot_qq, + col(PersonInfo.first_known_time).is_not(None), + col(PersonInfo.last_known_time).is_not(None), ) - .where( - (ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身 - ) - .order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc()) - .limit(1) - ) - companion_result = list(companion_query.dicts()) - if companion_result: - data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户" - duration = companion_result[0].get("duration", 0) or 0 - data.longest_companion_days = int(duration / 86400) # 转换为天 + persons = session.exec(statement).all() + if persons: + + def _companion_days(person: PersonInfo) -> float: + if not person.first_known_time or not person.last_known_time: + return 0.0 + return (person.last_known_time - person.first_known_time).total_seconds() + + longest = max(persons, key=_companion_days) + data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id + data.longest_companion_days = int(_companion_days(longest) / 86400) + else: + data.longest_companion_user = None + data.longest_companion_days = 0 except Exception as e: logger.error(f"获取社交网络数据失败: {e}") @@ -339,154 +379,139 @@ async 
def get_social_network(year: int = 2025) -> SocialNetworkData: async def get_brain_power(year: int = 2025) -> BrainPowerData: """获取最强大脑数据""" - data = BrainPowerData() + data = BrainPowerData.model_construct() start_dt, end_dt = get_year_datetime_range(year) start_ts, end_ts = get_year_time_range(year) try: # 1. 年度消耗 Token 总量和总花费 - token_query = LLMUsage.select( - fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"), - fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"), - ).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt)) - result = token_query.dicts().get() - data.total_tokens = int(result.get("total_tokens", 0) or 0) - data.total_cost = round(float(result.get("total_cost", 0) or 0), 4) + with get_db_session() as session: + statement = select( + func.sum(col(ModelUsage.total_tokens)).label("total_tokens"), + func.sum(col(ModelUsage.cost)).label("total_cost"), + ).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt) + result = session.exec(statement).first() + if result: + data.total_tokens = int(result[0] or 0) + data.total_cost = round(float(result[1] or 0), 4) # 2. 
最爱用的模型 - model_query = ( - LLMUsage.select( - fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"), - fn.COUNT(LLMUsage.id).alias("count"), - fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"), - fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"), + with get_db_session() as session: + statement = ( + select(ModelUsage) + .where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt) + .order_by(desc(col(ModelUsage.timestamp))) ) - .where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt)) - .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name)) - .order_by(fn.COUNT(LLMUsage.id).desc()) - .limit(10) - ) - model_results = list(model_query.dicts()) + records = session.exec(statement).all() + + model_agg: dict[str, dict[str, float | int]] = {} + for record in records: + model_name = record.model_assign_name or record.model_name or "unknown" + if model_name not in model_agg: + model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0} + bucket = model_agg[model_name] + bucket["count"] = int(bucket["count"]) + 1 + bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0) + bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0) + + model_results = sorted( + model_agg.items(), + key=lambda item: float(item[1]["count"]), + reverse=True, + )[:10] if model_results: - data.favorite_model = model_results[0].get("model") - data.favorite_model_count = model_results[0].get("count", 0) + data.favorite_model = model_results[0][0] + data.favorite_model_count = int(model_results[0][1]["count"]) data.model_distribution = [ { - "model": row["model"], - "count": row["count"], - "tokens": row["tokens"], - "cost": round(row["cost"], 4), + "model": model_name, + "count": int(bucket["count"]), + "tokens": int(bucket["tokens"]), + "cost": round(float(bucket["cost"]), 4), } - for row in model_results + for model_name, bucket in model_results ] # 3. 
最昂贵的一次思考 - expensive_query = ( - LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp) - .where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt)) - .order_by(LLMUsage.cost.desc()) - .limit(1) - ) - expensive_result = expensive_query.first() - if expensive_result: - data.most_expensive_cost = round(expensive_result.cost or 0, 4) - data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S") + if records: + expensive_record = max(records, key=lambda record: record.cost or 0.0) + data.most_expensive_cost = round(expensive_record.cost or 0.0, 4) + data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S") # 4. 烧钱大户 TOP3 (按用户,过滤 system) - consumer_query = ( - LLMUsage.select( - LLMUsage.user_id, - fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"), - fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"), - ) - .where( - (LLMUsage.timestamp >= start_dt) - & (LLMUsage.timestamp <= end_dt) - & (LLMUsage.user_id != "system") # 过滤 system 用户 - & (LLMUsage.user_id.is_null(False)) - ) - .group_by(LLMUsage.user_id) - .order_by(fn.SUM(LLMUsage.cost).desc()) - .limit(3) - ) + consumer_agg: dict[str, dict[str, float | int]] = {} + for record in records: + user_id = record.model_api_provider_name + if not user_id or user_id == "system": + continue + if user_id not in consumer_agg: + consumer_agg[user_id] = {"cost": 0.0, "tokens": 0} + bucket = consumer_agg[user_id] + bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0) + bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0) + data.top_token_consumers = [ { - "user_id": row["user_id"], - "cost": round(row["cost"], 4), - "tokens": row["tokens"], + "user_id": user_id, + "cost": round(float(bucket["cost"]), 4), + "tokens": int(bucket["tokens"]), } - for row in consumer_query.dicts() + for user_id, bucket in sorted( + consumer_agg.items(), + key=lambda item: float(item[1]["cost"]), + reverse=True, + )[:3] ] # 5. 
最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用) # 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别 - reply_model_query = ( - LLMUsage.select( - fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"), - fn.COUNT(LLMUsage.id).alias("count"), - ) - .where( - (LLMUsage.timestamp >= start_dt) - & (LLMUsage.timestamp <= end_dt) - & ( - LLMUsage.model_assign_name.contains("replyer") - | LLMUsage.model_assign_name.contains("回复") - | LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况 - ) - ) - .group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name)) - .order_by(fn.COUNT(LLMUsage.id).desc()) - .limit(5) - ) - data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()] + reply_model_agg: dict[str, int] = {} + for record in records: + model_assign_name = record.model_assign_name or "" + if "replyer" not in model_assign_name and "回复" not in model_assign_name: + continue + model_name = model_assign_name or record.model_name or "unknown" + reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1 + data.top_reply_models = [ + {"model": model_name, "count": count} + for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5] + ] # 6. 
高冷指数 (沉默率) - 基于 ActionRecords - total_actions = ( - ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count() - ) - no_reply_count = ( - ActionRecords.select() - .where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_name == "no_reply") + with get_db_session() as session: + statement = select(func.count()).where( + col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), + col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), ) - .count() - ) + total_actions = int(session.exec(statement).first() or 0) + with get_db_session() as session: + statement = select(func.count()).where( + col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), + col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), + col(ActionRecord.action_name) == "no_reply", + ) + no_reply_count = int(session.exec(statement).first() or 0) data.total_actions = total_actions data.no_reply_count = no_reply_count data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0 # 6. 
情绪波动 (兴趣值) - interest_query = Messages.select( - fn.AVG(Messages.interest_value).alias("avg_interest"), - fn.MAX(Messages.interest_value).alias("max_interest"), - ).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False))) - interest_result = interest_query.dicts().get() - data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2) - data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2) + data.avg_interest_value = 0.0 + data.max_interest_value = 0.0 # 找到最高兴趣值的时间 if data.max_interest_value > 0: - max_interest_msg = ( - Messages.select(Messages.time) - .where( - (Messages.time >= start_ts) - & (Messages.time <= end_ts) - & (Messages.interest_value == data.max_interest_value) - ) - .first() - ) - if max_interest_msg: - data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S") + data.max_interest_time = None # 7. 思考深度 (基于 action_reasoning 长度) - reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_reasoning.is_null(False)) - & (ActionRecords.action_reasoning != "") - ) + with get_db_session() as session: + statement = select(ActionRecord).where( + col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), + col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), + col(ActionRecord.action_reasoning).is_not(None), + col(ActionRecord.action_reasoning) != "", + ) + reasoning_records = session.exec(statement).all() reasoning_lengths = [] max_len = 0 max_len_time = None @@ -496,13 +521,13 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData: reasoning_lengths.append(length) if length > max_len: max_len = length - max_len_time = record.time + max_len_time = record.timestamp if reasoning_lengths: data.avg_reasoning_length = round(sum(reasoning_lengths) / 
len(reasoning_lengths), 1) data.max_reasoning_length = max_len if max_len_time: - data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S") + data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S") except Exception as e: logger.error(f"获取最强大脑数据失败: {e}") @@ -517,7 +542,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: """获取个性与表达数据""" from src.config.config import global_config - data = ExpressionVibeData() + data = ExpressionVibeData.model_construct() start_ts, end_ts = get_year_time_range(year) # 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息 @@ -525,75 +550,58 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: try: # 1. 表情包之王 - 使用次数最多的表情包 - top_emoji_query = ( - Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash) - .where(Emoji.is_registered == True) - .order_by(Emoji.usage_count.desc()) - .limit(5) - ) - top_emojis = list(top_emoji_query.dicts()) + with get_db_session() as session: + statement = ( + select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5) + ) + top_emojis = session.exec(statement).all() if top_emojis: data.top_emoji = { - "id": top_emojis[0].get("id"), - "path": top_emojis[0].get("full_path"), - "description": top_emojis[0].get("description"), - "usage_count": top_emojis[0].get("usage_count", 0), - "hash": top_emojis[0].get("emoji_hash"), + "id": top_emojis[0].id, + "path": top_emojis[0].full_path, + "description": top_emojis[0].description, + "usage_count": top_emojis[0].query_count, + "hash": top_emojis[0].image_hash, } data.top_emojis = [ { - "id": e.get("id"), - "path": e.get("full_path"), - "description": e.get("description"), - "usage_count": e.get("usage_count", 0), - "hash": e.get("emoji_hash"), + "id": e.id, + "path": e.full_path, + "description": e.description, + "usage_count": e.query_count, + "hash": e.image_hash, } for e in top_emojis ] # 2. 
百变麦麦 - 最常用的表达风格 - expression_query = ( - Expression.select( - Expression.style, - fn.SUM(Expression.count).alias("total_count"), + with get_db_session() as session: + statement = ( + select(Expression.style, func.sum(col(Expression.count)).label("total_count")) + .where( + col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts), + col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts), + ) + .group_by(Expression.style) + .order_by(func.sum(col(Expression.count)).desc()) + .limit(5) ) - .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts)) - .group_by(Expression.style) - .order_by(fn.SUM(Expression.count).desc()) - .limit(5) - ) - data.top_expressions = [ - {"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts() - ] + expression_rows = session.exec(statement).all() + data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows] # 3. 被拒绝的表达 - data.rejected_expression_count = ( - Expression.select() - .where( - (Expression.last_active_time >= start_ts) - & (Expression.last_active_time <= end_ts) - & (Expression.rejected == True) - ) - .count() - ) + data.rejected_expression_count = 0 # 4. 已检查的表达 - data.checked_expression_count = ( - Expression.select() - .where( - (Expression.last_active_time >= start_ts) - & (Expression.last_active_time <= end_ts) - & (Expression.checked == True) - ) - .count() - ) + data.checked_expression_count = 0 # 5. 表达总数 - data.total_expressions = ( - Expression.select() - .where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts)) - .count() - ) + with get_db_session() as session: + statement = select(func.count()).where( + col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts), + col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts), + ) + data.total_expressions = int(session.exec(statement).first() or 0) # 6. 
动作类型分布 (过滤无意义的动作) # 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore @@ -608,28 +616,29 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: "listening", "block_and_ignore", ] - action_query = ( - ActionRecords.select( - ActionRecords.action_name, - fn.COUNT(ActionRecords.id).alias("count"), + with get_db_session() as session: + statement = ( + select(ActionRecord.action_name, func.count().label("count")) + .where( + col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), + col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), + col(ActionRecord.action_name).not_in(excluded_actions), + ) + .group_by(ActionRecord.action_name) + .order_by(func.count().desc()) + .limit(10) ) - .where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_name.not_in(excluded_actions)) - ) - .group_by(ActionRecords.action_name) - .order_by(fn.COUNT(ActionRecords.id).desc()) - .limit(10) - ) - data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()] + action_rows = session.exec(statement).all() + data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows] # 7. 处理的图片数量 - data.image_processed_count = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True)) - .count() - ) + with get_db_session() as session: + statement = select(func.count()).where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + col(Messages.is_picture) == True, + ) + data.image_processed_count = int(session.exec(statement).first() or 0) # 8. 
深夜还在回复 (0-6点最晚的10条消息中随机抽取一条) import random @@ -648,21 +657,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: return content # 使用 user_id 判断是否是 bot 发送的消息 - late_night_messages = list( - Messages.select( - Messages.time, - Messages.processed_plain_text, - Messages.display_message, + with get_db_session() as session: + statement = ( + select(Messages) + .where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + col(Messages.user_id) == bot_qq, + ) + .order_by(desc(col(Messages.timestamp))) + .limit(200) ) - .where( - (Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息 - ) - .order_by(Messages.time.desc()) - ) + late_night_messages = session.exec(statement).all() # 筛选出0-6点的消息 late_night_filtered = [] for msg in late_night_messages: - msg_dt = datetime.fromtimestamp(msg.time) + msg_dt = msg.timestamp hour = msg_dt.hour if 0 <= hour < 6: # 0点到6点 raw_content = msg.processed_plain_text or msg.display_message or "" @@ -671,7 +681,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: if cleaned_content and len(cleaned_content) > 2: late_night_filtered.append( { - "time": msg.time, + "time": msg_dt.timestamp(), "hour": hour, "minute": msg_dt.minute, "content": cleaned_content, @@ -693,13 +703,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: from collections import Counter import json as json_lib - reply_records = ActionRecords.select(ActionRecords.action_data).where( - (ActionRecords.time >= start_ts) - & (ActionRecords.time <= end_ts) - & (ActionRecords.action_name == "reply") - & (ActionRecords.action_data.is_null(False)) - & (ActionRecords.action_data != "") - ) + with get_db_session() as session: + statement = select(ActionRecord).where( + col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts), + col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts), + 
col(ActionRecord.action_name) == "reply", + col(ActionRecord.action_data).is_not(None), + col(ActionRecord.action_data) != "", + ) + reply_records = session.exec(statement).all() reply_contents = [] for record in reply_records: @@ -762,21 +774,20 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData: async def get_achievements(year: int = 2025) -> AchievementData: """获取趣味成就数据""" - data = AchievementData() + data = AchievementData.model_construct() start_ts, end_ts = get_year_time_range(year) try: # 1. 新学到的黑话数量 # Jargon 表没有时间字段,统计全部已确认的黑话 - data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count() + with get_db_session() as session: + statement = select(func.count()).where(col(Jargon.is_jargon) == True) + data.new_jargon_count = int(session.exec(statement).first() or 0) # 2. 代表性黑话示例 - jargon_samples = ( - Jargon.select(Jargon.content, Jargon.meaning, Jargon.count) - .where(Jargon.is_jargon == True) - .order_by(Jargon.count.desc()) - .limit(5) - ) + with get_db_session() as session: + statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5) + jargon_samples = session.exec(statement).all() data.sample_jargons = [ { "content": j.content, @@ -787,14 +798,21 @@ async def get_achievements(year: int = 2025) -> AchievementData: ] # 3. 总消息数 - data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count() + with get_db_session() as session: + statement = select(func.count()).where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + ) + data.total_messages = int(session.exec(statement).first() or 0) # 4. 
总回复数 (有 reply_to 的消息) - data.total_replies = ( - Messages.select() - .where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False))) - .count() - ) + with get_db_session() as session: + statement = select(func.count()).where( + col(Messages.timestamp) >= datetime.fromtimestamp(start_ts), + col(Messages.timestamp) <= datetime.fromtimestamp(end_ts), + col(Messages.reply_to).is_not(None), + ) + data.total_replies = int(session.exec(statement).first() or 0) except Exception as e: logger.error(f"获取趣味成就数据失败: {e}") diff --git a/src/webui/routers/chat.py b/src/webui/routers/chat.py index f666e85c..e1e71780 100644 --- a/src/webui/routers/chat.py +++ b/src/webui/routers/chat.py @@ -7,16 +7,19 @@ import time import uuid -from typing import Dict, Any, Optional, List -from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header -from pydantic import BaseModel -from sqlalchemy import case, func as fn +from typing import Any, Dict, List, Optional + +from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect +from pydantic import BaseModel +from sqlalchemy import case, desc, func +from sqlmodel import col, select, delete -from src.common.logger import get_logger -from src.common.database.database_model import Messages, PersonInfo -from src.config.config import global_config from src.chat.message_receive.bot import chat_bot -from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager +from src.common.database.database import get_db_session +from src.common.database.database_model import Messages, PersonInfo +from src.common.logger import get_logger +from src.config.config import global_config +from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header from src.webui.routers.websocket.auth import verify_ws_token logger = get_logger("webui.chat") @@ -97,7 +100,7 @@ class ChatHistoryManager: "id": msg.message_id, "type": "bot" if 
is_bot else "user", "content": msg.processed_plain_text or msg.display_message or "", - "timestamp": msg.time, + "timestamp": msg.timestamp.timestamp(), "sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"), "sender_id": "bot" if is_bot else user_id, "is_bot": is_bot, @@ -113,12 +116,14 @@ class ChatHistoryManager: target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID try: # 查询指定群的消息,按时间排序 - messages = ( - Messages.select() - .where(Messages.chat_info_group_id == target_group_id) - .order_by(Messages.time.desc()) - .limit(limit) - ) + with get_db_session() as session: + statement = ( + select(Messages) + .where(col(Messages.group_id) == target_group_id) + .order_by(desc(col(Messages.timestamp))) + .limit(limit) + ) + messages = session.exec(statement).all() # 转换为列表并反转(使最旧的消息在前) # 传递 group_id 以便正确判断虚拟群中的机器人消息 @@ -139,7 +144,10 @@ class ChatHistoryManager: """ target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID try: - deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute() + with get_db_session() as session: + statement = delete(Messages).where(col(Messages.group_id) == target_group_id) + result = session.exec(statement) + deleted = result.rowcount or 0 logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})") return deleted except Exception as e: @@ -172,14 +180,14 @@ class ChatConnectionManager: del self.user_sessions[user_id] logger.info(f"WebUI 聊天会话已断开: session={session_id}") - async def send_message(self, session_id: str, message: dict): + async def send_message(self, session_id: str, message: dict[str, Any]): if session_id in self.active_connections: try: await self.active_connections[session_id].send_json(message) except Exception as e: logger.error(f"发送消息失败: {e}") - async def broadcast(self, message: dict): + async def broadcast(self, message: dict[str, Any]): """广播消息给所有连接""" for session_id in list(self.active_connections.keys()): await 
self.send_message(session_id, message) @@ -292,16 +300,18 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)): """ try: # 查询所有不同的平台 - platforms = ( - PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count")) - .group_by(PersonInfo.platform) - .order_by(fn.COUNT(PersonInfo.id).desc()) - ) + with get_db_session() as session: + statement = ( + select(PersonInfo.platform, func.count().label("count")) + .group_by(PersonInfo.platform) + .order_by(func.count().desc()) + ) + platforms = session.exec(statement).all() result = [] - for p in platforms: - if p.platform: # 排除空平台 - result.append({"platform": p.platform, "count": p.count}) + for platform, count in platforms: + if platform: + result.append({"platform": platform, "count": count}) return {"success": True, "platforms": result} except Exception as e: @@ -325,31 +335,36 @@ async def get_persons_by_platform( """ try: # 构建查询 - query = PersonInfo.select().where(PersonInfo.platform == platform) + statement = select(PersonInfo).where(col(PersonInfo.platform) == platform) # 搜索过滤 if search: - query = query.where( - (PersonInfo.person_name.contains(search)) - | (PersonInfo.nickname.contains(search)) - | (PersonInfo.user_id.contains(search)) + statement = statement.where( + (col(PersonInfo.person_name).contains(search)) + | (col(PersonInfo.user_nickname).contains(search)) + | (col(PersonInfo.user_id).contains(search)) ) # 按最后交互时间排序,优先显示活跃用户 - query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc()) - query = query.limit(limit) + statement = statement.order_by( + case((col(PersonInfo.last_known_time).is_(None), 1), else_=0), + col(PersonInfo.last_known_time).desc(), + ).limit(limit) + + with get_db_session() as session: + persons = session.exec(statement).all() result = [] - for person in query: + for person in persons: result.append( { "person_id": person.person_id, "user_id": person.user_id, "person_name": person.person_name, - 
"nickname": person.nickname, + "nickname": person.user_nickname, "is_known": person.is_known, "platform": person.platform, - "display_name": person.person_name or person.nickname or person.user_id, + "display_name": person.person_name or person.user_nickname or person.user_id, } ) @@ -448,7 +463,9 @@ async def websocket_chat( # 如果 URL 参数中提供了虚拟身份信息,自动配置 if platform and person_id: try: - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() if person: # 使用前端传递的 group_id,如果没有则生成一个稳定的 virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}" @@ -457,7 +474,7 @@ async def websocket_chat( platform=person.platform, person_id=person.person_id, user_id=person.user_id, - user_nickname=person.person_name or person.nickname or person.user_id, + user_nickname=person.person_name or person.user_nickname or person.user_id, group_id=virtual_group_id, group_name=group_name or "WebUI虚拟群聊", ) @@ -471,7 +488,7 @@ async def websocket_chat( try: # 构建会话信息 - session_info_data = { + session_info_data: dict[str, Any] = { "type": "session_info", "session_id": session_id, "user_id": user_id, @@ -641,7 +658,13 @@ async def websocket_chat( # 获取用户信息 try: - person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id")) + with get_db_session() as session: + statement = ( + select(PersonInfo) + .where(col(PersonInfo.person_id) == virtual_data.get("person_id")) + .limit(1) + ) + person = session.exec(statement).first() if not person: await chat_manager.send_message( session_id, @@ -665,7 +688,7 @@ async def websocket_chat( platform=person.platform, person_id=person.person_id, user_id=person.user_id, - user_nickname=person.person_name or person.nickname or person.user_id, + user_nickname=person.person_name or person.user_nickname or person.user_id, group_id=group_id, 
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"), ) @@ -769,7 +792,7 @@ async def get_chat_info(_auth: bool = Depends(require_auth)): } -def get_webui_chat_broadcaster() -> tuple: +def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]: """获取 WebUI 聊天广播器,供外部模块使用 Returns: diff --git a/src/webui/routers/expression.py b/src/webui/routers/expression.py index 67495263..0b051591 100644 --- a/src/webui/routers/expression.py +++ b/src/webui/routers/expression.py @@ -3,11 +3,16 @@ from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict -from sqlalchemy import case +from datetime import datetime, timedelta + +from sqlalchemy import case, func +from sqlmodel import col, select, delete + from src.common.logger import get_logger -from src.common.database.database_model import Expression, ChatStreams +from src.common.database.database import get_db_session +from src.common.database.database_model import Expression +from src.chat.message_receive.chat_stream import get_chat_manager from src.webui.core import verify_auth_token_from_cookie_or_header -import time logger = get_logger("webui.expression") @@ -98,30 +103,32 @@ def verify_auth_token( def expression_to_response(expression: Expression) -> ExpressionResponse: """将 Expression 模型转换为响应对象""" + last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0 + create_date = expression.create_time.timestamp() if expression.create_time else None return ExpressionResponse( - id=expression.id, + id=expression.id if expression.id is not None else 0, situation=expression.situation, style=expression.style, - last_active_time=expression.last_active_time, - chat_id=expression.chat_id, - create_date=expression.create_date, - checked=expression.checked, - rejected=expression.rejected, - modified_by=expression.modified_by, + last_active_time=last_active_time, + chat_id=expression.session_id or "", + 
create_date=create_date, + checked=False, + rejected=False, + modified_by=None, ) def get_chat_name(chat_id: str) -> str: """根据 chat_id 获取聊天名称""" try: - chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) - if chat_stream: - # 优先使用群聊名称,否则使用用户昵称 - if chat_stream.group_name: - return chat_stream.group_name - elif chat_stream.user_nickname: - return chat_stream.user_nickname - return chat_id # 找不到时返回原始ID + chat_stream = get_chat_manager().get_stream(chat_id) + if not chat_stream: + return chat_id + if chat_stream.group_info and chat_stream.group_info.group_name: + return chat_stream.group_info.group_name + if chat_stream.user_info and chat_stream.user_info.user_nickname: + return chat_stream.user_info.user_nickname + return chat_id except Exception: return chat_id @@ -130,12 +137,15 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]: """批量获取聊天名称""" result = {cid: cid for cid in chat_ids} # 默认值为原始ID try: - chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids)) - for cs in chat_streams: - if cs.group_name: - result[cs.stream_id] = cs.group_name - elif cs.user_nickname: - result[cs.stream_id] = cs.user_nickname + chat_manager = get_chat_manager() + for chat_id in chat_ids: + chat_stream = chat_manager.get_stream(chat_id) + if not chat_stream: + continue + if chat_stream.group_info and chat_stream.group_info.group_name: + result[chat_id] = chat_stream.group_info.group_name + elif chat_stream.user_info and chat_stream.user_info.user_nickname: + result[chat_id] = chat_stream.user_info.user_nickname except Exception as e: logger.warning(f"批量获取聊天名称失败: {e}") return result @@ -172,14 +182,17 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat verify_auth_token(maibot_session, authorization) chat_list = [] - for cs in ChatStreams.select(): - chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id) + for stream_id, stream in 
get_chat_manager().streams.items(): + chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None + if not chat_name and stream.user_info and stream.user_info.user_nickname: + chat_name = stream.user_info.user_nickname + chat_name = chat_name or stream_id chat_list.append( ChatInfo( - chat_id=cs.stream_id, + chat_id=stream_id, chat_name=chat_name, - platform=cs.platform, - is_group=bool(cs.group_id), + platform=stream.platform, + is_group=bool(stream.group_info and stream.group_info.group_id), ) ) @@ -221,29 +234,39 @@ async def get_expression_list( verify_auth_token(maibot_session, authorization) # 构建查询 - query = Expression.select() + statement = select(Expression) # 搜索过滤 if search: - query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search))) + statement = statement.where( + (col(Expression.situation).contains(search)) | (col(Expression.style).contains(search)) + ) # 聊天ID过滤 if chat_id: - query = query.where(Expression.chat_id == chat_id) + statement = statement.where(col(Expression.session_id) == chat_id) # 排序:最后活跃时间倒序(NULL 值放在最后) - query = query.order_by( - case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc() + statement = statement.order_by( + case((col(Expression.last_active_time).is_(None), 1), else_=0), + col(Expression.last_active_time).desc(), ) - # 获取总数 - total = query.count() - - # 分页 offset = (page - 1) * page_size - expressions = query.offset(offset).limit(page_size) + statement = statement.offset(offset).limit(page_size) + + with get_db_session() as session: + expressions = session.exec(statement).all() + + count_statement = select(Expression.id) + if search: + count_statement = count_statement.where( + (col(Expression.situation).contains(search)) | (col(Expression.style).contains(search)) + ) + if chat_id: + count_statement = count_statement.where(col(Expression.session_id) == chat_id) + total = 
len(session.exec(count_statement).all()) - # 转换为响应对象 data = [expression_to_response(expr) for expr in expressions] return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data) @@ -272,7 +295,9 @@ async def get_expression_detail( try: verify_auth_token(maibot_session, authorization) - expression = Expression.get_or_none(Expression.id == expression_id) + with get_db_session() as session: + statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) + expression = session.exec(statement).first() if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") @@ -305,16 +330,22 @@ async def create_expression( try: verify_auth_token(maibot_session, authorization) - current_time = time.time() + current_time = datetime.now() # 创建表达方式 - expression = Expression.create( - situation=request.situation, - style=request.style, - chat_id=request.chat_id, - last_active_time=current_time, - create_date=current_time, - ) + with get_db_session() as session: + expression = Expression( + situation=request.situation, + style=request.style, + context="", + up_content="", + content_list="[]", + count=0, + last_active_time=current_time, + create_time=current_time, + session_id=request.chat_id, + ) + session.add(expression) logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}") @@ -350,16 +381,18 @@ async def update_expression( try: verify_auth_token(maibot_session, authorization) - expression = Expression.get_or_none(Expression.id == expression_id) + with get_db_session() as session: + statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) + expression = session.exec(statement).first() if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") # 冲突检测:如果要求未检查状态,但已经被检查了 - if request.require_unchecked and expression.checked: + if request.require_unchecked and getattr(expression, "checked", False): raise 
HTTPException( status_code=409, - detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表", + detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表", ) # 只更新提供的字段 @@ -376,13 +409,18 @@ async def update_expression( update_data["modified_by"] = "user" # 更新最后活跃时间 - update_data["last_active_time"] = time.time() + update_data["last_active_time"] = datetime.now() # 执行更新 - for field, value in update_data.items(): - setattr(expression, field, value) - - expression.save() + with get_db_session() as session: + db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first() + if not db_expression: + raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") + for field, value in update_data.items(): + if hasattr(db_expression, field): + setattr(db_expression, field, value) + session.add(db_expression) + expression = db_expression logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}") @@ -414,7 +452,9 @@ async def delete_expression( try: verify_auth_token(maibot_session, authorization) - expression = Expression.get_or_none(Expression.id == expression_id) + with get_db_session() as session: + statement = select(Expression).where(col(Expression.id) == expression_id).limit(1) + expression = session.exec(statement).first() if not expression: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式") @@ -423,7 +463,8 @@ async def delete_expression( situation = expression.situation # 执行删除 - expression.delete_instance() + with get_db_session() as session: + session.exec(delete(Expression).where(col(Expression.id) == expression_id)) logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}") @@ -465,8 +506,9 @@ async def batch_delete_expressions( raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID") # 查找所有要删除的表达方式 - expressions = Expression.select().where(Expression.id.in_(request.ids)) - found_ids = 
[expr.id for expr in expressions] + with get_db_session() as session: + statements = select(Expression.id).where(col(Expression.id).in_(request.ids)) + found_ids = [expr_id for expr_id in session.exec(statements).all()] # 检查是否有未找到的ID not_found_ids = set(request.ids) - set(found_ids) @@ -474,7 +516,9 @@ async def batch_delete_expressions( logger.warning(f"部分表达方式未找到: {not_found_ids}") # 执行批量删除 - deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute() + with get_db_session() as session: + result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids))) + deleted_count = result.rowcount or 0 logger.info(f"批量删除了 {deleted_count} 个表达方式") @@ -503,21 +547,21 @@ async def get_expression_stats( try: verify_auth_token(maibot_session, authorization) - total = Expression.select().count() + with get_db_session() as session: + total = len(session.exec(select(Expression.id)).all()) - # 按 chat_id 统计 - chat_stats = {} - for expr in Expression.select(Expression.chat_id): - chat_id = expr.chat_id - chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1 + chat_stats = {} + for chat_id in session.exec(select(Expression.session_id)).all(): + if chat_id: + chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1 - # 获取最近创建的记录数(7天内) - seven_days_ago = time.time() - (7 * 24 * 60 * 60) - recent = ( - Expression.select() - .where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago)) - .count() - ) + seven_days_ago = datetime.now() - timedelta(days=7) + recent_statement = ( + select(func.count()) + .select_from(Expression) + .where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago) + ) + recent = session.exec(recent_statement).one() return { "success": True, @@ -561,12 +605,13 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori try: verify_auth_token(maibot_session, authorization) - total = Expression.select().count() - unchecked = 
Expression.select().where(Expression.checked == False).count() - passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count() - rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count() - ai_checked = Expression.select().where(Expression.modified_by == "ai").count() - user_checked = Expression.select().where(Expression.modified_by == "user").count() + with get_db_session() as session: + total = len(session.exec(select(Expression.id)).all()) + unchecked = 0 + passed = 0 + rejected = 0 + ai_checked = 0 + user_checked = 0 return ReviewStatsResponse( total=total, @@ -620,31 +665,44 @@ async def get_review_list( try: verify_auth_token(maibot_session, authorization) - query = Expression.select() + statement = select(Expression) - # 根据筛选类型过滤 - if filter_type == "unchecked": - query = query.where(Expression.checked == False) - elif filter_type == "passed": - query = query.where((Expression.checked == True) & (Expression.rejected == False)) - elif filter_type == "rejected": - query = query.where((Expression.checked == True) & (Expression.rejected == True)) + if filter_type in {"unchecked", "passed", "rejected"}: + statement = statement.where(col(Expression.id) == -1) # all 不需要额外过滤 # 搜索过滤 if search: - query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search))) + statement = statement.where( + (col(Expression.situation).contains(search)) | (col(Expression.style).contains(search)) + ) # 聊天ID过滤 if chat_id: - query = query.where(Expression.chat_id == chat_id) + statement = statement.where(col(Expression.session_id) == chat_id) # 排序:创建时间倒序 - query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc()) + statement = statement.order_by( + case((col(Expression.create_time).is_(None), 1), else_=0), + col(Expression.create_time).desc(), + ) - total = query.count() offset = (page - 1) * page_size - 
expressions = query.offset(offset).limit(page_size) + statement = statement.offset(offset).limit(page_size) + + with get_db_session() as session: + expressions = session.exec(statement).all() + + count_statement = select(Expression.id) + if filter_type in {"unchecked", "passed", "rejected"}: + count_statement = count_statement.where(col(Expression.id) == -1) + if search: + count_statement = count_statement.where( + (col(Expression.situation).contains(search)) | (col(Expression.style).contains(search)) + ) + if chat_id: + count_statement = count_statement.where(col(Expression.session_id) == chat_id) + total = len(session.exec(count_statement).all()) return ReviewListResponse( success=True, @@ -720,7 +778,8 @@ async def batch_review_expressions( for item in request.items: try: - expression = Expression.get_or_none(Expression.id == item.id) + with get_db_session() as session: + expression = session.exec(select(Expression).where(col(Expression.id) == item.id).limit(1)).first() if not expression: results.append( @@ -730,23 +789,28 @@ async def batch_review_expressions( continue # 冲突检测 - if item.require_unchecked and expression.checked: + if item.require_unchecked: results.append( - BatchReviewResultItem( - id=item.id, - success=False, - message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查", - ) + BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤") ) failed += 1 continue # 更新状态 - expression.checked = True - expression.rejected = item.rejected - expression.modified_by = "user" - expression.last_active_time = time.time() - expression.save() + with get_db_session() as session: + db_expression = session.exec( + select(Expression).where(col(Expression.id) == item.id).limit(1) + ).first() + if not db_expression: + results.append( + BatchReviewResultItem( + id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式" + ) + ) + failed += 1 + continue + db_expression.last_active_time = datetime.now() + session.add(db_expression) 
results.append( BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝") diff --git a/src/webui/routers/person.py b/src/webui/routers/person.py index de3b8587..d1b86a02 100644 --- a/src/webui/routers/person.py +++ b/src/webui/routers/person.py @@ -3,12 +3,16 @@ from fastapi import APIRouter, HTTPException, Header, Query, Cookie from pydantic import BaseModel from typing import Optional, List, Dict +from datetime import datetime + from sqlalchemy import case +from sqlmodel import col, select, delete + from src.common.logger import get_logger +from src.common.database.database import get_db_session from src.common.database.database_model import PersonInfo from src.webui.core import verify_auth_token_from_cookie_or_header import json -import time logger = get_logger("webui.person") @@ -29,7 +33,7 @@ class PersonInfoResponse(BaseModel): nickname: Optional[str] group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON memory_points: Optional[str] - know_times: Optional[float] + know_times: Optional[int] know_since: Optional[float] last_know: Optional[float] @@ -112,20 +116,22 @@ def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[D def person_to_response(person: PersonInfo) -> PersonInfoResponse: """将 PersonInfo 模型转换为响应对象""" + know_since = person.first_known_time.timestamp() if person.first_known_time else None + last_know = person.last_known_time.timestamp() if person.last_known_time else None return PersonInfoResponse( - id=person.id, + id=person.id or 0, is_known=person.is_known, person_id=person.person_id, person_name=person.person_name, name_reason=person.name_reason, platform=person.platform, user_id=person.user_id, - nickname=person.nickname, - group_nick_name=parse_group_nick_name(person.group_nick_name), + nickname=person.user_nickname, + group_nick_name=parse_group_nick_name(person.group_nickname), memory_points=person.memory_points, - know_times=person.know_times, - know_since=person.know_since, 
- last_know=person.last_know, + know_times=person.know_counts, + know_since=know_since, + last_know=last_know, ) @@ -157,36 +163,50 @@ async def get_person_list( verify_auth_token(maibot_session, authorization) # 构建查询 - query = PersonInfo.select() + statement = select(PersonInfo) # 搜索过滤 if search: - query = query.where( - (PersonInfo.person_name.contains(search)) - | (PersonInfo.nickname.contains(search)) - | (PersonInfo.user_id.contains(search)) + statement = statement.where( + (col(PersonInfo.person_name).contains(search)) + | (col(PersonInfo.user_nickname).contains(search)) + | (col(PersonInfo.user_id).contains(search)) ) # 已认识状态过滤 if is_known is not None: - query = query.where(PersonInfo.is_known == is_known) + statement = statement.where(col(PersonInfo.is_known) == is_known) # 平台过滤 if platform: - query = query.where(PersonInfo.platform == platform) + statement = statement.where(col(PersonInfo.platform) == platform) # 排序:最后更新时间倒序(NULL 值放在最后) # Peewee 不支持 nulls_last,使用 CASE WHEN 来实现 - query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc()) + statement = statement.order_by( + case((col(PersonInfo.last_known_time).is_(None), 1), else_=0), + col(PersonInfo.last_known_time).desc(), + ) - # 获取总数 - total = query.count() - - # 分页 offset = (page - 1) * page_size - persons = query.offset(offset).limit(page_size) + statement = statement.offset(offset).limit(page_size) + + with get_db_session() as session: + persons = session.exec(statement).all() + + count_statement = select(PersonInfo.id) + if search: + count_statement = count_statement.where( + (col(PersonInfo.person_name).contains(search)) + | (col(PersonInfo.user_nickname).contains(search)) + | (col(PersonInfo.user_id).contains(search)) + ) + if is_known is not None: + count_statement = count_statement.where(col(PersonInfo.is_known) == is_known) + if platform: + count_statement = count_statement.where(col(PersonInfo.platform) == platform) + total = 
len(session.exec(count_statement).all()) - # 转换为响应对象 data = [person_to_response(person) for person in persons] return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data) @@ -215,7 +235,9 @@ async def get_person_detail( try: verify_auth_token(maibot_session, authorization) - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") @@ -250,7 +272,9 @@ async def update_person( try: verify_auth_token(maibot_session, authorization) - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") @@ -262,13 +286,18 @@ async def update_person( raise HTTPException(status_code=400, detail="未提供任何需要更新的字段") # 更新最后修改时间 - update_data["last_know"] = time.time() + update_data["last_known_time"] = datetime.now() # 执行更新 - for field, value in update_data.items(): - setattr(person, field, value) - - person.save() + with get_db_session() as session: + db_person = session.exec(select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)).first() + if not db_person: + raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") + for field, value in update_data.items(): + if hasattr(db_person, field): + setattr(db_person, field, value) + session.add(db_person) + person = db_person logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}") @@ -300,16 +329,19 @@ async def delete_person( try: verify_auth_token(maibot_session, authorization) - person = PersonInfo.get_or_none(PersonInfo.person_id == 
person_id) + with get_db_session() as session: + statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1) + person = session.exec(statement).first() if not person: raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息") # 记录删除信息 - person_name = person.person_name or person.nickname or person.user_id + person_name = person.person_name or person.user_nickname or person.user_id # 执行删除 - person.delete_instance() + with get_db_session() as session: + session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id)) logger.info(f"人物信息已删除: {person_id} ({person_name})") @@ -336,15 +368,17 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori try: verify_auth_token(maibot_session, authorization) - total = PersonInfo.select().count() - known = PersonInfo.select().where(PersonInfo.is_known).count() + with get_db_session() as session: + total = len(session.exec(select(PersonInfo.id)).all()) + known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all()) unknown = total - known # 按平台统计 platforms = {} - for person in PersonInfo.select(PersonInfo.platform): - platform = person.platform - platforms[platform] = platforms.get(platform, 0) + 1 + with get_db_session() as session: + for platform in session.exec(select(PersonInfo.platform)).all(): + if platform: + platforms[platform] = platforms.get(platform, 0) + 1 return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}} @@ -383,14 +417,17 @@ async def batch_delete_persons( for person_id in request.person_ids: try: - person = PersonInfo.get_or_none(PersonInfo.person_id == person_id) - if person: - person.delete_instance() - deleted_count += 1 - logger.info(f"批量删除: {person_id}") - else: - failed_count += 1 - failed_ids.append(person_id) + with get_db_session() as session: + person = session.exec( + select(PersonInfo).where(col(PersonInfo.person_id) == 
person_id).limit(1) + ).first() + if person: + session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id)) + deleted_count += 1 + logger.info(f"批量删除: {person_id}") + else: + failed_count += 1 + failed_ids.append(person_id) except Exception as e: logger.error(f"删除 {person_id} 失败: {e}") failed_count += 1 diff --git a/src/webui/routers/statistics.py b/src/webui/routers/statistics.py index da49883d..05a1dde1 100644 --- a/src/webui/routers/statistics.py +++ b/src/webui/routers/statistics.py @@ -1,13 +1,16 @@ """统计数据 API 路由""" -from fastapi import APIRouter, HTTPException, Depends, Cookie, Header -from pydantic import BaseModel, Field -from typing import Dict, Any, List, Optional from datetime import datetime, timedelta -from sqlalchemy import func as fn +from typing import Any, Optional +from fastapi import APIRouter, Cookie, Depends, Header, HTTPException +from pydantic import BaseModel, Field +from sqlalchemy import desc, func, or_ +from sqlmodel import col, select + +from src.common.database.database import get_db_session +from src.common.database.database_model import Messages, ModelUsage, OnlineTime from src.common.logger import get_logger -from src.common.database.database_model import LLMUsage, OnlineTime, Messages from src.webui.core import verify_auth_token_from_cookie_or_header logger = get_logger("webui.statistics") @@ -60,10 +63,10 @@ class DashboardData(BaseModel): """仪表盘数据""" summary: StatisticsSummary - model_stats: List[ModelStatistics] - hourly_data: List[TimeSeriesData] - daily_data: List[TimeSeriesData] - recent_activity: List[Dict[str, Any]] + model_stats: list[ModelStatistics] + hourly_data: list[TimeSeriesData] + daily_data: list[TimeSeriesData] + recent_activity: list[dict[str, Any]] @router.get("/dashboard", response_model=DashboardData) @@ -111,26 +114,44 @@ async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> 
StatisticsSummary: """获取摘要统计数据(优化:使用数据库聚合)""" - summary = StatisticsSummary() + summary = StatisticsSummary( + total_requests=0, + total_cost=0.0, + total_tokens=0, + online_time=0.0, + total_messages=0, + total_replies=0, + avg_response_time=0.0, + cost_per_hour=0.0, + tokens_per_hour=0.0, + ) # 使用聚合查询替代全量加载 - query = LLMUsage.select( - fn.COUNT(LLMUsage.id).alias("total_requests"), - fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"), - fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"), - fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"), - ).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time)) + with get_db_session() as session: + statement = select( + func.count().label("total_requests"), + func.sum(col(ModelUsage.cost)).label("total_cost"), + func.sum(col(ModelUsage.total_tokens)).label("total_tokens"), + func.avg(col(ModelUsage.time_cost)).label("avg_response_time"), + ).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time) + result = session.execute(statement).first() - result = query.dicts().get() - summary.total_requests = result["total_requests"] - summary.total_cost = result["total_cost"] - summary.total_tokens = result["total_tokens"] - summary.avg_response_time = result["avg_response_time"] or 0.0 + if result: + total_requests, total_cost, total_tokens, avg_response_time = result + summary.total_requests = total_requests or 0 + summary.total_cost = float(total_cost or 0.0) + summary.total_tokens = total_tokens or 0 + summary.avg_response_time = float(avg_response_time or 0.0) # 查询在线时间 - 这个数据量通常不大,保留原逻辑 - online_records = list( - OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time)) - ) + with get_db_session() as session: + statement = select(OnlineTime).where( + or_( + col(OnlineTime.start_timestamp) >= start_time, + col(OnlineTime.end_timestamp) >= start_time, + ) 
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
    """Aggregate per-model usage statistics since *start_time*.

    Uses a single GROUP BY query so the aggregation happens inside the
    database instead of loading usage rows into Python (the previous
    rewrite fetched only the 200 most-recent rows, silently truncating
    the statistics, and iterated ``Row`` objects as if they were
    ``ModelUsage`` entities, which raises ``AttributeError``).

    Args:
        start_time: Lower bound (inclusive) for ``ModelUsage.timestamp``.

    Returns:
        Up to 10 ``ModelStatistics`` entries, ordered by request count
        descending.
    """
    # Prefer the assigned (display) name, fall back to the raw model name.
    name_expr = func.coalesce(
        col(ModelUsage.model_assign_name),
        col(ModelUsage.model_name),
        "unknown",
    )
    request_count = func.count()
    statement = (
        select(
            name_expr.label("model_name"),
            request_count.label("request_count"),
            func.sum(col(ModelUsage.cost)).label("total_cost"),
            func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
            func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
        )
        .where(col(ModelUsage.timestamp) >= start_time)
        .group_by(name_expr)
        .order_by(desc(request_count))
        .limit(10)  # report only the 10 most-used models
    )

    with get_db_session() as session:
        rows = session.execute(statement).all()

    # SUM/AVG return NULL for empty groups and may come back as Decimal
    # depending on the backend, so normalize every value defensively.
    return [
        ModelStatistics(
            model_name=row[0] or "unknown",
            request_count=int(row[1] or 0),
            total_cost=float(row[2] or 0.0),
            total_tokens=int(row[3] or 0),
            avg_response_time=float(row[4] or 0.0),
        )
        for row in rows
    ]
-228,22 +284,26 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li return result -async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]: +async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]: """获取日级统计数据(优化:使用数据库聚合)""" # 使用strftime按日期分组 - query = ( - LLMUsage.select( - fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"), - fn.COUNT(LLMUsage.id).alias("requests"), - fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"), - fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"), + day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp)) + statement = ( + select( + day_expr.label("day"), + func.count().label("requests"), + func.sum(col(ModelUsage.cost)).label("cost"), + func.sum(col(ModelUsage.total_tokens)).label("tokens"), ) - .where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time)) - .group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp)) + .where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time) + .group_by(day_expr) ) + with get_db_session() as session: + rows = session.execute(statement).all() + # 转换为字典 - data_dict = {row["day"]: row for row in query.dicts()} + data_dict = {row[0]: row for row in rows} # 填充所有天 result = [] @@ -253,7 +313,12 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis if day_str in data_dict: row = data_dict[day_str] result.append( - TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"]) + TimeSeriesData( + timestamp=day_str, + requests=row[1] or 0, + cost=float(row[2] or 0.0), + tokens=row[3] or 0, + ) ) else: result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0)) @@ -262,9 +327,11 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis return result -async def 
_get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: +async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]: """获取最近活动""" - records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit)) + with get_db_session() as session: + statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit) + records = session.execute(statement).scalars().all() activities = [] for record in records: @@ -273,10 +340,10 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]: "timestamp": record.timestamp.isoformat(), "model": record.model_assign_name or record.model_name, "request_type": record.request_type, - "tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0), + "tokens": record.total_tokens or 0, "cost": record.cost or 0.0, "time_cost": record.time_cost or 0.0, - "status": record.status, + "status": None, } )