mirror of https://github.com/Mai-with-u/MaiBot.git
重构绝大部分模块以适配新版本的数据库和数据模型,修复缺少依赖问题,更新 pyproject
parent
c14736ffca
commit
16b16d2ca6
2
bot.py
2
bot.py
|
|
@ -1,4 +1,4 @@
|
||||||
raise RuntimeError("System Not Ready")
|
# raise RuntimeError("System Not Ready")
|
||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
import os
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[project]
|
[project]
|
||||||
name = "MaiBot"
|
name = "MaiBot"
|
||||||
version = "0.11.6"
|
version = "1.0.0"
|
||||||
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
description = "MaiCore 是一个基于大语言模型的可交互智能体"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|
@ -35,6 +35,7 @@ dependencies = [
|
||||||
"tomlkit>=0.13.3",
|
"tomlkit>=0.13.3",
|
||||||
"urllib3>=2.5.0",
|
"urllib3>=2.5.0",
|
||||||
"uvicorn>=0.35.0",
|
"uvicorn>=0.35.0",
|
||||||
|
"msgpack>=1.1.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -29,3 +29,4 @@ toml>=0.10.2
|
||||||
tomlkit>=0.13.3
|
tomlkit>=0.13.3
|
||||||
urllib3>=2.5.0
|
urllib3>=2.5.0
|
||||||
uvicorn>=0.35.0
|
uvicorn>=0.35.0
|
||||||
|
msgpack>=1.1.2
|
||||||
|
|
@ -8,11 +8,15 @@
|
||||||
4. 未通过评估的:rejected=1, checked=1
|
4. 未通过评估的:rejected=1, checked=1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
from typing import List
|
|
||||||
|
|
||||||
|
from sqlmodel import select
|
||||||
|
|
||||||
|
from src.bw_learner.expression_review_store import get_review_state, set_review_state
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Expression
|
from src.common.database.database_model import Expression
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -39,7 +43,7 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||||
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
|
"表达方式或言语风格 是否与使用条件或使用情景 匹配",
|
||||||
"允许部分语法错误或口头化或缺省出现",
|
"允许部分语法错误或口头化或缺省出现",
|
||||||
"表达方式不能太过特指,需要具有泛用性",
|
"表达方式不能太过特指,需要具有泛用性",
|
||||||
"一般不涉及具体的人名或名称"
|
"一般不涉及具体的人名或名称",
|
||||||
]
|
]
|
||||||
|
|
||||||
# 从配置中获取额外的自定义标准
|
# 从配置中获取额外的自定义标准
|
||||||
|
|
@ -71,12 +75,11 @@ def create_evaluation_prompt(situation: str, style: str) -> str:
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
judge_llm = LLMRequest(
|
|
||||||
model_set=model_config.model_task_config.tool_use,
|
|
||||||
request_type="expression_check"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str]:
|
judge_llm = LLMRequest(model_set=model_config.model_task_config.tool_use, request_type="expression_check")
|
||||||
|
|
||||||
|
|
||||||
|
async def single_expression_check(situation: str, style: str) -> tuple[bool, str, str | None]:
|
||||||
"""
|
"""
|
||||||
执行单次LLM评估
|
执行单次LLM评估
|
||||||
|
|
||||||
|
|
@ -92,9 +95,7 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||||
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
logger.debug(f"正在评估表达方式: situation={situation}, style={style}")
|
||||||
|
|
||||||
response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
|
response, (reasoning, model_name, _) = await judge_llm.generate_response_async(
|
||||||
prompt=prompt,
|
prompt=prompt, temperature=0.6, max_tokens=1024
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=1024
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"LLM响应: {response}")
|
logger.debug(f"LLM响应: {response}")
|
||||||
|
|
@ -104,6 +105,7 @@ async def single_expression_check(situation: str, style: str) -> tuple[bool, str
|
||||||
evaluation = json.loads(response)
|
evaluation = json.loads(response)
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
import re
|
import re
|
||||||
|
|
||||||
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
json_match = re.search(r'\{[^{}]*"suitable"[^{}]*\}', response, re.DOTALL)
|
||||||
if json_match:
|
if json_match:
|
||||||
evaluation = json.loads(json_match.group())
|
evaluation = json.loads(json_match.group())
|
||||||
|
|
@ -130,7 +132,7 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
task_name="Expression Auto Check Task",
|
task_name="Expression Auto Check Task",
|
||||||
wait_before_start=60, # 启动后等待60秒再开始第一次检查
|
wait_before_start=60, # 启动后等待60秒再开始第一次检查
|
||||||
run_interval=check_interval
|
run_interval=check_interval,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _select_expressions(self, count: int) -> List[Expression]:
|
async def _select_expressions(self, count: int) -> List[Expression]:
|
||||||
|
|
@ -144,10 +146,11 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||||
选中的表达方式列表
|
选中的表达方式列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 查询所有未检查的表达方式(checked=False)
|
with get_db_session() as session:
|
||||||
unevaluated_expressions = list(
|
statement = select(Expression)
|
||||||
Expression.select().where(~Expression.checked)
|
all_expressions = session.exec(statement).all()
|
||||||
)
|
|
||||||
|
unevaluated_expressions = [expr for expr in all_expressions if not get_review_state(expr.id)["checked"]]
|
||||||
|
|
||||||
if not unevaluated_expressions:
|
if not unevaluated_expressions:
|
||||||
logger.info("没有未检查的表达方式")
|
logger.info("没有未检查的表达方式")
|
||||||
|
|
@ -182,10 +185,7 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||||
|
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
try:
|
try:
|
||||||
expression.checked = True
|
set_review_state(expression.id, True, not suitable, "ai")
|
||||||
expression.rejected = not suitable # 通过则rejected=0,不通过则rejected=1
|
|
||||||
expression.modified_by = 'ai' # 标记为AI检查
|
|
||||||
expression.save()
|
|
||||||
|
|
||||||
status = "通过" if suitable else "不通过"
|
status = "通过" if suitable else "不通过"
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
@ -219,7 +219,6 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||||
|
|
||||||
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")
|
logger.info(f"开始执行表达方式自动检查,本次将检查 {check_count} 条")
|
||||||
|
|
||||||
|
|
||||||
# 选择要检查的表达方式
|
# 选择要检查的表达方式
|
||||||
expressions = await self._select_expressions(check_count)
|
expressions = await self._select_expressions(check_count)
|
||||||
|
|
||||||
|
|
@ -243,10 +242,8 @@ class ExpressionAutoCheckTask(AsyncTask):
|
||||||
await asyncio.sleep(0.3)
|
await asyncio.sleep(0.3)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"表达方式自动检查完成: 总计 {len(expressions)} 条,"
|
f"表达方式自动检查完成: 总计 {len(expressions)} 条,通过 {passed_count} 条,不通过 {failed_count} 条"
|
||||||
f"通过 {passed_count} 条,不通过 {failed_count} 条"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)
|
logger.error(f"执行表达方式自动检查任务时出错: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from src.manager.local_store_manager import local_storage
|
||||||
|
|
||||||
|
|
||||||
|
def _review_key(expression_id: int) -> str:
|
||||||
|
return f"expression_review:{expression_id}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_review_state(expression_id: Optional[int]) -> Dict[str, Any]:
|
||||||
|
if expression_id is None:
|
||||||
|
return {"checked": False, "rejected": False, "modified_by": None}
|
||||||
|
value = local_storage[_review_key(expression_id)]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {
|
||||||
|
"checked": bool(value.get("checked", False)),
|
||||||
|
"rejected": bool(value.get("rejected", False)),
|
||||||
|
"modified_by": value.get("modified_by"),
|
||||||
|
}
|
||||||
|
return {"checked": False, "rejected": False, "modified_by": None}
|
||||||
|
|
||||||
|
|
||||||
|
def set_review_state(
|
||||||
|
expression_id: Optional[int],
|
||||||
|
checked: bool,
|
||||||
|
rejected: bool,
|
||||||
|
modified_by: Optional[str],
|
||||||
|
) -> None:
|
||||||
|
if expression_id is None:
|
||||||
|
return
|
||||||
|
local_storage[_review_key(expression_id)] = {
|
||||||
|
"checked": checked,
|
||||||
|
"rejected": rejected,
|
||||||
|
"modified_by": modified_by,
|
||||||
|
}
|
||||||
|
|
@ -3,11 +3,11 @@ MaiBot模块系统
|
||||||
包含聊天、情绪、记忆、日程等功能模块
|
包含聊天、情绪、记忆、日程等功能模块
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
|
||||||
|
|
||||||
# 导出主要组件供外部使用
|
# 导出主要组件供外部使用
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_chat_manager",
|
"get_chat_manager",
|
||||||
"get_emoji_manager",
|
"emoji_manager",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,11 @@
|
||||||
import time
|
import time
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from sqlmodel import select, col
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Messages
|
from src.common.database.database_model import Messages
|
||||||
from maim_message import UserInfo
|
from maim_message import UserInfo
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
@ -23,10 +26,11 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||||
Returns:
|
Returns:
|
||||||
Dict[str, Any]: Message dictionary
|
Dict[str, Any]: Message dictionary
|
||||||
"""
|
"""
|
||||||
|
message_timestamp = message.timestamp.timestamp() if isinstance(message.timestamp, datetime) else message.timestamp
|
||||||
return {
|
return {
|
||||||
"message_id": message.message_id,
|
"message_id": message.message_id,
|
||||||
"time": message.time,
|
"time": message_timestamp,
|
||||||
"chat_id": message.chat_id,
|
"chat_id": message.session_id,
|
||||||
"user_id": message.user_id,
|
"user_id": message.user_id,
|
||||||
"user_nickname": message.user_nickname,
|
"user_nickname": message.user_nickname,
|
||||||
"processed_plain_text": message.processed_plain_text,
|
"processed_plain_text": message.processed_plain_text,
|
||||||
|
|
@ -37,7 +41,7 @@ def _message_to_dict(message: Messages) -> Dict[str, Any]:
|
||||||
"user_info": {
|
"user_info": {
|
||||||
"user_id": message.user_id,
|
"user_id": message.user_id,
|
||||||
"user_nickname": message.user_nickname,
|
"user_nickname": message.user_nickname,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -109,10 +113,13 @@ class ChatObserver:
|
||||||
"""
|
"""
|
||||||
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
logger.debug(f"[私聊][{self.private_name}]检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||||
|
|
||||||
new_message_exists = Messages.select().where(
|
last_check_time = self.last_check_time or 0.0
|
||||||
(Messages.chat_id == self.stream_id) &
|
last_check_dt = datetime.fromtimestamp(last_check_time)
|
||||||
(Messages.time > self.last_check_time)
|
with get_db_session() as session:
|
||||||
).exists()
|
statement = select(Messages).where(
|
||||||
|
(col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_check_dt)
|
||||||
|
)
|
||||||
|
new_message_exists = session.exec(statement).first() is not None
|
||||||
|
|
||||||
if new_message_exists:
|
if new_message_exists:
|
||||||
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
logger.debug(f"[私聊][{self.private_name}]发现新消息")
|
||||||
|
|
@ -183,20 +190,21 @@ class ChatObserver:
|
||||||
)
|
)
|
||||||
return has_new
|
return has_new
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
async def _fetch_new_messages(self) -> List[Dict[str, Any]]:
|
||||||
"""获取新消息
|
"""获取新消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 新消息列表
|
List[Dict[str, Any]]: 新消息列表
|
||||||
"""
|
"""
|
||||||
query = Messages.select().where(
|
last_message_time = self.last_message_time or 0.0
|
||||||
(Messages.chat_id == self.stream_id) &
|
last_message_dt = datetime.fromtimestamp(last_message_time)
|
||||||
(Messages.time > self.last_message_time)
|
with get_db_session() as session:
|
||||||
).order_by(Messages.time.asc())
|
statement = (
|
||||||
|
select(Messages)
|
||||||
new_messages = [_message_to_dict(msg) for msg in query]
|
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) > last_message_dt))
|
||||||
|
.order_by(col(Messages.timestamp))
|
||||||
|
)
|
||||||
|
new_messages = [_message_to_dict(msg) for msg in session.exec(statement).all()]
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]
|
self.last_message_read = new_messages[-1]
|
||||||
|
|
@ -215,13 +223,16 @@ class ChatObserver:
|
||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 最多5条消息
|
List[Dict[str, Any]]: 最多5条消息
|
||||||
"""
|
"""
|
||||||
query = Messages.select().where(
|
time_point_dt = datetime.fromtimestamp(time_point)
|
||||||
(Messages.chat_id == self.stream_id) &
|
with get_db_session() as session:
|
||||||
(Messages.time < time_point)
|
statement = (
|
||||||
).order_by(Messages.time.desc()).limit(5)
|
select(Messages)
|
||||||
|
.where((col(Messages.session_id) == self.stream_id) & (col(Messages.timestamp) < time_point_dt))
|
||||||
messages = list(query)
|
.order_by(col(Messages.timestamp))
|
||||||
messages.reverse() # 需要按时间正序排列
|
.limit(5)
|
||||||
|
)
|
||||||
|
messages = list(session.exec(statement).all())
|
||||||
|
messages.reverse()
|
||||||
new_messages = [_message_to_dict(msg) for msg in messages]
|
new_messages = [_message_to_dict(msg) for msg in messages]
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,9 @@ from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
from src.chat.utils.chat_message_builder import replace_user_references
|
from src.chat.utils.chat_message_builder import replace_user_references
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.person_info.person_info import Person
|
from src.person_info.person_info import Person
|
||||||
from src.common.database.database_model import Images
|
from sqlmodel import select, col
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
pass
|
||||||
|
|
@ -47,6 +49,12 @@ class HeartFCMessageReceiver:
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
userinfo = message.message_info.user_info
|
userinfo = message.message_info.user_info
|
||||||
chat = message.chat_stream
|
chat = message.chat_stream
|
||||||
|
if userinfo is None or message.message_info.platform is None:
|
||||||
|
raise ValueError("message userinfo or platform is missing")
|
||||||
|
if userinfo.user_id is None or userinfo.user_nickname is None:
|
||||||
|
raise ValueError("message userinfo id or nickname is missing")
|
||||||
|
user_id = userinfo.user_id
|
||||||
|
nickname = userinfo.user_nickname
|
||||||
|
|
||||||
# 2. 计算at信息
|
# 2. 计算at信息
|
||||||
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
is_mentioned, is_at, reply_probability_boost = is_mentioned_bot_in_message(message)
|
||||||
|
|
@ -70,7 +78,15 @@ class HeartFCMessageReceiver:
|
||||||
processed_text = message.processed_plain_text
|
processed_text = message.processed_plain_text
|
||||||
if picid_list:
|
if picid_list:
|
||||||
for picid in picid_list:
|
for picid in picid_list:
|
||||||
image = Images.get_or_none(Images.image_id == picid)
|
with get_db_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(Images).where(
|
||||||
|
(col(Images.id) == int(picid)) & (col(Images.image_type) == ImageType.IMAGE)
|
||||||
|
)
|
||||||
|
if picid.isdigit()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
image = session.exec(statement).first() if statement is not None else None
|
||||||
if image and image.description:
|
if image and image.description:
|
||||||
# 将[picid:xxxx]替换成图片描述
|
# 将[picid:xxxx]替换成图片描述
|
||||||
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
|
processed_text = processed_text.replace(f"[picid:{picid}]", f"[图片:{image.description}]")
|
||||||
|
|
@ -80,26 +96,24 @@ class HeartFCMessageReceiver:
|
||||||
|
|
||||||
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
# 应用用户引用格式替换,将回复<aaa:bbb>和@<aaa:bbb>格式转换为可读格式
|
||||||
processed_plain_text = replace_user_references(
|
processed_plain_text = replace_user_references(
|
||||||
processed_text,
|
processed_text, message.message_info.platform, replace_bot_name=True
|
||||||
message.message_info.platform, # type: ignore
|
|
||||||
replace_bot_name=True,
|
|
||||||
)
|
)
|
||||||
# if not processed_plain_text:
|
# if not processed_plain_text:
|
||||||
# print(message)
|
# print(message)
|
||||||
|
|
||||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}") # type: ignore
|
logger.info(f"[{mes_name}]{userinfo.user_nickname}:{processed_plain_text}")
|
||||||
|
|
||||||
# 如果是群聊,获取群号和群昵称
|
# 如果是群聊,获取群号和群昵称
|
||||||
group_id = None
|
group_id = None
|
||||||
group_nick_name = None
|
group_nick_name = None
|
||||||
if chat.group_info:
|
if chat.group_info:
|
||||||
group_id = chat.group_info.group_id # type: ignore
|
group_id = chat.group_info.group_id
|
||||||
group_nick_name = userinfo.user_cardname # type: ignore
|
group_nick_name = userinfo.user_cardname
|
||||||
|
|
||||||
_ = Person.register_person(
|
_ = Person.register_person(
|
||||||
platform=message.message_info.platform, # type: ignore
|
platform=message.message_info.platform,
|
||||||
user_id=message.message_info.user_info.user_id, # type: ignore
|
user_id=user_id,
|
||||||
nickname=userinfo.user_nickname, # type: ignore
|
nickname=nickname,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
group_nick_name=group_nick_name,
|
group_nick_name=group_nick_name,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_emoji_manager",
|
|
||||||
"get_chat_manager",
|
"get_chat_manager",
|
||||||
"MessageStorage",
|
"MessageStorage",
|
||||||
|
"emoji_manager",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -2,13 +2,15 @@ import asyncio
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
from datetime import datetime
|
||||||
from typing import Dict, Optional, TYPE_CHECKING
|
from typing import Dict, Optional, TYPE_CHECKING
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
|
from sqlmodel import select, col
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import ChatStreams # 新增导入
|
from src.common.database.database_model import ChatSession
|
||||||
|
|
||||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
@ -76,7 +78,7 @@ class ChatStream:
|
||||||
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
self.create_time = data.get("create_time", time.time()) if data else time.time()
|
||||||
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
|
||||||
self.saved = False
|
self.saved = False
|
||||||
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
|
self.context: Optional[ChatMessageContext] = None
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
|
|
@ -95,10 +97,13 @@ class ChatStream:
|
||||||
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
|
||||||
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
|
||||||
|
|
||||||
|
if user_info is None:
|
||||||
|
raise ValueError("user_info is required to build ChatStream")
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
stream_id=data["stream_id"],
|
stream_id=data["stream_id"],
|
||||||
platform=data["platform"],
|
platform=data["platform"],
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info,
|
||||||
group_info=group_info,
|
group_info=group_info,
|
||||||
data=data,
|
data=data,
|
||||||
)
|
)
|
||||||
|
|
@ -128,12 +133,7 @@ class ChatManager:
|
||||||
if not self._initialized:
|
if not self._initialized:
|
||||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||||
try:
|
get_db_session()
|
||||||
db.connect(reuse_if_open=True)
|
|
||||||
# 确保 ChatStreams 表存在
|
|
||||||
db.create_tables([ChatStreams], safe=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库连接或 ChatStreams 表创建失败: {e}")
|
|
||||||
|
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
# 在事件循环中启动初始化
|
# 在事件循环中启动初始化
|
||||||
|
|
@ -161,8 +161,13 @@ class ChatManager:
|
||||||
|
|
||||||
def register_message(self, message: "MessageRecv"):
|
def register_message(self, message: "MessageRecv"):
|
||||||
"""注册消息到聊天流"""
|
"""注册消息到聊天流"""
|
||||||
|
platform = message.message_info.platform or ""
|
||||||
|
if not platform:
|
||||||
|
raise ValueError("platform is required for ChatStream")
|
||||||
|
if message.message_info.user_info is None and message.message_info.group_info is None:
|
||||||
|
raise ValueError("user_info or group_info is required for ChatStream")
|
||||||
stream_id = self._generate_stream_id(
|
stream_id = self._generate_stream_id(
|
||||||
message.message_info.platform, # type: ignore
|
platform,
|
||||||
message.message_info.user_info,
|
message.message_info.user_info,
|
||||||
message.message_info.group_info,
|
message.message_info.group_info,
|
||||||
)
|
)
|
||||||
|
|
@ -176,12 +181,18 @@ class ChatManager:
|
||||||
"""生成聊天流唯一ID"""
|
"""生成聊天流唯一ID"""
|
||||||
if not user_info and not group_info:
|
if not user_info and not group_info:
|
||||||
raise ValueError("用户信息或群组信息必须提供")
|
raise ValueError("用户信息或群组信息必须提供")
|
||||||
|
if group_info is None and user_info is None:
|
||||||
|
raise ValueError("用户信息或群组信息必须提供")
|
||||||
|
|
||||||
if group_info:
|
if group_info:
|
||||||
# 组合关键信息
|
# 组合关键信息
|
||||||
components = [platform, str(group_info.group_id)]
|
components = [platform, str(group_info.group_id)]
|
||||||
else:
|
else:
|
||||||
components = [platform, str(user_info.user_id), "private"] # type: ignore
|
if user_info is None:
|
||||||
|
raise ValueError("用户信息或群组信息必须提供")
|
||||||
|
if user_info.user_id is None:
|
||||||
|
raise ValueError("user_id is required for private stream")
|
||||||
|
components = [platform, str(user_info.user_id), "private"]
|
||||||
|
|
||||||
# 使用MD5生成唯一ID
|
# 使用MD5生成唯一ID
|
||||||
key = "_".join(components)
|
key = "_".join(components)
|
||||||
|
|
@ -231,33 +242,35 @@ class ChatManager:
|
||||||
|
|
||||||
# 检查数据库中是否存在
|
# 检查数据库中是否存在
|
||||||
def _db_find_stream_sync(s_id: str):
|
def _db_find_stream_sync(s_id: str):
|
||||||
return ChatStreams.get_or_none(ChatStreams.stream_id == s_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(ChatSession).where(col(ChatSession.session_id) == s_id)
|
||||||
|
return session.exec(statement).first()
|
||||||
|
|
||||||
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
model_instance = await asyncio.to_thread(_db_find_stream_sync, stream_id)
|
||||||
|
|
||||||
if model_instance:
|
if model_instance:
|
||||||
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
# 从 Peewee 模型转换回 ChatStream.from_dict 期望的格式
|
||||||
user_info_data = {
|
user_info_data = {
|
||||||
"platform": model_instance.user_platform,
|
"platform": model_instance.platform,
|
||||||
"user_id": model_instance.user_id,
|
"user_id": model_instance.user_id,
|
||||||
"user_nickname": model_instance.user_nickname,
|
"user_nickname": "",
|
||||||
"user_cardname": model_instance.user_cardname or "",
|
"user_cardname": "",
|
||||||
}
|
}
|
||||||
group_info_data = None
|
group_info_data = None
|
||||||
if model_instance.group_id: # 假设 group_id 为空字符串表示没有群组信息
|
if model_instance.group_id:
|
||||||
group_info_data = {
|
group_info_data = {
|
||||||
"platform": model_instance.group_platform,
|
"platform": model_instance.platform,
|
||||||
"group_id": model_instance.group_id,
|
"group_id": model_instance.group_id,
|
||||||
"group_name": model_instance.group_name,
|
"group_name": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
data_for_from_dict = {
|
data_for_from_dict = {
|
||||||
"stream_id": model_instance.stream_id,
|
"stream_id": model_instance.session_id,
|
||||||
"platform": model_instance.platform,
|
"platform": model_instance.platform,
|
||||||
"user_info": user_info_data,
|
"user_info": user_info_data,
|
||||||
"group_info": group_info_data,
|
"group_info": group_info_data,
|
||||||
"create_time": model_instance.create_time,
|
"create_time": model_instance.created_timestamp.timestamp(),
|
||||||
"last_active_time": model_instance.last_active_time,
|
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||||
}
|
}
|
||||||
stream = ChatStream.from_dict(data_for_from_dict)
|
stream = ChatStream.from_dict(data_for_from_dict)
|
||||||
# 更新用户信息和群组信息
|
# 更新用户信息和群组信息
|
||||||
|
|
@ -329,20 +342,26 @@ class ChatManager:
|
||||||
user_info_d = s_data_dict.get("user_info")
|
user_info_d = s_data_dict.get("user_info")
|
||||||
group_info_d = s_data_dict.get("group_info")
|
group_info_d = s_data_dict.get("group_info")
|
||||||
|
|
||||||
fields_to_save = {
|
with get_db_session() as session:
|
||||||
"platform": s_data_dict["platform"],
|
statement = select(ChatSession).where(col(ChatSession.session_id) == s_data_dict["stream_id"])
|
||||||
"create_time": s_data_dict["create_time"],
|
record = session.exec(statement).first()
|
||||||
"last_active_time": s_data_dict["last_active_time"],
|
if record is None:
|
||||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
record = ChatSession(
|
||||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
session_id=s_data_dict["stream_id"],
|
||||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
platform=s_data_dict["platform"],
|
||||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
user_id=user_info_d["user_id"] if user_info_d else None,
|
||||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
group_id=group_info_d["group_id"] if group_info_d else None,
|
||||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
created_timestamp=datetime.fromtimestamp(s_data_dict["create_time"]),
|
||||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
last_active_timestamp=datetime.fromtimestamp(s_data_dict["last_active_time"]),
|
||||||
}
|
)
|
||||||
|
session.add(record)
|
||||||
|
return
|
||||||
|
|
||||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
record.platform = s_data_dict["platform"]
|
||||||
|
record.user_id = user_info_d["user_id"] if user_info_d else None
|
||||||
|
record.group_id = group_info_d["group_id"] if group_info_d else None
|
||||||
|
record.created_timestamp = datetime.fromtimestamp(s_data_dict["create_time"])
|
||||||
|
record.last_active_timestamp = datetime.fromtimestamp(s_data_dict["last_active_time"])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||||
|
|
@ -361,28 +380,30 @@ class ChatManager:
|
||||||
|
|
||||||
def _db_load_all_streams_sync():
|
def _db_load_all_streams_sync():
|
||||||
loaded_streams_data = []
|
loaded_streams_data = []
|
||||||
for model_instance in ChatStreams.select():
|
with get_db_session() as session:
|
||||||
|
statement = select(ChatSession)
|
||||||
|
for model_instance in session.exec(statement).all():
|
||||||
user_info_data = {
|
user_info_data = {
|
||||||
"platform": model_instance.user_platform,
|
"platform": model_instance.platform,
|
||||||
"user_id": model_instance.user_id,
|
"user_id": model_instance.user_id or "",
|
||||||
"user_nickname": model_instance.user_nickname,
|
"user_nickname": "",
|
||||||
"user_cardname": model_instance.user_cardname or "",
|
"user_cardname": "",
|
||||||
}
|
}
|
||||||
group_info_data = None
|
group_info_data = None
|
||||||
if model_instance.group_id:
|
if model_instance.group_id:
|
||||||
group_info_data = {
|
group_info_data = {
|
||||||
"platform": model_instance.group_platform,
|
"platform": model_instance.platform,
|
||||||
"group_id": model_instance.group_id,
|
"group_id": model_instance.group_id,
|
||||||
"group_name": model_instance.group_name,
|
"group_name": "",
|
||||||
}
|
}
|
||||||
|
|
||||||
data_for_from_dict = {
|
data_for_from_dict = {
|
||||||
"stream_id": model_instance.stream_id,
|
"stream_id": model_instance.session_id,
|
||||||
"platform": model_instance.platform,
|
"platform": model_instance.platform,
|
||||||
"user_info": user_info_data,
|
"user_info": user_info_data,
|
||||||
"group_info": group_info_data,
|
"group_info": group_info_data,
|
||||||
"create_time": model_instance.create_time,
|
"create_time": model_instance.created_timestamp.timestamp(),
|
||||||
"last_active_time": model_instance.last_active_time,
|
"last_active_time": model_instance.last_active_timestamp.timestamp(),
|
||||||
}
|
}
|
||||||
loaded_streams_data.append(data_for_from_dict)
|
loaded_streams_data.append(data_for_from_dict)
|
||||||
return loaded_streams_data
|
return loaded_streams_data
|
||||||
|
|
|
||||||
|
|
@ -1,36 +1,74 @@
|
||||||
import re
|
from datetime import datetime
|
||||||
import json
|
from collections.abc import Mapping
|
||||||
import traceback
|
from typing import cast
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from src.common.database.database_model import Messages, Images
|
import json
|
||||||
|
import re
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
from sqlmodel import col, select
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType, Messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.data_models.message_component_model import MessageSequence, TextComponent
|
||||||
|
from src.common.utils.utils_message import MessageUtils
|
||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .message import MessageSending, MessageRecv
|
from .message import MessageRecv, MessageSending
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
|
|
||||||
class MessageStorage:
|
class MessageStorage:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _serialize_keywords(keywords) -> str:
|
def _coerce_str_list(value: object) -> list[str]:
|
||||||
|
if isinstance(value, list):
|
||||||
|
return [str(item) for item in value]
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return [str(item) for item in value]
|
||||||
|
if isinstance(value, set):
|
||||||
|
return [str(item) for item in value]
|
||||||
|
if isinstance(value, str):
|
||||||
|
return [value]
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_str(mapping: Mapping[str, object], key: str, default: str = "") -> str:
|
||||||
|
value = mapping.get(key)
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_optional_str(mapping: Mapping[str, object], key: str) -> str | None:
|
||||||
|
value = mapping.get(key)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _serialize_keywords(keywords: list[str] | None) -> str:
|
||||||
"""将关键词列表序列化为JSON字符串"""
|
"""将关键词列表序列化为JSON字符串"""
|
||||||
if isinstance(keywords, list):
|
if isinstance(keywords, list):
|
||||||
return json.dumps(keywords, ensure_ascii=False)
|
return json.dumps(keywords, ensure_ascii=False)
|
||||||
return "[]"
|
return "[]"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _deserialize_keywords(keywords_str: str) -> list:
|
def _deserialize_keywords(keywords_str: str) -> list[str]:
|
||||||
"""将JSON字符串反序列化为关键词列表"""
|
"""将JSON字符串反序列化为关键词列表"""
|
||||||
if not keywords_str:
|
if not keywords_str:
|
||||||
return []
|
return []
|
||||||
try:
|
try:
|
||||||
return json.loads(keywords_str)
|
parsed = cast(object, json.loads(keywords_str))
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
return []
|
return []
|
||||||
|
if isinstance(parsed, list):
|
||||||
|
return [str(item) for item in parsed]
|
||||||
|
if isinstance(parsed, str):
|
||||||
|
return [parsed]
|
||||||
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||||
"""存储消息到数据库"""
|
"""存储消息到数据库"""
|
||||||
try:
|
try:
|
||||||
# 通知消息不存储
|
# 通知消息不存储
|
||||||
|
|
@ -66,7 +104,7 @@ class MessageStorage:
|
||||||
priority_mode = ""
|
priority_mode = ""
|
||||||
priority_info = {}
|
priority_info = {}
|
||||||
is_emoji = False
|
is_emoji = False
|
||||||
is_picid = False
|
is_picture = False
|
||||||
is_notify = False
|
is_notify = False
|
||||||
is_command = False
|
is_command = False
|
||||||
key_words = ""
|
key_words = ""
|
||||||
|
|
@ -83,66 +121,73 @@ class MessageStorage:
|
||||||
priority_mode = message.priority_mode
|
priority_mode = message.priority_mode
|
||||||
priority_info = message.priority_info
|
priority_info = message.priority_info
|
||||||
is_emoji = message.is_emoji
|
is_emoji = message.is_emoji
|
||||||
is_picid = message.is_picid
|
is_picture = message.is_picid
|
||||||
is_notify = message.is_notify
|
is_notify = message.is_notify
|
||||||
is_command = message.is_command
|
is_command = message.is_command
|
||||||
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
intercept_message_level = getattr(message, "intercept_message_level", 0)
|
||||||
# 序列化关键词列表为JSON字符串
|
# 序列化关键词列表为JSON字符串
|
||||||
key_words = MessageStorage._serialize_keywords(message.key_words)
|
key_words = MessageStorage._serialize_keywords(MessageStorage._coerce_str_list(message.key_words))
|
||||||
key_words_lite = MessageStorage._serialize_keywords(message.key_words_lite)
|
key_words_lite = MessageStorage._serialize_keywords(
|
||||||
|
MessageStorage._coerce_str_list(message.key_words_lite)
|
||||||
|
)
|
||||||
selected_expressions = ""
|
selected_expressions = ""
|
||||||
|
|
||||||
chat_info_dict = chat_stream.to_dict()
|
chat_info_dict = cast(dict[str, object], chat_stream.to_dict())
|
||||||
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
|
if message.message_info.user_info is None:
|
||||||
|
raise ValueError("message.user_info is required")
|
||||||
|
user_info_dict = cast(dict[str, object], message.message_info.user_info.to_dict())
|
||||||
|
|
||||||
# message_id 现在是 TextField,直接使用字符串值
|
# message_id 现在是 TextField,直接使用字符串值
|
||||||
msg_id = message.message_info.message_id
|
msg_id = message.message_info.message_id or ""
|
||||||
|
|
||||||
# 安全地获取 group_info, 如果为 None 则视为空字典
|
# 安全地获取 group_info, 如果为 None 则视为空字典
|
||||||
group_info_from_chat = chat_info_dict.get("group_info") or {}
|
group_info_from_chat = cast(dict[str, object], chat_info_dict.get("group_info") or {})
|
||||||
# 安全地获取 user_info, 如果为 None 则视为空字典 (以防万一)
|
|
||||||
user_info_from_chat = chat_info_dict.get("user_info") or {}
|
|
||||||
|
|
||||||
Messages.create(
|
additional_config: dict[str, object] = dict(message.message_info.additional_config or {})
|
||||||
message_id=msg_id,
|
additional_config.update(
|
||||||
time=float(message.message_info.time), # type: ignore
|
{
|
||||||
chat_id=chat_stream.stream_id,
|
"interest_value": interest_value,
|
||||||
# Flattened chat_info
|
"priority_mode": priority_mode,
|
||||||
|
"priority_info": priority_info,
|
||||||
|
"reply_probability_boost": reply_probability_boost,
|
||||||
|
"intercept_message_level": intercept_message_level,
|
||||||
|
"key_words": key_words,
|
||||||
|
"key_words_lite": key_words_lite,
|
||||||
|
"selected_expressions": selected_expressions,
|
||||||
|
"is_picid": is_picture,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
processed_text_for_raw = filtered_processed_plain_text or filtered_display_message or ""
|
||||||
|
raw_sequence = MessageSequence([TextComponent(processed_text_for_raw)] if processed_text_for_raw else [])
|
||||||
|
raw_content = MessageUtils.from_MaiSeq_to_db_record_msg(raw_sequence)
|
||||||
|
|
||||||
|
timestamp_value = message.message_info.time
|
||||||
|
if timestamp_value is None:
|
||||||
|
raise ValueError("message.message_info.time is required")
|
||||||
|
db_message = Messages(
|
||||||
|
message_id=str(msg_id),
|
||||||
|
timestamp=datetime.fromtimestamp(float(timestamp_value)),
|
||||||
|
platform=MessageStorage._get_str(chat_info_dict, "platform"),
|
||||||
|
user_id=MessageStorage._get_str(user_info_dict, "user_id"),
|
||||||
|
user_nickname=MessageStorage._get_str(user_info_dict, "user_nickname"),
|
||||||
|
user_cardname=MessageStorage._get_optional_str(user_info_dict, "user_cardname"),
|
||||||
|
group_id=MessageStorage._get_optional_str(group_info_from_chat, "group_id"),
|
||||||
|
group_name=MessageStorage._get_optional_str(group_info_from_chat, "group_name"),
|
||||||
|
is_mentioned=bool(is_mentioned),
|
||||||
|
is_at=bool(is_at),
|
||||||
|
session_id=chat_stream.stream_id,
|
||||||
reply_to=reply_to,
|
reply_to=reply_to,
|
||||||
is_mentioned=is_mentioned,
|
is_emoji=is_emoji,
|
||||||
is_at=is_at,
|
is_picture=is_picture,
|
||||||
reply_probability_boost=reply_probability_boost,
|
is_command=is_command,
|
||||||
chat_info_stream_id=chat_info_dict.get("stream_id"),
|
is_notify=is_notify,
|
||||||
chat_info_platform=chat_info_dict.get("platform"),
|
raw_content=raw_content,
|
||||||
chat_info_user_platform=user_info_from_chat.get("platform"),
|
|
||||||
chat_info_user_id=user_info_from_chat.get("user_id"),
|
|
||||||
chat_info_user_nickname=user_info_from_chat.get("user_nickname"),
|
|
||||||
chat_info_user_cardname=user_info_from_chat.get("user_cardname"),
|
|
||||||
chat_info_group_platform=group_info_from_chat.get("platform"),
|
|
||||||
chat_info_group_id=group_info_from_chat.get("group_id"),
|
|
||||||
chat_info_group_name=group_info_from_chat.get("group_name"),
|
|
||||||
chat_info_create_time=float(chat_info_dict.get("create_time", 0.0)),
|
|
||||||
chat_info_last_active_time=float(chat_info_dict.get("last_active_time", 0.0)),
|
|
||||||
# Flattened user_info (message sender)
|
|
||||||
user_platform=user_info_dict.get("platform"),
|
|
||||||
user_id=user_info_dict.get("user_id"),
|
|
||||||
user_nickname=user_info_dict.get("user_nickname"),
|
|
||||||
user_cardname=user_info_dict.get("user_cardname"),
|
|
||||||
# Text content
|
|
||||||
processed_plain_text=filtered_processed_plain_text,
|
processed_plain_text=filtered_processed_plain_text,
|
||||||
display_message=filtered_display_message,
|
display_message=filtered_display_message,
|
||||||
interest_value=interest_value,
|
additional_config=json.dumps(additional_config, ensure_ascii=False),
|
||||||
priority_mode=priority_mode,
|
|
||||||
priority_info=priority_info,
|
|
||||||
is_emoji=is_emoji,
|
|
||||||
is_picid=is_picid,
|
|
||||||
is_notify=is_notify,
|
|
||||||
is_command=is_command,
|
|
||||||
intercept_message_level=intercept_message_level,
|
|
||||||
key_words=key_words,
|
|
||||||
key_words_lite=key_words_lite,
|
|
||||||
selected_expressions=selected_expressions,
|
|
||||||
)
|
)
|
||||||
|
with get_db_session() as session:
|
||||||
|
session.add(db_message)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("存储消息失败")
|
logger.exception("存储消息失败")
|
||||||
logger.error(f"消息:{message}")
|
logger.error(f"消息:{message}")
|
||||||
|
|
@ -156,14 +201,19 @@ class MessageStorage:
|
||||||
if not qq_message_id:
|
if not qq_message_id:
|
||||||
logger.info("消息不存在message_id,无法更新")
|
logger.info("消息不存在message_id,无法更新")
|
||||||
return False
|
return False
|
||||||
if matched_message := (
|
with get_db_session() as session:
|
||||||
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
|
statement = (
|
||||||
):
|
select(Messages)
|
||||||
# 更新找到的消息记录
|
.where(col(Messages.message_id) == mmc_message_id)
|
||||||
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
|
.order_by(col(Messages.timestamp).desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
matched_message = session.exec(statement).first()
|
||||||
|
if matched_message:
|
||||||
|
matched_message.message_id = qq_message_id
|
||||||
|
session.add(matched_message)
|
||||||
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
|
||||||
return True
|
return True
|
||||||
else:
|
|
||||||
logger.debug("未找到匹配的消息")
|
logger.debug("未找到匹配的消息")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
@ -182,13 +232,18 @@ class MessageStorage:
|
||||||
logger.debug("文本中没有图片标记,直接返回原文本")
|
logger.debug("文本中没有图片标记,直接返回原文本")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def replace_match(match):
|
def replace_match(match: re.Match[str]) -> str:
|
||||||
description = match.group(1).strip()
|
description = match.group(1).strip()
|
||||||
try:
|
try:
|
||||||
image_record = (
|
with get_db_session() as session:
|
||||||
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
|
statement = (
|
||||||
|
select(Images)
|
||||||
|
.where((col(Images.description) == description) & (col(Images.image_type) == ImageType.IMAGE))
|
||||||
|
.order_by(col(Images.record_time).desc())
|
||||||
|
.limit(1)
|
||||||
)
|
)
|
||||||
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
|
image_record = session.exec(statement).first()
|
||||||
|
return f"[picid:{image_record.id}]" if image_record else match.group(0)
|
||||||
except Exception:
|
except Exception:
|
||||||
return match.group(0)
|
return match.group(0)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,17 @@
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime
|
||||||
from typing import List, Dict, Any, Tuple, Optional, Callable
|
from typing import List, Dict, Any, Tuple, Optional, Callable
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from sqlmodel import select, col
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import find_messages, count_messages
|
from src.common.message_repository import find_messages, count_messages
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
|
from src.common.data_models.database_data_model import DatabaseMessages, DatabaseActionRecords
|
||||||
from src.common.data_models.message_data_model import MessageAndActionModel
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import ActionRecords
|
from src.common.database.database_model import ActionRecord, Images
|
||||||
from src.common.database.database_model import Images
|
|
||||||
from src.person_info.person_info import Person, get_person_id
|
from src.person_info.person_info import Person, get_person_id
|
||||||
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
|
from src.chat.utils.utils import translate_timestamp_to_human_readable, assign_message_ids, is_bot_self
|
||||||
|
|
||||||
|
|
@ -198,37 +198,38 @@ def get_actions_by_timestamp_with_chat(
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
) -> List[DatabaseActionRecords]:
|
) -> List[DatabaseActionRecords]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录,按时间升序排序,返回动作记录列表"""
|
||||||
query = ActionRecords.select().where(
|
with get_db_session() as session:
|
||||||
(ActionRecords.chat_id == chat_id)
|
statement = (
|
||||||
& (ActionRecords.time > timestamp_start) # type: ignore
|
select(ActionRecord)
|
||||||
& (ActionRecords.time < timestamp_end) # type: ignore
|
.where((col(ActionRecord.session_id) == chat_id))
|
||||||
|
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(timestamp_start))
|
||||||
|
.where(col(ActionRecord.timestamp) < datetime.fromtimestamp(timestamp_end))
|
||||||
)
|
)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "latest":
|
if limit_mode == "latest":
|
||||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
actions = list(session.exec(statement).all())
|
||||||
actions = list(query)
|
actions = list(reversed(actions))
|
||||||
actions.reverse()
|
|
||||||
else: # earliest
|
|
||||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
|
||||||
else:
|
else:
|
||||||
query = query.order_by(ActionRecords.time.asc())
|
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||||
|
actions = list(session.exec(statement).all())
|
||||||
actions = list(query)
|
else:
|
||||||
|
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||||
|
actions = session.exec(statement).all()
|
||||||
return [
|
return [
|
||||||
DatabaseActionRecords(
|
DatabaseActionRecords(
|
||||||
action_id=action.action_id,
|
action_id=action.action_id,
|
||||||
time=action.time,
|
time=action.timestamp.timestamp(),
|
||||||
action_name=action.action_name,
|
action_name=action.action_name,
|
||||||
action_data=action.action_data,
|
action_data=action.action_data or "{}",
|
||||||
action_done=action.action_done,
|
action_done=True,
|
||||||
action_build_into_prompt=action.action_build_into_prompt,
|
action_build_into_prompt=bool(action.action_display_prompt),
|
||||||
action_prompt_display=action.action_prompt_display,
|
action_prompt_display=action.action_display_prompt or "",
|
||||||
chat_id=action.chat_id,
|
chat_id=action.session_id,
|
||||||
chat_info_stream_id=action.chat_info_stream_id,
|
chat_info_stream_id=action.session_id,
|
||||||
chat_info_platform=action.chat_info_platform,
|
chat_info_platform=global_config.bot.platform,
|
||||||
action_reasoning=action.action_reasoning,
|
action_reasoning=action.action_reasoning or "",
|
||||||
)
|
)
|
||||||
for action in actions
|
for action in actions
|
||||||
]
|
]
|
||||||
|
|
@ -238,25 +239,27 @@ def get_actions_by_timestamp_with_chat_inclusive(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
chat_id: str, timestamp_start: float, timestamp_end: float, limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
"""获取在特定聊天从指定时间戳到指定时间戳的动作记录(包含边界),按时间升序排序,返回动作记录列表"""
|
||||||
query = ActionRecords.select().where(
|
with get_db_session() as session:
|
||||||
(ActionRecords.chat_id == chat_id)
|
statement = (
|
||||||
& (ActionRecords.time >= timestamp_start) # type: ignore
|
select(ActionRecord)
|
||||||
& (ActionRecords.time <= timestamp_end) # type: ignore
|
.where((col(ActionRecord.session_id) == chat_id))
|
||||||
|
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(timestamp_start))
|
||||||
|
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(timestamp_end))
|
||||||
)
|
)
|
||||||
|
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "latest":
|
if limit_mode == "latest":
|
||||||
query = query.order_by(ActionRecords.time.desc()).limit(limit)
|
statement = statement.order_by(col(ActionRecord.timestamp).desc()).limit(limit)
|
||||||
# 获取后需要反转列表,以保持最终输出为时间升序
|
actions = list(session.exec(statement).all())
|
||||||
actions = list(query)
|
actions = list(reversed(actions))
|
||||||
return [action.__data__ for action in reversed(actions)]
|
|
||||||
else: # earliest
|
|
||||||
query = query.order_by(ActionRecords.time.asc()).limit(limit)
|
|
||||||
else:
|
else:
|
||||||
query = query.order_by(ActionRecords.time.asc())
|
statement = statement.order_by(col(ActionRecord.timestamp)).limit(limit)
|
||||||
|
actions = list(session.exec(statement).all())
|
||||||
|
else:
|
||||||
|
statement = statement.order_by(col(ActionRecord.timestamp))
|
||||||
|
actions = session.exec(statement).all()
|
||||||
|
|
||||||
actions = list(query)
|
return [action.model_dump() for action in actions]
|
||||||
return [action.__data__ for action in actions]
|
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_random(
|
def get_raw_msg_by_timestamp_random(
|
||||||
|
|
@ -278,7 +281,7 @@ def get_raw_msg_by_timestamp_random(
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_by_timestamp_with_users(
|
def get_raw_msg_by_timestamp_with_users(
|
||||||
timestamp_start: float, timestamp_end: float, person_ids: list, limit: int = 0, limit_mode: str = "latest"
|
timestamp_start: float, timestamp_end: float, person_ids: List[str], limit: int = 0, limit_mode: str = "latest"
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
"""获取某些特定用户在 *所有聊天* 中从指定时间戳到指定时间戳的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
|
|
@ -316,7 +319,7 @@ def get_raw_msg_before_timestamp_with_chat(
|
||||||
|
|
||||||
|
|
||||||
def get_raw_msg_before_timestamp_with_users(
|
def get_raw_msg_before_timestamp_with_users(
|
||||||
timestamp: float, person_ids: list, limit: int = 0
|
timestamp: float, person_ids: List[str], limit: int = 0
|
||||||
) -> List[DatabaseMessages]:
|
) -> List[DatabaseMessages]:
|
||||||
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
"""获取指定时间戳之前的消息,按时间升序排序,返回消息列表
|
||||||
limit: 限制返回的消息数量,0为不限制
|
limit: 限制返回的消息数量,0为不限制
|
||||||
|
|
@ -344,7 +347,7 @@ def num_new_messages_since(chat_id: str, timestamp_start: float = 0.0, timestamp
|
||||||
|
|
||||||
|
|
||||||
def num_new_messages_since_with_users(
|
def num_new_messages_since_with_users(
|
||||||
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: list
|
chat_id: str, timestamp_start: float, timestamp_end: float, person_ids: List[str]
|
||||||
) -> int:
|
) -> int:
|
||||||
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
"""检查某些特定用户在特定聊天在指定时间戳之间有多少新消息"""
|
||||||
if not person_ids: # 保持空列表检查
|
if not person_ids: # 保持空列表检查
|
||||||
|
|
@ -358,7 +361,7 @@ def num_new_messages_since_with_users(
|
||||||
|
|
||||||
|
|
||||||
def _build_readable_messages_internal(
|
def _build_readable_messages_internal(
|
||||||
messages: List[MessageAndActionModel],
|
messages: List[DatabaseMessages],
|
||||||
replace_bot_name: bool = True,
|
replace_bot_name: bool = True,
|
||||||
timestamp_mode: str = "relative",
|
timestamp_mode: str = "relative",
|
||||||
truncate: bool = False,
|
truncate: bool = False,
|
||||||
|
|
@ -413,7 +416,7 @@ def _build_readable_messages_internal(
|
||||||
# 匹配 [picid:xxxxx] 格式
|
# 匹配 [picid:xxxxx] 格式
|
||||||
pic_pattern = r"\[picid:([^\]]+)\]"
|
pic_pattern = r"\[picid:([^\]]+)\]"
|
||||||
|
|
||||||
def replace_pic_id(match: re.Match) -> str:
|
def replace_pic_id(match: re.Match[str]) -> str:
|
||||||
nonlocal current_pic_counter
|
nonlocal current_pic_counter
|
||||||
nonlocal pic_counter
|
nonlocal pic_counter
|
||||||
pic_id = match.group(1)
|
pic_id = match.group(1)
|
||||||
|
|
@ -421,7 +424,8 @@ def _build_readable_messages_internal(
|
||||||
if pic_id not in pic_description_cache:
|
if pic_id not in pic_description_cache:
|
||||||
description = "内容正在阅读,请稍等"
|
description = "内容正在阅读,请稍等"
|
||||||
try:
|
try:
|
||||||
image = Images.get_or_none(Images.image_id == pic_id)
|
with get_db_session() as session:
|
||||||
|
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||||
if image and image.description:
|
if image and image.description:
|
||||||
description = image.description
|
description = image.description
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -438,16 +442,11 @@ def _build_readable_messages_internal(
|
||||||
|
|
||||||
# 1: 获取发送者信息并提取消息组件
|
# 1: 获取发送者信息并提取消息组件
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message.is_action_record:
|
user_info = message.user_info
|
||||||
# 对于动作记录,也处理图片ID
|
platform = user_info.platform
|
||||||
content = process_pic_ids(message.display_message)
|
user_id = user_info.user_id
|
||||||
detailed_messages_raw.append((message.time, message.user_nickname, content, True))
|
user_nickname = user_info.user_nickname
|
||||||
continue
|
user_cardname = user_info.user_cardname
|
||||||
|
|
||||||
platform = message.user_platform
|
|
||||||
user_id = message.user_id
|
|
||||||
user_nickname = message.user_nickname
|
|
||||||
user_cardname = message.user_cardname
|
|
||||||
|
|
||||||
timestamp = message.time
|
timestamp = message.time
|
||||||
content = message.display_message or message.processed_plain_text or ""
|
content = message.display_message or message.processed_plain_text or ""
|
||||||
|
|
@ -549,13 +548,8 @@ def _build_readable_messages_internal(
|
||||||
message_id = timestamp_to_id_mapping.get(timestamp, "")
|
message_id = timestamp_to_id_mapping.get(timestamp, "")
|
||||||
id_prefix = f"[{message_id}]" if message_id else ""
|
id_prefix = f"[{message_id}]" if message_id else ""
|
||||||
|
|
||||||
if is_action:
|
|
||||||
# 对于动作记录,使用特殊格式
|
|
||||||
output_lines.append(f"{id_prefix}{readable_time}, {content}")
|
|
||||||
else:
|
|
||||||
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
output_lines.append(f"{id_prefix}{readable_time}, {name}: {content}")
|
||||||
output_lines.append("\n") # 在每个消息块后添加换行,保持可读性
|
output_lines.append("\n")
|
||||||
|
|
||||||
prev_timestamp = timestamp
|
prev_timestamp = timestamp
|
||||||
|
|
||||||
formatted_string = "".join(output_lines).strip()
|
formatted_string = "".join(output_lines).strip()
|
||||||
|
|
@ -592,7 +586,8 @@ def build_pic_mapping_info(pic_id_mapping: Dict[str, str]) -> str:
|
||||||
# 从数据库中获取图片描述
|
# 从数据库中获取图片描述
|
||||||
description = "内容正在阅读,请稍等"
|
description = "内容正在阅读,请稍等"
|
||||||
try:
|
try:
|
||||||
image = Images.get_or_none(Images.image_id == pic_id)
|
with get_db_session() as session:
|
||||||
|
image = session.get(Images, int(pic_id)) if pic_id.isdigit() else None
|
||||||
if image and image.description:
|
if image and image.description:
|
||||||
description = image.description
|
description = image.description
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -663,7 +658,7 @@ async def build_readable_messages_with_list(
|
||||||
允许通过参数控制格式化行为。
|
允许通过参数控制格式化行为。
|
||||||
"""
|
"""
|
||||||
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
formatted_string, details_list, pic_id_mapping, _ = _build_readable_messages_internal(
|
||||||
[MessageAndActionModel.from_DatabaseMessages(msg) for msg in messages],
|
messages,
|
||||||
replace_bot_name,
|
replace_bot_name,
|
||||||
timestamp_mode,
|
timestamp_mode,
|
||||||
truncate,
|
truncate,
|
||||||
|
|
@ -754,7 +749,7 @@ def build_readable_messages(
|
||||||
filtered_messages = []
|
filtered_messages = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# 获取消息内容
|
# 获取消息内容
|
||||||
content = msg.processed_plain_text
|
content = msg.processed_plain_text or ""
|
||||||
# 移除表情包
|
# 移除表情包
|
||||||
emoji_pattern = r"\[表情包:[^\]]+\]"
|
emoji_pattern = r"\[表情包:[^\]]+\]"
|
||||||
content = re.sub(emoji_pattern, "", content)
|
content = re.sub(emoji_pattern, "", content)
|
||||||
|
|
@ -765,17 +760,14 @@ def build_readable_messages(
|
||||||
|
|
||||||
messages = filtered_messages
|
messages = filtered_messages
|
||||||
|
|
||||||
copy_messages: List[MessageAndActionModel] = []
|
copy_messages: List[DatabaseMessages] = []
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if remove_emoji_stickers:
|
if remove_emoji_stickers:
|
||||||
# 创建 MessageAndActionModel 但移除表情包
|
|
||||||
model = MessageAndActionModel.from_DatabaseMessages(msg)
|
|
||||||
# 移除表情包
|
# 移除表情包
|
||||||
if model.processed_plain_text:
|
msg.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", msg.processed_plain_text or "")
|
||||||
model.processed_plain_text = re.sub(r"\[表情包:[^\]]+\]", "", model.processed_plain_text)
|
copy_messages.append(msg)
|
||||||
copy_messages.append(model)
|
|
||||||
else:
|
else:
|
||||||
copy_messages.append(MessageAndActionModel.from_DatabaseMessages(msg))
|
copy_messages.append(msg)
|
||||||
|
|
||||||
if show_actions and copy_messages:
|
if show_actions and copy_messages:
|
||||||
# 获取所有消息的时间范围
|
# 获取所有消息的时间范围
|
||||||
|
|
@ -786,40 +778,45 @@ def build_readable_messages(
|
||||||
chat_id = messages[0].chat_id if messages else None
|
chat_id = messages[0].chat_id if messages else None
|
||||||
|
|
||||||
# 获取这个时间范围内的动作记录,并匹配chat_id
|
# 获取这个时间范围内的动作记录,并匹配chat_id
|
||||||
actions_in_range = (
|
with get_db_session() as session:
|
||||||
ActionRecords.select()
|
actions_in_range = session.exec(
|
||||||
.where(
|
select(ActionRecord)
|
||||||
(ActionRecords.time >= min_time) & (ActionRecords.time <= max_time) & (ActionRecords.chat_id == chat_id)
|
.where(col(ActionRecord.timestamp) >= datetime.fromtimestamp(min_time))
|
||||||
)
|
.where(col(ActionRecord.timestamp) <= datetime.fromtimestamp(max_time))
|
||||||
.order_by(ActionRecords.time)
|
.where(col(ActionRecord.session_id) == chat_id)
|
||||||
)
|
.order_by(col(ActionRecord.timestamp))
|
||||||
|
).all()
|
||||||
|
|
||||||
# 获取最新消息之后的第一个动作记录
|
# 获取最新消息之后的第一个动作记录
|
||||||
action_after_latest = (
|
with get_db_session() as session:
|
||||||
ActionRecords.select()
|
action_after_latest = session.exec(
|
||||||
.where((ActionRecords.time > max_time) & (ActionRecords.chat_id == chat_id))
|
select(ActionRecord)
|
||||||
.order_by(ActionRecords.time)
|
.where(col(ActionRecord.timestamp) > datetime.fromtimestamp(max_time))
|
||||||
|
.where(col(ActionRecord.session_id) == chat_id)
|
||||||
|
.order_by(col(ActionRecord.timestamp))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
).all()
|
||||||
|
|
||||||
# 合并两部分动作记录
|
# 合并两部分动作记录
|
||||||
actions: List[ActionRecords] = list(actions_in_range) + list(action_after_latest)
|
actions: List[ActionRecord] = list(actions_in_range) + list(action_after_latest)
|
||||||
|
|
||||||
# 将动作记录转换为消息格式
|
# 将动作记录转换为消息格式
|
||||||
for action in actions:
|
for action in actions:
|
||||||
# 只有当build_into_prompt为True时才添加动作记录
|
# 只有当build_into_prompt为True时才添加动作记录
|
||||||
if action.action_build_into_prompt:
|
action_display_prompt = action.action_display_prompt or ""
|
||||||
action_msg = MessageAndActionModel(
|
if action_display_prompt:
|
||||||
time=float(action.time), # type: ignore
|
action_msg = DatabaseMessages(
|
||||||
user_id=global_config.bot.qq_account, # 使用机器人的QQ账号
|
message_id=f"action_{action.action_id}",
|
||||||
user_platform=global_config.bot.platform, # 使用机器人的平台
|
time=float(action.timestamp.timestamp()),
|
||||||
user_nickname=global_config.bot.nickname, # 使用机器人的用户名
|
chat_id=chat_id or "",
|
||||||
user_cardname="", # 机器人没有群名片
|
processed_plain_text=action_display_prompt,
|
||||||
processed_plain_text=f"{action.action_prompt_display}",
|
display_message=action_display_prompt,
|
||||||
display_message=f"{action.action_prompt_display}",
|
user_platform=global_config.bot.platform,
|
||||||
chat_info_platform=str(action.chat_info_platform),
|
user_id=str(global_config.bot.qq_account),
|
||||||
is_action_record=True, # 添加标识字段
|
user_nickname=global_config.bot.nickname,
|
||||||
action_name=str(action.action_name), # 保存动作名称
|
user_cardname="",
|
||||||
|
chat_info_platform=str(global_config.bot.platform),
|
||||||
|
chat_info_stream_id=chat_id or "",
|
||||||
)
|
)
|
||||||
copy_messages.append(action_msg)
|
copy_messages.append(action_msg)
|
||||||
|
|
||||||
|
|
@ -1026,17 +1023,13 @@ async def get_person_id_list(messages: List[Dict[str, Any]]) -> List[str]:
|
||||||
person_ids_set = set() # 使用集合来自动去重
|
person_ids_set = set() # 使用集合来自动去重
|
||||||
|
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
platform: str = msg.get("user_platform") # type: ignore
|
platform = msg.get("user_platform") or ""
|
||||||
user_id: str = msg.get("user_id") # type: ignore
|
user_id = msg.get("user_id") or ""
|
||||||
|
|
||||||
# 检查必要信息是否存在 且 不是机器人自己
|
# 检查必要信息是否存在 且 不是机器人自己
|
||||||
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
if not all([platform, user_id]) or user_id == global_config.bot.qq_account:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 添加空值检查,防止 platform 为 None 时出错
|
|
||||||
if platform is None:
|
|
||||||
platform = "unknown"
|
|
||||||
|
|
||||||
if person_id := get_person_id(platform, user_id):
|
if person_id := get_person_id(platform, user_id):
|
||||||
person_ids_set.add(person_id)
|
person_ids_set.add(person_id)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,18 +1,21 @@
|
||||||
import base64
|
import base64
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import hashlib
|
|
||||||
import uuid
|
import uuid
|
||||||
import io
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from typing import Optional, Tuple
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
|
from sqlmodel import select, col
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Images, ImageDescriptions, EmojiDescriptionCache
|
from src.common.database.database_model import Images, ImageType
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
|
@ -38,11 +41,7 @@ class ImageManager:
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
self.vlm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="image")
|
||||||
|
|
||||||
try:
|
get_db_session()
|
||||||
db.connect(reuse_if_open=True)
|
|
||||||
db.create_tables([Images, ImageDescriptions, EmojiDescriptionCache], safe=True)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库连接或表创建失败: {e}")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self._cleanup_invalid_descriptions()
|
self._cleanup_invalid_descriptions()
|
||||||
|
|
@ -72,9 +71,11 @@ class ImageManager:
|
||||||
Optional[str]: 描述文本,如果不存在则返回None
|
Optional[str]: 描述文本,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
record = ImageDescriptions.get_or_none(
|
with get_db_session() as session:
|
||||||
(ImageDescriptions.image_description_hash == image_hash) & (ImageDescriptions.type == description_type)
|
statement = select(Images).where(
|
||||||
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||||
)
|
)
|
||||||
|
record = session.exec(statement).first()
|
||||||
return record.description if record else None
|
return record.description if record else None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
logger.error(f"从数据库获取描述失败 (Peewee): {str(e)}")
|
||||||
|
|
@ -90,15 +91,27 @@ class ImageManager:
|
||||||
description_type: 描述类型 ('emoji' 或 'image')
|
description_type: 描述类型 ('emoji' 或 'image')
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
current_timestamp = time.time()
|
with get_db_session() as session:
|
||||||
defaults = {"description": description, "timestamp": current_timestamp}
|
statement = select(Images).where(
|
||||||
desc_obj, created = ImageDescriptions.get_or_create(
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType(description_type))
|
||||||
image_description_hash=image_hash, type=description_type, defaults=defaults
|
|
||||||
)
|
)
|
||||||
if not created: # 如果记录已存在,则更新
|
record = session.exec(statement).first()
|
||||||
desc_obj.description = description
|
if record:
|
||||||
desc_obj.timestamp = current_timestamp
|
record.description = description
|
||||||
desc_obj.save()
|
session.add(record)
|
||||||
|
return
|
||||||
|
|
||||||
|
new_record = Images(
|
||||||
|
image_hash=image_hash,
|
||||||
|
description=description,
|
||||||
|
full_path="",
|
||||||
|
image_type=ImageType(description_type),
|
||||||
|
query_count=0,
|
||||||
|
is_registered=False,
|
||||||
|
is_banned=False,
|
||||||
|
vlm_processed=True,
|
||||||
|
)
|
||||||
|
session.add(new_record)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
logger.error(f"保存描述到数据库失败 (Peewee): {str(e)}")
|
||||||
|
|
||||||
|
|
@ -107,20 +120,18 @@ class ImageManager:
|
||||||
"""清理数据库中 description 为空或为 'None' 的记录"""
|
"""清理数据库中 description 为空或为 'None' 的记录"""
|
||||||
invalid_values = ["", "None"]
|
invalid_values = ["", "None"]
|
||||||
|
|
||||||
# 清理 Images 表
|
with get_db_session() as session:
|
||||||
deleted_images = (
|
statement = (
|
||||||
Images.delete().where((Images.description >> None) | (Images.description << invalid_values)).execute()
|
select(Images)
|
||||||
|
.where(col(Images.description).is_(None) | col(Images.description).in_(invalid_values))
|
||||||
|
.limit(1000)
|
||||||
)
|
)
|
||||||
|
records = session.exec(statement).all()
|
||||||
|
for record in records:
|
||||||
|
session.delete(record)
|
||||||
|
|
||||||
# 清理 ImageDescriptions 表
|
if records:
|
||||||
deleted_descriptions = (
|
logger.info(f"[清理完成] 删除 Images: {len(records)} 条")
|
||||||
ImageDescriptions.delete()
|
|
||||||
.where((ImageDescriptions.description >> None) | (ImageDescriptions.description << invalid_values))
|
|
||||||
.execute()
|
|
||||||
)
|
|
||||||
|
|
||||||
if deleted_images or deleted_descriptions:
|
|
||||||
logger.info(f"[清理完成] 删除 Images: {deleted_images} 条, ImageDescriptions: {deleted_descriptions} 条")
|
|
||||||
else:
|
else:
|
||||||
logger.info("[清理完成] 未发现无效描述记录")
|
logger.info("[清理完成] 未发现无效描述记录")
|
||||||
|
|
||||||
|
|
@ -128,19 +139,15 @@ class ImageManager:
|
||||||
def _cleanup_emoji_from_image_descriptions():
|
def _cleanup_emoji_from_image_descriptions():
|
||||||
"""清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)"""
|
"""清理Images和ImageDescriptions表中type为emoji的记录(已迁移到EmojiDescriptionCache)"""
|
||||||
try:
|
try:
|
||||||
# 清理Images表中type为emoji的记录
|
with get_db_session() as session:
|
||||||
deleted_images = Images.delete().where(Images.type == "emoji").execute()
|
statement = select(Images).where(col(Images.image_type) == ImageType.EMOJI)
|
||||||
|
records = session.exec(statement).all()
|
||||||
|
for record in records:
|
||||||
|
session.delete(record)
|
||||||
|
|
||||||
# 清理ImageDescriptions表中type为emoji的记录
|
total_deleted = len(records)
|
||||||
deleted_descriptions = ImageDescriptions.delete().where(ImageDescriptions.type == "emoji").execute()
|
|
||||||
|
|
||||||
total_deleted = deleted_images + deleted_descriptions
|
|
||||||
if total_deleted > 0:
|
if total_deleted > 0:
|
||||||
logger.info(
|
logger.info(f"[清理完成] 从Images表中删除 {total_deleted} 条emoji类型记录")
|
||||||
f"[清理完成] 从Images表中删除 {deleted_images} 条emoji类型记录, "
|
|
||||||
f"从ImageDescriptions表中删除 {deleted_descriptions} 条emoji类型记录, "
|
|
||||||
f"共删除 {total_deleted} 条记录"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
|
logger.info("[清理完成] Images和ImageDescriptions表中未发现emoji类型记录")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -148,14 +155,14 @@ class ImageManager:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_emoji_tag(self, image_base64: str) -> str:
|
async def get_emoji_tag(self, image_base64: str) -> str:
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = emoji_manager_instance
|
||||||
if isinstance(image_base64, str):
|
if isinstance(image_base64, str):
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||||
if not emoji:
|
if not emoji:
|
||||||
return "[表情包:未知]"
|
return "[表情包:未知]"
|
||||||
emotion_list = emoji.emotion
|
emotion_list = emoji.emotion
|
||||||
|
|
@ -175,14 +182,14 @@ class ImageManager:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
from src.chat.emoji_system.emoji_manager import EMOJI_DIR
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||||
|
|
||||||
# 确保目录存在
|
# 确保目录存在
|
||||||
os.makedirs(EMOJI_DIR, exist_ok=True)
|
os.makedirs(EMOJI_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 检查是否已存在该表情包(通过哈希值)
|
# 检查是否已存在该表情包(通过哈希值)
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = emoji_manager_instance
|
||||||
existing_emoji = await emoji_manager.get_emoji_from_manager(image_hash)
|
existing_emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||||
if existing_emoji:
|
if existing_emoji:
|
||||||
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
|
logger.debug(f"[自动保存] 表情包已注册,跳过保存: {image_hash[:8]}...")
|
||||||
return
|
return
|
||||||
|
|
@ -212,14 +219,15 @@ class ImageManager:
|
||||||
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
image_base64 = image_base64.encode("ascii", errors="ignore").decode("ascii")
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||||
|
|
||||||
# 优先使用EmojiManager查询已注册表情包的描述
|
# 优先使用EmojiManager查询已注册表情包的描述
|
||||||
try:
|
try:
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager as emoji_manager_instance
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_manager = emoji_manager_instance
|
||||||
tags = await emoji_manager.get_emoji_tag_by_hash(image_hash)
|
emoji = emoji_manager.get_emoji_by_hash(image_hash)
|
||||||
|
tags = emoji.emotion if emoji else None
|
||||||
if tags:
|
if tags:
|
||||||
tag_str = ",".join(tags)
|
tag_str = ",".join(tags)
|
||||||
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
|
logger.info(f"[缓存命中] 使用已注册表情包描述: {tag_str}...")
|
||||||
|
|
@ -227,29 +235,26 @@ class ImageManager:
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"查询EmojiManager时出错: {e}")
|
logger.debug(f"查询EmojiManager时出错: {e}")
|
||||||
|
|
||||||
# 查询EmojiDescriptionCache表的缓存(包含描述和情感标签)
|
|
||||||
try:
|
try:
|
||||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
with get_db_session() as session:
|
||||||
|
statement = select(Images).where(
|
||||||
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||||
|
)
|
||||||
|
cache_record = session.exec(statement).first()
|
||||||
if cache_record:
|
if cache_record:
|
||||||
# 优先使用情感标签,如果没有则使用详细描述
|
|
||||||
result_text = ""
|
result_text = ""
|
||||||
if cache_record.emotion_tags:
|
if cache_record.emotion:
|
||||||
logger.info(
|
logger.info(f"[缓存命中] 使用Images表中的情感标签: {cache_record.emotion[:50]}...")
|
||||||
f"[缓存命中] 使用EmojiDescriptionCache表中的情感标签: {cache_record.emotion_tags[:50]}..."
|
result_text = f"[表情包:{cache_record.emotion}]"
|
||||||
)
|
|
||||||
result_text = f"[表情包:{cache_record.emotion_tags}]"
|
|
||||||
elif cache_record.description:
|
elif cache_record.description:
|
||||||
logger.info(
|
logger.info(f"[缓存命中] 使用Images表中的描述: {cache_record.description[:50]}...")
|
||||||
f"[缓存命中] 使用EmojiDescriptionCache表中的描述: {cache_record.description[:50]}..."
|
|
||||||
)
|
|
||||||
result_text = f"[表情包:{cache_record.description}]"
|
result_text = f"[表情包:{cache_record.description}]"
|
||||||
|
|
||||||
# 即使缓存命中,如果启用了steal_emoji,也检查是否需要保存文件
|
|
||||||
if result_text:
|
if result_text:
|
||||||
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
await self._save_emoji_file_if_needed(image_base64, image_hash, image_format)
|
||||||
return result_text
|
return result_text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"查询EmojiDescriptionCache时出错: {e}")
|
logger.debug(f"查询Images缓存时出错: {e}")
|
||||||
|
|
||||||
# === 二步走识别流程 ===
|
# === 二步走识别流程 ===
|
||||||
|
|
||||||
|
|
@ -309,33 +314,42 @@ class ImageManager:
|
||||||
|
|
||||||
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
logger.debug(f"[emoji识别] 详细描述: {detailed_description[:50]}... -> 情感标签: {final_emotion}")
|
||||||
|
|
||||||
# 再次检查缓存(防止并发情况下其他线程已经保存)
|
|
||||||
try:
|
try:
|
||||||
cache_record = EmojiDescriptionCache.get_or_none(EmojiDescriptionCache.emoji_hash == image_hash)
|
with get_db_session() as session:
|
||||||
if cache_record and cache_record.emotion_tags:
|
statement = select(Images).where(
|
||||||
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion_tags}")
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||||
return f"[表情包:{cache_record.emotion_tags}]"
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"再次查询EmojiDescriptionCache时出错: {e}")
|
|
||||||
|
|
||||||
# 保存识别出的详细描述和情感标签到 emoji_description_cache
|
|
||||||
try:
|
|
||||||
current_timestamp = time.time()
|
|
||||||
cache_record, created = EmojiDescriptionCache.get_or_create(
|
|
||||||
emoji_hash=image_hash,
|
|
||||||
defaults={
|
|
||||||
"description": detailed_description,
|
|
||||||
"emotion_tags": final_emotion,
|
|
||||||
"timestamp": current_timestamp,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
if not created:
|
cache_record = session.exec(statement).first()
|
||||||
# 更新已有记录
|
if cache_record and cache_record.emotion:
|
||||||
|
logger.warning(f"虽然生成了描述,但是找到缓存表情包情感标签: {cache_record.emotion}")
|
||||||
|
return f"[表情包:{cache_record.emotion}]"
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"再次查询Images缓存时出错: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with get_db_session() as session:
|
||||||
|
statement = select(Images).where(
|
||||||
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.EMOJI)
|
||||||
|
)
|
||||||
|
cache_record = session.exec(statement).first()
|
||||||
|
if cache_record:
|
||||||
cache_record.description = detailed_description
|
cache_record.description = detailed_description
|
||||||
cache_record.emotion_tags = final_emotion
|
cache_record.emotion = final_emotion
|
||||||
cache_record.timestamp = current_timestamp
|
session.add(cache_record)
|
||||||
cache_record.save()
|
else:
|
||||||
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到EmojiDescriptionCache: {image_hash[:8]}...")
|
cache_record = Images(
|
||||||
|
image_hash=image_hash,
|
||||||
|
description=detailed_description,
|
||||||
|
full_path="",
|
||||||
|
image_type=ImageType.EMOJI,
|
||||||
|
emotion=final_emotion,
|
||||||
|
query_count=0,
|
||||||
|
is_registered=False,
|
||||||
|
is_banned=False,
|
||||||
|
vlm_processed=True,
|
||||||
|
)
|
||||||
|
session.add(cache_record)
|
||||||
|
logger.info(f"[缓存保存] 表情包描述和情感标签已保存到Images: {image_hash[:8]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
logger.error(f"保存表情包描述和情感标签缓存失败: {str(e)}")
|
||||||
|
|
||||||
|
|
@ -358,14 +372,13 @@ class ImageManager:
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 优先检查Images表中是否已有完整的描述
|
# 优先检查Images表中是否已有完整的描述
|
||||||
existing_image = Images.get_or_none(Images.emoji_hash == image_hash)
|
with get_db_session() as session:
|
||||||
|
statement = select(Images).where(col(Images.image_hash) == image_hash)
|
||||||
|
existing_image = session.exec(statement).first()
|
||||||
if existing_image:
|
if existing_image:
|
||||||
# 更新计数
|
existing_image.query_count += 1
|
||||||
if hasattr(existing_image, "count") and existing_image.count is not None:
|
with get_db_session() as session:
|
||||||
existing_image.count += 1
|
session.add(existing_image)
|
||||||
else:
|
|
||||||
existing_image.count = 1
|
|
||||||
existing_image.save()
|
|
||||||
|
|
||||||
# 如果已有描述,直接返回
|
# 如果已有描述,直接返回
|
||||||
if existing_image.description:
|
if existing_image.description:
|
||||||
|
|
@ -377,7 +390,7 @@ class ImageManager:
|
||||||
return f"[图片:{cached_description}]"
|
return f"[图片:{cached_description}]"
|
||||||
|
|
||||||
# 调用AI获取描述
|
# 调用AI获取描述
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||||
prompt = global_config.personality.visual_style
|
prompt = global_config.personality.visual_style
|
||||||
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
logger.info(f"[VLM调用] 为图片生成新描述 (Hash: {image_hash[:8]}...)")
|
||||||
description, _ = await self.vlm.generate_response_for_image(
|
description, _ = await self.vlm.generate_response_for_image(
|
||||||
|
|
@ -402,26 +415,27 @@ class ImageManager:
|
||||||
|
|
||||||
# 保存到数据库,补充缺失字段
|
# 保存到数据库,补充缺失字段
|
||||||
if existing_image:
|
if existing_image:
|
||||||
existing_image.path = file_path
|
existing_image.full_path = file_path
|
||||||
existing_image.description = description
|
existing_image.description = description
|
||||||
existing_image.timestamp = current_timestamp
|
existing_image.record_time = datetime.fromtimestamp(current_timestamp)
|
||||||
if not hasattr(existing_image, "image_id") or not existing_image.image_id:
|
|
||||||
existing_image.image_id = str(uuid.uuid4())
|
|
||||||
if not hasattr(existing_image, "vlm_processed") or existing_image.vlm_processed is None:
|
|
||||||
existing_image.vlm_processed = True
|
existing_image.vlm_processed = True
|
||||||
existing_image.save()
|
with get_db_session() as session:
|
||||||
|
session.add(existing_image)
|
||||||
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
logger.debug(f"[数据库] 更新已有图片记录: {image_hash[:8]}...")
|
||||||
else:
|
else:
|
||||||
Images.create(
|
with get_db_session() as session:
|
||||||
image_id=str(uuid.uuid4()),
|
new_record = Images(
|
||||||
emoji_hash=image_hash,
|
image_hash=image_hash,
|
||||||
path=file_path,
|
|
||||||
type="image",
|
|
||||||
description=description,
|
description=description,
|
||||||
timestamp=current_timestamp,
|
full_path=file_path,
|
||||||
|
image_type=ImageType.IMAGE,
|
||||||
|
query_count=1,
|
||||||
|
is_registered=False,
|
||||||
|
is_banned=False,
|
||||||
|
record_time=datetime.fromtimestamp(current_timestamp),
|
||||||
vlm_processed=True,
|
vlm_processed=True,
|
||||||
count=1,
|
|
||||||
)
|
)
|
||||||
|
session.add(new_record)
|
||||||
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
logger.debug(f"[数据库] 创建新图片记录: {image_hash[:8]}...")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
logger.error(f"保存图片文件或元数据失败: {str(e)}")
|
||||||
|
|
@ -575,29 +589,16 @@ class ImageManager:
|
||||||
image_bytes = base64.b64decode(image_base64)
|
image_bytes = base64.b64decode(image_base64)
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
if existing_image := Images.get_or_none(Images.emoji_hash == image_hash):
|
with get_db_session() as session:
|
||||||
# 检查是否缺少必要字段,如果缺少则创建新记录
|
statement = select(Images).where(
|
||||||
if (
|
(col(Images.image_hash) == image_hash) & (col(Images.image_type) == ImageType.IMAGE)
|
||||||
not hasattr(existing_image, "image_id")
|
)
|
||||||
or not existing_image.image_id
|
existing_image = session.exec(statement).first()
|
||||||
or not hasattr(existing_image, "count")
|
if existing_image:
|
||||||
or existing_image.count is None
|
existing_image.query_count += 1
|
||||||
or not hasattr(existing_image, "vlm_processed")
|
session.add(existing_image)
|
||||||
or existing_image.vlm_processed is None
|
return str(existing_image.id), f"[picid:{existing_image.id}]"
|
||||||
):
|
|
||||||
logger.debug(f"图片记录缺少必要字段,补全旧记录: {image_hash}")
|
|
||||||
if not existing_image.image_id:
|
|
||||||
existing_image.image_id = str(uuid.uuid4())
|
|
||||||
if existing_image.count is None:
|
|
||||||
existing_image.count = 0
|
|
||||||
if existing_image.vlm_processed is None:
|
|
||||||
existing_image.vlm_processed = False
|
|
||||||
|
|
||||||
existing_image.count += 1
|
|
||||||
existing_image.save()
|
|
||||||
return existing_image.image_id, f"[picid:{existing_image.image_id}]"
|
|
||||||
else:
|
|
||||||
# print(f"图片不存在: {image_hash}")
|
|
||||||
image_id = str(uuid.uuid4())
|
image_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 保存新图片
|
# 保存新图片
|
||||||
|
|
@ -612,15 +613,19 @@ class ImageManager:
|
||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
|
|
||||||
# 保存到数据库
|
# 保存到数据库
|
||||||
Images.create(
|
with get_db_session() as session:
|
||||||
image_id=image_id,
|
new_record = Images(
|
||||||
emoji_hash=image_hash,
|
image_hash=image_hash,
|
||||||
path=file_path,
|
description="",
|
||||||
type="image",
|
full_path=file_path,
|
||||||
timestamp=current_timestamp,
|
image_type=ImageType.IMAGE,
|
||||||
|
query_count=1,
|
||||||
|
is_registered=False,
|
||||||
|
is_banned=False,
|
||||||
|
record_time=datetime.fromtimestamp(current_timestamp),
|
||||||
vlm_processed=False,
|
vlm_processed=False,
|
||||||
count=1,
|
|
||||||
)
|
)
|
||||||
|
session.add(new_record)
|
||||||
|
|
||||||
# 启动异步VLM处理
|
# 启动异步VLM处理
|
||||||
await self._process_image_with_vlm(image_id, image_base64)
|
await self._process_image_with_vlm(image_id, image_base64)
|
||||||
|
|
@ -647,17 +652,26 @@ class ImageManager:
|
||||||
image_hash = hashlib.md5(image_bytes).hexdigest()
|
image_hash = hashlib.md5(image_bytes).hexdigest()
|
||||||
|
|
||||||
# 获取当前图片记录
|
# 获取当前图片记录
|
||||||
image = Images.get(Images.image_id == image_id)
|
with get_db_session() as session:
|
||||||
|
image = session.get(Images, int(image_id)) if image_id.isdigit() else None
|
||||||
|
if image is None:
|
||||||
|
logger.warning(f"未找到图片记录: {image_id}")
|
||||||
|
return
|
||||||
|
|
||||||
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
# 优先检查是否已有其他相同哈希的图片记录包含描述
|
||||||
existing_with_description = Images.get_or_none(
|
with get_db_session() as session:
|
||||||
(Images.emoji_hash == image_hash) & (Images.description.is_null(False)) & (Images.description != "")
|
statement = select(Images).where(
|
||||||
|
(col(Images.image_hash) == image_hash)
|
||||||
|
& (col(Images.description).is_not(None))
|
||||||
|
& (col(Images.description) != "")
|
||||||
)
|
)
|
||||||
|
existing_with_description = session.exec(statement).first()
|
||||||
if existing_with_description and existing_with_description.id != image.id:
|
if existing_with_description and existing_with_description.id != image.id:
|
||||||
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
logger.debug(f"[缓存复用] 从其他相同图片记录复用描述: {existing_with_description.description[:50]}...")
|
||||||
image.description = existing_with_description.description
|
image.description = existing_with_description.description
|
||||||
image.vlm_processed = True
|
image.vlm_processed = True
|
||||||
image.save()
|
with get_db_session() as session:
|
||||||
|
session.add(image)
|
||||||
# 同时保存到ImageDescriptions表作为备用缓存
|
# 同时保存到ImageDescriptions表作为备用缓存
|
||||||
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
self._save_description_to_db(image_hash, existing_with_description.description, "image")
|
||||||
return
|
return
|
||||||
|
|
@ -667,11 +681,12 @@ class ImageManager:
|
||||||
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
logger.debug(f"[缓存复用] 从ImageDescriptions表复用描述: {cached_description[:50]}...")
|
||||||
image.description = cached_description
|
image.description = cached_description
|
||||||
image.vlm_processed = True
|
image.vlm_processed = True
|
||||||
image.save()
|
with get_db_session() as session:
|
||||||
|
session.add(image)
|
||||||
return
|
return
|
||||||
|
|
||||||
# 获取图片格式
|
# 获取图片格式
|
||||||
image_format = Image.open(io.BytesIO(image_bytes)).format.lower() # type: ignore
|
image_format = (Image.open(io.BytesIO(image_bytes)).format or "png").lower() # type: ignore
|
||||||
|
|
||||||
# 构建prompt
|
# 构建prompt
|
||||||
prompt = global_config.personality.visual_style
|
prompt = global_config.personality.visual_style
|
||||||
|
|
@ -692,7 +707,8 @@ class ImageManager:
|
||||||
# 更新数据库
|
# 更新数据库
|
||||||
image.description = description
|
image.description = description
|
||||||
image.vlm_processed = True
|
image.vlm_processed = True
|
||||||
image.save()
|
with get_db_session() as session:
|
||||||
|
session.add(image)
|
||||||
|
|
||||||
# 保存描述到ImageDescriptions表作为备用缓存
|
# 保存描述到ImageDescriptions表作为备用缓存
|
||||||
self._save_description_to_db(image_hash, description, "image")
|
self._save_description_to_db(image_hash, description, "image")
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
from typing import Self, TypeVar, Generic, TYPE_CHECKING
|
from dataclasses import is_dataclass
|
||||||
|
from typing import Any, Dict, Self, TypeVar, Generic, TYPE_CHECKING
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
|
|
@ -15,9 +16,23 @@ class BaseDataModel:
|
||||||
return copy.deepcopy(self)
|
return copy.deepcopy(self)
|
||||||
|
|
||||||
|
|
||||||
|
def transform_class_to_dict(obj: Any) -> Dict[str, Any]:
|
||||||
|
if obj is None:
|
||||||
|
return {}
|
||||||
|
if is_dataclass(obj):
|
||||||
|
return obj.__dict__
|
||||||
|
if hasattr(obj, "dict"):
|
||||||
|
return obj.dict()
|
||||||
|
if hasattr(obj, "model_dump"):
|
||||||
|
return obj.model_dump()
|
||||||
|
if hasattr(obj, "__dict__"):
|
||||||
|
return obj.__dict__
|
||||||
|
return {"value": obj}
|
||||||
|
|
||||||
|
|
||||||
class BaseDatabaseDataModel(ABC, Generic[T]):
|
class BaseDatabaseDataModel(ABC, Generic[T]):
|
||||||
@abstractmethod
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
def from_db_instance(cls, db_record: T) -> Self:
|
def from_db_instance(cls, db_record: T) -> Self:
|
||||||
"""从数据库实例创建数据模型对象"""
|
"""从数据库实例创建数据模型对象"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from . import BaseDataModel
|
||||||
|
|
||||||
|
|
||||||
|
class ReplyContentType(Enum):
|
||||||
|
TEXT = "text"
|
||||||
|
IMAGE = "image"
|
||||||
|
EMOJI = "emoji"
|
||||||
|
COMMAND = "command"
|
||||||
|
VOICE = "voice"
|
||||||
|
HYBRID = "hybrid"
|
||||||
|
FORWARD = "forward"
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReplyContent:
|
||||||
|
content_type: ReplyContentType | str
|
||||||
|
content: Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ForwardNode:
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
user_nickname: Optional[str] = None
|
||||||
|
content: Union[str, List[ReplyContent], None] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def construct_as_id_reference(cls, message_id: str) -> "ForwardNode":
|
||||||
|
return cls(content=message_id)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def construct_as_created_node(
|
||||||
|
cls,
|
||||||
|
user_id: str,
|
||||||
|
user_nickname: str,
|
||||||
|
content: List[ReplyContent],
|
||||||
|
) -> "ForwardNode":
|
||||||
|
return cls(user_id=user_id, user_nickname=user_nickname, content=content)
|
||||||
|
|
||||||
|
|
||||||
|
class ReplySetModel(BaseDataModel):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.reply_data: List[ReplyContent] = []
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
return len(self.reply_data)
|
||||||
|
|
||||||
|
def add_text_content(self, text: str) -> None:
|
||||||
|
self.reply_data.append(ReplyContent(content_type=ReplyContentType.TEXT, content=text))
|
||||||
|
|
||||||
|
def add_voice_content(self, voice_base64: str) -> None:
|
||||||
|
self.reply_data.append(ReplyContent(content_type=ReplyContentType.VOICE, content=voice_base64))
|
||||||
|
|
||||||
|
def add_hybrid_content_by_raw(self, message_tuple_list: Iterable[Tuple[ReplyContentType | str, str]]) -> None:
|
||||||
|
hybrid_contents: List[ReplyContent] = []
|
||||||
|
for content_type, content in message_tuple_list:
|
||||||
|
hybrid_contents.append(
|
||||||
|
ReplyContent(content_type=self._normalize_content_type(content_type), content=content)
|
||||||
|
)
|
||||||
|
self.reply_data.append(ReplyContent(content_type=ReplyContentType.HYBRID, content=hybrid_contents))
|
||||||
|
|
||||||
|
def add_forward_content(self, forward_nodes: List[ForwardNode]) -> None:
|
||||||
|
self.reply_data.append(ReplyContent(content_type=ReplyContentType.FORWARD, content=forward_nodes))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_content_type(content_type: ReplyContentType | str) -> ReplyContentType | str:
|
||||||
|
if isinstance(content_type, ReplyContentType):
|
||||||
|
return content_type
|
||||||
|
if isinstance(content_type, str):
|
||||||
|
for item in ReplyContentType:
|
||||||
|
if item.value == content_type:
|
||||||
|
return item
|
||||||
|
return content_type
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from sqlalchemy.orm import sessionmaker
|
from pathlib import Path
|
||||||
|
from typing import Generator, TYPE_CHECKING
|
||||||
|
|
||||||
from sqlalchemy import event
|
from sqlalchemy import event
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlmodel import create_engine, Session
|
from sqlalchemy.orm import sessionmaker
|
||||||
from typing import TYPE_CHECKING, Generator
|
from sqlmodel import SQLModel, Session, create_engine
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlite3 import Connection as SQLite3Connection
|
from sqlite3 import Connection as SQLite3Connection
|
||||||
|
|
@ -53,6 +54,19 @@ SessionLocal = sessionmaker(
|
||||||
class_=Session,
|
class_=Session,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_db_initialized = False
|
||||||
|
|
||||||
|
|
||||||
|
def initialize_database() -> None:
|
||||||
|
global _db_initialized
|
||||||
|
if _db_initialized:
|
||||||
|
return
|
||||||
|
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
import src.common.database.database_model # noqa: F401
|
||||||
|
|
||||||
|
SQLModel.metadata.create_all(engine)
|
||||||
|
_db_initialized = True
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||||
|
|
@ -87,6 +101,7 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||||
- auto_commit=True 时,成功执行完会自动提交
|
- auto_commit=True 时,成功执行完会自动提交
|
||||||
- auto_commit=False 时,需要手动调用 session.commit()
|
- auto_commit=False 时,需要手动调用 session.commit()
|
||||||
"""
|
"""
|
||||||
|
initialize_database()
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
|
@ -120,6 +135,7 @@ def get_db() -> Generator[Session, None, None]:
|
||||||
Yields:
|
Yields:
|
||||||
Session: SQLAlchemy 数据库会话
|
Session: SQLAlchemy 数据库会话
|
||||||
"""
|
"""
|
||||||
|
initialize_database()
|
||||||
session = SessionLocal()
|
session = SessionLocal()
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
|
|
||||||
|
|
@ -1,35 +1,153 @@
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from typing import List, Any, Optional
|
import json
|
||||||
|
|
||||||
from src.config.config import global_config
|
from sqlalchemy import func
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from sqlmodel import col, select
|
||||||
|
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import Messages
|
from src.common.database.database_model import Messages
|
||||||
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _model_to_instance(model_instance: Any) -> DatabaseMessages:
|
FIELD_MAP: dict[str, Any] = {
|
||||||
"""
|
"time": Messages.timestamp,
|
||||||
将 Peewee 模型实例转换为字典。
|
"timestamp": Messages.timestamp,
|
||||||
"""
|
"chat_id": Messages.session_id,
|
||||||
if isinstance(model_instance, dict):
|
"session_id": Messages.session_id,
|
||||||
return DatabaseMessages(**model_instance)
|
"user_id": Messages.user_id,
|
||||||
if hasattr(model_instance, "model_dump"):
|
"message_id": Messages.message_id,
|
||||||
return DatabaseMessages(**model_instance.model_dump())
|
"group_id": Messages.group_id,
|
||||||
return DatabaseMessages(**model_instance.__dict__)
|
"platform": Messages.platform,
|
||||||
|
"is_command": Messages.is_command,
|
||||||
|
"is_mentioned": Messages.is_mentioned,
|
||||||
|
"is_at": Messages.is_at,
|
||||||
|
"is_emoji": Messages.is_emoji,
|
||||||
|
"is_picid": Messages.is_picture,
|
||||||
|
"is_picture": Messages.is_picture,
|
||||||
|
"reply_to": Messages.reply_to,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_additional_config(message: Messages) -> dict[str, Any]:
|
||||||
|
if not message.additional_config:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(message.additional_config)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return {}
|
||||||
|
if isinstance(parsed, dict):
|
||||||
|
return parsed
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_optional_str(value: object) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
try:
|
||||||
|
return json.dumps(value, ensure_ascii=False)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _message_to_instance(message: Messages) -> DatabaseMessages:
|
||||||
|
config = _parse_additional_config(message)
|
||||||
|
timestamp_value = message.timestamp
|
||||||
|
if isinstance(timestamp_value, datetime):
|
||||||
|
time_value = timestamp_value.timestamp()
|
||||||
|
else:
|
||||||
|
time_value = float(timestamp_value)
|
||||||
|
selected_expressions = _normalize_optional_str(config.get("selected_expressions"))
|
||||||
|
priority_info = _normalize_optional_str(config.get("priority_info"))
|
||||||
|
return DatabaseMessages(
|
||||||
|
message_id=message.message_id,
|
||||||
|
time=time_value,
|
||||||
|
chat_id=message.session_id,
|
||||||
|
reply_to=message.reply_to,
|
||||||
|
interest_value=config.get("interest_value"),
|
||||||
|
key_words=_normalize_optional_str(config.get("key_words")),
|
||||||
|
key_words_lite=_normalize_optional_str(config.get("key_words_lite")),
|
||||||
|
is_mentioned=message.is_mentioned,
|
||||||
|
is_at=message.is_at,
|
||||||
|
reply_probability_boost=config.get("reply_probability_boost"),
|
||||||
|
processed_plain_text=message.processed_plain_text,
|
||||||
|
display_message=message.display_message,
|
||||||
|
priority_mode=_normalize_optional_str(config.get("priority_mode")),
|
||||||
|
priority_info=priority_info,
|
||||||
|
additional_config=message.additional_config,
|
||||||
|
is_emoji=message.is_emoji,
|
||||||
|
is_picid=message.is_picture,
|
||||||
|
is_command=message.is_command,
|
||||||
|
intercept_message_level=config.get("intercept_message_level", 0),
|
||||||
|
is_notify=message.is_notify,
|
||||||
|
selected_expressions=selected_expressions,
|
||||||
|
user_id=message.user_id,
|
||||||
|
user_nickname=message.user_nickname,
|
||||||
|
user_cardname=message.user_cardname,
|
||||||
|
user_platform=message.platform,
|
||||||
|
chat_info_group_id=message.group_id,
|
||||||
|
chat_info_group_name=message.group_name,
|
||||||
|
chat_info_group_platform=message.platform,
|
||||||
|
chat_info_user_id=message.user_id,
|
||||||
|
chat_info_user_nickname=message.user_nickname,
|
||||||
|
chat_info_user_cardname=message.user_cardname,
|
||||||
|
chat_info_user_platform=message.platform,
|
||||||
|
chat_info_stream_id=message.session_id,
|
||||||
|
chat_info_platform=message.platform,
|
||||||
|
chat_info_create_time=0.0,
|
||||||
|
chat_info_last_active_time=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_datetime(value: Any) -> Any:
|
||||||
|
if isinstance(value, (int, float)):
|
||||||
|
return datetime.fromtimestamp(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _cast_value_for_field(field: Any, value: Any) -> Any:
|
||||||
|
if field is Messages.timestamp:
|
||||||
|
return _coerce_datetime(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_list(value: Any) -> list[Any]:
|
||||||
|
if value is None:
|
||||||
|
return []
|
||||||
|
if isinstance(value, list):
|
||||||
|
return value
|
||||||
|
if isinstance(value, tuple):
|
||||||
|
return list(value)
|
||||||
|
if isinstance(value, set):
|
||||||
|
return list(value)
|
||||||
|
return [value]
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_field(field_name: str) -> Any | None:
|
||||||
|
if field_name in FIELD_MAP:
|
||||||
|
return FIELD_MAP[field_name]
|
||||||
|
if hasattr(Messages, field_name):
|
||||||
|
return getattr(Messages, field_name)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def find_messages(
|
def find_messages(
|
||||||
message_filter: dict[str, Any],
|
message_filter: dict[str, Any],
|
||||||
sort: Optional[List[tuple[str, int]]] = None,
|
sort: list[tuple[str, int]] | None = None,
|
||||||
limit: int = 0,
|
limit: int = 0,
|
||||||
limit_mode: str = "latest",
|
limit_mode: str = "latest",
|
||||||
filter_bot=False,
|
filter_bot: bool = False,
|
||||||
filter_command=False,
|
filter_command: bool = False,
|
||||||
filter_intercept_message_level: Optional[int] = None,
|
filter_intercept_message_level: int | None = None,
|
||||||
) -> List[DatabaseMessages]:
|
) -> list[DatabaseMessages]:
|
||||||
"""
|
"""
|
||||||
根据提供的过滤器、排序和限制条件查找消息。
|
根据提供的过滤器、排序和限制条件查找消息。
|
||||||
|
|
||||||
|
|
@ -43,92 +161,79 @@ def find_messages(
|
||||||
消息字典列表,如果出错则返回空列表。
|
消息字典列表,如果出错则返回空列表。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query = Messages.select()
|
conditions: list[Any] = []
|
||||||
|
|
||||||
# 应用过滤器
|
|
||||||
if message_filter:
|
if message_filter:
|
||||||
conditions = []
|
|
||||||
for key, value in message_filter.items():
|
for key, value in message_filter.items():
|
||||||
if hasattr(Messages, key):
|
field = _resolve_field(key)
|
||||||
field = getattr(Messages, key)
|
if field is None:
|
||||||
|
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||||
|
continue
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
# 处理 MongoDB 风格的操作符
|
|
||||||
for op, op_value in value.items():
|
for op, op_value in value.items():
|
||||||
|
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
||||||
if op == "$gt":
|
if op == "$gt":
|
||||||
conditions.append(field > op_value)
|
conditions.append(field > coerced_value)
|
||||||
elif op == "$lt":
|
elif op == "$lt":
|
||||||
conditions.append(field < op_value)
|
conditions.append(field < coerced_value)
|
||||||
elif op == "$gte":
|
elif op == "$gte":
|
||||||
conditions.append(field >= op_value)
|
conditions.append(field >= coerced_value)
|
||||||
elif op == "$lte":
|
elif op == "$lte":
|
||||||
conditions.append(field <= op_value)
|
conditions.append(field <= coerced_value)
|
||||||
elif op == "$ne":
|
elif op == "$ne":
|
||||||
conditions.append(field != op_value)
|
conditions.append(field != coerced_value)
|
||||||
elif op == "$in":
|
elif op == "$in":
|
||||||
conditions.append(field.in_(op_value))
|
conditions.append(field.in_(_ensure_list(coerced_value)))
|
||||||
elif op == "$nin":
|
elif op == "$nin":
|
||||||
conditions.append(field.not_in(op_value))
|
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
||||||
else:
|
else:
|
||||||
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
||||||
else:
|
else:
|
||||||
# 直接相等比较
|
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
||||||
conditions.append(field == value)
|
conditions.append(field == coerced_value)
|
||||||
else:
|
|
||||||
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
|
||||||
if conditions:
|
|
||||||
query = query.where(*conditions)
|
|
||||||
|
|
||||||
# 排除 id 为 "notice" 的消息
|
|
||||||
query = query.where(Messages.message_id != "notice")
|
|
||||||
|
|
||||||
|
conditions.append(Messages.message_id != "notice")
|
||||||
if filter_bot:
|
if filter_bot:
|
||||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
conditions.append(Messages.user_id != global_config.bot.qq_account)
|
||||||
|
|
||||||
if filter_command:
|
if filter_command:
|
||||||
# 使用按位取反构造 Peewee 的 NOT 条件,避免直接与 False 比较
|
conditions.append(Messages.is_command == False) # noqa: E712
|
||||||
query = query.where(~Messages.is_command)
|
|
||||||
|
|
||||||
if filter_intercept_message_level is not None:
|
|
||||||
# 过滤掉所有 intercept_message_level > filter_intercept_message_level 的消息
|
|
||||||
query = query.where(Messages.intercept_message_level <= filter_intercept_message_level)
|
|
||||||
|
|
||||||
|
statement = select(Messages).where(*conditions)
|
||||||
if limit > 0:
|
if limit > 0:
|
||||||
if limit_mode == "earliest":
|
if limit_mode == "earliest":
|
||||||
# 获取时间最早的 limit 条记录,已经是正序
|
statement = statement.order_by(col(Messages.timestamp)).limit(limit)
|
||||||
query = query.order_by("time").limit(limit)
|
with get_db_session() as session:
|
||||||
peewee_results = list(query)
|
results = list(session.exec(statement).all())
|
||||||
else: # 默认为 'latest'
|
else:
|
||||||
# 获取时间最晚的 limit 条记录
|
statement = statement.order_by(col(Messages.timestamp).desc()).limit(limit)
|
||||||
query = query.order_by("-time").limit(limit)
|
with get_db_session() as session:
|
||||||
latest_results_peewee = list(query)
|
results = list(session.exec(statement).all())
|
||||||
# 将结果按时间正序排列
|
results = list(reversed(results))
|
||||||
peewee_results = sorted(
|
|
||||||
latest_results_peewee,
|
|
||||||
key=lambda msg: msg.get("time", 0) if isinstance(msg, dict) else getattr(msg, "time", 0),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
# limit 为 0 时,应用传入的 sort 参数
|
|
||||||
if sort:
|
if sort:
|
||||||
peewee_sort_terms = []
|
order_terms: list[Any] = []
|
||||||
for field_name, direction in sort:
|
for field_name, direction in sort:
|
||||||
if hasattr(Messages, field_name):
|
sort_field = _resolve_field(field_name)
|
||||||
field = getattr(Messages, field_name)
|
if sort_field is None:
|
||||||
if direction == 1: # ASC
|
|
||||||
peewee_sort_terms.append(field_name)
|
|
||||||
elif direction == -1: # DESC
|
|
||||||
peewee_sort_terms.append(f"-{field_name}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
|
||||||
else:
|
|
||||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||||
if peewee_sort_terms:
|
continue
|
||||||
query = query.order_by(*peewee_sort_terms)
|
order_terms.append(sort_field.asc() if direction == 1 else sort_field.desc())
|
||||||
peewee_results = list(query)
|
if order_terms:
|
||||||
|
statement = statement.order_by(*order_terms)
|
||||||
|
with get_db_session() as session:
|
||||||
|
results = list(session.exec(statement).all())
|
||||||
|
|
||||||
return [_model_to_instance(msg) for msg in peewee_results]
|
if filter_intercept_message_level is not None:
|
||||||
|
filtered_results = []
|
||||||
|
for msg in results:
|
||||||
|
config = _parse_additional_config(msg)
|
||||||
|
if config.get("intercept_message_level", 0) <= filter_intercept_message_level:
|
||||||
|
filtered_results.append(msg)
|
||||||
|
results = filtered_results
|
||||||
|
|
||||||
|
return [_message_to_instance(msg) for msg in results]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = (
|
log_message = (
|
||||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
f"使用 SQLModel 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||||
+ traceback.format_exc()
|
+ traceback.format_exc()
|
||||||
)
|
)
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
|
|
@ -146,54 +251,42 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||||
符合条件的消息数量,如果出错则返回 0。
|
符合条件的消息数量,如果出错则返回 0。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
query = Messages.select()
|
conditions: list[Any] = []
|
||||||
|
|
||||||
# 应用过滤器
|
|
||||||
if message_filter:
|
if message_filter:
|
||||||
conditions = []
|
|
||||||
for key, value in message_filter.items():
|
for key, value in message_filter.items():
|
||||||
if hasattr(Messages, key):
|
field = _resolve_field(key)
|
||||||
field = getattr(Messages, key)
|
if field is None:
|
||||||
if isinstance(value, dict):
|
|
||||||
# 处理 MongoDB 风格的操作符
|
|
||||||
for op, op_value in value.items():
|
|
||||||
if op == "$gt":
|
|
||||||
conditions.append(field > op_value)
|
|
||||||
elif op == "$lt":
|
|
||||||
conditions.append(field < op_value)
|
|
||||||
elif op == "$gte":
|
|
||||||
conditions.append(field >= op_value)
|
|
||||||
elif op == "$lte":
|
|
||||||
conditions.append(field <= op_value)
|
|
||||||
elif op == "$ne":
|
|
||||||
conditions.append(field != op_value)
|
|
||||||
elif op == "$in":
|
|
||||||
conditions.append(field.in_(op_value))
|
|
||||||
elif op == "$nin":
|
|
||||||
conditions.append(field.not_in(op_value))
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 直接相等比较
|
|
||||||
conditions.append(field == value)
|
|
||||||
else:
|
|
||||||
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
|
||||||
if conditions:
|
continue
|
||||||
query = query.where(*conditions)
|
if isinstance(value, dict):
|
||||||
|
for op, op_value in value.items():
|
||||||
|
coerced_value = _coerce_datetime(op_value) if field is Messages.timestamp else op_value
|
||||||
|
if op == "$gt":
|
||||||
|
conditions.append(field > coerced_value)
|
||||||
|
elif op == "$lt":
|
||||||
|
conditions.append(field < coerced_value)
|
||||||
|
elif op == "$gte":
|
||||||
|
conditions.append(field >= coerced_value)
|
||||||
|
elif op == "$lte":
|
||||||
|
conditions.append(field <= coerced_value)
|
||||||
|
elif op == "$ne":
|
||||||
|
conditions.append(field != coerced_value)
|
||||||
|
elif op == "$in":
|
||||||
|
conditions.append(field.in_(_ensure_list(coerced_value)))
|
||||||
|
elif op == "$nin":
|
||||||
|
conditions.append(field.not_in(_ensure_list(coerced_value)))
|
||||||
|
else:
|
||||||
|
logger.warning(f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
|
||||||
|
else:
|
||||||
|
coerced_value = _coerce_datetime(value) if field is Messages.timestamp else value
|
||||||
|
conditions.append(field == coerced_value)
|
||||||
|
|
||||||
# 排除 id 为 "notice" 的消息
|
conditions.append(Messages.message_id != "notice")
|
||||||
query = query.where(Messages.message_id != "notice")
|
statement = select(func.count()).select_from(Messages).where(*conditions)
|
||||||
|
with get_db_session() as session:
|
||||||
count = query.count()
|
result = session.exec(statement).one()
|
||||||
return count
|
return int(result or 0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
log_message = f"使用 SQLModel 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||||
logger.error(log_message)
|
logger.error(log_message)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
|
||||||
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
|
||||||
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@ from PIL import Image
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db # 确保 db 被导入用于 create_tables
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import LLMUsage
|
from src.common.database.database_model import ModelUsage, ModelUser
|
||||||
from src.config.model_configs import ModelInfo
|
from src.config.model_configs import ModelInfo
|
||||||
from .payload_content.message import Message, MessageBuilder
|
from .payload_content.message import Message, MessageBuilder
|
||||||
from .model_client.base_client import UsageRecord
|
from .model_client.base_client import UsageRecord
|
||||||
|
|
@ -158,12 +158,7 @@ class LLMUsageRecorder:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
try:
|
pass
|
||||||
# 使用 Peewee 创建表,safe=True 表示如果表已存在则不会抛出错误
|
|
||||||
db.create_tables([LLMUsage], safe=True)
|
|
||||||
# logger.debug("LLMUsage 表已初始化/确保存在。")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建 LLMUsage 表失败: {str(e)}")
|
|
||||||
|
|
||||||
def record_usage_to_database(
|
def record_usage_to_database(
|
||||||
self,
|
self,
|
||||||
|
|
@ -178,22 +173,22 @@ class LLMUsageRecorder:
|
||||||
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
|
||||||
total_cost = round(input_cost + output_cost, 6)
|
total_cost = round(input_cost + output_cost, 6)
|
||||||
try:
|
try:
|
||||||
# 使用 Peewee 模型创建记录
|
with get_db_session() as session:
|
||||||
LLMUsage.create(
|
record = ModelUsage(
|
||||||
model_name=model_info.model_identifier,
|
model_name=model_info.model_identifier,
|
||||||
model_assign_name=model_info.name,
|
model_assign_name=model_info.name,
|
||||||
model_api_provider=model_info.api_provider,
|
model_api_provider_name=model_info.api_provider,
|
||||||
user_id=user_id,
|
|
||||||
request_type=request_type,
|
|
||||||
endpoint=endpoint,
|
endpoint=endpoint,
|
||||||
|
user_type=ModelUser.SYSTEM,
|
||||||
|
request_type=request_type,
|
||||||
|
time_cost=round(time_cost or 0.0, 3),
|
||||||
|
timestamp=datetime.now(),
|
||||||
prompt_tokens=model_usage.prompt_tokens or 0,
|
prompt_tokens=model_usage.prompt_tokens or 0,
|
||||||
completion_tokens=model_usage.completion_tokens or 0,
|
completion_tokens=model_usage.completion_tokens or 0,
|
||||||
total_tokens=model_usage.total_tokens or 0,
|
total_tokens=model_usage.total_tokens or 0,
|
||||||
cost=total_cost or 0.0,
|
cost=total_cost or 0.0,
|
||||||
time_cost=round(time_cost or 0.0, 3),
|
|
||||||
status="success",
|
|
||||||
timestamp=datetime.now(), # Peewee 会处理 DateTimeField
|
|
||||||
)
|
)
|
||||||
|
session.add(record)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
f"Token使用情况 - 模型: {model_usage.model_name}, "
|
||||||
f"用户: {user_id}, 类型: {request_type}, "
|
f"用户: {user_id}, 类型: {request_type}, "
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ from src.manager.async_task_manager import async_task_manager
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
|
|
||||||
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
# from src.chat.utils.token_statistics import TokenStatisticsTask
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
|
|
@ -107,7 +107,7 @@ class MainSystem:
|
||||||
plugin_manager.load_all_plugins()
|
plugin_manager.load_all_plugins()
|
||||||
|
|
||||||
# 初始化表情管理器
|
# 初始化表情管理器
|
||||||
get_emoji_manager().initialize()
|
emoji_manager.load_emojis_from_db()
|
||||||
logger.info("表情包管理器初始化成功")
|
logger.info("表情包管理器初始化成功")
|
||||||
|
|
||||||
# 初始化聊天管理器
|
# 初始化聊天管理器
|
||||||
|
|
@ -141,7 +141,7 @@ class MainSystem:
|
||||||
"""调度定时任务"""
|
"""调度定时任务"""
|
||||||
try:
|
try:
|
||||||
tasks = [
|
tasks = [
|
||||||
get_emoji_manager().start_periodic_check_register(),
|
emoji_manager.periodic_emoji_maintenance(),
|
||||||
start_dream_scheduler(),
|
start_dream_scheduler(),
|
||||||
self.app.run(),
|
self.app.run(),
|
||||||
self.server.run(),
|
self.server.run(),
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
from datetime import datetime
|
||||||
|
from typing import List, Dict, Any, Optional, Tuple, Callable, cast
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.prompt.prompt_manager import prompt_manager
|
from src.prompt.prompt_manager import prompt_manager
|
||||||
from src.plugin_system.apis import llm_api
|
from src.plugin_system.apis import llm_api
|
||||||
from src.common.database.database_model import ThinkingBack
|
from sqlmodel import select, col
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import ThinkingQuestion
|
||||||
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
from src.memory_system.retrieval_tools import get_tool_registry, init_all_tools
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType, Message
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
|
|
@ -29,13 +32,16 @@ def _cleanup_stale_not_found_thinking_back() -> None:
|
||||||
|
|
||||||
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
|
threshold_time = now - THINKING_BACK_NOT_FOUND_RETENTION_SECONDS
|
||||||
try:
|
try:
|
||||||
deleted_rows = (
|
with get_db_session() as session:
|
||||||
ThinkingBack.delete()
|
statement = select(ThinkingQuestion).where(
|
||||||
.where((ThinkingBack.found_answer == 0) & (ThinkingBack.update_time < threshold_time))
|
(ThinkingQuestion.found_answer == False)
|
||||||
.execute()
|
& (ThinkingQuestion.updated_timestamp < datetime.fromtimestamp(threshold_time))
|
||||||
)
|
)
|
||||||
if deleted_rows:
|
records = session.exec(statement).all()
|
||||||
logger.info(f"清理过期的未找到答案thinking_back记录 {deleted_rows} 条")
|
for record in records:
|
||||||
|
session.delete(record)
|
||||||
|
if records:
|
||||||
|
logger.info(f"清理过期的未找到答案thinking_question记录 {len(records)} 条")
|
||||||
_last_not_found_cleanup_ts = now
|
_last_not_found_cleanup_ts = now
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
logger.error(f"清理未找到答案的thinking_back记录失败: {e}")
|
||||||
|
|
@ -249,12 +255,12 @@ async def _react_agent_solve_question(
|
||||||
# 后续迭代都复用第一次构建的head_prompt
|
# 后续迭代都复用第一次构建的head_prompt
|
||||||
head_prompt = first_head_prompt
|
head_prompt = first_head_prompt
|
||||||
|
|
||||||
def message_factory(
|
def _build_messages(
|
||||||
_client,
|
_client,
|
||||||
*,
|
*,
|
||||||
_head_prompt: str = head_prompt,
|
_head_prompt: str = head_prompt,
|
||||||
_conversation_messages: List[Message] = conversation_messages,
|
_conversation_messages: List[Message] = conversation_messages,
|
||||||
) -> List[Message]:
|
):
|
||||||
messages: List[Message] = []
|
messages: List[Message] = []
|
||||||
|
|
||||||
system_builder = MessageBuilder()
|
system_builder = MessageBuilder()
|
||||||
|
|
@ -266,6 +272,7 @@ async def _react_agent_solve_question(
|
||||||
|
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
message_factory_fn: Callable[..., List[Message]] = _build_messages # pyright: ignore[reportGeneralTypeIssues]
|
||||||
(
|
(
|
||||||
success,
|
success,
|
||||||
response,
|
response,
|
||||||
|
|
@ -273,7 +280,7 @@ async def _react_agent_solve_question(
|
||||||
model_name,
|
model_name,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
) = await llm_api.generate_with_model_with_tools_by_message_factory(
|
||||||
message_factory,
|
message_factory_fn, # type: ignore[arg-type]
|
||||||
model_config=model_config.model_task_config.tool_use,
|
model_config=model_config.model_task_config.tool_use,
|
||||||
tool_options=tool_definitions,
|
tool_options=tool_definitions,
|
||||||
request_type="memory.react",
|
request_type="memory.react",
|
||||||
|
|
@ -304,7 +311,12 @@ async def _react_agent_solve_question(
|
||||||
assistant_message = assistant_builder.build()
|
assistant_message = assistant_builder.build()
|
||||||
|
|
||||||
# 记录思考步骤
|
# 记录思考步骤
|
||||||
step = {"iteration": iteration + 1, "thought": response, "actions": [], "observations": []}
|
step: Dict[str, Any] = {
|
||||||
|
"iteration": iteration + 1,
|
||||||
|
"thought": response,
|
||||||
|
"actions": [],
|
||||||
|
"observations": [],
|
||||||
|
}
|
||||||
|
|
||||||
if assistant_message:
|
if assistant_message:
|
||||||
conversation_messages.append(assistant_message)
|
conversation_messages.append(assistant_message)
|
||||||
|
|
@ -417,20 +429,21 @@ async def _react_agent_solve_question(
|
||||||
"action_params": {"information": parsed_information or ""},
|
"action_params": {"information": parsed_information or ""},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
if parsed_information and parsed_information.strip():
|
parsed_info_text = parsed_information if isinstance(parsed_information, str) else ""
|
||||||
|
if parsed_info_text.strip():
|
||||||
step["observations"] = [f"检测到return_information{format_type}调用,返回信息"]
|
step["observations"] = [f"检测到return_information{format_type}调用,返回信息"]
|
||||||
thinking_steps.append(step)
|
thinking_steps.append(step)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_information[:100]}..."
|
f"{react_log_prefix}第 {iteration + 1} 次迭代 通过return_information{format_type}返回信息: {parsed_info_text[:100]}..."
|
||||||
)
|
)
|
||||||
|
|
||||||
_log_conversation_messages(
|
_log_conversation_messages(
|
||||||
conversation_messages,
|
conversation_messages,
|
||||||
head_prompt=first_head_prompt,
|
head_prompt=first_head_prompt,
|
||||||
final_status=f"返回信息:{parsed_information}",
|
final_status=f"返回信息:{parsed_info_text}",
|
||||||
)
|
)
|
||||||
|
|
||||||
return True, parsed_information, thinking_steps, False
|
return True, parsed_info_text, thinking_steps, False
|
||||||
else:
|
else:
|
||||||
# 信息为空,直接退出查询
|
# 信息为空,直接退出查询
|
||||||
step["observations"] = [f"检测到return_information{format_type}调用,信息为空"]
|
step["observations"] = [f"检测到return_information{format_type}调用,信息为空"]
|
||||||
|
|
@ -776,15 +789,16 @@ def _get_recent_query_history(chat_id: str, time_window_seconds: float = 600.0)
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
start_time = current_time - time_window_seconds
|
start_time = current_time - time_window_seconds
|
||||||
|
|
||||||
# 查询最近时间窗口内的记录,按更新时间倒序
|
with get_db_session() as session:
|
||||||
records = (
|
statement = (
|
||||||
ThinkingBack.select()
|
select(ThinkingQuestion)
|
||||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.update_time >= start_time))
|
.where(col(ThinkingQuestion.context) == chat_id)
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||||
.limit(5) # 最多返回5条最近的记录
|
.limit(5)
|
||||||
)
|
)
|
||||||
|
records = session.exec(statement).all()
|
||||||
|
|
||||||
if not records.exists():
|
if not records:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
history_lines = []
|
history_lines = []
|
||||||
|
|
@ -828,20 +842,19 @@ def _get_recent_found_answers(chat_id: str, time_window_seconds: float = 600.0)
|
||||||
start_time = current_time - time_window_seconds
|
start_time = current_time - time_window_seconds
|
||||||
|
|
||||||
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
|
# 查询最近时间窗口内已找到答案的记录,按更新时间倒序
|
||||||
records = (
|
with get_db_session() as session:
|
||||||
ThinkingBack.select()
|
statement = (
|
||||||
.where(
|
select(ThinkingQuestion)
|
||||||
(ThinkingBack.chat_id == chat_id)
|
.where(col(ThinkingQuestion.context) == chat_id)
|
||||||
& (ThinkingBack.update_time >= start_time)
|
.where(col(ThinkingQuestion.found_answer) == True)
|
||||||
& (ThinkingBack.found_answer == 1)
|
.where(col(ThinkingQuestion.answer).is_not(None))
|
||||||
& (ThinkingBack.answer.is_null(False))
|
.where(col(ThinkingQuestion.answer) != "")
|
||||||
& (ThinkingBack.answer != "")
|
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||||
)
|
.limit(3)
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
|
||||||
.limit(3) # 最多返回5条最近的记录
|
|
||||||
)
|
)
|
||||||
|
records = session.exec(statement).all()
|
||||||
|
|
||||||
if not records.exists():
|
if not records:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
found_answers = []
|
found_answers = []
|
||||||
|
|
@ -873,36 +886,35 @@ def _store_thinking_back(
|
||||||
now = time.time()
|
now = time.time()
|
||||||
|
|
||||||
# 先查询是否已存在相同chat_id和问题的记录
|
# 先查询是否已存在相同chat_id和问题的记录
|
||||||
existing = (
|
with get_db_session() as session:
|
||||||
ThinkingBack.select()
|
statement = (
|
||||||
.where((ThinkingBack.chat_id == chat_id) & (ThinkingBack.question == question))
|
select(ThinkingQuestion)
|
||||||
.order_by(ThinkingBack.update_time.desc())
|
.where(col(ThinkingQuestion.context) == chat_id)
|
||||||
|
.where(col(ThinkingQuestion.question) == question)
|
||||||
|
.order_by(col(ThinkingQuestion.updated_timestamp).desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
|
record = session.exec(statement).first()
|
||||||
if existing.exists():
|
if record:
|
||||||
# 更新现有记录
|
|
||||||
record = existing.get()
|
|
||||||
record.context = context
|
record.context = context
|
||||||
record.found_answer = found_answer
|
record.found_answer = found_answer
|
||||||
record.answer = answer
|
record.answer = answer
|
||||||
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
|
record.thinking_steps = json.dumps(thinking_steps, ensure_ascii=False)
|
||||||
record.update_time = now
|
record.updated_timestamp = datetime.fromtimestamp(now)
|
||||||
record.save()
|
session.add(record)
|
||||||
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
|
logger.info(f"已更新思考过程到数据库,问题: {question[:50]}...")
|
||||||
else:
|
return
|
||||||
# 创建新记录
|
|
||||||
ThinkingBack.create(
|
new_record = ThinkingQuestion(
|
||||||
chat_id=chat_id,
|
|
||||||
question=question,
|
question=question,
|
||||||
context=context,
|
context=chat_id,
|
||||||
found_answer=found_answer,
|
found_answer=found_answer,
|
||||||
answer=answer,
|
answer=answer,
|
||||||
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
thinking_steps=json.dumps(thinking_steps, ensure_ascii=False),
|
||||||
create_time=now,
|
created_timestamp=datetime.fromtimestamp(now),
|
||||||
update_time=now,
|
updated_timestamp=datetime.fromtimestamp(now),
|
||||||
)
|
)
|
||||||
# logger.info(f"已创建思考过程到数据库,问题: {question[:50]}...")
|
session.add(new_record)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"存储思考过程失败: {e}")
|
logger.error(f"存储思考过程失败: {e}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,13 @@ import random
|
||||||
import math
|
import math
|
||||||
|
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from typing import Union, Optional
|
from typing import Union, Optional, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlmodel import col, select
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database import db
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
|
|
@ -35,24 +38,37 @@ def get_person_id(platform: str, user_id: Union[int, str]) -> str:
|
||||||
def get_person_id_by_person_name(person_name: str) -> str:
|
def get_person_id_by_person_name(person_name: str) -> str:
|
||||||
"""根据用户名获取用户ID"""
|
"""根据用户名获取用户ID"""
|
||||||
try:
|
try:
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_name == person_name)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_name) == person_name).limit(1)
|
||||||
|
record = session.exec(statement).first()
|
||||||
return record.person_id if record else ""
|
return record.person_id if record else ""
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"根据用户名 {person_name} 获取用户ID时出错 (Peewee): {e}")
|
logger.error(f"根据用户名 {person_name} 获取用户ID时出错: {e}")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def is_person_known(person_id: str = None, user_id: str = None, platform: str = None, person_name: str = None) -> bool: # type: ignore
|
def is_person_known(
|
||||||
|
person_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
platform: Optional[str] = None,
|
||||||
|
person_name: Optional[str] = None,
|
||||||
|
) -> bool:
|
||||||
if person_id:
|
if person_id:
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
elif user_id and platform:
|
elif user_id and platform:
|
||||||
person_id = get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
elif person_name:
|
elif person_name:
|
||||||
person_id = get_person_id_by_person_name(person_name)
|
person_id = get_person_id_by_person_name(person_name)
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
return person.is_known if person else False
|
return person.is_known if person else False
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
@ -442,17 +458,18 @@ class Person:
|
||||||
def load_from_database(self):
|
def load_from_database(self):
|
||||||
"""从数据库加载个人信息数据"""
|
"""从数据库加载个人信息数据"""
|
||||||
try:
|
try:
|
||||||
# 查询数据库中的记录
|
with get_db_session() as session:
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
||||||
|
record = session.exec(statement).first()
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
self.user_id = record.user_id or ""
|
self.user_id = record.user_id or ""
|
||||||
self.platform = record.platform or ""
|
self.platform = record.platform or ""
|
||||||
self.is_known = record.is_known or False
|
self.is_known = record.is_known or False
|
||||||
self.nickname = record.nickname or ""
|
self.nickname = record.user_nickname or ""
|
||||||
self.person_name = record.person_name or self.nickname
|
self.person_name = record.person_name or self.nickname
|
||||||
self.name_reason = record.name_reason or None
|
self.name_reason = record.name_reason or None
|
||||||
self.know_times = record.know_times or 0
|
self.know_times = record.know_counts or 0
|
||||||
|
|
||||||
# 处理points字段(JSON格式的列表)
|
# 处理points字段(JSON格式的列表)
|
||||||
if record.memory_points:
|
if record.memory_points:
|
||||||
|
|
@ -470,16 +487,16 @@ class Person:
|
||||||
self.memory_points = []
|
self.memory_points = []
|
||||||
|
|
||||||
# 处理group_nick_name字段(JSON格式的列表)
|
# 处理group_nick_name字段(JSON格式的列表)
|
||||||
if record.group_nick_name:
|
if record.group_nickname:
|
||||||
try:
|
try:
|
||||||
loaded_group_nick_names = json.loads(record.group_nick_name)
|
loaded_group_nick_names = json.loads(record.group_nickname)
|
||||||
# 确保是列表格式
|
# 确保是列表格式
|
||||||
if isinstance(loaded_group_nick_names, list):
|
if isinstance(loaded_group_nick_names, list):
|
||||||
self.group_nick_name = loaded_group_nick_names
|
self.group_nick_name = loaded_group_nick_names
|
||||||
else:
|
else:
|
||||||
self.group_nick_name = []
|
self.group_nick_name = []
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
logger.warning(f"解析用户 {self.person_id} 的group_nick_name字段失败,使用默认值")
|
logger.warning(f"解析用户 {self.person_id} 的group_nickname字段失败,使用默认值")
|
||||||
self.group_nick_name = []
|
self.group_nick_name = []
|
||||||
else:
|
else:
|
||||||
self.group_nick_name = []
|
self.group_nick_name = []
|
||||||
|
|
@ -498,41 +515,54 @@ class Person:
|
||||||
if not self.is_known:
|
if not self.is_known:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
# 准备数据
|
memory_points_value = (
|
||||||
data = {
|
json.dumps([point for point in self.memory_points if point is not None], ensure_ascii=False)
|
||||||
"person_id": self.person_id,
|
|
||||||
"is_known": self.is_known,
|
|
||||||
"platform": self.platform,
|
|
||||||
"user_id": self.user_id,
|
|
||||||
"nickname": self.nickname,
|
|
||||||
"person_name": self.person_name,
|
|
||||||
"name_reason": self.name_reason,
|
|
||||||
"know_times": self.know_times,
|
|
||||||
"know_since": self.know_since,
|
|
||||||
"last_know": self.last_know,
|
|
||||||
"memory_points": json.dumps(
|
|
||||||
[point for point in self.memory_points if point is not None], ensure_ascii=False
|
|
||||||
)
|
|
||||||
if self.memory_points
|
if self.memory_points
|
||||||
else json.dumps([], ensure_ascii=False),
|
else json.dumps([], ensure_ascii=False)
|
||||||
"group_nick_name": json.dumps(self.group_nick_name, ensure_ascii=False)
|
)
|
||||||
|
group_nickname_value = (
|
||||||
|
json.dumps(self.group_nick_name, ensure_ascii=False)
|
||||||
if self.group_nick_name
|
if self.group_nick_name
|
||||||
else json.dumps([], ensure_ascii=False),
|
else json.dumps([], ensure_ascii=False)
|
||||||
}
|
)
|
||||||
|
first_known_time = datetime.fromtimestamp(self.know_since) if self.know_since else None
|
||||||
|
last_known_time = datetime.fromtimestamp(self.last_know) if self.last_know else None
|
||||||
|
|
||||||
# 检查记录是否存在
|
with get_db_session() as session:
|
||||||
record = PersonInfo.get_or_none(PersonInfo.person_id == self.person_id)
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == self.person_id).limit(1)
|
||||||
|
record = session.exec(statement).first()
|
||||||
|
|
||||||
if record:
|
if record:
|
||||||
# 更新现有记录
|
record.person_id = self.person_id
|
||||||
for field, value in data.items():
|
record.is_known = self.is_known
|
||||||
if hasattr(record, field):
|
record.platform = self.platform
|
||||||
setattr(record, field, value)
|
record.user_id = self.user_id
|
||||||
record.save()
|
record.user_nickname = self.nickname
|
||||||
|
record.person_name = self.person_name
|
||||||
|
record.name_reason = self.name_reason
|
||||||
|
record.know_counts = self.know_times
|
||||||
|
record.first_known_time = first_known_time
|
||||||
|
record.last_known_time = last_known_time
|
||||||
|
record.memory_points = memory_points_value
|
||||||
|
record.group_nickname = group_nickname_value
|
||||||
|
session.add(record)
|
||||||
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
logger.debug(f"已同步用户 {self.person_id} 的信息到数据库")
|
||||||
else:
|
else:
|
||||||
# 创建新记录
|
record = PersonInfo(
|
||||||
PersonInfo.create(**data)
|
person_id=self.person_id,
|
||||||
|
is_known=self.is_known,
|
||||||
|
platform=self.platform,
|
||||||
|
user_id=self.user_id,
|
||||||
|
user_nickname=self.nickname,
|
||||||
|
person_name=self.person_name,
|
||||||
|
name_reason=self.name_reason,
|
||||||
|
know_counts=self.know_times,
|
||||||
|
first_known_time=first_known_time,
|
||||||
|
last_known_time=last_known_time,
|
||||||
|
memory_points=memory_points_value,
|
||||||
|
group_nickname=group_nickname_value,
|
||||||
|
)
|
||||||
|
session.add(record)
|
||||||
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
logger.debug(f"已创建用户 {self.person_id} 的信息到数据库")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -621,30 +651,26 @@ class PersonInfoManager:
|
||||||
self.person_name_list = {}
|
self.person_name_list = {}
|
||||||
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
self.qv_name_llm = LLMRequest(model_set=model_config.model_task_config.utils, request_type="relation.qv_name")
|
||||||
try:
|
try:
|
||||||
db.connect(reuse_if_open=True)
|
with get_db_session() as _:
|
||||||
# 设置连接池参数
|
pass
|
||||||
if hasattr(db, "execute_sql"):
|
|
||||||
# 设置SQLite优化参数
|
|
||||||
db.execute_sql("PRAGMA cache_size = -64000") # 64MB缓存
|
|
||||||
db.execute_sql("PRAGMA temp_store = memory") # 临时存储在内存中
|
|
||||||
db.execute_sql("PRAGMA mmap_size = 268435456") # 256MB内存映射
|
|
||||||
db.create_tables([PersonInfo], safe=True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
logger.error(f"数据库连接或 PersonInfo 表创建失败: {e}")
|
||||||
|
|
||||||
# 初始化时读取所有person_name
|
# 初始化时读取所有person_name
|
||||||
try:
|
try:
|
||||||
for record in PersonInfo.select(PersonInfo.person_id, PersonInfo.person_name).where(
|
with get_db_session() as session:
|
||||||
PersonInfo.person_name.is_null(False)
|
statement = select(PersonInfo.person_id, PersonInfo.person_name).where(
|
||||||
):
|
col(PersonInfo.person_name).is_not(None)
|
||||||
if record.person_name:
|
)
|
||||||
self.person_name_list[record.person_id] = record.person_name
|
for person_id, person_name in session.exec(statement).all():
|
||||||
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称 (Peewee)")
|
if person_name:
|
||||||
|
self.person_name_list[person_id] = person_name
|
||||||
|
logger.debug(f"已加载 {len(self.person_name_list)} 个用户名称")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从 Peewee 加载 person_name_list 失败: {e}")
|
logger.error(f"加载 person_name_list 失败: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_json_from_text(text: str) -> dict:
|
def _extract_json_from_text(text: str) -> Dict[str, str]:
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
try:
|
try:
|
||||||
fixed_json = repair_json(text)
|
fixed_json = repair_json(text)
|
||||||
|
|
@ -744,7 +770,9 @@ class PersonInfoManager:
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def _db_check_name_exists_sync(name_to_check):
|
def _db_check_name_exists_sync(name_to_check):
|
||||||
return PersonInfo.select().where(PersonInfo.person_name == name_to_check).exists()
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo.person_id).where(col(PersonInfo.person_name) == name_to_check)
|
||||||
|
return session.exec(statement).first() is not None
|
||||||
|
|
||||||
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
if await asyncio.to_thread(_db_check_name_exists_sync, generated_nickname):
|
||||||
is_duplicate = True
|
is_duplicate = True
|
||||||
|
|
@ -804,7 +832,7 @@ async def store_person_memory_from_answer(person_name: str, memory_content: str,
|
||||||
|
|
||||||
if not person_id:
|
if not person_id:
|
||||||
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
# 如果通过person_name找不到,尝试从chat_stream获取user_info
|
||||||
if chat_stream.user_info:
|
if platform and chat_stream.user_info and chat_stream.user_info.user_id:
|
||||||
user_id = chat_stream.user_info.user_id
|
user_id = chat_stream.user_info.user_id
|
||||||
person_id = get_person_id(platform, user_id)
|
person_id = get_person_id(platform, user_id)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -15,8 +15,9 @@ import uuid
|
||||||
|
|
||||||
from typing import Optional, Tuple, List, Dict, Any
|
from typing import Optional, Tuple, List, Dict, Any
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager, EMOJI_DIR
|
from src.chat.emoji_system.emoji_manager import emoji_manager, EMOJI_DIR
|
||||||
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
|
from src.chat.utils.utils_image import image_path_to_base64, base64_to_image
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger = get_logger("emoji_api")
|
logger = get_logger("emoji_api")
|
||||||
|
|
||||||
|
|
@ -46,14 +47,15 @@ async def get_by_description(description: str) -> Optional[Tuple[str, str, str]]
|
||||||
try:
|
try:
|
||||||
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
logger.debug(f"[EmojiAPI] 根据描述获取表情包: {description}")
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
emoji_obj = await emoji_manager.get_emoji_for_emotion(description)
|
||||||
emoji_result = await emoji_manager.get_emoji_for_text(description)
|
|
||||||
|
|
||||||
if not emoji_result:
|
if not emoji_obj:
|
||||||
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
logger.warning(f"[EmojiAPI] 未找到匹配描述 '{description}' 的表情包")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
emoji_path, emoji_description, matched_emotion = emoji_result
|
emoji_path = str(emoji_obj.full_path)
|
||||||
|
emoji_description = emoji_obj.description
|
||||||
|
matched_emotion = random.choice(emoji_obj.emotion) if emoji_obj.emotion else ""
|
||||||
emoji_base64 = image_path_to_base64(emoji_path)
|
emoji_base64 = image_path_to_base64(emoji_path)
|
||||||
|
|
||||||
if not emoji_base64:
|
if not emoji_base64:
|
||||||
|
|
@ -90,8 +92,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
all_emojis = emoji_manager.emojis
|
||||||
all_emojis = emoji_manager.emoji_objects
|
|
||||||
|
|
||||||
if not all_emojis:
|
if not all_emojis:
|
||||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||||
|
|
@ -114,7 +115,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for selected_emoji in selected_emojis:
|
for selected_emoji in selected_emojis:
|
||||||
emoji_base64 = image_path_to_base64(selected_emoji.full_path)
|
emoji_base64 = image_path_to_base64(str(selected_emoji.full_path))
|
||||||
|
|
||||||
if not emoji_base64:
|
if not emoji_base64:
|
||||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {selected_emoji.full_path}")
|
||||||
|
|
@ -123,7 +124,7 @@ async def get_random(count: Optional[int] = 1) -> List[Tuple[str, str, str]]:
|
||||||
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
matched_emotion = random.choice(selected_emoji.emotion) if selected_emoji.emotion else "随机表情"
|
||||||
|
|
||||||
# 记录使用次数
|
# 记录使用次数
|
||||||
emoji_manager.record_usage(selected_emoji.hash)
|
emoji_manager.update_emoji_usage(selected_emoji)
|
||||||
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
results.append((emoji_base64, selected_emoji.description, matched_emotion))
|
||||||
|
|
||||||
if not results and count > 0:
|
if not results and count > 0:
|
||||||
|
|
@ -158,8 +159,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||||
try:
|
try:
|
||||||
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
logger.info(f"[EmojiAPI] 根据情感获取表情包: {emotion}")
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
all_emojis = emoji_manager.emojis
|
||||||
all_emojis = emoji_manager.emoji_objects
|
|
||||||
|
|
||||||
# 筛选匹配情感的表情包
|
# 筛选匹配情感的表情包
|
||||||
matching_emojis = []
|
matching_emojis = []
|
||||||
|
|
@ -181,7 +181,7 @@ async def get_by_emotion(emotion: str) -> Optional[Tuple[str, str, str]]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 记录使用次数
|
# 记录使用次数
|
||||||
emoji_manager.record_usage(selected_emoji.hash)
|
emoji_manager.update_emoji_usage(selected_emoji)
|
||||||
|
|
||||||
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
logger.info(f"[EmojiAPI] 成功获取情感表情包: {selected_emoji.description}")
|
||||||
return emoji_base64, selected_emoji.description, emotion
|
return emoji_base64, selected_emoji.description, emotion
|
||||||
|
|
@ -203,8 +203,7 @@ def get_count() -> int:
|
||||||
int: 当前可用的表情包数量
|
int: 当前可用的表情包数量
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
return len(emoji_manager.emojis)
|
||||||
return emoji_manager.emoji_num
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
logger.error(f"[EmojiAPI] 获取表情包数量失败: {e}")
|
||||||
return 0
|
return 0
|
||||||
|
|
@ -217,11 +216,10 @@ def get_info():
|
||||||
dict: 包含表情包数量、最大数量、可用数量信息
|
dict: 包含表情包数量、最大数量、可用数量信息
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
|
||||||
return {
|
return {
|
||||||
"current_count": emoji_manager.emoji_num,
|
"current_count": len(emoji_manager.emojis),
|
||||||
"max_count": emoji_manager.emoji_num_max,
|
"max_count": global_config.emoji.max_reg_num,
|
||||||
"available_emojis": len([e for e in emoji_manager.emoji_objects if not e.is_deleted]),
|
"available_emojis": len([e for e in emoji_manager.emojis if not e.is_deleted]),
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
logger.error(f"[EmojiAPI] 获取表情包信息失败: {e}")
|
||||||
|
|
@ -235,10 +233,9 @@ def get_emotions() -> List[str]:
|
||||||
list: 所有表情包的情感标签列表(去重)
|
list: 所有表情包的情感标签列表(去重)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
|
||||||
emotions = set()
|
emotions = set()
|
||||||
|
|
||||||
for emoji_obj in emoji_manager.emoji_objects:
|
for emoji_obj in emoji_manager.emojis:
|
||||||
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
if not emoji_obj.is_deleted and emoji_obj.emotion:
|
||||||
emotions.update(emoji_obj.emotion)
|
emotions.update(emoji_obj.emotion)
|
||||||
|
|
||||||
|
|
@ -255,8 +252,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
|
||||||
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
|
List[Tuple[str, str, str]]: 包含(base64编码, 表情包描述, 随机情感标签)的元组列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
all_emojis = emoji_manager.emojis
|
||||||
all_emojis = emoji_manager.emoji_objects
|
|
||||||
|
|
||||||
if not all_emojis:
|
if not all_emojis:
|
||||||
logger.warning("[EmojiAPI] 没有可用的表情包")
|
logger.warning("[EmojiAPI] 没有可用的表情包")
|
||||||
|
|
@ -267,7 +263,7 @@ async def get_all() -> List[Tuple[str, str, str]]:
|
||||||
if emoji_obj.is_deleted:
|
if emoji_obj.is_deleted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
emoji_base64 = image_path_to_base64(emoji_obj.full_path)
|
emoji_base64 = image_path_to_base64(str(emoji_obj.full_path))
|
||||||
|
|
||||||
if not emoji_base64:
|
if not emoji_base64:
|
||||||
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
|
logger.error(f"[EmojiAPI] 无法转换表情包为base64: {emoji_obj.full_path}")
|
||||||
|
|
@ -291,12 +287,11 @@ def get_descriptions() -> List[str]:
|
||||||
list: 所有可用表情包的描述列表
|
list: 所有可用表情包的描述列表
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
emoji_manager = get_emoji_manager()
|
|
||||||
descriptions = []
|
descriptions = []
|
||||||
|
|
||||||
descriptions.extend(
|
descriptions.extend(
|
||||||
emoji_obj.description
|
emoji_obj.description
|
||||||
for emoji_obj in emoji_manager.emoji_objects
|
for emoji_obj in emoji_manager.emojis
|
||||||
if not emoji_obj.is_deleted and emoji_obj.description
|
if not emoji_obj.is_deleted and emoji_obj.description
|
||||||
)
|
)
|
||||||
return descriptions
|
return descriptions
|
||||||
|
|
@ -341,14 +336,11 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
|
||||||
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
|
logger.info(f"[EmojiAPI] 开始注册表情包,文件名: {filename or '自动生成'}")
|
||||||
|
|
||||||
# 1. 获取emoji管理器并检查容量
|
# 1. 获取emoji管理器并检查容量
|
||||||
emoji_manager = get_emoji_manager()
|
count_before = len(emoji_manager.emojis)
|
||||||
count_before = emoji_manager.emoji_num
|
max_count = global_config.emoji.max_reg_num
|
||||||
max_count = emoji_manager.emoji_num_max
|
|
||||||
|
|
||||||
# 2. 检查是否可以注册(未达到上限或启用替换)
|
# 2. 检查是否可以注册(未达到上限或启用替换)
|
||||||
can_register = count_before < max_count or (
|
can_register = count_before < max_count or (count_before >= max_count and global_config.emoji.do_replace)
|
||||||
count_before >= max_count and emoji_manager.emoji_num_max_reach_deletion
|
|
||||||
)
|
|
||||||
|
|
||||||
if not can_register:
|
if not can_register:
|
||||||
return {
|
return {
|
||||||
|
|
@ -474,7 +466,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
|
||||||
|
|
||||||
# 8. 构建返回结果
|
# 8. 构建返回结果
|
||||||
if register_success:
|
if register_success:
|
||||||
count_after = emoji_manager.emoji_num
|
count_after = len(emoji_manager.emojis)
|
||||||
replaced = count_after <= count_before # 如果数量没增加,说明是替换
|
replaced = count_after <= count_before # 如果数量没增加,说明是替换
|
||||||
|
|
||||||
# 尝试获取新注册的表情包信息
|
# 尝试获取新注册的表情包信息
|
||||||
|
|
@ -483,10 +475,10 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
|
||||||
# 获取最新的表情包信息
|
# 获取最新的表情包信息
|
||||||
try:
|
try:
|
||||||
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
|
# 通过文件名查找新注册的表情包(注意:文件名在注册后可能已经改变)
|
||||||
for emoji_obj in reversed(emoji_manager.emoji_objects):
|
for emoji_obj in reversed(emoji_manager.emojis):
|
||||||
if not emoji_obj.is_deleted and (
|
if not emoji_obj.is_deleted and (
|
||||||
emoji_obj.filename == filename # 直接匹配
|
emoji_obj.file_name == filename
|
||||||
or (hasattr(emoji_obj, "full_path") and filename in emoji_obj.full_path) # 路径包含匹配
|
or (hasattr(emoji_obj, "full_path") and filename in str(emoji_obj.full_path))
|
||||||
):
|
):
|
||||||
new_emoji_info = emoji_obj
|
new_emoji_info = emoji_obj
|
||||||
break
|
break
|
||||||
|
|
@ -495,7 +487,7 @@ async def register_emoji(image_base64: str, filename: Optional[str] = None) -> D
|
||||||
|
|
||||||
description = new_emoji_info.description if new_emoji_info else None
|
description = new_emoji_info.description if new_emoji_info else None
|
||||||
emotions = new_emoji_info.emotion if new_emoji_info else None
|
emotions = new_emoji_info.emotion if new_emoji_info else None
|
||||||
emoji_hash = new_emoji_info.hash if new_emoji_info else None
|
emoji_hash = new_emoji_info.emoji_hash if new_emoji_info else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -560,12 +552,14 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
|
||||||
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
|
logger.info(f"[EmojiAPI] 开始删除表情包,哈希值: {emoji_hash}")
|
||||||
|
|
||||||
# 1. 获取emoji管理器和删除前的数量
|
# 1. 获取emoji管理器和删除前的数量
|
||||||
emoji_manager = get_emoji_manager()
|
count_before = len(emoji_manager.emojis)
|
||||||
count_before = emoji_manager.emoji_num
|
|
||||||
|
|
||||||
# 2. 获取被删除表情包的信息(用于返回结果)
|
# 2. 获取被删除表情包的信息(用于返回结果)
|
||||||
|
deleted_emoji = None
|
||||||
try:
|
try:
|
||||||
deleted_emoji = await emoji_manager.get_emoji_from_manager(emoji_hash)
|
deleted_emoji = emoji_manager.get_emoji_by_hash(emoji_hash) or emoji_manager.get_emoji_by_hash_from_db(
|
||||||
|
emoji_hash
|
||||||
|
)
|
||||||
description = deleted_emoji.description if deleted_emoji else None
|
description = deleted_emoji.description if deleted_emoji else None
|
||||||
emotions = deleted_emoji.emotion if deleted_emoji else None
|
emotions = deleted_emoji.emotion if deleted_emoji else None
|
||||||
except Exception as info_error:
|
except Exception as info_error:
|
||||||
|
|
@ -574,10 +568,12 @@ async def delete_emoji(emoji_hash: str) -> Dict[str, Any]:
|
||||||
emotions = None
|
emotions = None
|
||||||
|
|
||||||
# 3. 执行删除操作
|
# 3. 执行删除操作
|
||||||
delete_success = await emoji_manager.delete_emoji(emoji_hash)
|
delete_success = False
|
||||||
|
if deleted_emoji:
|
||||||
|
delete_success = emoji_manager.delete_emoji(deleted_emoji)
|
||||||
|
|
||||||
# 4. 获取删除后的数量
|
# 4. 获取删除后的数量
|
||||||
count_after = emoji_manager.emoji_num
|
count_after = len(emoji_manager.emojis)
|
||||||
|
|
||||||
# 5. 构建返回结果
|
# 5. 构建返回结果
|
||||||
if delete_success:
|
if delete_success:
|
||||||
|
|
@ -638,8 +634,7 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals
|
||||||
try:
|
try:
|
||||||
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
|
logger.info(f"[EmojiAPI] 根据描述删除表情包: {description} (精确匹配: {exact_match})")
|
||||||
|
|
||||||
emoji_manager = get_emoji_manager()
|
all_emojis = emoji_manager.emojis
|
||||||
all_emojis = emoji_manager.emoji_objects
|
|
||||||
|
|
||||||
# 筛选匹配的表情包
|
# 筛选匹配的表情包
|
||||||
matching_emojis = []
|
matching_emojis = []
|
||||||
|
|
@ -669,12 +664,12 @@ async def delete_emoji_by_description(description: str, exact_match: bool = Fals
|
||||||
deleted_hashes = []
|
deleted_hashes = []
|
||||||
for emoji_obj in matching_emojis:
|
for emoji_obj in matching_emojis:
|
||||||
try:
|
try:
|
||||||
delete_success = await emoji_manager.delete_emoji(emoji_obj.hash)
|
delete_success = emoji_manager.delete_emoji(emoji_obj)
|
||||||
if delete_success:
|
if delete_success:
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
deleted_hashes.append(emoji_obj.hash)
|
deleted_hashes.append(emoji_obj.emoji_hash)
|
||||||
except Exception as delete_error:
|
except Exception as delete_error:
|
||||||
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.hash}): {delete_error}")
|
logger.error(f"[EmojiAPI] 删除表情包失败 (哈希: {emoji_obj.emoji_hash}): {delete_error}")
|
||||||
|
|
||||||
# 构建返回结果
|
# 构建返回结果
|
||||||
if deleted_count > 0:
|
if deleted_count > 0:
|
||||||
|
|
|
||||||
|
|
@ -10,8 +10,10 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import List, Dict, Any, Tuple, Optional
|
from typing import List, Dict, Any, Tuple, Optional
|
||||||
|
from sqlmodel import col, select
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.database.database_model import Images
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Images, ImageType
|
||||||
from src.chat.utils.utils import is_bot_self
|
from src.chat.utils.utils import is_bot_self
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_by_timestamp,
|
get_raw_msg_by_timestamp,
|
||||||
|
|
@ -516,7 +518,13 @@ def filter_mai_messages(messages: List[DatabaseMessages]) -> List[DatabaseMessag
|
||||||
|
|
||||||
|
|
||||||
def translate_pid_to_description(pid: str) -> str:
|
def translate_pid_to_description(pid: str) -> str:
|
||||||
image = Images.get_or_none(Images.image_id == pid)
|
with get_db_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(Images).where((col(Images.id) == int(pid)) & (col(Images.image_type) == ImageType.IMAGE))
|
||||||
|
if pid.isdigit()
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
image = session.exec(statement).first() if statement is not None else None
|
||||||
description = ""
|
description = ""
|
||||||
if image and image.description and image.description.strip():
|
if image and image.description and image.description.strip():
|
||||||
description = image.description.strip()
|
description = image.description.strip()
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,25 @@
|
||||||
"""麦麦 2025 年度总结 API 路由"""
|
"""麦麦 2025 年度总结 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from sqlalchemy import func as fn
|
from typing import Any, Optional
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import desc, func
|
||||||
|
from sqlmodel import col, select
|
||||||
|
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import (
|
from src.common.database.database_model import (
|
||||||
LLMUsage,
|
ActionRecord,
|
||||||
OnlineTime,
|
|
||||||
Messages,
|
|
||||||
ChatStreams,
|
|
||||||
PersonInfo,
|
|
||||||
Emoji,
|
|
||||||
Expression,
|
Expression,
|
||||||
ActionRecords,
|
Images,
|
||||||
Jargon,
|
Jargon,
|
||||||
|
Messages,
|
||||||
|
ModelUsage,
|
||||||
|
OnlineTime,
|
||||||
|
PersonInfo,
|
||||||
)
|
)
|
||||||
|
from src.common.logger import get_logger
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui.annual_report")
|
logger = get_logger("webui.annual_report")
|
||||||
|
|
@ -45,7 +47,7 @@ class TimeFootprintData(BaseModel):
|
||||||
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
|
first_message_content: Optional[str] = Field(None, description="初次消息内容(截断)")
|
||||||
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
|
busiest_day: Optional[str] = Field(None, description="最忙碌的一天")
|
||||||
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
|
busiest_day_count: int = Field(0, description="最忙碌那天的消息数")
|
||||||
hourly_distribution: List[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
|
hourly_distribution: list[int] = Field(default_factory=lambda: [0] * 24, description="24小时活跃分布")
|
||||||
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
|
midnight_chat_count: int = Field(0, description="深夜(0-4点)互动次数")
|
||||||
is_night_owl: bool = Field(False, description="是否是夜猫子")
|
is_night_owl: bool = Field(False, description="是否是夜猫子")
|
||||||
|
|
||||||
|
|
@ -54,8 +56,8 @@ class SocialNetworkData(BaseModel):
|
||||||
"""社交网络数据"""
|
"""社交网络数据"""
|
||||||
|
|
||||||
total_groups: int = Field(0, description="加入的群组总数")
|
total_groups: int = Field(0, description="加入的群组总数")
|
||||||
top_groups: List[Dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
|
top_groups: list[dict[str, Any]] = Field(default_factory=list, description="话痨群组TOP5")
|
||||||
top_users: List[Dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
|
top_users: list[dict[str, Any]] = Field(default_factory=list, description="互动最多的用户TOP5")
|
||||||
at_count: int = Field(0, description="被@次数")
|
at_count: int = Field(0, description="被@次数")
|
||||||
mentioned_count: int = Field(0, description="被提及次数")
|
mentioned_count: int = Field(0, description="被提及次数")
|
||||||
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
|
longest_companion_user: Optional[str] = Field(None, description="最长情陪伴的用户")
|
||||||
|
|
@ -69,11 +71,11 @@ class BrainPowerData(BaseModel):
|
||||||
total_cost: float = Field(0.0, description="年度总花费")
|
total_cost: float = Field(0.0, description="年度总花费")
|
||||||
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
|
favorite_model: Optional[str] = Field(None, description="最爱用的模型")
|
||||||
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
|
favorite_model_count: int = Field(0, description="最爱模型的调用次数")
|
||||||
model_distribution: List[Dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
|
model_distribution: list[dict[str, Any]] = Field(default_factory=list, description="模型使用分布")
|
||||||
top_reply_models: List[Dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
|
top_reply_models: list[dict[str, Any]] = Field(default_factory=list, description="最喜欢的回复模型TOP5")
|
||||||
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
|
most_expensive_cost: float = Field(0.0, description="最昂贵的一次思考花费")
|
||||||
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
|
most_expensive_time: Optional[str] = Field(None, description="最昂贵思考的时间")
|
||||||
top_token_consumers: List[Dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
|
top_token_consumers: list[dict[str, Any]] = Field(default_factory=list, description="烧钱大户TOP3")
|
||||||
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
|
silence_rate: float = Field(0.0, description="高冷指数(沉默率)")
|
||||||
total_actions: int = Field(0, description="总动作数")
|
total_actions: int = Field(0, description="总动作数")
|
||||||
no_reply_count: int = Field(0, description="选择沉默的次数")
|
no_reply_count: int = Field(0, description="选择沉默的次数")
|
||||||
|
|
@ -88,23 +90,23 @@ class BrainPowerData(BaseModel):
|
||||||
class ExpressionVibeData(BaseModel):
|
class ExpressionVibeData(BaseModel):
|
||||||
"""个性与表达数据"""
|
"""个性与表达数据"""
|
||||||
|
|
||||||
top_emoji: Optional[Dict[str, Any]] = Field(None, description="表情包之王")
|
top_emoji: Optional[dict[str, Any]] = Field(None, description="表情包之王")
|
||||||
top_emojis: List[Dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
|
top_emojis: list[dict[str, Any]] = Field(default_factory=list, description="TOP3表情包")
|
||||||
top_expressions: List[Dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
|
top_expressions: list[dict[str, Any]] = Field(default_factory=list, description="印象最深刻的表达风格")
|
||||||
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
|
rejected_expression_count: int = Field(0, description="被拒绝的表达次数")
|
||||||
checked_expression_count: int = Field(0, description="已检查的表达次数")
|
checked_expression_count: int = Field(0, description="已检查的表达次数")
|
||||||
total_expressions: int = Field(0, description="表达总数")
|
total_expressions: int = Field(0, description="表达总数")
|
||||||
action_types: List[Dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
|
action_types: list[dict[str, Any]] = Field(default_factory=list, description="动作类型分布")
|
||||||
image_processed_count: int = Field(0, description="处理的图片数量")
|
image_processed_count: int = Field(0, description="处理的图片数量")
|
||||||
late_night_reply: Optional[Dict[str, Any]] = Field(None, description="深夜还在回复")
|
late_night_reply: Optional[dict[str, Any]] = Field(None, description="深夜还在回复")
|
||||||
favorite_reply: Optional[Dict[str, Any]] = Field(None, description="最喜欢的回复")
|
favorite_reply: Optional[dict[str, Any]] = Field(None, description="最喜欢的回复")
|
||||||
|
|
||||||
|
|
||||||
class AchievementData(BaseModel):
|
class AchievementData(BaseModel):
|
||||||
"""趣味成就数据"""
|
"""趣味成就数据"""
|
||||||
|
|
||||||
new_jargon_count: int = Field(0, description="新学到的黑话数量")
|
new_jargon_count: int = Field(0, description="新学到的黑话数量")
|
||||||
sample_jargons: List[Dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
|
sample_jargons: list[dict[str, Any]] = Field(default_factory=list, description="代表性黑话示例")
|
||||||
total_messages: int = Field(0, description="总消息数")
|
total_messages: int = Field(0, description="总消息数")
|
||||||
total_replies: int = Field(0, description="总回复数")
|
total_replies: int = Field(0, description="总回复数")
|
||||||
|
|
||||||
|
|
@ -115,11 +117,11 @@ class AnnualReportData(BaseModel):
|
||||||
year: int = Field(2025, description="报告年份")
|
year: int = Field(2025, description="报告年份")
|
||||||
bot_name: str = Field("麦麦", description="Bot名称")
|
bot_name: str = Field("麦麦", description="Bot名称")
|
||||||
generated_at: str = Field(..., description="报告生成时间")
|
generated_at: str = Field(..., description="报告生成时间")
|
||||||
time_footprint: TimeFootprintData = Field(default_factory=TimeFootprintData)
|
time_footprint: TimeFootprintData = Field(default_factory=lambda: TimeFootprintData.model_construct())
|
||||||
social_network: SocialNetworkData = Field(default_factory=SocialNetworkData)
|
social_network: SocialNetworkData = Field(default_factory=lambda: SocialNetworkData.model_construct())
|
||||||
brain_power: BrainPowerData = Field(default_factory=BrainPowerData)
|
brain_power: BrainPowerData = Field(default_factory=lambda: BrainPowerData.model_construct())
|
||||||
expression_vibe: ExpressionVibeData = Field(default_factory=ExpressionVibeData)
|
expression_vibe: ExpressionVibeData = Field(default_factory=lambda: ExpressionVibeData.model_construct())
|
||||||
achievements: AchievementData = Field(default_factory=AchievementData)
|
achievements: AchievementData = Field(default_factory=lambda: AchievementData.model_construct())
|
||||||
|
|
||||||
|
|
||||||
# ==================== 辅助函数 ====================
|
# ==================== 辅助函数 ====================
|
||||||
|
|
@ -144,15 +146,18 @@ def get_year_datetime_range(year: int = 2025) -> tuple[datetime, datetime]:
|
||||||
|
|
||||||
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
||||||
"""获取时光足迹数据"""
|
"""获取时光足迹数据"""
|
||||||
data = TimeFootprintData()
|
data = TimeFootprintData.model_construct()
|
||||||
start_ts, end_ts = get_year_time_range(year)
|
start_ts, end_ts = get_year_time_range(year)
|
||||||
start_dt, end_dt = get_year_datetime_range(year)
|
start_dt, end_dt = get_year_datetime_range(year)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 年度在线时长
|
# 1. 年度在线时长
|
||||||
online_records = list(
|
with get_db_session() as session:
|
||||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_dt) | (OnlineTime.end_timestamp <= end_dt))
|
statement = select(OnlineTime).where(
|
||||||
|
col(OnlineTime.start_timestamp) >= start_dt,
|
||||||
|
col(OnlineTime.end_timestamp) <= end_dt,
|
||||||
)
|
)
|
||||||
|
online_records = session.exec(statement).all()
|
||||||
total_seconds = 0
|
total_seconds = 0
|
||||||
for record in online_records:
|
for record in online_records:
|
||||||
try:
|
try:
|
||||||
|
|
@ -165,50 +170,66 @@ async def get_time_footprint(year: int = 2025) -> TimeFootprintData:
|
||||||
data.total_online_hours = round(total_seconds / 3600, 2)
|
data.total_online_hours = round(total_seconds / 3600, 2)
|
||||||
|
|
||||||
# 2. 初次相遇 - 年度第一条消息
|
# 2. 初次相遇 - 年度第一条消息
|
||||||
first_msg = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = (
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
select(Messages)
|
||||||
.order_by(Messages.time.asc())
|
.where(
|
||||||
.first()
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
|
.order_by(col(Messages.timestamp).asc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
first_msg = session.exec(statement).first()
|
||||||
if first_msg:
|
if first_msg:
|
||||||
data.first_message_time = datetime.fromtimestamp(first_msg.time).strftime("%Y-%m-%d %H:%M:%S")
|
data.first_message_time = first_msg.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
|
data.first_message_user = first_msg.user_nickname or first_msg.user_id or "未知用户"
|
||||||
content = first_msg.processed_plain_text or first_msg.display_message or ""
|
content = first_msg.processed_plain_text or first_msg.display_message or ""
|
||||||
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
|
data.first_message_content = content[:50] + "..." if len(content) > 50 else content
|
||||||
|
|
||||||
# 3. 最忙碌的一天
|
# 3. 最忙碌的一天
|
||||||
# 使用 SQLite 的 date 函数按日期分组
|
# 使用 SQLite 的 date 函数按日期分组
|
||||||
busiest_query = (
|
day_expr = func.date(col(Messages.timestamp))
|
||||||
Messages.select(
|
with get_db_session() as session:
|
||||||
fn.date(Messages.time, "unixepoch").alias("day"),
|
statement = (
|
||||||
fn.COUNT(Messages.id).alias("count"),
|
select(
|
||||||
|
day_expr.label("day"),
|
||||||
|
func.count().label("count"),
|
||||||
)
|
)
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
.where(
|
||||||
.group_by(fn.date(Messages.time, "unixepoch"))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
.order_by(fn.COUNT(Messages.id).desc())
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
)
|
||||||
|
.group_by(day_expr)
|
||||||
|
.order_by(func.count().desc())
|
||||||
.limit(1)
|
.limit(1)
|
||||||
)
|
)
|
||||||
busiest_result = list(busiest_query.dicts())
|
busiest_result = session.exec(statement).all()
|
||||||
if busiest_result:
|
if busiest_result:
|
||||||
data.busiest_day = busiest_result[0].get("day")
|
data.busiest_day = busiest_result[0][0]
|
||||||
data.busiest_day_count = busiest_result[0].get("count", 0)
|
data.busiest_day_count = busiest_result[0][1] or 0
|
||||||
|
|
||||||
# 4. 昼夜节律 - 24小时活跃分布
|
# 4. 昼夜节律 - 24小时活跃分布
|
||||||
hourly_query = (
|
hour_expr = func.strftime("%H", col(Messages.timestamp))
|
||||||
Messages.select(
|
with get_db_session() as session:
|
||||||
fn.strftime("%H", Messages.time, "unixepoch").alias("hour"),
|
statement = (
|
||||||
fn.COUNT(Messages.id).alias("count"),
|
select(
|
||||||
|
hour_expr.label("hour"),
|
||||||
|
func.count().label("count"),
|
||||||
)
|
)
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts))
|
.where(
|
||||||
.group_by(fn.strftime("%H", Messages.time, "unixepoch"))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
|
.group_by(hour_expr)
|
||||||
|
)
|
||||||
|
hourly_rows = session.exec(statement).all()
|
||||||
hourly_distribution = [0] * 24
|
hourly_distribution = [0] * 24
|
||||||
for row in hourly_query.dicts():
|
for row in hourly_rows:
|
||||||
try:
|
try:
|
||||||
hour = int(row.get("hour", 0))
|
hour = int(row[0] or 0)
|
||||||
if 0 <= hour < 24:
|
if 0 <= hour < 24:
|
||||||
hourly_distribution[hour] = row.get("count", 0)
|
hourly_distribution[hour] = row[1] or 0
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
continue
|
continue
|
||||||
data.hourly_distribution = hourly_distribution
|
data.hourly_distribution = hourly_distribution
|
||||||
|
|
@ -234,7 +255,7 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||||
"""获取社交网络数据"""
|
"""获取社交网络数据"""
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
data = SocialNetworkData()
|
data = SocialNetworkData.model_construct()
|
||||||
start_ts, end_ts = get_year_time_range(year)
|
start_ts, end_ts = get_year_time_range(year)
|
||||||
|
|
||||||
# 获取 bot 自身的 QQ 账号,用于过滤
|
# 获取 bot 自身的 QQ 账号,用于过滤
|
||||||
|
|
@ -242,91 +263,110 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 加入的群组总数
|
# 1. 加入的群组总数
|
||||||
data.total_groups = ChatStreams.select().where(ChatStreams.group_id.is_null(False)).count()
|
with get_db_session() as session:
|
||||||
|
statement = select(func.count(func.distinct(col(Messages.group_id)))).where(
|
||||||
|
col(Messages.group_id).is_not(None),
|
||||||
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
)
|
||||||
|
data.total_groups = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 2. 话痨群组 TOP3
|
# 2. 话痨群组 TOP3
|
||||||
top_groups_query = (
|
with get_db_session() as session:
|
||||||
Messages.select(
|
statement = (
|
||||||
Messages.chat_info_group_id,
|
select(
|
||||||
Messages.chat_info_group_name,
|
col(Messages.group_id),
|
||||||
fn.COUNT(Messages.id).alias("count"),
|
func.max(col(Messages.group_name)).label("group_name"),
|
||||||
|
func.count().label("count"),
|
||||||
)
|
)
|
||||||
.where(
|
.where(
|
||||||
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.chat_info_group_id.is_null(False))
|
col(Messages.group_id).is_not(None),
|
||||||
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
.group_by(Messages.chat_info_group_id)
|
.group_by(col(Messages.group_id))
|
||||||
.order_by(fn.COUNT(Messages.id).desc())
|
.order_by(func.count().desc())
|
||||||
.limit(5)
|
.limit(5)
|
||||||
)
|
)
|
||||||
|
top_groups_rows = session.exec(statement).all()
|
||||||
data.top_groups = [
|
data.top_groups = [
|
||||||
{
|
{
|
||||||
"group_id": row["chat_info_group_id"],
|
"group_id": row[0],
|
||||||
"group_name": row["chat_info_group_name"] or "未知群组",
|
"group_name": row[1] or "未知群组",
|
||||||
"message_count": row["count"],
|
"message_count": row[2] or 0,
|
||||||
"is_webui": str(row["chat_info_group_id"]).startswith("webui_"),
|
"is_webui": str(row[0]).startswith("webui_"),
|
||||||
}
|
}
|
||||||
for row in top_groups_query.dicts()
|
for row in top_groups_rows
|
||||||
]
|
]
|
||||||
|
|
||||||
# 3. 互动最多的用户 TOP5(过滤 bot 自身)
|
# 3. 互动最多的用户 TOP5(过滤 bot 自身)
|
||||||
top_users_query = (
|
with get_db_session() as session:
|
||||||
Messages.select(
|
statement = (
|
||||||
Messages.user_id,
|
select(
|
||||||
Messages.user_nickname,
|
col(Messages.user_id),
|
||||||
fn.COUNT(Messages.id).alias("count"),
|
func.max(col(Messages.user_nickname)).label("user_nickname"),
|
||||||
|
func.count().label("count"),
|
||||||
)
|
)
|
||||||
.where(
|
.where(
|
||||||
(Messages.time >= start_ts)
|
col(Messages.user_id).is_not(None),
|
||||||
& (Messages.time <= end_ts)
|
col(Messages.user_id) != bot_qq,
|
||||||
& (Messages.user_id.is_null(False))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
& (Messages.user_id != bot_qq) # 过滤 bot 自身
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
.group_by(Messages.user_id)
|
.group_by(col(Messages.user_id))
|
||||||
.order_by(fn.COUNT(Messages.id).desc())
|
.order_by(func.count().desc())
|
||||||
.limit(5)
|
.limit(5)
|
||||||
)
|
)
|
||||||
|
top_users_rows = session.exec(statement).all()
|
||||||
data.top_users = [
|
data.top_users = [
|
||||||
{
|
{
|
||||||
"user_id": row["user_id"],
|
"user_id": row[0],
|
||||||
"user_nickname": row["user_nickname"] or "未知用户",
|
"user_nickname": row[1] or "未知用户",
|
||||||
"message_count": row["count"],
|
"message_count": row[2] or 0,
|
||||||
"is_webui": str(row["user_id"]).startswith("webui_"),
|
"is_webui": str(row[0]).startswith("webui_"),
|
||||||
}
|
}
|
||||||
for row in top_users_query.dicts()
|
for row in top_users_rows
|
||||||
]
|
]
|
||||||
|
|
||||||
# 4. 被@次数
|
# 4. 被@次数
|
||||||
data.at_count = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = select(func.count()).where(
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_at == True))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
.count()
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
col(Messages.is_at) == True,
|
||||||
)
|
)
|
||||||
|
data.at_count = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 5. 被提及次数
|
# 5. 被提及次数
|
||||||
data.mentioned_count = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = select(func.count()).where(
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_mentioned == True))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
.count()
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
col(Messages.is_mentioned) == True,
|
||||||
)
|
)
|
||||||
|
data.mentioned_count = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 6. 最长情陪伴的用户(过滤 bot 自身)
|
# 6. 最长情陪伴的用户(过滤 bot 自身)
|
||||||
companion_query = (
|
with get_db_session() as session:
|
||||||
ChatStreams.select(
|
statement = select(PersonInfo).where(
|
||||||
ChatStreams.user_id,
|
col(PersonInfo.user_id) != bot_qq,
|
||||||
ChatStreams.user_nickname,
|
col(PersonInfo.first_known_time).is_not(None),
|
||||||
(ChatStreams.last_active_time - ChatStreams.create_time).alias("duration"),
|
col(PersonInfo.last_known_time).is_not(None),
|
||||||
)
|
)
|
||||||
.where(
|
persons = session.exec(statement).all()
|
||||||
(ChatStreams.user_id.is_null(False)) & (ChatStreams.user_id != bot_qq) # 过滤 bot 自身
|
if persons:
|
||||||
)
|
|
||||||
.order_by((ChatStreams.last_active_time - ChatStreams.create_time).desc())
|
def _companion_days(person: PersonInfo) -> float:
|
||||||
.limit(1)
|
if not person.first_known_time or not person.last_known_time:
|
||||||
)
|
return 0.0
|
||||||
companion_result = list(companion_query.dicts())
|
return (person.last_known_time - person.first_known_time).total_seconds()
|
||||||
if companion_result:
|
|
||||||
data.longest_companion_user = companion_result[0].get("user_nickname") or "未知用户"
|
longest = max(persons, key=_companion_days)
|
||||||
duration = companion_result[0].get("duration", 0) or 0
|
data.longest_companion_user = longest.person_name or longest.user_nickname or longest.user_id
|
||||||
data.longest_companion_days = int(duration / 86400) # 转换为天
|
data.longest_companion_days = int(_companion_days(longest) / 86400)
|
||||||
|
else:
|
||||||
|
data.longest_companion_user = None
|
||||||
|
data.longest_companion_days = 0
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取社交网络数据失败: {e}")
|
logger.error(f"获取社交网络数据失败: {e}")
|
||||||
|
|
@ -339,154 +379,139 @@ async def get_social_network(year: int = 2025) -> SocialNetworkData:
|
||||||
|
|
||||||
async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
||||||
"""获取最强大脑数据"""
|
"""获取最强大脑数据"""
|
||||||
data = BrainPowerData()
|
data = BrainPowerData.model_construct()
|
||||||
start_dt, end_dt = get_year_datetime_range(year)
|
start_dt, end_dt = get_year_datetime_range(year)
|
||||||
start_ts, end_ts = get_year_time_range(year)
|
start_ts, end_ts = get_year_time_range(year)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 年度消耗 Token 总量和总花费
|
# 1. 年度消耗 Token 总量和总花费
|
||||||
token_query = LLMUsage.select(
|
with get_db_session() as session:
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("total_tokens"),
|
statement = select(
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
|
||||||
).where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
func.sum(col(ModelUsage.cost)).label("total_cost"),
|
||||||
result = token_query.dicts().get()
|
).where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
|
||||||
data.total_tokens = int(result.get("total_tokens", 0) or 0)
|
result = session.exec(statement).first()
|
||||||
data.total_cost = round(float(result.get("total_cost", 0) or 0), 4)
|
if result:
|
||||||
|
data.total_tokens = int(result[0] or 0)
|
||||||
|
data.total_cost = round(float(result[1] or 0), 4)
|
||||||
|
|
||||||
# 2. 最爱用的模型
|
# 2. 最爱用的模型
|
||||||
model_query = (
|
with get_db_session() as session:
|
||||||
LLMUsage.select(
|
statement = (
|
||||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
select(ModelUsage)
|
||||||
fn.COUNT(LLMUsage.id).alias("count"),
|
.where(col(ModelUsage.timestamp) >= start_dt, col(ModelUsage.timestamp) <= end_dt)
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
.order_by(desc(col(ModelUsage.timestamp)))
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
|
||||||
)
|
)
|
||||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
records = session.exec(statement).all()
|
||||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
|
||||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
model_agg: dict[str, dict[str, float | int]] = {}
|
||||||
.limit(10)
|
for record in records:
|
||||||
)
|
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||||
model_results = list(model_query.dicts())
|
if model_name not in model_agg:
|
||||||
|
model_agg[model_name] = {"count": 0, "tokens": 0, "cost": 0.0}
|
||||||
|
bucket = model_agg[model_name]
|
||||||
|
bucket["count"] = int(bucket["count"]) + 1
|
||||||
|
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
|
||||||
|
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
|
||||||
|
|
||||||
|
model_results = sorted(
|
||||||
|
model_agg.items(),
|
||||||
|
key=lambda item: float(item[1]["count"]),
|
||||||
|
reverse=True,
|
||||||
|
)[:10]
|
||||||
if model_results:
|
if model_results:
|
||||||
data.favorite_model = model_results[0].get("model")
|
data.favorite_model = model_results[0][0]
|
||||||
data.favorite_model_count = model_results[0].get("count", 0)
|
data.favorite_model_count = int(model_results[0][1]["count"])
|
||||||
data.model_distribution = [
|
data.model_distribution = [
|
||||||
{
|
{
|
||||||
"model": row["model"],
|
"model": model_name,
|
||||||
"count": row["count"],
|
"count": int(bucket["count"]),
|
||||||
"tokens": row["tokens"],
|
"tokens": int(bucket["tokens"]),
|
||||||
"cost": round(row["cost"], 4),
|
"cost": round(float(bucket["cost"]), 4),
|
||||||
}
|
}
|
||||||
for row in model_results
|
for model_name, bucket in model_results
|
||||||
]
|
]
|
||||||
|
|
||||||
# 3. 最昂贵的一次思考
|
# 3. 最昂贵的一次思考
|
||||||
expensive_query = (
|
if records:
|
||||||
LLMUsage.select(LLMUsage.cost, LLMUsage.timestamp)
|
expensive_record = max(records, key=lambda record: record.cost or 0.0)
|
||||||
.where((LLMUsage.timestamp >= start_dt) & (LLMUsage.timestamp <= end_dt))
|
data.most_expensive_cost = round(expensive_record.cost or 0.0, 4)
|
||||||
.order_by(LLMUsage.cost.desc())
|
data.most_expensive_time = expensive_record.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
expensive_result = expensive_query.first()
|
|
||||||
if expensive_result:
|
|
||||||
data.most_expensive_cost = round(expensive_result.cost or 0, 4)
|
|
||||||
data.most_expensive_time = expensive_result.timestamp.strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
|
# 4. 烧钱大户 TOP3 (按用户,过滤 system)
|
||||||
consumer_query = (
|
consumer_agg: dict[str, dict[str, float | int]] = {}
|
||||||
LLMUsage.select(
|
for record in records:
|
||||||
LLMUsage.user_id,
|
user_id = record.model_api_provider_name
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
if not user_id or user_id == "system":
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.total_tokens), 0).alias("tokens"),
|
continue
|
||||||
)
|
if user_id not in consumer_agg:
|
||||||
.where(
|
consumer_agg[user_id] = {"cost": 0.0, "tokens": 0}
|
||||||
(LLMUsage.timestamp >= start_dt)
|
bucket = consumer_agg[user_id]
|
||||||
& (LLMUsage.timestamp <= end_dt)
|
bucket["cost"] = float(bucket["cost"]) + float(record.cost or 0.0)
|
||||||
& (LLMUsage.user_id != "system") # 过滤 system 用户
|
bucket["tokens"] = int(bucket["tokens"]) + int(record.total_tokens or 0)
|
||||||
& (LLMUsage.user_id.is_null(False))
|
|
||||||
)
|
|
||||||
.group_by(LLMUsage.user_id)
|
|
||||||
.order_by(fn.SUM(LLMUsage.cost).desc())
|
|
||||||
.limit(3)
|
|
||||||
)
|
|
||||||
data.top_token_consumers = [
|
data.top_token_consumers = [
|
||||||
{
|
{
|
||||||
"user_id": row["user_id"],
|
"user_id": user_id,
|
||||||
"cost": round(row["cost"], 4),
|
"cost": round(float(bucket["cost"]), 4),
|
||||||
"tokens": row["tokens"],
|
"tokens": int(bucket["tokens"]),
|
||||||
}
|
}
|
||||||
for row in consumer_query.dicts()
|
for user_id, bucket in sorted(
|
||||||
|
consumer_agg.items(),
|
||||||
|
key=lambda item: float(item[1]["cost"]),
|
||||||
|
reverse=True,
|
||||||
|
)[:3]
|
||||||
]
|
]
|
||||||
|
|
||||||
# 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用)
|
# 5. 最喜欢的回复模型 TOP5(按模型的回复次数统计,只统计 replyer 调用)
|
||||||
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
|
# 假设 replyer 调用有特定的 model_assign_name 格式或可以通过某种方式识别
|
||||||
reply_model_query = (
|
reply_model_agg: dict[str, int] = {}
|
||||||
LLMUsage.select(
|
for record in records:
|
||||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name).alias("model"),
|
model_assign_name = record.model_assign_name or ""
|
||||||
fn.COUNT(LLMUsage.id).alias("count"),
|
if "replyer" not in model_assign_name and "回复" not in model_assign_name:
|
||||||
)
|
continue
|
||||||
.where(
|
model_name = model_assign_name or record.model_name or "unknown"
|
||||||
(LLMUsage.timestamp >= start_dt)
|
reply_model_agg[model_name] = reply_model_agg.get(model_name, 0) + 1
|
||||||
& (LLMUsage.timestamp <= end_dt)
|
data.top_reply_models = [
|
||||||
& (
|
{"model": model_name, "count": count}
|
||||||
LLMUsage.model_assign_name.contains("replyer")
|
for model_name, count in sorted(reply_model_agg.items(), key=lambda item: item[1], reverse=True)[:5]
|
||||||
| LLMUsage.model_assign_name.contains("回复")
|
]
|
||||||
| LLMUsage.model_assign_name.is_null(True) # 包含没有 assign_name 的情况
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name))
|
|
||||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
|
||||||
.limit(5)
|
|
||||||
)
|
|
||||||
data.top_reply_models = [{"model": row["model"], "count": row["count"]} for row in reply_model_query.dicts()]
|
|
||||||
|
|
||||||
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
# 6. 高冷指数 (沉默率) - 基于 ActionRecords
|
||||||
total_actions = (
|
with get_db_session() as session:
|
||||||
ActionRecords.select().where((ActionRecords.time >= start_ts) & (ActionRecords.time <= end_ts)).count()
|
statement = select(func.count()).where(
|
||||||
|
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
no_reply_count = (
|
total_actions = int(session.exec(statement).first() or 0)
|
||||||
ActionRecords.select()
|
with get_db_session() as session:
|
||||||
.where(
|
statement = select(func.count()).where(
|
||||||
(ActionRecords.time >= start_ts)
|
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
& (ActionRecords.time <= end_ts)
|
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
& (ActionRecords.action_name == "no_reply")
|
col(ActionRecord.action_name) == "no_reply",
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
)
|
||||||
|
no_reply_count = int(session.exec(statement).first() or 0)
|
||||||
data.total_actions = total_actions
|
data.total_actions = total_actions
|
||||||
data.no_reply_count = no_reply_count
|
data.no_reply_count = no_reply_count
|
||||||
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
data.silence_rate = round(no_reply_count / total_actions * 100, 2) if total_actions > 0 else 0
|
||||||
|
|
||||||
# 6. 情绪波动 (兴趣值)
|
# 6. 情绪波动 (兴趣值)
|
||||||
interest_query = Messages.select(
|
data.avg_interest_value = 0.0
|
||||||
fn.AVG(Messages.interest_value).alias("avg_interest"),
|
data.max_interest_value = 0.0
|
||||||
fn.MAX(Messages.interest_value).alias("max_interest"),
|
|
||||||
).where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.interest_value.is_null(False)))
|
|
||||||
interest_result = interest_query.dicts().get()
|
|
||||||
data.avg_interest_value = round(float(interest_result.get("avg_interest") or 0), 2)
|
|
||||||
data.max_interest_value = round(float(interest_result.get("max_interest") or 0), 2)
|
|
||||||
|
|
||||||
# 找到最高兴趣值的时间
|
# 找到最高兴趣值的时间
|
||||||
if data.max_interest_value > 0:
|
if data.max_interest_value > 0:
|
||||||
max_interest_msg = (
|
data.max_interest_time = None
|
||||||
Messages.select(Messages.time)
|
|
||||||
.where(
|
|
||||||
(Messages.time >= start_ts)
|
|
||||||
& (Messages.time <= end_ts)
|
|
||||||
& (Messages.interest_value == data.max_interest_value)
|
|
||||||
)
|
|
||||||
.first()
|
|
||||||
)
|
|
||||||
if max_interest_msg:
|
|
||||||
data.max_interest_time = datetime.fromtimestamp(max_interest_msg.time).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
|
|
||||||
# 7. 思考深度 (基于 action_reasoning 长度)
|
# 7. 思考深度 (基于 action_reasoning 长度)
|
||||||
reasoning_records = ActionRecords.select(ActionRecords.action_reasoning, ActionRecords.time).where(
|
with get_db_session() as session:
|
||||||
(ActionRecords.time >= start_ts)
|
statement = select(ActionRecord).where(
|
||||||
& (ActionRecords.time <= end_ts)
|
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
& (ActionRecords.action_reasoning.is_null(False))
|
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
& (ActionRecords.action_reasoning != "")
|
col(ActionRecord.action_reasoning).is_not(None),
|
||||||
|
col(ActionRecord.action_reasoning) != "",
|
||||||
)
|
)
|
||||||
|
reasoning_records = session.exec(statement).all()
|
||||||
reasoning_lengths = []
|
reasoning_lengths = []
|
||||||
max_len = 0
|
max_len = 0
|
||||||
max_len_time = None
|
max_len_time = None
|
||||||
|
|
@ -496,13 +521,13 @@ async def get_brain_power(year: int = 2025) -> BrainPowerData:
|
||||||
reasoning_lengths.append(length)
|
reasoning_lengths.append(length)
|
||||||
if length > max_len:
|
if length > max_len:
|
||||||
max_len = length
|
max_len = length
|
||||||
max_len_time = record.time
|
max_len_time = record.timestamp
|
||||||
|
|
||||||
if reasoning_lengths:
|
if reasoning_lengths:
|
||||||
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
data.avg_reasoning_length = round(sum(reasoning_lengths) / len(reasoning_lengths), 1)
|
||||||
data.max_reasoning_length = max_len
|
data.max_reasoning_length = max_len
|
||||||
if max_len_time:
|
if max_len_time:
|
||||||
data.max_reasoning_time = datetime.fromtimestamp(max_len_time).strftime("%Y-%m-%d %H:%M:%S")
|
data.max_reasoning_time = max_len_time.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取最强大脑数据失败: {e}")
|
logger.error(f"获取最强大脑数据失败: {e}")
|
||||||
|
|
@ -517,7 +542,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
"""获取个性与表达数据"""
|
"""获取个性与表达数据"""
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
data = ExpressionVibeData()
|
data = ExpressionVibeData.model_construct()
|
||||||
start_ts, end_ts = get_year_time_range(year)
|
start_ts, end_ts = get_year_time_range(year)
|
||||||
|
|
||||||
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
# 获取 bot 自身的 QQ 账号,用于筛选 bot 发送的消息
|
||||||
|
|
@ -525,75 +550,58 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 表情包之王 - 使用次数最多的表情包
|
# 1. 表情包之王 - 使用次数最多的表情包
|
||||||
top_emoji_query = (
|
with get_db_session() as session:
|
||||||
Emoji.select(Emoji.id, Emoji.full_path, Emoji.description, Emoji.usage_count, Emoji.emoji_hash)
|
statement = (
|
||||||
.where(Emoji.is_registered == True)
|
select(Images).where(col(Images.is_registered) == True).order_by(desc(col(Images.query_count))).limit(5)
|
||||||
.order_by(Emoji.usage_count.desc())
|
|
||||||
.limit(5)
|
|
||||||
)
|
)
|
||||||
top_emojis = list(top_emoji_query.dicts())
|
top_emojis = session.exec(statement).all()
|
||||||
if top_emojis:
|
if top_emojis:
|
||||||
data.top_emoji = {
|
data.top_emoji = {
|
||||||
"id": top_emojis[0].get("id"),
|
"id": top_emojis[0].id,
|
||||||
"path": top_emojis[0].get("full_path"),
|
"path": top_emojis[0].full_path,
|
||||||
"description": top_emojis[0].get("description"),
|
"description": top_emojis[0].description,
|
||||||
"usage_count": top_emojis[0].get("usage_count", 0),
|
"usage_count": top_emojis[0].query_count,
|
||||||
"hash": top_emojis[0].get("emoji_hash"),
|
"hash": top_emojis[0].image_hash,
|
||||||
}
|
}
|
||||||
data.top_emojis = [
|
data.top_emojis = [
|
||||||
{
|
{
|
||||||
"id": e.get("id"),
|
"id": e.id,
|
||||||
"path": e.get("full_path"),
|
"path": e.full_path,
|
||||||
"description": e.get("description"),
|
"description": e.description,
|
||||||
"usage_count": e.get("usage_count", 0),
|
"usage_count": e.query_count,
|
||||||
"hash": e.get("emoji_hash"),
|
"hash": e.image_hash,
|
||||||
}
|
}
|
||||||
for e in top_emojis
|
for e in top_emojis
|
||||||
]
|
]
|
||||||
|
|
||||||
# 2. 百变麦麦 - 最常用的表达风格
|
# 2. 百变麦麦 - 最常用的表达风格
|
||||||
expression_query = (
|
with get_db_session() as session:
|
||||||
Expression.select(
|
statement = (
|
||||||
Expression.style,
|
select(Expression.style, func.sum(col(Expression.count)).label("total_count"))
|
||||||
fn.SUM(Expression.count).alias("total_count"),
|
.where(
|
||||||
|
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
|
|
||||||
.group_by(Expression.style)
|
.group_by(Expression.style)
|
||||||
.order_by(fn.SUM(Expression.count).desc())
|
.order_by(func.sum(col(Expression.count)).desc())
|
||||||
.limit(5)
|
.limit(5)
|
||||||
)
|
)
|
||||||
data.top_expressions = [
|
expression_rows = session.exec(statement).all()
|
||||||
{"style": row["style"], "count": row["total_count"]} for row in expression_query.dicts()
|
data.top_expressions = [{"style": row[0], "count": row[1] or 0} for row in expression_rows]
|
||||||
]
|
|
||||||
|
|
||||||
# 3. 被拒绝的表达
|
# 3. 被拒绝的表达
|
||||||
data.rejected_expression_count = (
|
data.rejected_expression_count = 0
|
||||||
Expression.select()
|
|
||||||
.where(
|
|
||||||
(Expression.last_active_time >= start_ts)
|
|
||||||
& (Expression.last_active_time <= end_ts)
|
|
||||||
& (Expression.rejected == True)
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. 已检查的表达
|
# 4. 已检查的表达
|
||||||
data.checked_expression_count = (
|
data.checked_expression_count = 0
|
||||||
Expression.select()
|
|
||||||
.where(
|
|
||||||
(Expression.last_active_time >= start_ts)
|
|
||||||
& (Expression.last_active_time <= end_ts)
|
|
||||||
& (Expression.checked == True)
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. 表达总数
|
# 5. 表达总数
|
||||||
data.total_expressions = (
|
with get_db_session() as session:
|
||||||
Expression.select()
|
statement = select(func.count()).where(
|
||||||
.where((Expression.last_active_time >= start_ts) & (Expression.last_active_time <= end_ts))
|
col(Expression.last_active_time) >= datetime.fromtimestamp(start_ts),
|
||||||
.count()
|
col(Expression.last_active_time) <= datetime.fromtimestamp(end_ts),
|
||||||
)
|
)
|
||||||
|
data.total_expressions = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 6. 动作类型分布 (过滤无意义的动作)
|
# 6. 动作类型分布 (过滤无意义的动作)
|
||||||
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
# 过滤掉: no_reply_until_call, make_question, no_action, wait, complete_talk, listening, block_and_ignore
|
||||||
|
|
@ -608,28 +616,29 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
"listening",
|
"listening",
|
||||||
"block_and_ignore",
|
"block_and_ignore",
|
||||||
]
|
]
|
||||||
action_query = (
|
with get_db_session() as session:
|
||||||
ActionRecords.select(
|
statement = (
|
||||||
ActionRecords.action_name,
|
select(ActionRecord.action_name, func.count().label("count"))
|
||||||
fn.COUNT(ActionRecords.id).alias("count"),
|
|
||||||
)
|
|
||||||
.where(
|
.where(
|
||||||
(ActionRecords.time >= start_ts)
|
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
& (ActionRecords.time <= end_ts)
|
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
& (ActionRecords.action_name.not_in(excluded_actions))
|
col(ActionRecord.action_name).not_in(excluded_actions),
|
||||||
)
|
)
|
||||||
.group_by(ActionRecords.action_name)
|
.group_by(ActionRecord.action_name)
|
||||||
.order_by(fn.COUNT(ActionRecords.id).desc())
|
.order_by(func.count().desc())
|
||||||
.limit(10)
|
.limit(10)
|
||||||
)
|
)
|
||||||
data.action_types = [{"action": row["action_name"], "count": row["count"]} for row in action_query.dicts()]
|
action_rows = session.exec(statement).all()
|
||||||
|
data.action_types = [{"action": row[0], "count": row[1]} for row in action_rows]
|
||||||
|
|
||||||
# 7. 处理的图片数量
|
# 7. 处理的图片数量
|
||||||
data.image_processed_count = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = select(func.count()).where(
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.is_picid == True))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
.count()
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
col(Messages.is_picture) == True,
|
||||||
)
|
)
|
||||||
|
data.image_processed_count = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
# 8. 深夜还在回复 (0-6点最晚的10条消息中随机抽取一条)
|
||||||
import random
|
import random
|
||||||
|
|
@ -648,21 +657,22 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
return content
|
return content
|
||||||
|
|
||||||
# 使用 user_id 判断是否是 bot 发送的消息
|
# 使用 user_id 判断是否是 bot 发送的消息
|
||||||
late_night_messages = list(
|
with get_db_session() as session:
|
||||||
Messages.select(
|
statement = (
|
||||||
Messages.time,
|
select(Messages)
|
||||||
Messages.processed_plain_text,
|
|
||||||
Messages.display_message,
|
|
||||||
)
|
|
||||||
.where(
|
.where(
|
||||||
(Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.user_id == bot_qq) # bot 发送的消息
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
col(Messages.user_id) == bot_qq,
|
||||||
)
|
)
|
||||||
.order_by(Messages.time.desc())
|
.order_by(desc(col(Messages.timestamp)))
|
||||||
|
.limit(200)
|
||||||
)
|
)
|
||||||
|
late_night_messages = session.exec(statement).all()
|
||||||
# 筛选出0-6点的消息
|
# 筛选出0-6点的消息
|
||||||
late_night_filtered = []
|
late_night_filtered = []
|
||||||
for msg in late_night_messages:
|
for msg in late_night_messages:
|
||||||
msg_dt = datetime.fromtimestamp(msg.time)
|
msg_dt = msg.timestamp
|
||||||
hour = msg_dt.hour
|
hour = msg_dt.hour
|
||||||
if 0 <= hour < 6: # 0点到6点
|
if 0 <= hour < 6: # 0点到6点
|
||||||
raw_content = msg.processed_plain_text or msg.display_message or ""
|
raw_content = msg.processed_plain_text or msg.display_message or ""
|
||||||
|
|
@ -671,7 +681,7 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
if cleaned_content and len(cleaned_content) > 2:
|
if cleaned_content and len(cleaned_content) > 2:
|
||||||
late_night_filtered.append(
|
late_night_filtered.append(
|
||||||
{
|
{
|
||||||
"time": msg.time,
|
"time": msg_dt.timestamp(),
|
||||||
"hour": hour,
|
"hour": hour,
|
||||||
"minute": msg_dt.minute,
|
"minute": msg_dt.minute,
|
||||||
"content": cleaned_content,
|
"content": cleaned_content,
|
||||||
|
|
@ -693,13 +703,15 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
import json as json_lib
|
import json as json_lib
|
||||||
|
|
||||||
reply_records = ActionRecords.select(ActionRecords.action_data).where(
|
with get_db_session() as session:
|
||||||
(ActionRecords.time >= start_ts)
|
statement = select(ActionRecord).where(
|
||||||
& (ActionRecords.time <= end_ts)
|
col(ActionRecord.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
& (ActionRecords.action_name == "reply")
|
col(ActionRecord.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
& (ActionRecords.action_data.is_null(False))
|
col(ActionRecord.action_name) == "reply",
|
||||||
& (ActionRecords.action_data != "")
|
col(ActionRecord.action_data).is_not(None),
|
||||||
|
col(ActionRecord.action_data) != "",
|
||||||
)
|
)
|
||||||
|
reply_records = session.exec(statement).all()
|
||||||
|
|
||||||
reply_contents = []
|
reply_contents = []
|
||||||
for record in reply_records:
|
for record in reply_records:
|
||||||
|
|
@ -762,21 +774,20 @@ async def get_expression_vibe(year: int = 2025) -> ExpressionVibeData:
|
||||||
|
|
||||||
async def get_achievements(year: int = 2025) -> AchievementData:
|
async def get_achievements(year: int = 2025) -> AchievementData:
|
||||||
"""获取趣味成就数据"""
|
"""获取趣味成就数据"""
|
||||||
data = AchievementData()
|
data = AchievementData.model_construct()
|
||||||
start_ts, end_ts = get_year_time_range(year)
|
start_ts, end_ts = get_year_time_range(year)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 新学到的黑话数量
|
# 1. 新学到的黑话数量
|
||||||
# Jargon 表没有时间字段,统计全部已确认的黑话
|
# Jargon 表没有时间字段,统计全部已确认的黑话
|
||||||
data.new_jargon_count = Jargon.select().where(Jargon.is_jargon == True).count()
|
with get_db_session() as session:
|
||||||
|
statement = select(func.count()).where(col(Jargon.is_jargon) == True)
|
||||||
|
data.new_jargon_count = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 2. 代表性黑话示例
|
# 2. 代表性黑话示例
|
||||||
jargon_samples = (
|
with get_db_session() as session:
|
||||||
Jargon.select(Jargon.content, Jargon.meaning, Jargon.count)
|
statement = select(Jargon).where(col(Jargon.is_jargon) == True).order_by(desc(col(Jargon.count))).limit(5)
|
||||||
.where(Jargon.is_jargon == True)
|
jargon_samples = session.exec(statement).all()
|
||||||
.order_by(Jargon.count.desc())
|
|
||||||
.limit(5)
|
|
||||||
)
|
|
||||||
data.sample_jargons = [
|
data.sample_jargons = [
|
||||||
{
|
{
|
||||||
"content": j.content,
|
"content": j.content,
|
||||||
|
|
@ -787,14 +798,21 @@ async def get_achievements(year: int = 2025) -> AchievementData:
|
||||||
]
|
]
|
||||||
|
|
||||||
# 3. 总消息数
|
# 3. 总消息数
|
||||||
data.total_messages = Messages.select().where((Messages.time >= start_ts) & (Messages.time <= end_ts)).count()
|
with get_db_session() as session:
|
||||||
|
statement = select(func.count()).where(
|
||||||
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
)
|
||||||
|
data.total_messages = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
# 4. 总回复数 (有 reply_to 的消息)
|
# 4. 总回复数 (有 reply_to 的消息)
|
||||||
data.total_replies = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = select(func.count()).where(
|
||||||
.where((Messages.time >= start_ts) & (Messages.time <= end_ts) & (Messages.reply_to.is_null(False)))
|
col(Messages.timestamp) >= datetime.fromtimestamp(start_ts),
|
||||||
.count()
|
col(Messages.timestamp) <= datetime.fromtimestamp(end_ts),
|
||||||
|
col(Messages.reply_to).is_not(None),
|
||||||
)
|
)
|
||||||
|
data.total_replies = int(session.exec(statement).first() or 0)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取趣味成就数据失败: {e}")
|
logger.error(f"获取趣味成就数据失败: {e}")
|
||||||
|
|
|
||||||
|
|
@ -7,16 +7,19 @@
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, Optional, List
|
from typing import Any, Dict, List, Optional
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Query, Depends, Cookie, Header
|
|
||||||
from pydantic import BaseModel
|
from fastapi import APIRouter, Cookie, Depends, Header, Query, WebSocket, WebSocketDisconnect
|
||||||
from sqlalchemy import case, func as fn
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import case, desc, func
|
||||||
|
from sqlmodel import col, select, delete
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.common.database.database_model import Messages, PersonInfo
|
|
||||||
from src.config.config import global_config
|
|
||||||
from src.chat.message_receive.bot import chat_bot
|
from src.chat.message_receive.bot import chat_bot
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header, get_token_manager
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Messages, PersonInfo
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.webui.core import get_token_manager, verify_auth_token_from_cookie_or_header
|
||||||
from src.webui.routers.websocket.auth import verify_ws_token
|
from src.webui.routers.websocket.auth import verify_ws_token
|
||||||
|
|
||||||
logger = get_logger("webui.chat")
|
logger = get_logger("webui.chat")
|
||||||
|
|
@ -97,7 +100,7 @@ class ChatHistoryManager:
|
||||||
"id": msg.message_id,
|
"id": msg.message_id,
|
||||||
"type": "bot" if is_bot else "user",
|
"type": "bot" if is_bot else "user",
|
||||||
"content": msg.processed_plain_text or msg.display_message or "",
|
"content": msg.processed_plain_text or msg.display_message or "",
|
||||||
"timestamp": msg.time,
|
"timestamp": msg.timestamp.timestamp(),
|
||||||
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
"sender_name": msg.user_nickname or (global_config.bot.nickname if is_bot else "未知用户"),
|
||||||
"sender_id": "bot" if is_bot else user_id,
|
"sender_id": "bot" if is_bot else user_id,
|
||||||
"is_bot": is_bot,
|
"is_bot": is_bot,
|
||||||
|
|
@ -113,12 +116,14 @@ class ChatHistoryManager:
|
||||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||||
try:
|
try:
|
||||||
# 查询指定群的消息,按时间排序
|
# 查询指定群的消息,按时间排序
|
||||||
messages = (
|
with get_db_session() as session:
|
||||||
Messages.select()
|
statement = (
|
||||||
.where(Messages.chat_info_group_id == target_group_id)
|
select(Messages)
|
||||||
.order_by(Messages.time.desc())
|
.where(col(Messages.group_id) == target_group_id)
|
||||||
|
.order_by(desc(col(Messages.timestamp)))
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
)
|
)
|
||||||
|
messages = session.exec(statement).all()
|
||||||
|
|
||||||
# 转换为列表并反转(使最旧的消息在前)
|
# 转换为列表并反转(使最旧的消息在前)
|
||||||
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
# 传递 group_id 以便正确判断虚拟群中的机器人消息
|
||||||
|
|
@ -139,7 +144,10 @@ class ChatHistoryManager:
|
||||||
"""
|
"""
|
||||||
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
target_group_id = group_id if group_id else WEBUI_CHAT_GROUP_ID
|
||||||
try:
|
try:
|
||||||
deleted = Messages.delete().where(Messages.chat_info_group_id == target_group_id).execute()
|
with get_db_session() as session:
|
||||||
|
statement = delete(Messages).where(col(Messages.group_id) == target_group_id)
|
||||||
|
result = session.exec(statement)
|
||||||
|
deleted = result.rowcount or 0
|
||||||
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
logger.info(f"已清空 {deleted} 条聊天记录 (group_id={target_group_id})")
|
||||||
return deleted
|
return deleted
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -172,14 +180,14 @@ class ChatConnectionManager:
|
||||||
del self.user_sessions[user_id]
|
del self.user_sessions[user_id]
|
||||||
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
logger.info(f"WebUI 聊天会话已断开: session={session_id}")
|
||||||
|
|
||||||
async def send_message(self, session_id: str, message: dict):
|
async def send_message(self, session_id: str, message: dict[str, Any]):
|
||||||
if session_id in self.active_connections:
|
if session_id in self.active_connections:
|
||||||
try:
|
try:
|
||||||
await self.active_connections[session_id].send_json(message)
|
await self.active_connections[session_id].send_json(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e}")
|
logger.error(f"发送消息失败: {e}")
|
||||||
|
|
||||||
async def broadcast(self, message: dict):
|
async def broadcast(self, message: dict[str, Any]):
|
||||||
"""广播消息给所有连接"""
|
"""广播消息给所有连接"""
|
||||||
for session_id in list(self.active_connections.keys()):
|
for session_id in list(self.active_connections.keys()):
|
||||||
await self.send_message(session_id, message)
|
await self.send_message(session_id, message)
|
||||||
|
|
@ -292,16 +300,18 @@ async def get_available_platforms(_auth: bool = Depends(require_auth)):
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 查询所有不同的平台
|
# 查询所有不同的平台
|
||||||
platforms = (
|
with get_db_session() as session:
|
||||||
PersonInfo.select(PersonInfo.platform, fn.COUNT(PersonInfo.id).alias("count"))
|
statement = (
|
||||||
|
select(PersonInfo.platform, func.count().label("count"))
|
||||||
.group_by(PersonInfo.platform)
|
.group_by(PersonInfo.platform)
|
||||||
.order_by(fn.COUNT(PersonInfo.id).desc())
|
.order_by(func.count().desc())
|
||||||
)
|
)
|
||||||
|
platforms = session.exec(statement).all()
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for p in platforms:
|
for platform, count in platforms:
|
||||||
if p.platform: # 排除空平台
|
if platform:
|
||||||
result.append({"platform": p.platform, "count": p.count})
|
result.append({"platform": platform, "count": count})
|
||||||
|
|
||||||
return {"success": True, "platforms": result}
|
return {"success": True, "platforms": result}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -325,31 +335,36 @@ async def get_persons_by_platform(
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = PersonInfo.select().where(PersonInfo.platform == platform)
|
statement = select(PersonInfo).where(col(PersonInfo.platform) == platform)
|
||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
statement = statement.where(
|
||||||
(PersonInfo.person_name.contains(search))
|
(col(PersonInfo.person_name).contains(search))
|
||||||
| (PersonInfo.nickname.contains(search))
|
| (col(PersonInfo.user_nickname).contains(search))
|
||||||
| (PersonInfo.user_id.contains(search))
|
| (col(PersonInfo.user_id).contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 按最后交互时间排序,优先显示活跃用户
|
# 按最后交互时间排序,优先显示活跃用户
|
||||||
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
|
statement = statement.order_by(
|
||||||
query = query.limit(limit)
|
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
|
||||||
|
col(PersonInfo.last_known_time).desc(),
|
||||||
|
).limit(limit)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
persons = session.exec(statement).all()
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for person in query:
|
for person in persons:
|
||||||
result.append(
|
result.append(
|
||||||
{
|
{
|
||||||
"person_id": person.person_id,
|
"person_id": person.person_id,
|
||||||
"user_id": person.user_id,
|
"user_id": person.user_id,
|
||||||
"person_name": person.person_name,
|
"person_name": person.person_name,
|
||||||
"nickname": person.nickname,
|
"nickname": person.user_nickname,
|
||||||
"is_known": person.is_known,
|
"is_known": person.is_known,
|
||||||
"platform": person.platform,
|
"platform": person.platform,
|
||||||
"display_name": person.person_name or person.nickname or person.user_id,
|
"display_name": person.person_name or person.user_nickname or person.user_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -448,7 +463,9 @@ async def websocket_chat(
|
||||||
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
# 如果 URL 参数中提供了虚拟身份信息,自动配置
|
||||||
if platform and person_id:
|
if platform and person_id:
|
||||||
try:
|
try:
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
if person:
|
if person:
|
||||||
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
# 使用前端传递的 group_id,如果没有则生成一个稳定的
|
||||||
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
virtual_group_id = group_id or f"{VIRTUAL_GROUP_ID_PREFIX}{platform}_{person.user_id}"
|
||||||
|
|
@ -457,7 +474,7 @@ async def websocket_chat(
|
||||||
platform=person.platform,
|
platform=person.platform,
|
||||||
person_id=person.person_id,
|
person_id=person.person_id,
|
||||||
user_id=person.user_id,
|
user_id=person.user_id,
|
||||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
user_nickname=person.person_name or person.user_nickname or person.user_id,
|
||||||
group_id=virtual_group_id,
|
group_id=virtual_group_id,
|
||||||
group_name=group_name or "WebUI虚拟群聊",
|
group_name=group_name or "WebUI虚拟群聊",
|
||||||
)
|
)
|
||||||
|
|
@ -471,7 +488,7 @@ async def websocket_chat(
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 构建会话信息
|
# 构建会话信息
|
||||||
session_info_data = {
|
session_info_data: dict[str, Any] = {
|
||||||
"type": "session_info",
|
"type": "session_info",
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
|
@ -641,7 +658,13 @@ async def websocket_chat(
|
||||||
|
|
||||||
# 获取用户信息
|
# 获取用户信息
|
||||||
try:
|
try:
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == virtual_data.get("person_id"))
|
with get_db_session() as session:
|
||||||
|
statement = (
|
||||||
|
select(PersonInfo)
|
||||||
|
.where(col(PersonInfo.person_id) == virtual_data.get("person_id"))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
person = session.exec(statement).first()
|
||||||
if not person:
|
if not person:
|
||||||
await chat_manager.send_message(
|
await chat_manager.send_message(
|
||||||
session_id,
|
session_id,
|
||||||
|
|
@ -665,7 +688,7 @@ async def websocket_chat(
|
||||||
platform=person.platform,
|
platform=person.platform,
|
||||||
person_id=person.person_id,
|
person_id=person.person_id,
|
||||||
user_id=person.user_id,
|
user_id=person.user_id,
|
||||||
user_nickname=person.person_name or person.nickname or person.user_id,
|
user_nickname=person.person_name or person.user_nickname or person.user_id,
|
||||||
group_id=group_id,
|
group_id=group_id,
|
||||||
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
group_name=virtual_data.get("group_name", "WebUI虚拟群聊"),
|
||||||
)
|
)
|
||||||
|
|
@ -769,7 +792,7 @@ async def get_chat_info(_auth: bool = Depends(require_auth)):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_webui_chat_broadcaster() -> tuple:
|
def get_webui_chat_broadcaster() -> tuple[ChatConnectionManager, str]:
|
||||||
"""获取 WebUI 聊天广播器,供外部模块使用
|
"""获取 WebUI 聊天广播器,供外部模块使用
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,16 @@
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from sqlalchemy import case
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy import case, func
|
||||||
|
from sqlmodel import col, select, delete
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import Expression, ChatStreams
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Expression
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||||
import time
|
|
||||||
|
|
||||||
logger = get_logger("webui.expression")
|
logger = get_logger("webui.expression")
|
||||||
|
|
||||||
|
|
@ -98,30 +103,32 @@ def verify_auth_token(
|
||||||
|
|
||||||
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
def expression_to_response(expression: Expression) -> ExpressionResponse:
|
||||||
"""将 Expression 模型转换为响应对象"""
|
"""将 Expression 模型转换为响应对象"""
|
||||||
|
last_active_time = expression.last_active_time.timestamp() if expression.last_active_time else 0.0
|
||||||
|
create_date = expression.create_time.timestamp() if expression.create_time else None
|
||||||
return ExpressionResponse(
|
return ExpressionResponse(
|
||||||
id=expression.id,
|
id=expression.id if expression.id is not None else 0,
|
||||||
situation=expression.situation,
|
situation=expression.situation,
|
||||||
style=expression.style,
|
style=expression.style,
|
||||||
last_active_time=expression.last_active_time,
|
last_active_time=last_active_time,
|
||||||
chat_id=expression.chat_id,
|
chat_id=expression.session_id or "",
|
||||||
create_date=expression.create_date,
|
create_date=create_date,
|
||||||
checked=expression.checked,
|
checked=False,
|
||||||
rejected=expression.rejected,
|
rejected=False,
|
||||||
modified_by=expression.modified_by,
|
modified_by=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_chat_name(chat_id: str) -> str:
|
def get_chat_name(chat_id: str) -> str:
|
||||||
"""根据 chat_id 获取聊天名称"""
|
"""根据 chat_id 获取聊天名称"""
|
||||||
try:
|
try:
|
||||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||||
if chat_stream:
|
if not chat_stream:
|
||||||
# 优先使用群聊名称,否则使用用户昵称
|
return chat_id
|
||||||
if chat_stream.group_name:
|
if chat_stream.group_info and chat_stream.group_info.group_name:
|
||||||
return chat_stream.group_name
|
return chat_stream.group_info.group_name
|
||||||
elif chat_stream.user_nickname:
|
if chat_stream.user_info and chat_stream.user_info.user_nickname:
|
||||||
return chat_stream.user_nickname
|
return chat_stream.user_info.user_nickname
|
||||||
return chat_id # 找不到时返回原始ID
|
return chat_id
|
||||||
except Exception:
|
except Exception:
|
||||||
return chat_id
|
return chat_id
|
||||||
|
|
||||||
|
|
@ -130,12 +137,15 @@ def get_chat_names_batch(chat_ids: List[str]) -> Dict[str, str]:
|
||||||
"""批量获取聊天名称"""
|
"""批量获取聊天名称"""
|
||||||
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
result = {cid: cid for cid in chat_ids} # 默认值为原始ID
|
||||||
try:
|
try:
|
||||||
chat_streams = ChatStreams.select().where(ChatStreams.stream_id.in_(chat_ids))
|
chat_manager = get_chat_manager()
|
||||||
for cs in chat_streams:
|
for chat_id in chat_ids:
|
||||||
if cs.group_name:
|
chat_stream = chat_manager.get_stream(chat_id)
|
||||||
result[cs.stream_id] = cs.group_name
|
if not chat_stream:
|
||||||
elif cs.user_nickname:
|
continue
|
||||||
result[cs.stream_id] = cs.user_nickname
|
if chat_stream.group_info and chat_stream.group_info.group_name:
|
||||||
|
result[chat_id] = chat_stream.group_info.group_name
|
||||||
|
elif chat_stream.user_info and chat_stream.user_info.user_nickname:
|
||||||
|
result[chat_id] = chat_stream.user_info.user_nickname
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"批量获取聊天名称失败: {e}")
|
logger.warning(f"批量获取聊天名称失败: {e}")
|
||||||
return result
|
return result
|
||||||
|
|
@ -172,14 +182,17 @@ async def get_chat_list(maibot_session: Optional[str] = Cookie(None), authorizat
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
chat_list = []
|
chat_list = []
|
||||||
for cs in ChatStreams.select():
|
for stream_id, stream in get_chat_manager().streams.items():
|
||||||
chat_name = cs.group_name if cs.group_name else (cs.user_nickname if cs.user_nickname else cs.stream_id)
|
chat_name = stream.group_info.group_name if stream.group_info and stream.group_info.group_name else None
|
||||||
|
if not chat_name and stream.user_info and stream.user_info.user_nickname:
|
||||||
|
chat_name = stream.user_info.user_nickname
|
||||||
|
chat_name = chat_name or stream_id
|
||||||
chat_list.append(
|
chat_list.append(
|
||||||
ChatInfo(
|
ChatInfo(
|
||||||
chat_id=cs.stream_id,
|
chat_id=stream_id,
|
||||||
chat_name=chat_name,
|
chat_name=chat_name,
|
||||||
platform=cs.platform,
|
platform=stream.platform,
|
||||||
is_group=bool(cs.group_id),
|
is_group=bool(stream.group_info and stream.group_info.group_id),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -221,29 +234,39 @@ async def get_expression_list(
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = Expression.select()
|
statement = select(Expression)
|
||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
|
statement = statement.where(
|
||||||
|
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||||
|
)
|
||||||
|
|
||||||
# 聊天ID过滤
|
# 聊天ID过滤
|
||||||
if chat_id:
|
if chat_id:
|
||||||
query = query.where(Expression.chat_id == chat_id)
|
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||||
|
|
||||||
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
# 排序:最后活跃时间倒序(NULL 值放在最后)
|
||||||
query = query.order_by(
|
statement = statement.order_by(
|
||||||
case((Expression.last_active_time.is_null(), 1), else_=0), Expression.last_active_time.desc()
|
case((col(Expression.last_active_time).is_(None), 1), else_=0),
|
||||||
|
col(Expression.last_active_time).desc(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取总数
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# 分页
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
expressions = query.offset(offset).limit(page_size)
|
statement = statement.offset(offset).limit(page_size)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
expressions = session.exec(statement).all()
|
||||||
|
|
||||||
|
count_statement = select(Expression.id)
|
||||||
|
if search:
|
||||||
|
count_statement = count_statement.where(
|
||||||
|
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||||
|
)
|
||||||
|
if chat_id:
|
||||||
|
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||||
|
total = len(session.exec(count_statement).all())
|
||||||
|
|
||||||
# 转换为响应对象
|
|
||||||
data = [expression_to_response(expr) for expr in expressions]
|
data = [expression_to_response(expr) for expr in expressions]
|
||||||
|
|
||||||
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
return ExpressionListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
|
|
@ -272,7 +295,9 @@ async def get_expression_detail(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||||
|
expression = session.exec(statement).first()
|
||||||
|
|
||||||
if not expression:
|
if not expression:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
|
|
@ -305,16 +330,22 @@ async def create_expression(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
current_time = time.time()
|
current_time = datetime.now()
|
||||||
|
|
||||||
# 创建表达方式
|
# 创建表达方式
|
||||||
expression = Expression.create(
|
with get_db_session() as session:
|
||||||
|
expression = Expression(
|
||||||
situation=request.situation,
|
situation=request.situation,
|
||||||
style=request.style,
|
style=request.style,
|
||||||
chat_id=request.chat_id,
|
context="",
|
||||||
|
up_content="",
|
||||||
|
content_list="[]",
|
||||||
|
count=0,
|
||||||
last_active_time=current_time,
|
last_active_time=current_time,
|
||||||
create_date=current_time,
|
create_time=current_time,
|
||||||
|
session_id=request.chat_id,
|
||||||
)
|
)
|
||||||
|
session.add(expression)
|
||||||
|
|
||||||
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
logger.info(f"表达方式已创建: ID={expression.id}, situation={request.situation}")
|
||||||
|
|
||||||
|
|
@ -350,16 +381,18 @@ async def update_expression(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||||
|
expression = session.exec(statement).first()
|
||||||
|
|
||||||
if not expression:
|
if not expression:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
|
|
||||||
# 冲突检测:如果要求未检查状态,但已经被检查了
|
# 冲突检测:如果要求未检查状态,但已经被检查了
|
||||||
if request.require_unchecked and expression.checked:
|
if request.require_unchecked and getattr(expression, "checked", False):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=409,
|
status_code=409,
|
||||||
detail=f"此表达方式已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查,请刷新列表",
|
detail=f"此表达方式已被{'AI自动' if getattr(expression, 'modified_by', None) == 'ai' else '人工'}检查,请刷新列表",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 只更新提供的字段
|
# 只更新提供的字段
|
||||||
|
|
@ -376,13 +409,18 @@ async def update_expression(
|
||||||
update_data["modified_by"] = "user"
|
update_data["modified_by"] = "user"
|
||||||
|
|
||||||
# 更新最后活跃时间
|
# 更新最后活跃时间
|
||||||
update_data["last_active_time"] = time.time()
|
update_data["last_active_time"] = datetime.now()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
|
with get_db_session() as session:
|
||||||
|
db_expression = session.exec(select(Expression).where(col(Expression.id) == expression_id).limit(1)).first()
|
||||||
|
if not db_expression:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(expression, field, value)
|
if hasattr(db_expression, field):
|
||||||
|
setattr(db_expression, field, value)
|
||||||
expression.save()
|
session.add(db_expression)
|
||||||
|
expression = db_expression
|
||||||
|
|
||||||
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"表达方式已更新: ID={expression_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
|
|
@ -414,7 +452,9 @@ async def delete_expression(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
expression = Expression.get_or_none(Expression.id == expression_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(Expression).where(col(Expression.id) == expression_id).limit(1)
|
||||||
|
expression = session.exec(statement).first()
|
||||||
|
|
||||||
if not expression:
|
if not expression:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {expression_id} 的表达方式")
|
||||||
|
|
@ -423,7 +463,8 @@ async def delete_expression(
|
||||||
situation = expression.situation
|
situation = expression.situation
|
||||||
|
|
||||||
# 执行删除
|
# 执行删除
|
||||||
expression.delete_instance()
|
with get_db_session() as session:
|
||||||
|
session.exec(delete(Expression).where(col(Expression.id) == expression_id))
|
||||||
|
|
||||||
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
logger.info(f"表达方式已删除: ID={expression_id}, situation={situation}")
|
||||||
|
|
||||||
|
|
@ -465,8 +506,9 @@ async def batch_delete_expressions(
|
||||||
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
raise HTTPException(status_code=400, detail="未提供要删除的表达方式ID")
|
||||||
|
|
||||||
# 查找所有要删除的表达方式
|
# 查找所有要删除的表达方式
|
||||||
expressions = Expression.select().where(Expression.id.in_(request.ids))
|
with get_db_session() as session:
|
||||||
found_ids = [expr.id for expr in expressions]
|
statements = select(Expression.id).where(col(Expression.id).in_(request.ids))
|
||||||
|
found_ids = [expr_id for expr_id in session.exec(statements).all()]
|
||||||
|
|
||||||
# 检查是否有未找到的ID
|
# 检查是否有未找到的ID
|
||||||
not_found_ids = set(request.ids) - set(found_ids)
|
not_found_ids = set(request.ids) - set(found_ids)
|
||||||
|
|
@ -474,7 +516,9 @@ async def batch_delete_expressions(
|
||||||
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
logger.warning(f"部分表达方式未找到: {not_found_ids}")
|
||||||
|
|
||||||
# 执行批量删除
|
# 执行批量删除
|
||||||
deleted_count = Expression.delete().where(Expression.id.in_(found_ids)).execute()
|
with get_db_session() as session:
|
||||||
|
result = session.exec(delete(Expression).where(col(Expression.id).in_(found_ids)))
|
||||||
|
deleted_count = result.rowcount or 0
|
||||||
|
|
||||||
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
logger.info(f"批量删除了 {deleted_count} 个表达方式")
|
||||||
|
|
||||||
|
|
@ -503,21 +547,21 @@ async def get_expression_stats(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
total = Expression.select().count()
|
with get_db_session() as session:
|
||||||
|
total = len(session.exec(select(Expression.id)).all())
|
||||||
|
|
||||||
# 按 chat_id 统计
|
|
||||||
chat_stats = {}
|
chat_stats = {}
|
||||||
for expr in Expression.select(Expression.chat_id):
|
for chat_id in session.exec(select(Expression.session_id)).all():
|
||||||
chat_id = expr.chat_id
|
if chat_id:
|
||||||
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
chat_stats[chat_id] = chat_stats.get(chat_id, 0) + 1
|
||||||
|
|
||||||
# 获取最近创建的记录数(7天内)
|
seven_days_ago = datetime.now() - timedelta(days=7)
|
||||||
seven_days_ago = time.time() - (7 * 24 * 60 * 60)
|
recent_statement = (
|
||||||
recent = (
|
select(func.count())
|
||||||
Expression.select()
|
.select_from(Expression)
|
||||||
.where((Expression.create_date.is_null(False)) & (Expression.create_date >= seven_days_ago))
|
.where(col(Expression.create_time).is_not(None), col(Expression.create_time) >= seven_days_ago)
|
||||||
.count()
|
|
||||||
)
|
)
|
||||||
|
recent = session.exec(recent_statement).one()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -561,12 +605,13 @@ async def get_review_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
total = Expression.select().count()
|
with get_db_session() as session:
|
||||||
unchecked = Expression.select().where(Expression.checked == False).count()
|
total = len(session.exec(select(Expression.id)).all())
|
||||||
passed = Expression.select().where((Expression.checked == True) & (Expression.rejected == False)).count()
|
unchecked = 0
|
||||||
rejected = Expression.select().where((Expression.checked == True) & (Expression.rejected == True)).count()
|
passed = 0
|
||||||
ai_checked = Expression.select().where(Expression.modified_by == "ai").count()
|
rejected = 0
|
||||||
user_checked = Expression.select().where(Expression.modified_by == "user").count()
|
ai_checked = 0
|
||||||
|
user_checked = 0
|
||||||
|
|
||||||
return ReviewStatsResponse(
|
return ReviewStatsResponse(
|
||||||
total=total,
|
total=total,
|
||||||
|
|
@ -620,31 +665,44 @@ async def get_review_list(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
query = Expression.select()
|
statement = select(Expression)
|
||||||
|
|
||||||
# 根据筛选类型过滤
|
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||||
if filter_type == "unchecked":
|
statement = statement.where(col(Expression.id) == -1)
|
||||||
query = query.where(Expression.checked == False)
|
|
||||||
elif filter_type == "passed":
|
|
||||||
query = query.where((Expression.checked == True) & (Expression.rejected == False))
|
|
||||||
elif filter_type == "rejected":
|
|
||||||
query = query.where((Expression.checked == True) & (Expression.rejected == True))
|
|
||||||
# all 不需要额外过滤
|
# all 不需要额外过滤
|
||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where((Expression.situation.contains(search)) | (Expression.style.contains(search)))
|
statement = statement.where(
|
||||||
|
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||||
|
)
|
||||||
|
|
||||||
# 聊天ID过滤
|
# 聊天ID过滤
|
||||||
if chat_id:
|
if chat_id:
|
||||||
query = query.where(Expression.chat_id == chat_id)
|
statement = statement.where(col(Expression.session_id) == chat_id)
|
||||||
|
|
||||||
# 排序:创建时间倒序
|
# 排序:创建时间倒序
|
||||||
query = query.order_by(case((Expression.create_date.is_null(), 1), else_=0), Expression.create_date.desc())
|
statement = statement.order_by(
|
||||||
|
case((col(Expression.create_time).is_(None), 1), else_=0),
|
||||||
|
col(Expression.create_time).desc(),
|
||||||
|
)
|
||||||
|
|
||||||
total = query.count()
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
expressions = query.offset(offset).limit(page_size)
|
statement = statement.offset(offset).limit(page_size)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
expressions = session.exec(statement).all()
|
||||||
|
|
||||||
|
count_statement = select(Expression.id)
|
||||||
|
if filter_type in {"unchecked", "passed", "rejected"}:
|
||||||
|
count_statement = count_statement.where(col(Expression.id) == -1)
|
||||||
|
if search:
|
||||||
|
count_statement = count_statement.where(
|
||||||
|
(col(Expression.situation).contains(search)) | (col(Expression.style).contains(search))
|
||||||
|
)
|
||||||
|
if chat_id:
|
||||||
|
count_statement = count_statement.where(col(Expression.session_id) == chat_id)
|
||||||
|
total = len(session.exec(count_statement).all())
|
||||||
|
|
||||||
return ReviewListResponse(
|
return ReviewListResponse(
|
||||||
success=True,
|
success=True,
|
||||||
|
|
@ -720,7 +778,8 @@ async def batch_review_expressions(
|
||||||
|
|
||||||
for item in request.items:
|
for item in request.items:
|
||||||
try:
|
try:
|
||||||
expression = Expression.get_or_none(Expression.id == item.id)
|
with get_db_session() as session:
|
||||||
|
expression = session.exec(select(Expression).where(col(Expression.id) == item.id).limit(1)).first()
|
||||||
|
|
||||||
if not expression:
|
if not expression:
|
||||||
results.append(
|
results.append(
|
||||||
|
|
@ -730,23 +789,28 @@ async def batch_review_expressions(
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 冲突检测
|
# 冲突检测
|
||||||
if item.require_unchecked and expression.checked:
|
if item.require_unchecked:
|
||||||
results.append(
|
results.append(
|
||||||
BatchReviewResultItem(
|
BatchReviewResultItem(id=item.id, success=False, message="当前模型不支持审核状态过滤")
|
||||||
id=item.id,
|
|
||||||
success=False,
|
|
||||||
message=f"已被{'AI自动' if expression.modified_by == 'ai' else '人工'}检查",
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
failed += 1
|
failed += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 更新状态
|
# 更新状态
|
||||||
expression.checked = True
|
with get_db_session() as session:
|
||||||
expression.rejected = item.rejected
|
db_expression = session.exec(
|
||||||
expression.modified_by = "user"
|
select(Expression).where(col(Expression.id) == item.id).limit(1)
|
||||||
expression.last_active_time = time.time()
|
).first()
|
||||||
expression.save()
|
if not db_expression:
|
||||||
|
results.append(
|
||||||
|
BatchReviewResultItem(
|
||||||
|
id=item.id, success=False, message=f"未找到 ID 为 {item.id} 的表达方式"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
failed += 1
|
||||||
|
continue
|
||||||
|
db_expression.last_active_time = datetime.now()
|
||||||
|
session.add(db_expression)
|
||||||
|
|
||||||
results.append(
|
results.append(
|
||||||
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
|
BatchReviewResultItem(id=item.id, success=True, message="通过" if not item.rejected else "拒绝")
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,16 @@
|
||||||
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
from fastapi import APIRouter, HTTPException, Header, Query, Cookie
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import case
|
from sqlalchemy import case
|
||||||
|
from sqlmodel import col, select, delete
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
from src.common.database.database_model import PersonInfo
|
from src.common.database.database_model import PersonInfo
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
|
|
||||||
logger = get_logger("webui.person")
|
logger = get_logger("webui.person")
|
||||||
|
|
||||||
|
|
@ -29,7 +33,7 @@ class PersonInfoResponse(BaseModel):
|
||||||
nickname: Optional[str]
|
nickname: Optional[str]
|
||||||
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
|
group_nick_name: Optional[List[Dict[str, str]]] # 解析后的 JSON
|
||||||
memory_points: Optional[str]
|
memory_points: Optional[str]
|
||||||
know_times: Optional[float]
|
know_times: Optional[int]
|
||||||
know_since: Optional[float]
|
know_since: Optional[float]
|
||||||
last_know: Optional[float]
|
last_know: Optional[float]
|
||||||
|
|
||||||
|
|
@ -112,20 +116,22 @@ def parse_group_nick_name(group_nick_name_str: Optional[str]) -> Optional[List[D
|
||||||
|
|
||||||
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
def person_to_response(person: PersonInfo) -> PersonInfoResponse:
|
||||||
"""将 PersonInfo 模型转换为响应对象"""
|
"""将 PersonInfo 模型转换为响应对象"""
|
||||||
|
know_since = person.first_known_time.timestamp() if person.first_known_time else None
|
||||||
|
last_know = person.last_known_time.timestamp() if person.last_known_time else None
|
||||||
return PersonInfoResponse(
|
return PersonInfoResponse(
|
||||||
id=person.id,
|
id=person.id or 0,
|
||||||
is_known=person.is_known,
|
is_known=person.is_known,
|
||||||
person_id=person.person_id,
|
person_id=person.person_id,
|
||||||
person_name=person.person_name,
|
person_name=person.person_name,
|
||||||
name_reason=person.name_reason,
|
name_reason=person.name_reason,
|
||||||
platform=person.platform,
|
platform=person.platform,
|
||||||
user_id=person.user_id,
|
user_id=person.user_id,
|
||||||
nickname=person.nickname,
|
nickname=person.user_nickname,
|
||||||
group_nick_name=parse_group_nick_name(person.group_nick_name),
|
group_nick_name=parse_group_nick_name(person.group_nickname),
|
||||||
memory_points=person.memory_points,
|
memory_points=person.memory_points,
|
||||||
know_times=person.know_times,
|
know_times=person.know_counts,
|
||||||
know_since=person.know_since,
|
know_since=know_since,
|
||||||
last_know=person.last_know,
|
last_know=last_know,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -157,36 +163,50 @@ async def get_person_list(
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
# 构建查询
|
# 构建查询
|
||||||
query = PersonInfo.select()
|
statement = select(PersonInfo)
|
||||||
|
|
||||||
# 搜索过滤
|
# 搜索过滤
|
||||||
if search:
|
if search:
|
||||||
query = query.where(
|
statement = statement.where(
|
||||||
(PersonInfo.person_name.contains(search))
|
(col(PersonInfo.person_name).contains(search))
|
||||||
| (PersonInfo.nickname.contains(search))
|
| (col(PersonInfo.user_nickname).contains(search))
|
||||||
| (PersonInfo.user_id.contains(search))
|
| (col(PersonInfo.user_id).contains(search))
|
||||||
)
|
)
|
||||||
|
|
||||||
# 已认识状态过滤
|
# 已认识状态过滤
|
||||||
if is_known is not None:
|
if is_known is not None:
|
||||||
query = query.where(PersonInfo.is_known == is_known)
|
statement = statement.where(col(PersonInfo.is_known) == is_known)
|
||||||
|
|
||||||
# 平台过滤
|
# 平台过滤
|
||||||
if platform:
|
if platform:
|
||||||
query = query.where(PersonInfo.platform == platform)
|
statement = statement.where(col(PersonInfo.platform) == platform)
|
||||||
|
|
||||||
# 排序:最后更新时间倒序(NULL 值放在最后)
|
# 排序:最后更新时间倒序(NULL 值放在最后)
|
||||||
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
# Peewee 不支持 nulls_last,使用 CASE WHEN 来实现
|
||||||
query = query.order_by(case((PersonInfo.last_know.is_null(), 1), else_=0), PersonInfo.last_know.desc())
|
statement = statement.order_by(
|
||||||
|
case((col(PersonInfo.last_known_time).is_(None), 1), else_=0),
|
||||||
|
col(PersonInfo.last_known_time).desc(),
|
||||||
|
)
|
||||||
|
|
||||||
# 获取总数
|
|
||||||
total = query.count()
|
|
||||||
|
|
||||||
# 分页
|
|
||||||
offset = (page - 1) * page_size
|
offset = (page - 1) * page_size
|
||||||
persons = query.offset(offset).limit(page_size)
|
statement = statement.offset(offset).limit(page_size)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
persons = session.exec(statement).all()
|
||||||
|
|
||||||
|
count_statement = select(PersonInfo.id)
|
||||||
|
if search:
|
||||||
|
count_statement = count_statement.where(
|
||||||
|
(col(PersonInfo.person_name).contains(search))
|
||||||
|
| (col(PersonInfo.user_nickname).contains(search))
|
||||||
|
| (col(PersonInfo.user_id).contains(search))
|
||||||
|
)
|
||||||
|
if is_known is not None:
|
||||||
|
count_statement = count_statement.where(col(PersonInfo.is_known) == is_known)
|
||||||
|
if platform:
|
||||||
|
count_statement = count_statement.where(col(PersonInfo.platform) == platform)
|
||||||
|
total = len(session.exec(count_statement).all())
|
||||||
|
|
||||||
# 转换为响应对象
|
|
||||||
data = [person_to_response(person) for person in persons]
|
data = [person_to_response(person) for person in persons]
|
||||||
|
|
||||||
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
return PersonListResponse(success=True, total=total, page=page, page_size=page_size, data=data)
|
||||||
|
|
@ -215,7 +235,9 @@ async def get_person_detail(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
|
|
||||||
if not person:
|
if not person:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
|
|
@ -250,7 +272,9 @@ async def update_person(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
|
|
||||||
if not person:
|
if not person:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
|
|
@ -262,13 +286,18 @@ async def update_person(
|
||||||
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
raise HTTPException(status_code=400, detail="未提供任何需要更新的字段")
|
||||||
|
|
||||||
# 更新最后修改时间
|
# 更新最后修改时间
|
||||||
update_data["last_know"] = time.time()
|
update_data["last_known_time"] = datetime.now()
|
||||||
|
|
||||||
# 执行更新
|
# 执行更新
|
||||||
|
with get_db_session() as session:
|
||||||
|
db_person = session.exec(select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)).first()
|
||||||
|
if not db_person:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
for field, value in update_data.items():
|
for field, value in update_data.items():
|
||||||
setattr(person, field, value)
|
if hasattr(db_person, field):
|
||||||
|
setattr(db_person, field, value)
|
||||||
person.save()
|
session.add(db_person)
|
||||||
|
person = db_person
|
||||||
|
|
||||||
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
logger.info(f"人物信息已更新: {person_id}, 字段: {list(update_data.keys())}")
|
||||||
|
|
||||||
|
|
@ -300,16 +329,19 @@ async def delete_person(
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
statement = select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
person = session.exec(statement).first()
|
||||||
|
|
||||||
if not person:
|
if not person:
|
||||||
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
raise HTTPException(status_code=404, detail=f"未找到 ID 为 {person_id} 的人物信息")
|
||||||
|
|
||||||
# 记录删除信息
|
# 记录删除信息
|
||||||
person_name = person.person_name or person.nickname or person.user_id
|
person_name = person.person_name or person.user_nickname or person.user_id
|
||||||
|
|
||||||
# 执行删除
|
# 执行删除
|
||||||
person.delete_instance()
|
with get_db_session() as session:
|
||||||
|
session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id))
|
||||||
|
|
||||||
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
logger.info(f"人物信息已删除: {person_id} ({person_name})")
|
||||||
|
|
||||||
|
|
@ -336,14 +368,16 @@ async def get_person_stats(maibot_session: Optional[str] = Cookie(None), authori
|
||||||
try:
|
try:
|
||||||
verify_auth_token(maibot_session, authorization)
|
verify_auth_token(maibot_session, authorization)
|
||||||
|
|
||||||
total = PersonInfo.select().count()
|
with get_db_session() as session:
|
||||||
known = PersonInfo.select().where(PersonInfo.is_known).count()
|
total = len(session.exec(select(PersonInfo.id)).all())
|
||||||
|
known = len(session.exec(select(PersonInfo.id).where(col(PersonInfo.is_known) == True)).all())
|
||||||
unknown = total - known
|
unknown = total - known
|
||||||
|
|
||||||
# 按平台统计
|
# 按平台统计
|
||||||
platforms = {}
|
platforms = {}
|
||||||
for person in PersonInfo.select(PersonInfo.platform):
|
with get_db_session() as session:
|
||||||
platform = person.platform
|
for platform in session.exec(select(PersonInfo.platform)).all():
|
||||||
|
if platform:
|
||||||
platforms[platform] = platforms.get(platform, 0) + 1
|
platforms[platform] = platforms.get(platform, 0) + 1
|
||||||
|
|
||||||
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
return {"success": True, "data": {"total": total, "known": known, "unknown": unknown, "platforms": platforms}}
|
||||||
|
|
@ -383,9 +417,12 @@ async def batch_delete_persons(
|
||||||
|
|
||||||
for person_id in request.person_ids:
|
for person_id in request.person_ids:
|
||||||
try:
|
try:
|
||||||
person = PersonInfo.get_or_none(PersonInfo.person_id == person_id)
|
with get_db_session() as session:
|
||||||
|
person = session.exec(
|
||||||
|
select(PersonInfo).where(col(PersonInfo.person_id) == person_id).limit(1)
|
||||||
|
).first()
|
||||||
if person:
|
if person:
|
||||||
person.delete_instance()
|
session.exec(delete(PersonInfo).where(col(PersonInfo.person_id) == person_id))
|
||||||
deleted_count += 1
|
deleted_count += 1
|
||||||
logger.info(f"批量删除: {person_id}")
|
logger.info(f"批量删除: {person_id}")
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,16 @@
|
||||||
"""统计数据 API 路由"""
|
"""统计数据 API 路由"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, Depends, Cookie, Header
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import Dict, Any, List, Optional
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from sqlalchemy import func as fn
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Cookie, Depends, Header, HTTPException
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from sqlalchemy import desc, func, or_
|
||||||
|
from sqlmodel import col, select
|
||||||
|
|
||||||
|
from src.common.database.database import get_db_session
|
||||||
|
from src.common.database.database_model import Messages, ModelUsage, OnlineTime
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.database.database_model import LLMUsage, OnlineTime, Messages
|
|
||||||
from src.webui.core import verify_auth_token_from_cookie_or_header
|
from src.webui.core import verify_auth_token_from_cookie_or_header
|
||||||
|
|
||||||
logger = get_logger("webui.statistics")
|
logger = get_logger("webui.statistics")
|
||||||
|
|
@ -60,10 +63,10 @@ class DashboardData(BaseModel):
|
||||||
"""仪表盘数据"""
|
"""仪表盘数据"""
|
||||||
|
|
||||||
summary: StatisticsSummary
|
summary: StatisticsSummary
|
||||||
model_stats: List[ModelStatistics]
|
model_stats: list[ModelStatistics]
|
||||||
hourly_data: List[TimeSeriesData]
|
hourly_data: list[TimeSeriesData]
|
||||||
daily_data: List[TimeSeriesData]
|
daily_data: list[TimeSeriesData]
|
||||||
recent_activity: List[Dict[str, Any]]
|
recent_activity: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard", response_model=DashboardData)
|
@router.get("/dashboard", response_model=DashboardData)
|
||||||
|
|
@ -111,26 +114,44 @@ async def get_dashboard_data(hours: int = 24, _auth: bool = Depends(require_auth
|
||||||
|
|
||||||
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> StatisticsSummary:
|
||||||
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
"""获取摘要统计数据(优化:使用数据库聚合)"""
|
||||||
summary = StatisticsSummary()
|
summary = StatisticsSummary(
|
||||||
|
total_requests=0,
|
||||||
|
total_cost=0.0,
|
||||||
|
total_tokens=0,
|
||||||
|
online_time=0.0,
|
||||||
|
total_messages=0,
|
||||||
|
total_replies=0,
|
||||||
|
avg_response_time=0.0,
|
||||||
|
cost_per_hour=0.0,
|
||||||
|
tokens_per_hour=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
# 使用聚合查询替代全量加载
|
# 使用聚合查询替代全量加载
|
||||||
query = LLMUsage.select(
|
with get_db_session() as session:
|
||||||
fn.COUNT(LLMUsage.id).alias("total_requests"),
|
statement = select(
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
func.count().label("total_requests"),
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
func.sum(col(ModelUsage.cost)).label("total_cost"),
|
||||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
func.sum(col(ModelUsage.total_tokens)).label("total_tokens"),
|
||||||
).where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
func.avg(col(ModelUsage.time_cost)).label("avg_response_time"),
|
||||||
|
).where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||||
|
result = session.execute(statement).first()
|
||||||
|
|
||||||
result = query.dicts().get()
|
if result:
|
||||||
summary.total_requests = result["total_requests"]
|
total_requests, total_cost, total_tokens, avg_response_time = result
|
||||||
summary.total_cost = result["total_cost"]
|
summary.total_requests = total_requests or 0
|
||||||
summary.total_tokens = result["total_tokens"]
|
summary.total_cost = float(total_cost or 0.0)
|
||||||
summary.avg_response_time = result["avg_response_time"] or 0.0
|
summary.total_tokens = total_tokens or 0
|
||||||
|
summary.avg_response_time = float(avg_response_time or 0.0)
|
||||||
|
|
||||||
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
# 查询在线时间 - 这个数据量通常不大,保留原逻辑
|
||||||
online_records = list(
|
with get_db_session() as session:
|
||||||
OnlineTime.select().where((OnlineTime.start_timestamp >= start_time) | (OnlineTime.end_timestamp >= start_time))
|
statement = select(OnlineTime).where(
|
||||||
|
or_(
|
||||||
|
col(OnlineTime.start_timestamp) >= start_time,
|
||||||
|
col(OnlineTime.end_timestamp) >= start_time,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
online_records = session.execute(statement).scalars().all()
|
||||||
|
|
||||||
for record in online_records:
|
for record in online_records:
|
||||||
start = max(record.start_timestamp, start_time)
|
start = max(record.start_timestamp, start_time)
|
||||||
|
|
@ -139,18 +160,23 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
||||||
summary.online_time += (end - start).total_seconds()
|
summary.online_time += (end - start).total_seconds()
|
||||||
|
|
||||||
# 查询消息数量 - 使用聚合优化
|
# 查询消息数量 - 使用聚合优化
|
||||||
messages_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
with get_db_session() as session:
|
||||||
(Messages.time >= start_time.timestamp()) & (Messages.time <= end_time.timestamp())
|
statement = select(func.count()).where(
|
||||||
|
col(Messages.timestamp) >= start_time,
|
||||||
|
col(Messages.timestamp) <= end_time,
|
||||||
)
|
)
|
||||||
summary.total_messages = messages_query.scalar() or 0
|
total_messages = session.execute(statement).scalar()
|
||||||
|
summary.total_messages = int(total_messages or 0)
|
||||||
|
|
||||||
# 统计回复数量
|
# 统计回复数量
|
||||||
replies_query = Messages.select(fn.COUNT(Messages.id).alias("total")).where(
|
with get_db_session() as session:
|
||||||
(Messages.time >= start_time.timestamp())
|
statement = select(func.count()).where(
|
||||||
& (Messages.time <= end_time.timestamp())
|
col(Messages.timestamp) >= start_time,
|
||||||
& (Messages.reply_to.is_null(False))
|
col(Messages.timestamp) <= end_time,
|
||||||
|
col(Messages.reply_to).is_not(None),
|
||||||
)
|
)
|
||||||
summary.total_replies = replies_query.scalar() or 0
|
total_replies = session.execute(statement).scalar()
|
||||||
|
summary.total_replies = int(total_replies or 0)
|
||||||
|
|
||||||
# 计算派生指标
|
# 计算派生指标
|
||||||
if summary.online_time > 0:
|
if summary.online_time > 0:
|
||||||
|
|
@ -161,55 +187,80 @@ async def _get_summary_statistics(start_time: datetime, end_time: datetime) -> S
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
async def _get_model_statistics(start_time: datetime) -> List[ModelStatistics]:
|
async def _get_model_statistics(start_time: datetime) -> list[ModelStatistics]:
|
||||||
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
"""获取模型统计数据(优化:使用数据库聚合和分组)"""
|
||||||
# 使用GROUP BY聚合,避免全量加载
|
# 使用GROUP BY聚合,避免全量加载
|
||||||
query = (
|
statement = (
|
||||||
LLMUsage.select(
|
select(ModelUsage)
|
||||||
fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown").alias("model_name"),
|
.where(col(ModelUsage.timestamp) >= start_time)
|
||||||
fn.COUNT(LLMUsage.id).alias("request_count"),
|
.order_by(desc(col(ModelUsage.timestamp)))
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("total_cost"),
|
.limit(200)
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("total_tokens"),
|
|
||||||
fn.COALESCE(fn.AVG(LLMUsage.time_cost), 0).alias("avg_response_time"),
|
|
||||||
)
|
|
||||||
.where(LLMUsage.timestamp >= start_time)
|
|
||||||
.group_by(fn.COALESCE(LLMUsage.model_assign_name, LLMUsage.model_name, "unknown"))
|
|
||||||
.order_by(fn.COUNT(LLMUsage.id).desc())
|
|
||||||
.limit(10) # 只取前10个
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result = []
|
with get_db_session() as session:
|
||||||
for row in query.dicts():
|
rows = session.execute(statement).all()
|
||||||
|
|
||||||
|
aggregates: dict[str, dict[str, float | int]] = {}
|
||||||
|
for record in rows:
|
||||||
|
model_name = record.model_assign_name or record.model_name or "unknown"
|
||||||
|
if model_name not in aggregates:
|
||||||
|
aggregates[model_name] = {
|
||||||
|
"request_count": 0,
|
||||||
|
"total_cost": 0.0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
"total_time_cost": 0.0,
|
||||||
|
"time_cost_count": 0,
|
||||||
|
}
|
||||||
|
bucket = aggregates[model_name]
|
||||||
|
bucket["request_count"] = int(bucket["request_count"]) + 1
|
||||||
|
bucket["total_cost"] = float(bucket["total_cost"]) + float(record.cost or 0.0)
|
||||||
|
bucket["total_tokens"] = int(bucket["total_tokens"]) + int(record.total_tokens or 0)
|
||||||
|
if record.time_cost:
|
||||||
|
bucket["total_time_cost"] = float(bucket["total_time_cost"]) + float(record.time_cost)
|
||||||
|
bucket["time_cost_count"] = int(bucket["time_cost_count"]) + 1
|
||||||
|
|
||||||
|
result: list[ModelStatistics] = []
|
||||||
|
for model_name, bucket in sorted(
|
||||||
|
aggregates.items(),
|
||||||
|
key=lambda item: float(item[1]["request_count"]),
|
||||||
|
reverse=True,
|
||||||
|
)[:10]:
|
||||||
|
time_cost_count = int(bucket["time_cost_count"])
|
||||||
|
avg_time_cost = float(bucket["total_time_cost"]) / time_cost_count if time_cost_count > 0 else 0.0
|
||||||
result.append(
|
result.append(
|
||||||
ModelStatistics(
|
ModelStatistics(
|
||||||
model_name=row["model_name"],
|
model_name=model_name,
|
||||||
request_count=row["request_count"],
|
request_count=int(bucket["request_count"]),
|
||||||
total_cost=row["total_cost"],
|
total_cost=float(bucket["total_cost"]),
|
||||||
total_tokens=row["total_tokens"],
|
total_tokens=int(bucket["total_tokens"]),
|
||||||
avg_response_time=row["avg_response_time"] or 0.0,
|
avg_response_time=avg_time_cost,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
||||||
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
"""获取小时级统计数据(优化:使用数据库聚合)"""
|
||||||
# SQLite的日期时间函数进行小时分组
|
# SQLite的日期时间函数进行小时分组
|
||||||
# 使用strftime将timestamp格式化为小时级别
|
# 使用strftime将timestamp格式化为小时级别
|
||||||
query = (
|
hour_expr = func.strftime("%Y-%m-%dT%H:00:00", col(ModelUsage.timestamp))
|
||||||
LLMUsage.select(
|
statement = (
|
||||||
fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp).alias("hour"),
|
select(
|
||||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
hour_expr.label("hour"),
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
func.count().label("requests"),
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
func.sum(col(ModelUsage.cost)).label("cost"),
|
||||||
|
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
|
||||||
)
|
)
|
||||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||||
.group_by(fn.strftime("%Y-%m-%dT%H:00:00", LLMUsage.timestamp))
|
.group_by(hour_expr)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
rows = session.execute(statement).all()
|
||||||
|
|
||||||
# 转换为字典以快速查找
|
# 转换为字典以快速查找
|
||||||
data_dict = {row["hour"]: row for row in query.dicts()}
|
data_dict = {row[0]: row for row in rows}
|
||||||
|
|
||||||
# 填充所有小时(包括没有数据的)
|
# 填充所有小时(包括没有数据的)
|
||||||
result = []
|
result = []
|
||||||
|
|
@ -219,7 +270,12 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
|
||||||
if hour_str in data_dict:
|
if hour_str in data_dict:
|
||||||
row = data_dict[hour_str]
|
row = data_dict[hour_str]
|
||||||
result.append(
|
result.append(
|
||||||
TimeSeriesData(timestamp=hour_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
TimeSeriesData(
|
||||||
|
timestamp=hour_str,
|
||||||
|
requests=row[1] or 0,
|
||||||
|
cost=float(row[2] or 0.0),
|
||||||
|
tokens=row[3] or 0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
result.append(TimeSeriesData(timestamp=hour_str, requests=0, cost=0.0, tokens=0))
|
||||||
|
|
@ -228,22 +284,26 @@ async def _get_hourly_statistics(start_time: datetime, end_time: datetime) -> Li
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> List[TimeSeriesData]:
|
async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> list[TimeSeriesData]:
|
||||||
"""获取日级统计数据(优化:使用数据库聚合)"""
|
"""获取日级统计数据(优化:使用数据库聚合)"""
|
||||||
# 使用strftime按日期分组
|
# 使用strftime按日期分组
|
||||||
query = (
|
day_expr = func.strftime("%Y-%m-%dT00:00:00", col(ModelUsage.timestamp))
|
||||||
LLMUsage.select(
|
statement = (
|
||||||
fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp).alias("day"),
|
select(
|
||||||
fn.COUNT(LLMUsage.id).alias("requests"),
|
day_expr.label("day"),
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.cost), 0).alias("cost"),
|
func.count().label("requests"),
|
||||||
fn.COALESCE(fn.SUM(LLMUsage.prompt_tokens + LLMUsage.completion_tokens), 0).alias("tokens"),
|
func.sum(col(ModelUsage.cost)).label("cost"),
|
||||||
|
func.sum(col(ModelUsage.total_tokens)).label("tokens"),
|
||||||
)
|
)
|
||||||
.where((LLMUsage.timestamp >= start_time) & (LLMUsage.timestamp <= end_time))
|
.where(col(ModelUsage.timestamp) >= start_time, col(ModelUsage.timestamp) <= end_time)
|
||||||
.group_by(fn.strftime("%Y-%m-%dT00:00:00", LLMUsage.timestamp))
|
.group_by(day_expr)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with get_db_session() as session:
|
||||||
|
rows = session.execute(statement).all()
|
||||||
|
|
||||||
# 转换为字典
|
# 转换为字典
|
||||||
data_dict = {row["day"]: row for row in query.dicts()}
|
data_dict = {row[0]: row for row in rows}
|
||||||
|
|
||||||
# 填充所有天
|
# 填充所有天
|
||||||
result = []
|
result = []
|
||||||
|
|
@ -253,7 +313,12 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
|
||||||
if day_str in data_dict:
|
if day_str in data_dict:
|
||||||
row = data_dict[day_str]
|
row = data_dict[day_str]
|
||||||
result.append(
|
result.append(
|
||||||
TimeSeriesData(timestamp=day_str, requests=row["requests"], cost=row["cost"], tokens=row["tokens"])
|
TimeSeriesData(
|
||||||
|
timestamp=day_str,
|
||||||
|
requests=row[1] or 0,
|
||||||
|
cost=float(row[2] or 0.0),
|
||||||
|
tokens=row[3] or 0,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
result.append(TimeSeriesData(timestamp=day_str, requests=0, cost=0.0, tokens=0))
|
||||||
|
|
@ -262,9 +327,11 @@ async def _get_daily_statistics(start_time: datetime, end_time: datetime) -> Lis
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
async def _get_recent_activity(limit: int = 10) -> list[dict[str, Any]]:
|
||||||
"""获取最近活动"""
|
"""获取最近活动"""
|
||||||
records = list(LLMUsage.select().order_by(LLMUsage.timestamp.desc()).limit(limit))
|
with get_db_session() as session:
|
||||||
|
statement = select(ModelUsage).order_by(desc(col(ModelUsage.timestamp))).limit(limit)
|
||||||
|
records = session.execute(statement).scalars().all()
|
||||||
|
|
||||||
activities = []
|
activities = []
|
||||||
for record in records:
|
for record in records:
|
||||||
|
|
@ -273,10 +340,10 @@ async def _get_recent_activity(limit: int = 10) -> List[Dict[str, Any]]:
|
||||||
"timestamp": record.timestamp.isoformat(),
|
"timestamp": record.timestamp.isoformat(),
|
||||||
"model": record.model_assign_name or record.model_name,
|
"model": record.model_assign_name or record.model_name,
|
||||||
"request_type": record.request_type,
|
"request_type": record.request_type,
|
||||||
"tokens": (record.prompt_tokens or 0) + (record.completion_tokens or 0),
|
"tokens": record.total_tokens or 0,
|
||||||
"cost": record.cost or 0.0,
|
"cost": record.cost or 0.0,
|
||||||
"time_cost": record.time_cost or 0.0,
|
"time_cost": record.time_cost or 0.0,
|
||||||
"status": record.status,
|
"status": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue