mirror of https://github.com/Mai-with-u/MaiBot.git
分离嵌入向量方法
parent
080c862fcd
commit
28eb827c5f
|
|
@ -6,7 +6,7 @@ from maim_message import UserInfo # UserInfo 来自 maim_message 包 # 从 maim_
|
||||||
from src.plugins.chat.message import MessageRecv # MessageRecv 来自message.py
|
from src.plugins.chat.message import MessageRecv # MessageRecv 来自message.py
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from ..chat.chat_stream import chat_manager
|
from ..chat.chat_stream import ChatStream, chat_manager
|
||||||
from src.plugins.chat.utils import get_embedding
|
from src.plugins.chat.utils import get_embedding
|
||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
from .pfc_manager import PFCManager
|
from .pfc_manager import PFCManager
|
||||||
|
|
@ -24,7 +24,7 @@ async def _handle_error(error: Exception, context: str, message: MessageRecv | N
|
||||||
if message and hasattr(message, 'message_info') and hasattr(message.message_info, 'raw_message'): # MessageRecv 结构可能没有直接的 raw_message
|
if message and hasattr(message, 'message_info') and hasattr(message.message_info, 'raw_message'): # MessageRecv 结构可能没有直接的 raw_message
|
||||||
raw_msg_content = getattr(message.message_info, 'raw_message', None) # 安全获取
|
raw_msg_content = getattr(message.message_info, 'raw_message', None) # 安全获取
|
||||||
if raw_msg_content:
|
if raw_msg_content:
|
||||||
logger.error(f"相关消息原始内容: {raw_msg_content}")
|
logger.error(f"相关消息原始内容: {raw_msg_content}")
|
||||||
elif message and hasattr(message, 'raw_message'): # 如果 MessageRecv 直接有 raw_message
|
elif message and hasattr(message, 'raw_message'): # 如果 MessageRecv 直接有 raw_message
|
||||||
logger.error(f"相关消息原始内容: {message.raw_message}")
|
logger.error(f"相关消息原始内容: {message.raw_message}")
|
||||||
|
|
||||||
|
|
@ -47,21 +47,10 @@ class PFCProcessor:
|
||||||
try:
|
try:
|
||||||
# 1. 消息解析与初始化
|
# 1. 消息解析与初始化
|
||||||
message_obj = MessageRecv(message_data) # 使用你提供的 message.py 中的 MessageRecv
|
message_obj = MessageRecv(message_data) # 使用你提供的 message.py 中的 MessageRecv
|
||||||
# 确保 message_obj.message_info 存在
|
|
||||||
if not hasattr(message_obj, 'message_info'):
|
|
||||||
logger.error("MessageRecv 对象缺少 message_info 属性。跳过处理。")
|
|
||||||
return
|
|
||||||
|
|
||||||
groupinfo = getattr(message_obj.message_info, 'group_info', None)
|
groupinfo = getattr(message_obj.message_info, 'group_info', None)
|
||||||
userinfo = getattr(message_obj.message_info, 'user_info', None)
|
userinfo = getattr(message_obj.message_info, 'user_info', None)
|
||||||
|
|
||||||
if userinfo is None: # 确保 userinfo 存在
|
|
||||||
logger.error("message_obj.message_info 中缺少 user_info。跳过处理。")
|
|
||||||
return
|
|
||||||
if not hasattr(userinfo, 'user_id'): # 确保 user_id 存在
|
|
||||||
logger.error("userinfo 对象中缺少 user_id。跳过处理。")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.trace(f"准备为{userinfo.user_id}创建/获取聊天流")
|
logger.trace(f"准备为{userinfo.user_id}创建/获取聊天流")
|
||||||
chat = await chat_manager.get_or_create_stream(
|
chat = await chat_manager.get_or_create_stream(
|
||||||
platform=message_obj.message_info.platform,
|
platform=message_obj.message_info.platform,
|
||||||
|
|
@ -73,7 +62,7 @@ class PFCProcessor:
|
||||||
# 2. 过滤检查
|
# 2. 过滤检查
|
||||||
await message_obj.process() # 调用 MessageRecv 的异步 process 方法
|
await message_obj.process() # 调用 MessageRecv 的异步 process 方法
|
||||||
if self._check_ban_words(message_obj.processed_plain_text, userinfo) or \
|
if self._check_ban_words(message_obj.processed_plain_text, userinfo) or \
|
||||||
self._check_ban_regex(message_obj.raw_message, userinfo): # MessageRecv 有 raw_message 属性
|
self._check_ban_regex(message_obj.raw_message, userinfo): # MessageRecv 有 raw_message 属性
|
||||||
return
|
return
|
||||||
|
|
||||||
# 3. 消息存储 (保持原有调用)
|
# 3. 消息存储 (保持原有调用)
|
||||||
|
|
@ -82,49 +71,10 @@ class PFCProcessor:
|
||||||
await self.storage.store_message(message_obj, chat)
|
await self.storage.store_message(message_obj, chat)
|
||||||
logger.trace(f"存储成功 (初步): {message_obj.processed_plain_text}")
|
logger.trace(f"存储成功 (初步): {message_obj.processed_plain_text}")
|
||||||
|
|
||||||
# === 新增:为已存储的消息生成嵌入并更新数据库文档 ===
|
await self._update_embedding_vector(message_obj) # 明确传递 message_obj
|
||||||
embedding_vector = None
|
|
||||||
text_for_embedding = message_obj.processed_plain_text # 使用处理后的纯文本
|
|
||||||
|
|
||||||
# 在 storage.py 中,会对 processed_plain_text 进行一次过滤
|
|
||||||
# 为了保持一致,我们也在这里应用相同的过滤逻辑
|
|
||||||
# 当然,更优的做法是 store_message 返回过滤后的文本,或在 message_obj 中增加一个 filtered_processed_plain_text 属性
|
|
||||||
# 这里为了简单,我们先重复一次过滤逻辑
|
|
||||||
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
|
||||||
if text_for_embedding:
|
|
||||||
filtered_text_for_embedding = re.sub(pattern, "", text_for_embedding, flags=re.DOTALL)
|
|
||||||
else:
|
|
||||||
filtered_text_for_embedding = ""
|
|
||||||
|
|
||||||
if filtered_text_for_embedding and filtered_text_for_embedding.strip():
|
|
||||||
try:
|
|
||||||
# request_type 参数根据你的 get_embedding 函数实际需求来定
|
|
||||||
embedding_vector = await get_embedding(filtered_text_for_embedding, request_type="pfc_private_memory")
|
|
||||||
if embedding_vector:
|
|
||||||
logger.debug(f"成功为消息 ID '{message_obj.message_info.message_id}' 生成嵌入向量。")
|
|
||||||
|
|
||||||
# 更新数据库中的对应文档
|
|
||||||
# 确保你有权限访问和操作 db 对象
|
|
||||||
update_result = db.messages.update_one(
|
|
||||||
{"message_id": message_obj.message_info.message_id, "chat_id": chat.stream_id},
|
|
||||||
{"$set": {"embedding_vector": embedding_vector}}
|
|
||||||
)
|
|
||||||
if update_result.modified_count > 0:
|
|
||||||
logger.info(f"成功为消息 ID '{message_obj.message_info.message_id}' 更新嵌入向量到数据库。")
|
|
||||||
elif update_result.matched_count > 0:
|
|
||||||
logger.warning(f"消息 ID '{message_obj.message_info.message_id}' 已存在嵌入向量或未作修改。")
|
|
||||||
else:
|
|
||||||
logger.error(f"未能找到消息 ID '{message_obj.message_info.message_id}' (chat_id: {chat.stream_id}) 来更新嵌入向量。可能是存储和更新之间存在延迟或问题。")
|
|
||||||
else:
|
|
||||||
logger.warning(f"未能为消息 ID '{message_obj.message_info.message_id}' 的文本 '{filtered_text_for_embedding[:30]}...' 生成嵌入向量。")
|
|
||||||
except Exception as e_embed_update:
|
|
||||||
logger.error(f"为消息 ID '{message_obj.message_info.message_id}' 生成嵌入或更新数据库时发生异常: {e_embed_update}", exc_info=True)
|
|
||||||
else:
|
|
||||||
logger.debug(f"消息 ID '{message_obj.message_info.message_id}' 的过滤后纯文本为空,不生成或更新嵌入。")
|
|
||||||
# === 新增结束 ===
|
|
||||||
|
|
||||||
# 4. 创建 PFC 聊天流
|
# 4. 创建 PFC 聊天流
|
||||||
await self._create_pfc_chat(message_obj)
|
await self._create_pfc_chat(message_obj, chat)
|
||||||
|
|
||||||
# 5. 日志记录
|
# 5. 日志记录
|
||||||
# 确保 message_obj.message_info.time 是 float 类型的时间戳
|
# 确保 message_obj.message_info.time 是 float 类型的时间戳
|
||||||
|
|
@ -169,4 +119,47 @@ class PFCProcessor:
|
||||||
logger.info(f"[私聊]{userinfo.user_nickname}:{text}") # _nickname
|
logger.info(f"[私聊]{userinfo.user_nickname}:{text}") # _nickname
|
||||||
logger.info(f"[正则表达式过滤]消息匹配到{pattern.pattern},filtered") # .pattern 获取原始表达式字符串
|
logger.info(f"[正则表达式过滤]消息匹配到{pattern.pattern},filtered") # .pattern 获取原始表达式字符串
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _update_embedding_vector(self, message_obj: MessageRecv, chat: ChatStream) -> None:
|
||||||
|
"""更新消息的嵌入向量"""
|
||||||
|
# === 新增:为已存储的消息生成嵌入并更新数据库文档 ===
|
||||||
|
embedding_vector = None
|
||||||
|
text_for_embedding = message_obj.processed_plain_text # 使用处理后的纯文本
|
||||||
|
|
||||||
|
# 在 storage.py 中,会对 processed_plain_text 进行一次过滤
|
||||||
|
# 为了保持一致,我们也在这里应用相同的过滤逻辑
|
||||||
|
# 当然,更优的做法是 store_message 返回过滤后的文本,或在 message_obj 中增加一个 filtered_processed_plain_text 属性
|
||||||
|
# 这里为了简单,我们先重复一次过滤逻辑
|
||||||
|
pattern = r"<MainRule>.*?</MainRule>|<schedule>.*?</schedule>|<UserMessage>.*?</UserMessage>"
|
||||||
|
if text_for_embedding:
|
||||||
|
filtered_text_for_embedding = re.sub(pattern, "", text_for_embedding, flags=re.DOTALL)
|
||||||
|
else:
|
||||||
|
filtered_text_for_embedding = ""
|
||||||
|
|
||||||
|
if filtered_text_for_embedding and filtered_text_for_embedding.strip():
|
||||||
|
try:
|
||||||
|
# request_type 参数根据你的 get_embedding 函数实际需求来定
|
||||||
|
embedding_vector = await get_embedding(filtered_text_for_embedding, request_type="pfc_private_memory")
|
||||||
|
if embedding_vector:
|
||||||
|
logger.debug(f"成功为消息 ID '{message_obj.message_info.message_id}' 生成嵌入向量。")
|
||||||
|
|
||||||
|
# 更新数据库中的对应文档
|
||||||
|
# 确保你有权限访问和操作 db 对象
|
||||||
|
update_result = db.messages.update_one(
|
||||||
|
{"message_id": message_obj.message_info.message_id, "chat_id": chat.stream_id},
|
||||||
|
{"$set": {"embedding_vector": embedding_vector}}
|
||||||
|
)
|
||||||
|
if update_result.modified_count > 0:
|
||||||
|
logger.info(f"成功为消息 ID '{message_obj.message_info.message_id}' 更新嵌入向量到数据库。")
|
||||||
|
elif update_result.matched_count > 0:
|
||||||
|
logger.warning(f"消息 ID '{message_obj.message_info.message_id}' 已存在嵌入向量或未作修改。")
|
||||||
|
else:
|
||||||
|
logger.error(f"未能找到消息 ID '{message_obj.message_info.message_id}' (chat_id: {chat.stream_id}) 来更新嵌入向量。可能是存储和更新之间存在延迟或问题。")
|
||||||
|
else:
|
||||||
|
logger.warning(f"未能为消息 ID '{message_obj.message_info.message_id}' 的文本 '{filtered_text_for_embedding[:30]}...' 生成嵌入向量。")
|
||||||
|
except Exception as e_embed_update:
|
||||||
|
logger.error(f"为消息 ID '{message_obj.message_info.message_id}' 生成嵌入或更新数据库时发生异常: {e_embed_update}", exc_info=True)
|
||||||
|
else:
|
||||||
|
logger.debug(f"消息 ID '{message_obj.message_info.message_id}' 的过滤后纯文本为空,不生成或更新嵌入。")
|
||||||
|
# === 新增结束 ===
|
||||||
Loading…
Reference in New Issue