mirror of https://github.com/Mai-with-u/MaiBot.git
commit
84aa4fc172
|
|
@ -146,7 +146,7 @@ class ChattingObservation(Observation):
|
||||||
"platform": find_msg.get("user_platform", ""),
|
"platform": find_msg.get("user_platform", ""),
|
||||||
"user_id": find_msg.get("user_id", ""),
|
"user_id": find_msg.get("user_id", ""),
|
||||||
"user_nickname": find_msg.get("user_nickname", ""),
|
"user_nickname": find_msg.get("user_nickname", ""),
|
||||||
"user_cardname": find_msg.get("user_cardname", "")
|
"user_cardname": find_msg.get("user_cardname", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建所需的group_info字段,如果是群聊的话
|
# 创建所需的group_info字段,如果是群聊的话
|
||||||
|
|
@ -155,7 +155,7 @@ class ChattingObservation(Observation):
|
||||||
group_info = {
|
group_info = {
|
||||||
"platform": find_msg.get("chat_info_group_platform", ""),
|
"platform": find_msg.get("chat_info_group_platform", ""),
|
||||||
"group_id": find_msg.get("chat_info_group_id", ""),
|
"group_id": find_msg.get("chat_info_group_id", ""),
|
||||||
"group_name": find_msg.get("chat_info_group_name", "")
|
"group_name": find_msg.get("chat_info_group_name", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
content_format = ""
|
content_format = ""
|
||||||
|
|
|
||||||
|
|
@ -861,9 +861,7 @@ class EntorhinalCortex:
|
||||||
# 确保在更新前获取最新的 memorized_times
|
# 确保在更新前获取最新的 memorized_times
|
||||||
current_memorized_times = message.get("memorized_times", 0)
|
current_memorized_times = message.get("memorized_times", 0)
|
||||||
# 使用 Peewee 更新记录
|
# 使用 Peewee 更新记录
|
||||||
Messages.update(
|
Messages.update(memorized_times=current_memorized_times + 1).where(
|
||||||
memorized_times=current_memorized_times + 1
|
|
||||||
).where(
|
|
||||||
Messages.message_id == message["message_id"]
|
Messages.message_id == message["message_id"]
|
||||||
).execute()
|
).execute()
|
||||||
return messages # 直接返回原始的消息列表
|
return messages # 直接返回原始的消息列表
|
||||||
|
|
@ -983,9 +981,7 @@ class EntorhinalCortex:
|
||||||
if not node.last_modified:
|
if not node.last_modified:
|
||||||
update_data["last_modified"] = current_time
|
update_data["last_modified"] = current_time
|
||||||
|
|
||||||
GraphNodes.update(
|
GraphNodes.update(**update_data).where(GraphNodes.concept == concept).execute()
|
||||||
**update_data
|
|
||||||
).where(GraphNodes.concept == concept).execute()
|
|
||||||
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
logger.info(f"[时间更新] 节点 {concept} 添加缺失的时间字段")
|
||||||
|
|
||||||
# 获取时间信息(如果不存在则使用当前时间)
|
# 获取时间信息(如果不存在则使用当前时间)
|
||||||
|
|
@ -1014,9 +1010,7 @@ class EntorhinalCortex:
|
||||||
if not edge.last_modified:
|
if not edge.last_modified:
|
||||||
update_data["last_modified"] = current_time
|
update_data["last_modified"] = current_time
|
||||||
|
|
||||||
GraphEdges.update(
|
GraphEdges.update(**update_data).where(
|
||||||
**update_data
|
|
||||||
).where(
|
|
||||||
(GraphEdges.source == source) & (GraphEdges.target == target)
|
(GraphEdges.source == source) & (GraphEdges.target == target)
|
||||||
).execute()
|
).execute()
|
||||||
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
logger.info(f"[时间更新] 边 {source} - {target} 添加缺失的时间字段")
|
||||||
|
|
|
||||||
|
|
@ -38,10 +38,10 @@ class ChatBot:
|
||||||
|
|
||||||
async def _create_pfc_chat(self, message: MessageRecv):
|
async def _create_pfc_chat(self, message: MessageRecv):
|
||||||
try:
|
try:
|
||||||
chat_id = str(message.chat_stream.stream_id)
|
if global_config.experimental.pfc_chatting:
|
||||||
private_name = str(message.message_info.user_info.user_nickname)
|
chat_id = str(message.chat_stream.stream_id)
|
||||||
|
private_name = str(message.message_info.user_info.user_nickname)
|
||||||
|
|
||||||
if global_config.experimental.enable_pfc_chatting:
|
|
||||||
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
await self.pfc_manager.get_or_create_conversation(chat_id, private_name)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -75,27 +75,27 @@ class ChatBot:
|
||||||
# print(message_data)
|
# print(message_data)
|
||||||
logger.trace(f"处理消息:{str(message_data)[:120]}...")
|
logger.trace(f"处理消息:{str(message_data)[:120]}...")
|
||||||
message = MessageRecv(message_data)
|
message = MessageRecv(message_data)
|
||||||
groupinfo = message.message_info.group_info
|
group_info = message.message_info.group_info
|
||||||
userinfo = message.message_info.user_info
|
user_info = message.message_info.user_info
|
||||||
|
|
||||||
# 用户黑名单拦截
|
# 用户黑名单拦截
|
||||||
if userinfo.user_id in global_config.chat_target.ban_user_id:
|
# if userinfo.user_id in global_config.chat_target.ban_user_id:
|
||||||
logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
# logger.debug(f"用户{userinfo.user_id}被禁止回复")
|
||||||
return
|
# return
|
||||||
|
|
||||||
if groupinfo is None:
|
# if groupinfo is None:
|
||||||
logger.trace("检测到私聊消息,检查")
|
# logger.trace("检测到私聊消息,检查")
|
||||||
# 好友黑名单拦截
|
# # 好友黑名单拦截
|
||||||
if userinfo.user_id not in global_config.experimental.talk_allowed_private:
|
# if userinfo.user_id not in global_config.experimental.talk_allowed_private:
|
||||||
# logger.debug(f"用户{userinfo.user_id}没有私聊权限")
|
# # logger.debug(f"用户{userinfo.user_id}没有私聊权限")
|
||||||
return
|
# return
|
||||||
|
|
||||||
# 群聊黑名单拦截
|
# 群聊黑名单拦截
|
||||||
# print(groupinfo.group_id)
|
# print(groupinfo.group_id)
|
||||||
# print(global_config.chat_target.talk_allowed_groups)
|
# print(global_config.chat_target.talk_allowed_groups)
|
||||||
if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
|
# if groupinfo is not None and groupinfo.group_id not in global_config.chat_target.talk_allowed_groups:
|
||||||
logger.debug(f"群{groupinfo.group_id}被禁止回复")
|
# logger.debug(f"群{groupinfo.group_id}被禁止回复")
|
||||||
return
|
# return
|
||||||
|
|
||||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||||
|
|
@ -112,33 +112,49 @@ class ChatBot:
|
||||||
async def preprocess():
|
async def preprocess():
|
||||||
logger.trace("开始预处理消息...")
|
logger.trace("开始预处理消息...")
|
||||||
# 如果在私聊中
|
# 如果在私聊中
|
||||||
if groupinfo is None:
|
if group_info is None:
|
||||||
logger.trace("检测到私聊消息")
|
logger.trace("检测到私聊消息")
|
||||||
# 是否在配置信息中开启私聊模式
|
# 是否在配置信息中开启私聊模式
|
||||||
if global_config.experimental.enable_friend_chat:
|
# if global_config.experimental.enable_friend_chat:
|
||||||
logger.trace("私聊模式已启用")
|
# logger.trace("私聊模式已启用")
|
||||||
# 是否进入PFC
|
# # 是否进入PFC
|
||||||
if global_config.enable_pfc_chatting:
|
# if global_config.enable_pfc_chatting:
|
||||||
logger.trace("进入PFC私聊处理流程")
|
# logger.trace("进入PFC私聊处理流程")
|
||||||
userinfo = message.message_info.user_info
|
# userinfo = message.message_info.user_info
|
||||||
messageinfo = message.message_info
|
# messageinfo = message.message_info
|
||||||
# 创建聊天流
|
# # 创建聊天流
|
||||||
logger.trace(f"为{userinfo.user_id}创建/获取聊天流")
|
# logger.trace(f"为{userinfo.user_id}创建/获取聊天流")
|
||||||
chat = await chat_manager.get_or_create_stream(
|
# chat = await chat_manager.get_or_create_stream(
|
||||||
platform=messageinfo.platform,
|
# platform=messageinfo.platform,
|
||||||
user_info=userinfo,
|
# user_info=userinfo,
|
||||||
group_info=groupinfo,
|
# group_info=groupinfo,
|
||||||
)
|
# )
|
||||||
message.update_chat_stream(chat)
|
# message.update_chat_stream(chat)
|
||||||
await self.only_process_chat.process_message(message)
|
# await self.only_process_chat.process_message(message)
|
||||||
await self._create_pfc_chat(message)
|
# await self._create_pfc_chat(message)
|
||||||
# 禁止PFC,进入普通的心流消息处理逻辑
|
# # 禁止PFC,进入普通的心流消息处理逻辑
|
||||||
else:
|
# else:
|
||||||
logger.trace("进入普通心流私聊处理")
|
# logger.trace("进入普通心流私聊处理")
|
||||||
await self.heartflow_processor.process_message(message_data)
|
# await self.heartflow_processor.process_message(message_data)
|
||||||
|
if global_config.experimental.pfc_chatting:
|
||||||
|
logger.trace("进入PFC私聊处理流程")
|
||||||
|
# 创建聊天流
|
||||||
|
logger.trace(f"为{user_info.user_id}创建/获取聊天流")
|
||||||
|
chat = await chat_manager.get_or_create_stream(
|
||||||
|
platform=message.message_info.platform,
|
||||||
|
user_info=user_info,
|
||||||
|
group_info=group_info,
|
||||||
|
)
|
||||||
|
message.update_chat_stream(chat)
|
||||||
|
await self.only_process_chat.process_message(message)
|
||||||
|
await self._create_pfc_chat(message)
|
||||||
|
# 禁止PFC,进入普通的心流消息处理逻辑
|
||||||
|
else:
|
||||||
|
logger.trace("进入普通心流私聊处理")
|
||||||
|
await self.heartflow_processor.process_message(message_data)
|
||||||
# 群聊默认进入心流消息处理逻辑
|
# 群聊默认进入心流消息处理逻辑
|
||||||
else:
|
else:
|
||||||
logger.trace(f"检测到群聊消息,群ID: {groupinfo.group_id}")
|
logger.trace(f"检测到群聊消息,群ID: {group_info.group_id}")
|
||||||
await self.heartflow_processor.process_message(message_data)
|
await self.heartflow_processor.process_message(message_data)
|
||||||
|
|
||||||
if template_group_name:
|
if template_group_name:
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class ChatStream:
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
def to_dict(self) -> dict:
|
||||||
"""转换为字典格式"""
|
"""转换为字典格式"""
|
||||||
result = {
|
return {
|
||||||
"stream_id": self.stream_id,
|
"stream_id": self.stream_id,
|
||||||
"platform": self.platform,
|
"platform": self.platform,
|
||||||
"user_info": self.user_info.to_dict() if self.user_info else None,
|
"user_info": self.user_info.to_dict() if self.user_info else None,
|
||||||
|
|
@ -47,7 +47,6 @@ class ChatStream:
|
||||||
"create_time": self.create_time,
|
"create_time": self.create_time,
|
||||||
"last_active_time": self.last_active_time,
|
"last_active_time": self.last_active_time,
|
||||||
}
|
}
|
||||||
return result
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: dict) -> "ChatStream":
|
def from_dict(cls, data: dict) -> "ChatStream":
|
||||||
|
|
@ -235,33 +234,34 @@ class ChatManager:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def _save_stream(stream: ChatStream):
|
async def _save_stream(stream: ChatStream):
|
||||||
"""保存聊天流到数据库"""
|
"""保存聊天流到数据库"""
|
||||||
if not stream.saved:
|
if stream.saved:
|
||||||
stream_data_dict = stream.to_dict()
|
return
|
||||||
|
stream_data_dict = stream.to_dict()
|
||||||
|
|
||||||
def _db_save_stream_sync(s_data_dict: dict):
|
def _db_save_stream_sync(s_data_dict: dict):
|
||||||
user_info_d = s_data_dict.get("user_info")
|
user_info_d = s_data_dict.get("user_info")
|
||||||
group_info_d = s_data_dict.get("group_info")
|
group_info_d = s_data_dict.get("group_info")
|
||||||
|
|
||||||
fields_to_save = {
|
fields_to_save = {
|
||||||
"platform": s_data_dict["platform"],
|
"platform": s_data_dict["platform"],
|
||||||
"create_time": s_data_dict["create_time"],
|
"create_time": s_data_dict["create_time"],
|
||||||
"last_active_time": s_data_dict["last_active_time"],
|
"last_active_time": s_data_dict["last_active_time"],
|
||||||
"user_platform": user_info_d["platform"] if user_info_d else "",
|
"user_platform": user_info_d["platform"] if user_info_d else "",
|
||||||
"user_id": user_info_d["user_id"] if user_info_d else "",
|
"user_id": user_info_d["user_id"] if user_info_d else "",
|
||||||
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
"user_nickname": user_info_d["user_nickname"] if user_info_d else "",
|
||||||
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
"user_cardname": user_info_d.get("user_cardname", "") if user_info_d else None,
|
||||||
"group_platform": group_info_d["platform"] if group_info_d else "",
|
"group_platform": group_info_d["platform"] if group_info_d else "",
|
||||||
"group_id": group_info_d["group_id"] if group_info_d else "",
|
"group_id": group_info_d["group_id"] if group_info_d else "",
|
||||||
"group_name": group_info_d["group_name"] if group_info_d else "",
|
"group_name": group_info_d["group_name"] if group_info_d else "",
|
||||||
}
|
}
|
||||||
|
|
||||||
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
ChatStreams.replace(stream_id=s_data_dict["stream_id"], **fields_to_save).execute()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
await asyncio.to_thread(_db_save_stream_sync, stream_data_dict)
|
||||||
stream.saved = True
|
stream.saved = True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
logger.error(f"保存聊天流 {stream.stream_id} 到数据库失败 (Peewee): {e}", exc_info=True)
|
||||||
|
|
||||||
async def _save_all_streams(self):
|
async def _save_all_streams(self):
|
||||||
"""保存所有聊天流"""
|
"""保存所有聊天流"""
|
||||||
|
|
|
||||||
|
|
@ -175,13 +175,13 @@ async def _build_readable_messages_internal(
|
||||||
# 1 & 2: 获取发送者信息并提取消息组件
|
# 1 & 2: 获取发送者信息并提取消息组件
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
# 检查并修复缺少的user_info字段
|
# 检查并修复缺少的user_info字段
|
||||||
if 'user_info' not in msg:
|
if "user_info" not in msg:
|
||||||
# 创建user_info字段
|
# 创建user_info字段
|
||||||
msg['user_info'] = {
|
msg["user_info"] = {
|
||||||
'platform': msg.get('user_platform', ''),
|
"platform": msg.get("user_platform", ""),
|
||||||
'user_id': msg.get('user_id', ''),
|
"user_id": msg.get("user_id", ""),
|
||||||
'user_nickname': msg.get('user_nickname', ''),
|
"user_nickname": msg.get("user_nickname", ""),
|
||||||
'user_cardname': msg.get('user_cardname', '')
|
"user_cardname": msg.get("user_cardname", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
user_info = msg.get("user_info", {})
|
user_info = msg.get("user_info", {})
|
||||||
|
|
|
||||||
|
|
@ -279,6 +279,7 @@ class GraphNodes(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储记忆图节点的模型
|
用于存储记忆图节点的模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
concept = TextField(unique=True, index=True) # 节点概念
|
concept = TextField(unique=True, index=True) # 节点概念
|
||||||
memory_items = TextField() # JSON格式存储的记忆列表
|
memory_items = TextField() # JSON格式存储的记忆列表
|
||||||
hash = TextField() # 节点哈希值
|
hash = TextField() # 节点哈希值
|
||||||
|
|
@ -293,6 +294,7 @@ class GraphEdges(BaseModel):
|
||||||
"""
|
"""
|
||||||
用于存储记忆图边的模型
|
用于存储记忆图边的模型
|
||||||
"""
|
"""
|
||||||
|
|
||||||
source = TextField(index=True) # 源节点
|
source = TextField(index=True) # 源节点
|
||||||
target = TextField(index=True) # 目标节点
|
target = TextField(index=True) # 目标节点
|
||||||
strength = IntegerField() # 连接强度
|
strength = IntegerField() # 连接强度
|
||||||
|
|
|
||||||
|
|
@ -340,11 +340,11 @@ class TelemetryConfig(ConfigBase):
|
||||||
class ExperimentalConfig(ConfigBase):
|
class ExperimentalConfig(ConfigBase):
|
||||||
"""实验功能配置类"""
|
"""实验功能配置类"""
|
||||||
|
|
||||||
enable_friend_chat: bool = False
|
# enable_friend_chat: bool = False
|
||||||
"""是否启用好友聊天"""
|
# """是否启用好友聊天"""
|
||||||
|
|
||||||
talk_allowed_private: set[str] = field(default_factory=lambda: set())
|
# talk_allowed_private: set[str] = field(default_factory=lambda: set())
|
||||||
"""允许聊天的私聊列表"""
|
# """允许聊天的私聊列表"""
|
||||||
|
|
||||||
pfc_chatting: bool = False
|
pfc_chatting: bool = False
|
||||||
"""是否启用PFC"""
|
"""是否启用PFC"""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
[inner]
|
[inner]
|
||||||
version = "2.0.0"
|
version = "2.1.0"
|
||||||
|
|
||||||
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
|
||||||
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
#如果你想要修改配置文件,请在修改后将version的值进行变更
|
||||||
|
|
@ -18,12 +18,7 @@ nickname = "麦麦"
|
||||||
alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效
|
alias_names = ["麦叠", "牢麦"] #该选项还在调试中,暂时未生效
|
||||||
|
|
||||||
[chat_target]
|
[chat_target]
|
||||||
talk_allowed_groups = [
|
|
||||||
123,
|
|
||||||
123,
|
|
||||||
] #可以回复消息的群号码
|
|
||||||
talk_frequency_down_groups = [] #降低回复频率的群号码
|
talk_frequency_down_groups = [] #降低回复频率的群号码
|
||||||
ban_user_id = [] #禁止回复和读取消息的QQ号
|
|
||||||
|
|
||||||
[personality] #未完善
|
[personality] #未完善
|
||||||
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
personality_core = "用一句话或几句话描述人格的核心特点" # 建议20字以内,谁再写3000字小作文敲谁脑袋
|
||||||
|
|
@ -171,8 +166,6 @@ enable_kaomoji_protection = false # 是否启用颜文字保护
|
||||||
enable = true
|
enable = true
|
||||||
|
|
||||||
[experimental] #实验性功能
|
[experimental] #实验性功能
|
||||||
enable_friend_chat = false # 是否启用好友聊天
|
|
||||||
talk_allowed_private = [] # 可以回复消息的QQ号
|
|
||||||
pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立
|
pfc_chatting = false # 是否启用PFC聊天,该功能仅作用于私聊,与回复模式独立
|
||||||
|
|
||||||
#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写
|
#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env自定义的宏,使用自定义模型则选择定位相似的模型自己填写
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
# 添加项目根目录到Python路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
|
||||||
|
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
from src.common.database.database_model import Messages, BaseModel
|
from src.common.database.database_model import Messages, BaseModel
|
||||||
|
|
@ -15,7 +15,7 @@ from src.common.message_repository import find_messages
|
||||||
class TestMessageRepository(unittest.TestCase):
|
class TestMessageRepository(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# 创建内存中的SQLite数据库用于测试
|
# 创建内存中的SQLite数据库用于测试
|
||||||
self.test_db = SqliteDatabase(':memory:')
|
self.test_db = SqliteDatabase(":memory:")
|
||||||
|
|
||||||
# 覆盖原有数据库连接
|
# 覆盖原有数据库连接
|
||||||
BaseModel._meta.database = self.test_db
|
BaseModel._meta.database = self.test_db
|
||||||
|
|
@ -28,74 +28,74 @@ class TestMessageRepository(unittest.TestCase):
|
||||||
current_time = datetime.datetime.now().timestamp()
|
current_time = datetime.datetime.now().timestamp()
|
||||||
self.test_messages = [
|
self.test_messages = [
|
||||||
{
|
{
|
||||||
'message_id': 'msg1',
|
"message_id": "msg1",
|
||||||
'time': current_time - 3600, # 1小时前
|
"time": current_time - 3600, # 1小时前
|
||||||
'chat_id': '5ed68437e28644da51f314f37df68d18',
|
"chat_id": "5ed68437e28644da51f314f37df68d18",
|
||||||
'chat_info_stream_id': 'stream1',
|
"chat_info_stream_id": "stream1",
|
||||||
'chat_info_platform': 'qq',
|
"chat_info_platform": "qq",
|
||||||
'chat_info_user_platform': 'qq',
|
"chat_info_user_platform": "qq",
|
||||||
'chat_info_user_id': 'user1',
|
"chat_info_user_id": "user1",
|
||||||
'chat_info_user_nickname': '用户1',
|
"chat_info_user_nickname": "用户1",
|
||||||
'chat_info_user_cardname': '卡片名1',
|
"chat_info_user_cardname": "卡片名1",
|
||||||
'chat_info_group_platform': 'qq',
|
"chat_info_group_platform": "qq",
|
||||||
'chat_info_group_id': 'group1',
|
"chat_info_group_id": "group1",
|
||||||
'chat_info_group_name': '群组1',
|
"chat_info_group_name": "群组1",
|
||||||
'chat_info_create_time': current_time - 7200, # 2小时前
|
"chat_info_create_time": current_time - 7200, # 2小时前
|
||||||
'chat_info_last_active_time': current_time - 1800, # 30分钟前
|
"chat_info_last_active_time": current_time - 1800, # 30分钟前
|
||||||
'user_platform': 'qq',
|
"user_platform": "qq",
|
||||||
'user_id': 'user1',
|
"user_id": "user1",
|
||||||
'user_nickname': '用户1',
|
"user_nickname": "用户1",
|
||||||
'user_cardname': '卡片名1',
|
"user_cardname": "卡片名1",
|
||||||
'processed_plain_text': '你好',
|
"processed_plain_text": "你好",
|
||||||
'detailed_plain_text': '你好',
|
"detailed_plain_text": "你好",
|
||||||
'memorized_times': 1
|
"memorized_times": 1,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'message_id': 'msg2',
|
"message_id": "msg2",
|
||||||
'time': current_time - 1800, # 30分钟前
|
"time": current_time - 1800, # 30分钟前
|
||||||
'chat_id': 'chat1',
|
"chat_id": "chat1",
|
||||||
'chat_info_stream_id': 'stream1',
|
"chat_info_stream_id": "stream1",
|
||||||
'chat_info_platform': 'qq',
|
"chat_info_platform": "qq",
|
||||||
'chat_info_user_platform': 'qq',
|
"chat_info_user_platform": "qq",
|
||||||
'chat_info_user_id': 'user1',
|
"chat_info_user_id": "user1",
|
||||||
'chat_info_user_nickname': '用户1',
|
"chat_info_user_nickname": "用户1",
|
||||||
'chat_info_user_cardname': '卡片名1',
|
"chat_info_user_cardname": "卡片名1",
|
||||||
'chat_info_group_platform': 'qq',
|
"chat_info_group_platform": "qq",
|
||||||
'chat_info_group_id': 'group1',
|
"chat_info_group_id": "group1",
|
||||||
'chat_info_group_name': '群组1',
|
"chat_info_group_name": "群组1",
|
||||||
'chat_info_create_time': current_time - 7200,
|
"chat_info_create_time": current_time - 7200,
|
||||||
'chat_info_last_active_time': current_time - 900, # 15分钟前
|
"chat_info_last_active_time": current_time - 900, # 15分钟前
|
||||||
'user_platform': 'qq',
|
"user_platform": "qq",
|
||||||
'user_id': 'user1',
|
"user_id": "user1",
|
||||||
'user_nickname': '用户1',
|
"user_nickname": "用户1",
|
||||||
'user_cardname': '卡片名1',
|
"user_cardname": "卡片名1",
|
||||||
'processed_plain_text': '世界',
|
"processed_plain_text": "世界",
|
||||||
'detailed_plain_text': '世界',
|
"detailed_plain_text": "世界",
|
||||||
'memorized_times': 2
|
"memorized_times": 2,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
'message_id': 'msg3',
|
"message_id": "msg3",
|
||||||
'time': current_time - 900, # 15分钟前
|
"time": current_time - 900, # 15分钟前
|
||||||
'chat_id': 'chat2',
|
"chat_id": "chat2",
|
||||||
'chat_info_stream_id': 'stream2',
|
"chat_info_stream_id": "stream2",
|
||||||
'chat_info_platform': 'wechat',
|
"chat_info_platform": "wechat",
|
||||||
'chat_info_user_platform': 'wechat',
|
"chat_info_user_platform": "wechat",
|
||||||
'chat_info_user_id': 'user2',
|
"chat_info_user_id": "user2",
|
||||||
'chat_info_user_nickname': '用户2',
|
"chat_info_user_nickname": "用户2",
|
||||||
'chat_info_user_cardname': '卡片名2',
|
"chat_info_user_cardname": "卡片名2",
|
||||||
'chat_info_group_platform': 'wechat',
|
"chat_info_group_platform": "wechat",
|
||||||
'chat_info_group_id': 'group2',
|
"chat_info_group_id": "group2",
|
||||||
'chat_info_group_name': '群组2',
|
"chat_info_group_name": "群组2",
|
||||||
'chat_info_create_time': current_time - 3600,
|
"chat_info_create_time": current_time - 3600,
|
||||||
'chat_info_last_active_time': current_time - 600, # 10分钟前
|
"chat_info_last_active_time": current_time - 600, # 10分钟前
|
||||||
'user_platform': 'wechat',
|
"user_platform": "wechat",
|
||||||
'user_id': 'user2',
|
"user_id": "user2",
|
||||||
'user_nickname': '用户2',
|
"user_nickname": "用户2",
|
||||||
'user_cardname': '卡片名2',
|
"user_cardname": "卡片名2",
|
||||||
'processed_plain_text': '测试',
|
"processed_plain_text": "测试",
|
||||||
'detailed_plain_text': '测试',
|
"detailed_plain_text": "测试",
|
||||||
'memorized_times': 0
|
"memorized_times": 0,
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
for msg_data in self.test_messages:
|
for msg_data in self.test_messages:
|
||||||
|
|
@ -110,65 +110,63 @@ class TestMessageRepository(unittest.TestCase):
|
||||||
results = find_messages({})
|
results = find_messages({})
|
||||||
self.assertEqual(len(results), 3)
|
self.assertEqual(len(results), 3)
|
||||||
# 验证结果是否按时间升序排列
|
# 验证结果是否按时间升序排列
|
||||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
self.assertEqual(results[0]["message_id"], "msg1")
|
||||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
self.assertEqual(results[1]["message_id"], "msg2")
|
||||||
self.assertEqual(results[2]['message_id'], 'msg3')
|
self.assertEqual(results[2]["message_id"], "msg3")
|
||||||
|
|
||||||
def test_find_messages_with_filter(self):
|
def test_find_messages_with_filter(self):
|
||||||
"""测试带过滤器的查询"""
|
"""测试带过滤器的查询"""
|
||||||
results = find_messages({'chat_id': 'chat1'})
|
results = find_messages({"chat_id": "chat1"})
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
self.assertEqual(results[0]["message_id"], "msg1")
|
||||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
self.assertEqual(results[1]["message_id"], "msg2")
|
||||||
|
|
||||||
results = find_messages({'user_id': 'user2'})
|
results = find_messages({"user_id": "user2"})
|
||||||
self.assertEqual(len(results), 1)
|
self.assertEqual(len(results), 1)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg3')
|
self.assertEqual(results[0]["message_id"], "msg3")
|
||||||
|
|
||||||
def test_find_messages_with_operators(self):
|
def test_find_messages_with_operators(self):
|
||||||
"""测试带操作符的查询"""
|
"""测试带操作符的查询"""
|
||||||
results = find_messages({'memorized_times': {'$gt': 0}})
|
results = find_messages({"memorized_times": {"$gt": 0}})
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
self.assertEqual(results[0]["message_id"], "msg1")
|
||||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
self.assertEqual(results[1]["message_id"], "msg2")
|
||||||
|
|
||||||
results = find_messages({'memorized_times': {'$gte': 2}})
|
results = find_messages({"memorized_times": {"$gte": 2}})
|
||||||
self.assertEqual(len(results), 1)
|
self.assertEqual(len(results), 1)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
self.assertEqual(results[0]["message_id"], "msg2")
|
||||||
|
|
||||||
def test_find_messages_with_sort(self):
|
def test_find_messages_with_sort(self):
|
||||||
"""测试带排序的查询"""
|
"""测试带排序的查询"""
|
||||||
results = find_messages({}, sort=[('memorized_times', -1)])
|
results = find_messages({}, sort=[("memorized_times", -1)])
|
||||||
self.assertEqual(len(results), 3)
|
self.assertEqual(len(results), 3)
|
||||||
# 验证结果是否按memorized_times降序排列
|
# 验证结果是否按memorized_times降序排列
|
||||||
self.assertEqual(results[0]['message_id'], 'msg2') # memorized_times = 2
|
self.assertEqual(results[0]["message_id"], "msg2") # memorized_times = 2
|
||||||
self.assertEqual(results[1]['message_id'], 'msg1') # memorized_times = 1
|
self.assertEqual(results[1]["message_id"], "msg1") # memorized_times = 1
|
||||||
self.assertEqual(results[2]['message_id'], 'msg3') # memorized_times = 0
|
self.assertEqual(results[2]["message_id"], "msg3") # memorized_times = 0
|
||||||
|
|
||||||
def test_find_messages_with_limit(self):
|
def test_find_messages_with_limit(self):
|
||||||
"""测试带限制的查询"""
|
"""测试带限制的查询"""
|
||||||
# 默认limit_mode为latest,应返回最新的2条记录
|
# 默认limit_mode为latest,应返回最新的2条记录
|
||||||
results = find_messages({}, limit=2)
|
results = find_messages({}, limit=2)
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
self.assertEqual(results[0]["message_id"], "msg2")
|
||||||
self.assertEqual(results[1]['message_id'], 'msg3')
|
self.assertEqual(results[1]["message_id"], "msg3")
|
||||||
|
|
||||||
# 使用earliest模式,应返回最早的2条记录
|
# 使用earliest模式,应返回最早的2条记录
|
||||||
results = find_messages({}, limit=2, limit_mode='earliest')
|
results = find_messages({}, limit=2, limit_mode="earliest")
|
||||||
self.assertEqual(len(results), 2)
|
self.assertEqual(len(results), 2)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg1')
|
self.assertEqual(results[0]["message_id"], "msg1")
|
||||||
self.assertEqual(results[1]['message_id'], 'msg2')
|
self.assertEqual(results[1]["message_id"], "msg2")
|
||||||
|
|
||||||
def test_find_messages_with_combined_criteria(self):
|
def test_find_messages_with_combined_criteria(self):
|
||||||
"""测试组合查询条件"""
|
"""测试组合查询条件"""
|
||||||
results = find_messages(
|
results = find_messages(
|
||||||
{'chat_info_platform': 'qq', 'memorized_times': {'$gt': 0}},
|
{"chat_info_platform": "qq", "memorized_times": {"$gt": 0}}, sort=[("time", 1)], limit=1
|
||||||
sort=[('time', 1)],
|
|
||||||
limit=1
|
|
||||||
)
|
)
|
||||||
self.assertEqual(len(results), 1)
|
self.assertEqual(len(results), 1)
|
||||||
self.assertEqual(results[0]['message_id'], 'msg2')
|
self.assertEqual(results[0]["message_id"], "msg2")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
@ -9,7 +9,7 @@ import json
|
||||||
import copy
|
import copy
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
# 添加项目根目录到Python路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat, build_readable_messages
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
@ -17,10 +17,11 @@ from src.common.logger import get_module_logger
|
||||||
# 创建测试日志记录器
|
# 创建测试日志记录器
|
||||||
logger = get_module_logger("test_readable_msg")
|
logger = get_module_logger("test_readable_msg")
|
||||||
|
|
||||||
|
|
||||||
class TestBuildReadableMessages(unittest.TestCase):
|
class TestBuildReadableMessages(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# 准备测试数据:从真实数据库获取消息
|
# 准备测试数据:从真实数据库获取消息
|
||||||
self.chat_id = '5ed68437e28644da51f314f37df68d18'
|
self.chat_id = "5ed68437e28644da51f314f37df68d18"
|
||||||
self.current_time = time.time()
|
self.current_time = time.time()
|
||||||
self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
|
self.thirty_days_ago = self.current_time - (30 * 24 * 60 * 60) # 30天前的时间戳
|
||||||
|
|
||||||
|
|
@ -31,7 +32,7 @@ class TestBuildReadableMessages(unittest.TestCase):
|
||||||
timestamp_start=self.thirty_days_ago,
|
timestamp_start=self.thirty_days_ago,
|
||||||
timestamp_end=self.current_time,
|
timestamp_end=self.current_time,
|
||||||
limit=10,
|
limit=10,
|
||||||
limit_mode="latest"
|
limit_mode="latest",
|
||||||
)
|
)
|
||||||
logger.info(f"已获取 {len(self.messages)} 条测试消息")
|
logger.info(f"已获取 {len(self.messages)} 条测试消息")
|
||||||
|
|
||||||
|
|
@ -61,14 +62,14 @@ class TestBuildReadableMessages(unittest.TestCase):
|
||||||
fixed_msg = copy.deepcopy(msg)
|
fixed_msg = copy.deepcopy(msg)
|
||||||
|
|
||||||
# 构建 user_info 对象
|
# 构建 user_info 对象
|
||||||
if 'user_info' not in fixed_msg:
|
if "user_info" not in fixed_msg:
|
||||||
user_info = {
|
user_info = {
|
||||||
'platform': fixed_msg.get('user_platform', 'qq'),
|
"platform": fixed_msg.get("user_platform", "qq"),
|
||||||
'user_id': fixed_msg.get('user_id', '10000'),
|
"user_id": fixed_msg.get("user_id", "10000"),
|
||||||
'user_nickname': fixed_msg.get('user_nickname', '测试用户'),
|
"user_nickname": fixed_msg.get("user_nickname", "测试用户"),
|
||||||
'user_cardname': fixed_msg.get('user_cardname', '')
|
"user_cardname": fixed_msg.get("user_cardname", ""),
|
||||||
}
|
}
|
||||||
fixed_msg['user_info'] = user_info
|
fixed_msg["user_info"] = user_info
|
||||||
logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info")
|
logger.info(f"为消息 {fixed_msg.get('message_id')} 添加了 user_info")
|
||||||
|
|
||||||
fixed_messages.append(fixed_msg)
|
fixed_messages.append(fixed_msg)
|
||||||
|
|
@ -77,14 +78,16 @@ class TestBuildReadableMessages(unittest.TestCase):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 使用修复后的消息尝试格式化
|
# 使用修复后的消息尝试格式化
|
||||||
formatted_text = asyncio.run(build_readable_messages(
|
formatted_text = asyncio.run(
|
||||||
messages=fixed_messages,
|
build_readable_messages(
|
||||||
replace_bot_name=True,
|
messages=fixed_messages,
|
||||||
merge_messages=False,
|
replace_bot_name=True,
|
||||||
timestamp_mode="absolute",
|
merge_messages=False,
|
||||||
read_mark=0.0,
|
timestamp_mode="absolute",
|
||||||
truncate=False
|
read_mark=0.0,
|
||||||
))
|
truncate=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("使用修复后的消息格式化完成")
|
logger.info("使用修复后的消息格式化完成")
|
||||||
logger.info(f"格式化结果长度: {len(formatted_text)}")
|
logger.info(f"格式化结果长度: {len(formatted_text)}")
|
||||||
|
|
@ -120,26 +123,24 @@ class TestBuildReadableMessages(unittest.TestCase):
|
||||||
logger.info(f"user_info存在: {'user_info' in test_msg}")
|
logger.info(f"user_info存在: {'user_info' in test_msg}")
|
||||||
|
|
||||||
# 修复缺少的user_info字段
|
# 修复缺少的user_info字段
|
||||||
if 'user_info' not in test_msg:
|
if "user_info" not in test_msg:
|
||||||
logger.warning("消息中缺少user_info字段,添加模拟数据")
|
logger.warning("消息中缺少user_info字段,添加模拟数据")
|
||||||
test_msg['user_info'] = {
|
test_msg["user_info"] = {
|
||||||
'platform': test_msg.get('user_platform', 'qq'),
|
"platform": test_msg.get("user_platform", "qq"),
|
||||||
'user_id': test_msg.get('user_id', '10000'),
|
"user_id": test_msg.get("user_id", "10000"),
|
||||||
'user_nickname': test_msg.get('user_nickname', '测试用户'),
|
"user_nickname": test_msg.get("user_nickname", "测试用户"),
|
||||||
'user_cardname': test_msg.get('user_cardname', '')
|
"user_cardname": test_msg.get("user_cardname", ""),
|
||||||
}
|
}
|
||||||
logger.info(f"添加的user_info: {test_msg['user_info']}")
|
logger.info(f"添加的user_info: {test_msg['user_info']}")
|
||||||
|
|
||||||
simple_msgs = [test_msg]
|
simple_msgs = [test_msg]
|
||||||
|
|
||||||
# 运行内部函数
|
# 运行内部函数
|
||||||
result_text, result_details = asyncio.run(_build_readable_messages_internal(
|
result_text, result_details = asyncio.run(
|
||||||
simple_msgs,
|
_build_readable_messages_internal(
|
||||||
replace_bot_name=True,
|
simple_msgs, replace_bot_name=True, merge_messages=False, timestamp_mode="absolute", truncate=False
|
||||||
merge_messages=False,
|
)
|
||||||
timestamp_mode="absolute",
|
)
|
||||||
truncate=False
|
|
||||||
))
|
|
||||||
|
|
||||||
logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}")
|
logger.info(f"内部函数返回结果: {result_text[:200] if result_text else '空'}")
|
||||||
logger.info(f"详情列表长度: {len(result_details)}")
|
logger.info(f"详情列表长度: {len(result_details)}")
|
||||||
|
|
@ -167,5 +168,6 @@ class TestBuildReadableMessages(unittest.TestCase):
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
@ -5,13 +5,14 @@ import datetime
|
||||||
import time
|
import time
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
# 添加项目根目录到Python路径
|
||||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
from src.common.message_repository import find_messages
|
from src.common.message_repository import find_messages
|
||||||
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import get_raw_msg_by_timestamp_with_chat
|
||||||
from peewee import SqliteDatabase
|
from peewee import SqliteDatabase
|
||||||
from src.common.database.database import db # 导入实际的数据库连接
|
from src.common.database.database import db # 导入实际的数据库连接
|
||||||
|
|
||||||
|
|
||||||
class TestExtractMessages(unittest.TestCase):
|
class TestExtractMessages(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
# 这个测试使用真实的数据库,所以不需要创建测试数据
|
# 这个测试使用真实的数据库,所以不需要创建测试数据
|
||||||
|
|
@ -19,13 +20,10 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
|
|
||||||
def test_extract_latest_messages_direct(self):
|
def test_extract_latest_messages_direct(self):
|
||||||
"""测试直接使用message_repository.find_messages函数"""
|
"""测试直接使用message_repository.find_messages函数"""
|
||||||
chat_id = '5ed68437e28644da51f314f37df68d18'
|
chat_id = "5ed68437e28644da51f314f37df68d18"
|
||||||
|
|
||||||
# 提取最新的10条消息
|
# 提取最新的10条消息
|
||||||
results = find_messages(
|
results = find_messages({"chat_id": chat_id}, limit=10)
|
||||||
{'chat_id': chat_id},
|
|
||||||
limit=10
|
|
||||||
)
|
|
||||||
|
|
||||||
# 打印结果数量
|
# 打印结果数量
|
||||||
print(f"\n直接使用find_messages,找到 {len(results)} 条消息")
|
print(f"\n直接使用find_messages,找到 {len(results)} 条消息")
|
||||||
|
|
@ -34,12 +32,12 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
if results:
|
if results:
|
||||||
print("\n消息时间顺序:")
|
print("\n消息时间顺序:")
|
||||||
for idx, msg in enumerate(results):
|
for idx, msg in enumerate(results):
|
||||||
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S')
|
msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
||||||
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
||||||
|
|
||||||
# 验证结果按时间排序
|
# 验证结果按时间排序
|
||||||
times = [msg['time'] for msg in results]
|
times = [msg["time"] for msg in results]
|
||||||
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
||||||
else:
|
else:
|
||||||
print(f"未找到chat_id为 {chat_id} 的消息")
|
print(f"未找到chat_id为 {chat_id} 的消息")
|
||||||
|
|
@ -49,7 +47,7 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
|
|
||||||
def test_extract_latest_messages_via_builder(self):
|
def test_extract_latest_messages_via_builder(self):
|
||||||
"""使用chat_message_builder中的函数测试从真实数据库提取消息"""
|
"""使用chat_message_builder中的函数测试从真实数据库提取消息"""
|
||||||
chat_id = '5ed68437e28644da51f314f37df68d18'
|
chat_id = "5ed68437e28644da51f314f37df68d18"
|
||||||
|
|
||||||
# 设置时间范围为过去30天到现在
|
# 设置时间范围为过去30天到现在
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
|
@ -57,11 +55,7 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
|
|
||||||
# 使用chat_message_builder中的函数
|
# 使用chat_message_builder中的函数
|
||||||
results = get_raw_msg_by_timestamp_with_chat(
|
results = get_raw_msg_by_timestamp_with_chat(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id, timestamp_start=thirty_days_ago, timestamp_end=current_time, limit=10, limit_mode="latest"
|
||||||
timestamp_start=thirty_days_ago,
|
|
||||||
timestamp_end=current_time,
|
|
||||||
limit=10,
|
|
||||||
limit_mode="latest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 打印结果数量
|
# 打印结果数量
|
||||||
|
|
@ -71,12 +65,12 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
if results:
|
if results:
|
||||||
print("\n消息时间顺序:")
|
print("\n消息时间顺序:")
|
||||||
for idx, msg in enumerate(results):
|
for idx, msg in enumerate(results):
|
||||||
msg_time = datetime.datetime.fromtimestamp(msg['time']).strftime('%Y-%m-%d %H:%M:%S')
|
msg_time = datetime.datetime.fromtimestamp(msg["time"]).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
print(f"{idx+1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
print(f"{idx + 1}. ID: {msg['message_id']}, 时间: {msg_time}")
|
||||||
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
print(f" 文本: {msg.get('processed_plain_text', '无文本内容')[:50]}...")
|
||||||
|
|
||||||
# 验证结果按时间排序
|
# 验证结果按时间排序
|
||||||
times = [msg['time'] for msg in results]
|
times = [msg["time"] for msg in results]
|
||||||
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
self.assertEqual(times, sorted(times), "消息应该按时间升序排列")
|
||||||
else:
|
else:
|
||||||
print(f"未找到chat_id为 {chat_id} 的消息")
|
print(f"未找到chat_id为 {chat_id} 的消息")
|
||||||
|
|
@ -84,5 +78,6 @@ class TestExtractMessages(unittest.TestCase):
|
||||||
# 最基本的断言,确保测试有效
|
# 最基本的断言,确保测试有效
|
||||||
self.assertIsInstance(results, list, "结果应该是一个列表")
|
self.assertIsInstance(results, list, "结果应该是一个列表")
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
Loading…
Reference in New Issue