refactor: 日志打印优化(终于改完了,爽了

pull/155/head
AL76 2025-03-10 02:25:03 +08:00
parent 8d99592b32
commit 052e67b576
15 changed files with 431 additions and 412 deletions

2
bot.py
View File

@ -100,7 +100,7 @@ def load_logger():
"#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg " "#777777>|</> <cyan>{name:.<8}</cyan>:<cyan>{function:.<8}</cyan>:<cyan>{line: >4}</cyan> <fg "
"#777777>-</> <level>{message}</level>", "#777777>-</> <level>{message}</level>",
colorize=True, colorize=True,
level=os.getenv("LOG_LEVEL", "INFO") # 根据环境设置日志级别默认为INFO level=os.getenv("LOG_LEVEL", "DEBUG") # 根据环境设置日志级别默认为INFO
) )

View File

@ -71,8 +71,8 @@ class ChatBot:
for word in global_config.ban_words: for word in global_config.ban_words:
if word in message.detailed_plain_text: if word in message.detailed_plain_text:
logger.info( logger.info(
f"\033[1;32m[{message.group_name}]{message.user_nickname}:\033[0m {message.processed_plain_text}") f"[{message.group_name}]{message.user_nickname}:{message.processed_plain_text}")
logger.info(f"\033[1;32m[过滤词识别]\033[0m 消息中含有{word}filtered") logger.info(f"[过滤词识别]消息中含有{word}filtered")
return return
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
@ -81,8 +81,8 @@ class ChatBot:
topic = '' topic = ''
interested_rate = 0 interested_rate = 0
interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100 interested_rate = await hippocampus.memory_activate_value(message.processed_plain_text) / 100
logger.debug(f"\033[1;32m[记忆激活]\033[0m {message.processed_plain_text}" logger.debug(f"{message.processed_plain_text}"
"的激活度:---------------------------------------{interested_rate}\n") f"的激活度:{interested_rate}")
# logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}") # logger.info(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
await self.storage.store_message(message, topic[0] if topic else None) await self.storage.store_message(message, topic[0] if topic else None)
@ -99,10 +99,9 @@ class ChatBot:
) )
current_willing = willing_manager.get_willing(event.group_id) current_willing = willing_manager.get_willing(event.group_id)
logger.debug( logger.info(
f"\033[1;32m[{current_time}][{message.group_name}]{message.user_nickname}:\033[0m " f"[{current_time}][{message.group_name}]{message.user_nickname}:"
"{message.processed_plain_text}\033[1;36m[回复意愿:{current_willing:.2f}][概率:{reply_probability * " f"{message.processed_plain_text}[回复意愿:{current_willing:.2f}][概率:{reply_probability * 100:.1f}%]")
"100:.1f}%]\033[0m")
response = "" response = ""
@ -130,7 +129,7 @@ class ChatBot:
# 如果找不到思考消息,直接返回 # 如果找不到思考消息,直接返回
if not thinking_message: if not thinking_message:
print(f"\033[1;33m[警告]\033[0m 未找到对应的思考消息,可能已超时被移除") logger.warning(f"未找到对应的思考消息,可能已超时被移除")
return return
# 记录开始思考的时间,避免从思考到回复的时间太久 # 记录开始思考的时间,避免从思考到回复的时间太久

View File

@ -4,6 +4,7 @@ import os
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Dict, Optional
from loguru import logger
import requests import requests
@ -151,11 +152,11 @@ class CQCode:
except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e: except (requests.exceptions.SSLError, requests.exceptions.HTTPError) as e:
if retry == max_retries - 1: if retry == max_retries - 1:
print(f"\033[1;31m[致命错误]\033[0m 最终请求失败: {str(e)}") logger.error(f"最终请求失败: {str(e)}")
time.sleep(1.5 ** retry) # 指数退避 time.sleep(1.5 ** retry) # 指数退避
except Exception as e: except Exception as e:
print(f"\033[1;33m[未知错误]\033[0m {str(e)}") logger.exception(f"[未知错误]")
return None return None
return None return None
@ -194,7 +195,7 @@ class CQCode:
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]" return f"[表情包:{description}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") logger.exception(f"AI接口调用失败: {str(e)}")
return "[表情包]" return "[表情包]"
async def get_image_description(self, image_base64: str) -> str: async def get_image_description(self, image_base64: str) -> str:
@ -205,7 +206,7 @@ class CQCode:
description, _ = await self._llm.generate_response_for_image(prompt, image_base64) description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]" return f"[图片:{description}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") logger.exception(f"AI接口调用失败: {str(e)}")
return "[图片]" return "[图片]"
async def translate_forward(self) -> str: async def translate_forward(self) -> str:
@ -222,7 +223,7 @@ class CQCode:
try: try:
messages = ast.literal_eval(content) messages = ast.literal_eval(content)
except ValueError as e: except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") logger.error(f"解析转发消息内容失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
# 处理每条消息 # 处理每条消息
@ -277,11 +278,11 @@ class CQCode:
# 合并所有消息 # 合并所有消息
combined_messages = '\n'.join(formatted_messages) combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") logger.debug(f"合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]" return f"[转发消息:\n{combined_messages}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") logger.exception("处理转发消息失败")
return '[转发消息]' return '[转发消息]'
async def translate_reply(self) -> str: async def translate_reply(self) -> str:
@ -307,7 +308,7 @@ class CQCode:
return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]" return f"[回复 {self.reply_message.sender.nickname} 的消息: {message_obj.processed_plain_text}]"
else: else:
print("\033[1;31m[错误]\033[0m 回复消息的sender.user_id为空") logger.error("回复消息的sender.user_id为空")
return '[回复某人消息]' return '[回复某人消息]'
@staticmethod @staticmethod

View File

@ -21,24 +21,25 @@ config = driver.config
class EmojiManager: class EmojiManager:
_instance = None _instance = None
EMOJI_DIR = "data/emoji" # 表情包存储目录 EMOJI_DIR = "data/emoji" # 表情包存储目录
def __new__(cls): def __new__(cls):
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance.db = None cls._instance.db = None
cls._instance._initialized = False cls._instance._initialized = False
return cls._instance return cls._instance
def __init__(self): def __init__(self):
self.db = Database.get_instance() self.db = Database.get_instance()
self._scan_task = None self._scan_task = None
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,temperature=0.8) #更高的温度更少的token后续可以根据情绪来调整温度 self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60,
temperature=0.8) # 更高的温度更少的token后续可以根据情绪来调整温度
def _ensure_emoji_dir(self): def _ensure_emoji_dir(self):
"""确保表情存储目录存在""" """确保表情存储目录存在"""
os.makedirs(self.EMOJI_DIR, exist_ok=True) os.makedirs(self.EMOJI_DIR, exist_ok=True)
def initialize(self): def initialize(self):
"""初始化数据库连接和表情目录""" """初始化数据库连接和表情目录"""
if not self._initialized: if not self._initialized:
@ -50,15 +51,15 @@ class EmojiManager:
# 启动时执行一次完整性检查 # 启动时执行一次完整性检查
self.check_emoji_file_integrity() self.check_emoji_file_integrity()
except Exception as e: except Exception as e:
logger.error(f"初始化表情管理器失败: {str(e)}") logger.exception(f"初始化表情管理器失败")
def _ensure_db(self): def _ensure_db(self):
"""确保数据库已初始化""" """确保数据库已初始化"""
if not self._initialized: if not self._initialized:
self.initialize() self.initialize()
if not self._initialized: if not self._initialized:
raise RuntimeError("EmojiManager not initialized") raise RuntimeError("EmojiManager not initialized")
def _ensure_emoji_collection(self): def _ensure_emoji_collection(self):
"""确保emoji集合存在并创建索引 """确保emoji集合存在并创建索引
@ -76,7 +77,7 @@ class EmojiManager:
self.db.db.emoji.create_index([('embedding', '2dsphere')]) self.db.db.emoji.create_index([('embedding', '2dsphere')])
self.db.db.emoji.create_index([('tags', 1)]) self.db.db.emoji.create_index([('tags', 1)])
self.db.db.emoji.create_index([('filename', 1)], unique=True) self.db.db.emoji.create_index([('filename', 1)], unique=True)
def record_usage(self, emoji_id: str): def record_usage(self, emoji_id: str):
"""记录表情使用次数""" """记录表情使用次数"""
try: try:
@ -86,8 +87,8 @@ class EmojiManager:
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
except Exception as e: except Exception as e:
logger.error(f"记录表情使用失败: {str(e)}") logger.exception(f"记录表情使用失败")
async def get_emoji_for_text(self, text: str) -> Optional[str]: async def get_emoji_for_text(self, text: str) -> Optional[str]:
"""根据文本内容获取相关表情包 """根据文本内容获取相关表情包
Args: Args:
@ -102,9 +103,9 @@ class EmojiManager:
""" """
try: try:
self._ensure_db() self._ensure_db()
# 获取文本的embedding # 获取文本的embedding
text_for_search= await self._get_kimoji_for_text(text) text_for_search = await self._get_kimoji_for_text(text)
if not text_for_search: if not text_for_search:
logger.error("无法获取文本的情绪") logger.error("无法获取文本的情绪")
return None return None
@ -112,15 +113,15 @@ class EmojiManager:
if not text_embedding: if not text_embedding:
logger.error("无法获取文本的embedding") logger.error("无法获取文本的embedding")
return None return None
try: try:
# 获取所有表情包 # 获取所有表情包
all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1})) all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'discription': 1}))
if not all_emojis: if not all_emojis:
logger.warning("数据库中没有任何表情包") logger.warning("数据库中没有任何表情包")
return None return None
# 计算余弦相似度并排序 # 计算余弦相似度并排序
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
if not v1 or not v2: if not v1 or not v2:
@ -131,42 +132,43 @@ class EmojiManager:
if norm_v1 == 0 or norm_v2 == 0: if norm_v1 == 0 or norm_v2 == 0:
return 0 return 0
return dot_product / (norm_v1 * norm_v2) return dot_product / (norm_v1 * norm_v2)
# 计算所有表情包与输入文本的相似度 # 计算所有表情包与输入文本的相似度
emoji_similarities = [ emoji_similarities = [
(emoji, cosine_similarity(text_embedding, emoji.get('embedding', []))) (emoji, cosine_similarity(text_embedding, emoji.get('embedding', [])))
for emoji in all_emojis for emoji in all_emojis
] ]
# 按相似度降序排序 # 按相似度降序排序
emoji_similarities.sort(key=lambda x: x[1], reverse=True) emoji_similarities.sort(key=lambda x: x[1], reverse=True)
# 获取前3个最相似的表情包 # 获取前3个最相似的表情包
top_3_emojis = emoji_similarities[:3] top_3_emojis = emoji_similarities[:3]
if not top_3_emojis: if not top_3_emojis:
logger.warning("未找到匹配的表情包") logger.warning("未找到匹配的表情包")
return None return None
# 从前3个中随机选择一个 # 从前3个中随机选择一个
selected_emoji, similarity = random.choice(top_3_emojis) selected_emoji, similarity = random.choice(top_3_emojis)
if selected_emoji and 'path' in selected_emoji: if selected_emoji and 'path' in selected_emoji:
# 更新使用次数 # 更新使用次数
self.db.db.emoji.update_one( self.db.db.emoji.update_one(
{'_id': selected_emoji['_id']}, {'_id': selected_emoji['_id']},
{'$inc': {'usage_count': 1}} {'$inc': {'usage_count': 1}}
) )
logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})") logger.success(
f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
# 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了 # 稍微改一下文本描述,不然容易产生幻觉,描述已经包含 表情包 了
return selected_emoji['path'],"[ %s ]" % selected_emoji.get('discription', '无描述') return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')
except Exception as search_error: except Exception as search_error:
logger.error(f"搜索表情包失败: {str(search_error)}") logger.error(f"搜索表情包失败: {str(search_error)}")
return None return None
return None return None
except Exception as e: except Exception as e:
logger.error(f"获取表情包失败: {str(e)}") logger.error(f"获取表情包失败: {str(e)}")
return None return None
@ -175,39 +177,39 @@ class EmojiManager:
"""获取表情包的标签""" """获取表情包的标签"""
try: try:
prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感' prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _check_emoji(self, image_base64: str) -> str: async def _check_emoji(self, image_base64: str) -> str:
try: try:
prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容' prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'
content, _ = await self.vlm.generate_response_for_image(prompt, image_base64) content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
logger.debug(f"输出描述: {content}") logger.debug(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _get_kimoji_for_text(self, text:str): async def _get_kimoji_for_text(self, text: str):
try: try:
prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
content, _ = await self.llm_emotion_judge.generate_response_async(prompt) content, _ = await self.llm_emotion_judge.generate_response_async(prompt)
logger.info(f"输出描述: {content}") logger.info(f"输出描述: {content}")
return content return content
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def scan_new_emojis(self): async def scan_new_emojis(self):
"""扫描新的表情包""" """扫描新的表情包"""
try: try:
@ -215,22 +217,23 @@ class EmojiManager:
os.makedirs(emoji_dir, exist_ok=True) os.makedirs(emoji_dir, exist_ok=True)
# 获取所有支持的图片文件 # 获取所有支持的图片文件
files_to_process = [f for f in os.listdir(emoji_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))] files_to_process = [f for f in os.listdir(emoji_dir) if
f.lower().endswith(('.jpg', '.jpeg', '.png', '.gif'))]
for filename in files_to_process: for filename in files_to_process:
image_path = os.path.join(emoji_dir, filename) image_path = os.path.join(emoji_dir, filename)
# 检查是否已经注册过 # 检查是否已经注册过
existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
if existing_emoji: if existing_emoji:
continue continue
# 压缩图片并获取base64编码 # 压缩图片并获取base64编码
image_base64 = image_path_to_base64(image_path) image_base64 = image_path_to_base64(image_path)
if image_base64 is None: if image_base64 is None:
os.remove(image_path) os.remove(image_path)
continue continue
# 获取表情包的描述 # 获取表情包的描述
discription = await self._get_emoji_discription(image_base64) discription = await self._get_emoji_discription(image_base64)
if global_config.EMOJI_CHECK: if global_config.EMOJI_CHECK:
@ -247,30 +250,28 @@ class EmojiManager:
emoji_record = { emoji_record = {
'filename': filename, 'filename': filename,
'path': image_path, 'path': image_path,
'embedding':embedding, 'embedding': embedding,
'discription': discription, 'discription': discription,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }
# 保存到数据库 # 保存到数据库
self.db.db['emoji'].insert_one(emoji_record) self.db.db['emoji'].insert_one(emoji_record)
logger.success(f"注册新表情包: {filename}") logger.success(f"注册新表情包: {filename}")
logger.info(f"描述: {discription}") logger.info(f"描述: {discription}")
else: else:
logger.warning(f"跳过表情包: {filename}") logger.warning(f"跳过表情包: {filename}")
except Exception as e: except Exception as e:
logger.error(f"扫描表情包失败: {str(e)}") logger.exception(f"扫描表情包失败")
logger.error(traceback.format_exc())
async def _periodic_scan(self, interval_MINS: int = 10): async def _periodic_scan(self, interval_MINS: int = 10):
"""定期扫描新表情包""" """定期扫描新表情包"""
while True: while True:
print("\033[1;36m[表情包]\033[0m 开始扫描新表情包...") logger.info("开始扫描新表情包...")
await self.scan_new_emojis() await self.scan_new_emojis()
await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次 await asyncio.sleep(interval_MINS * 60) # 每600秒扫描一次
def check_emoji_file_integrity(self): def check_emoji_file_integrity(self):
"""检查表情包文件完整性 """检查表情包文件完整性
如果文件已被删除则从数据库中移除对应记录 如果文件已被删除则从数据库中移除对应记录
@ -281,7 +282,7 @@ class EmojiManager:
all_emojis = list(self.db.db.emoji.find()) all_emojis = list(self.db.db.emoji.find())
removed_count = 0 removed_count = 0
total_count = len(all_emojis) total_count = len(all_emojis)
for emoji in all_emojis: for emoji in all_emojis:
try: try:
if 'path' not in emoji: if 'path' not in emoji:
@ -289,13 +290,13 @@ class EmojiManager:
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
if 'embedding' not in emoji: if 'embedding' not in emoji:
logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}") logger.warning(f"发现过时记录缺少embedding字段ID: {emoji.get('_id', 'unknown')}")
self.db.db.emoji.delete_one({'_id': emoji['_id']}) self.db.db.emoji.delete_one({'_id': emoji['_id']})
removed_count += 1 removed_count += 1
continue continue
# 检查文件是否存在 # 检查文件是否存在
if not os.path.exists(emoji['path']): if not os.path.exists(emoji['path']):
logger.warning(f"表情包文件已被删除: {emoji['path']}") logger.warning(f"表情包文件已被删除: {emoji['path']}")
@ -309,7 +310,7 @@ class EmojiManager:
except Exception as item_error: except Exception as item_error:
logger.error(f"处理表情包记录时出错: {str(item_error)}") logger.error(f"处理表情包记录时出错: {str(item_error)}")
continue continue
# 验证清理结果 # 验证清理结果
remaining_count = self.db.db.emoji.count_documents({}) remaining_count = self.db.db.emoji.count_documents({})
if removed_count > 0: if removed_count > 0:
@ -317,7 +318,7 @@ class EmojiManager:
logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}")
else: else:
logger.info(f"已检查 {total_count} 个表情包记录") logger.info(f"已检查 {total_count} 个表情包记录")
except Exception as e: except Exception as e:
logger.error(f"检查表情包完整性失败: {str(e)}") logger.error(f"检查表情包完整性失败: {str(e)}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@ -328,6 +329,5 @@ class EmojiManager:
await asyncio.sleep(interval_MINS * 60) await asyncio.sleep(interval_MINS * 60)
# 创建全局单例 # 创建全局单例
emoji_manager = EmojiManager() emoji_manager = EmojiManager()

View File

@ -3,6 +3,7 @@ import time
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
from nonebot import get_driver from nonebot import get_driver
from loguru import logger
from ...common.database import Database from ...common.database import Database
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
@ -39,13 +40,13 @@ class ResponseGenerator:
self.current_model_type = 'r1_distill' self.current_model_type = 'r1_distill'
current_model = self.model_r1_distill current_model = self.model_r1_distill
print(f"+++++++++++++++++{global_config.BOT_NICKNAME}{self.current_model_type}思考中+++++++++++++++++") logger.info(f"{global_config.BOT_NICKNAME}{self.current_model_type}思考中")
model_response = await self._generate_response_with_model(message, current_model) model_response = await self._generate_response_with_model(message, current_model)
raw_content=model_response raw_content=model_response
if model_response: if model_response:
print(f'{global_config.BOT_NICKNAME}的回复是:{model_response}') logger.info(f'{global_config.BOT_NICKNAME}的回复是:{model_response}')
model_response = await self._process_response(model_response) model_response = await self._process_response(model_response)
if model_response: if model_response:
@ -93,7 +94,7 @@ class ResponseGenerator:
try: try:
content, reasoning_content = await model.generate_response(prompt) content, reasoning_content = await model.generate_response(prompt)
except Exception as e: except Exception as e:
print(f"生成回复时出错: {e}") logger.exception(f"生成回复时出错: {e}")
return None return None
# 保存到数据库 # 保存到数据库
@ -145,7 +146,7 @@ class ResponseGenerator:
return ["neutral"] return ["neutral"]
except Exception as e: except Exception as e:
print(f"获取情感标签时出错: {e}") logger.exception(f"获取情感标签时出错: {e}")
return ["neutral"] return ["neutral"]
async def _process_response(self, content: str) -> Tuple[List[str], List[str]]: async def _process_response(self, content: str) -> Tuple[List[str], List[str]]:
@ -172,7 +173,7 @@ class InitiativeMessageGenerate:
prompt_builder._build_initiative_prompt_select(message.group_id) prompt_builder._build_initiative_prompt_select(message.group_id)
) )
content_select, reasoning = self.model_v3.generate_response(topic_select_prompt) content_select, reasoning = self.model_v3.generate_response(topic_select_prompt)
print(f"[DEBUG] {content_select} {reasoning}") logger.debug(f"{content_select} {reasoning}")
topics_list = [dot[0] for dot in dots_for_select] topics_list = [dot[0] for dot in dots_for_select]
if content_select: if content_select:
if content_select in topics_list: if content_select in topics_list:
@ -185,12 +186,12 @@ class InitiativeMessageGenerate:
select_dot[1], prompt_template select_dot[1], prompt_template
) )
content_check, reasoning_check = self.model_v3.generate_response(prompt_check) content_check, reasoning_check = self.model_v3.generate_response(prompt_check)
print(f"[DEBUG] {content_check} {reasoning_check}") logger.info(f"{content_check} {reasoning_check}")
if "yes" not in content_check.lower(): if "yes" not in content_check.lower():
return None return None
prompt = prompt_builder._build_initiative_prompt( prompt = prompt_builder._build_initiative_prompt(
select_dot, prompt_template, memory select_dot, prompt_template, memory
) )
content, reasoning = self.model_r1.generate_response_async(prompt) content, reasoning = self.model_r1.generate_response_async(prompt)
print(f"[DEBUG] {content} {reasoning}") logger.debug(f"[DEBUG] {content} {reasoning}")
return content return content

View File

@ -2,6 +2,7 @@ import asyncio
import time import time
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
from loguru import logger
from nonebot.adapters.onebot.v11 import Bot from nonebot.adapters.onebot.v11 import Bot
from .cq_code import cq_code_tool from .cq_code import cq_code_tool
@ -13,45 +14,45 @@ from .config import global_config
class Message_Sender: class Message_Sender:
"""发送器""" """发送器"""
def __init__(self): def __init__(self):
self.message_interval = (0.5, 1) # 消息间隔时间范围(秒) self.message_interval = (0.5, 1) # 消息间隔时间范围(秒)
self.last_send_time = 0 self.last_send_time = 0
self._current_bot = None self._current_bot = None
def set_bot(self, bot: Bot): def set_bot(self, bot: Bot):
"""设置当前bot实例""" """设置当前bot实例"""
self._current_bot = bot self._current_bot = bot
async def send_group_message( async def send_group_message(
self, self,
group_id: int, group_id: int,
send_text: str, send_text: str,
auto_escape: bool = False, auto_escape: bool = False,
reply_message_id: int = None, reply_message_id: int = None,
at_user_id: int = None at_user_id: int = None
) -> None: ) -> None:
if not self._current_bot: if not self._current_bot:
raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例") raise RuntimeError("Bot未设置请先调用set_bot方法设置bot实例")
message = send_text message = send_text
# 如果需要回复 # 如果需要回复
if reply_message_id: if reply_message_id:
reply_cq = cq_code_tool.create_reply_cq(reply_message_id) reply_cq = cq_code_tool.create_reply_cq(reply_message_id)
message = reply_cq + message message = reply_cq + message
# 如果需要at # 如果需要at
# if at_user_id: # if at_user_id:
# at_cq = cq_code_tool.create_at_cq(at_user_id) # at_cq = cq_code_tool.create_at_cq(at_user_id)
# message = at_cq + " " + message # message = at_cq + " " + message
typing_time = calculate_typing_time(message) typing_time = calculate_typing_time(message)
if typing_time > 10: if typing_time > 10:
typing_time = 10 typing_time = 10
await asyncio.sleep(typing_time) await asyncio.sleep(typing_time)
# 发送消息 # 发送消息
try: try:
await self._current_bot.send_group_msg( await self._current_bot.send_group_msg(
@ -59,49 +60,49 @@ class Message_Sender:
message=message, message=message,
auto_escape=auto_escape auto_escape=auto_escape
) )
print(f"\033[1;34m[调试]\033[0m 发送消息{message}成功") logger.debug(f"发送消息{message}成功")
except Exception as e: except Exception as e:
print(f"发生错误 {e}") logger.exception(f"发送消息{message}失败")
print(f"\033[1;34m[调试]\033[0m 发送消息{message}失败")
class MessageContainer: class MessageContainer:
"""单个群的发送/思考消息容器""" """单个群的发送/思考消息容器"""
def __init__(self, group_id: int, max_size: int = 100): def __init__(self, group_id: int, max_size: int = 100):
self.group_id = group_id self.group_id = group_id
self.max_size = max_size self.max_size = max_size
self.messages = [] self.messages = []
self.last_send_time = 0 self.last_send_time = 0
self.thinking_timeout = 20 # 思考超时时间(秒) self.thinking_timeout = 20 # 思考超时时间(秒)
def get_timeout_messages(self) -> List[Message_Sending]: def get_timeout_messages(self) -> List[Message_Sending]:
"""获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序""" """获取所有超时的Message_Sending对象思考时间超过30秒按thinking_start_time排序"""
current_time = time.time() current_time = time.time()
timeout_messages = [] timeout_messages = []
for msg in self.messages: for msg in self.messages:
if isinstance(msg, Message_Sending): if isinstance(msg, Message_Sending):
if current_time - msg.thinking_start_time > self.thinking_timeout: if current_time - msg.thinking_start_time > self.thinking_timeout:
timeout_messages.append(msg) timeout_messages.append(msg)
# 按thinking_start_time排序时间早的在前面 # 按thinking_start_time排序时间早的在前面
timeout_messages.sort(key=lambda x: x.thinking_start_time) timeout_messages.sort(key=lambda x: x.thinking_start_time)
return timeout_messages return timeout_messages
def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]: def get_earliest_message(self) -> Optional[Union[Message_Thinking, Message_Sending]]:
"""获取thinking_start_time最早的消息对象""" """获取thinking_start_time最早的消息对象"""
if not self.messages: if not self.messages:
return None return None
earliest_time = float('inf') earliest_time = float('inf')
earliest_message = None earliest_message = None
for msg in self.messages: for msg in self.messages:
msg_time = msg.thinking_start_time msg_time = msg.thinking_start_time
if msg_time < earliest_time: if msg_time < earliest_time:
earliest_time = msg_time earliest_time = msg_time
earliest_message = msg earliest_message = msg
return earliest_message return earliest_message
def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None: def add_message(self, message: Union[Message_Thinking, Message_Sending]) -> None:
"""添加消息到队列""" """添加消息到队列"""
# print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群") # print(f"\033[1;32m[添加消息]\033[0m 添加消息到对应群")
@ -110,7 +111,7 @@ class MessageContainer:
self.messages.append(single_message) self.messages.append(single_message)
else: else:
self.messages.append(message) self.messages.append(message)
def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool: def remove_message(self, message: Union[Message_Thinking, Message_Sending]) -> bool:
"""移除消息如果消息存在则返回True否则返回False""" """移除消息如果消息存在则返回True否则返回False"""
try: try:
@ -119,97 +120,103 @@ class MessageContainer:
return True return True
return False return False
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 移除消息时发生错误: {e}") logger.exception(f"移除消息时发生错误: {e}")
return False return False
def has_messages(self) -> bool: def has_messages(self) -> bool:
"""检查是否有待发送的消息""" """检查是否有待发送的消息"""
return bool(self.messages) return bool(self.messages)
def get_all_messages(self) -> List[Union[Message, Message_Thinking]]: def get_all_messages(self) -> List[Union[Message, Message_Thinking]]:
"""获取所有消息""" """获取所有消息"""
return list(self.messages) return list(self.messages)
class MessageManager: class MessageManager:
"""管理所有群的消息容器""" """管理所有群的消息容器"""
def __init__(self): def __init__(self):
self.containers: Dict[int, MessageContainer] = {} self.containers: Dict[int, MessageContainer] = {}
self.storage = MessageStorage() self.storage = MessageStorage()
self._running = True self._running = True
def get_container(self, group_id: int) -> MessageContainer: def get_container(self, group_id: int) -> MessageContainer:
"""获取或创建群的消息容器""" """获取或创建群的消息容器"""
if group_id not in self.containers: if group_id not in self.containers:
self.containers[group_id] = MessageContainer(group_id) self.containers[group_id] = MessageContainer(group_id)
return self.containers[group_id] return self.containers[group_id]
def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None: def add_message(self, message: Union[Message_Thinking, Message_Sending, MessageSet]) -> None:
container = self.get_container(message.group_id) container = self.get_container(message.group_id)
container.add_message(message) container.add_message(message)
async def process_group_messages(self, group_id: int): async def process_group_messages(self, group_id: int):
"""处理群消息""" """处理群消息"""
# if int(time.time() / 3) == time.time() / 3: # if int(time.time() / 3) == time.time() / 3:
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息") # print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
container = self.get_container(group_id) container = self.get_container(group_id)
if container.has_messages(): if container.has_messages():
#最早的对象,可能是思考消息,也可能是发送消息 # 最早的对象,可能是思考消息,也可能是发送消息
message_earliest = container.get_earliest_message() #一个message_thinking or message_sending message_earliest = container.get_earliest_message() # 一个message_thinking or message_sending
#如果是思考消息 # 如果是思考消息
if isinstance(message_earliest, Message_Thinking): if isinstance(message_earliest, Message_Thinking):
#优先等待这条消息 # 优先等待这条消息
message_earliest.update_thinking_time() message_earliest.update_thinking_time()
thinking_time = message_earliest.thinking_time thinking_time = message_earliest.thinking_time
print(f"\033[1;34m[调试]\033[0m 消息正在思考中,已思考{int(thinking_time)}\033[K\r", end='', flush=True) print(f"消息正在思考中,已思考{int(thinking_time)}\r", end='', flush=True)
# 检查是否超时 # 检查是否超时
if thinking_time > global_config.thinking_timeout: if thinking_time > global_config.thinking_timeout:
print(f"\033[1;33m[警告]\033[0m 消息思考超时({thinking_time}秒),移除该消息") logger.warning(f"消息思考超时({thinking_time}秒),移除该消息")
container.remove_message(message_earliest) container.remove_message(message_earliest)
else:# 如果不是message_thinking就只能是message_sending else: # 如果不是message_thinking就只能是message_sending
print(f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中") logger.debug(f"消息'{message_earliest.processed_plain_text}'正在发送中")
#直接发,等什么呢 # 直接发,等什么呢
if message_earliest.is_head and message_earliest.update_thinking_time() >30: if message_earliest.is_head and message_earliest.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False, reply_message_id=message_earliest.reply_message_id) await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
auto_escape=False,
reply_message_id=message_earliest.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, message_earliest.processed_plain_text, auto_escape=False) await message_sender.send_group_message(group_id, message_earliest.processed_plain_text,
#移除消息 auto_escape=False)
# 移除消息
if message_earliest.is_emoji: if message_earliest.is_emoji:
message_earliest.processed_plain_text = "[表情包]" message_earliest.processed_plain_text = "[表情包]"
await self.storage.store_message(message_earliest, None) await self.storage.store_message(message_earliest, None)
container.remove_message(message_earliest) container.remove_message(message_earliest)
#获取并处理超时消息 # 获取并处理超时消息
message_timeout = container.get_timeout_messages() #也许是一堆message_sending message_timeout = container.get_timeout_messages() # 也许是一堆message_sending
if message_timeout: if message_timeout:
print(f"\033[1;34m[调试]\033[0m 发现{len(message_timeout)}条超时消息") logger.warning(f"发现{len(message_timeout)}条超时消息")
for msg in message_timeout: for msg in message_timeout:
if msg == message_earliest: if msg == message_earliest:
continue # 跳过已经处理过的消息 continue # 跳过已经处理过的消息
try: try:
#发送 # 发送
if msg.is_head and msg.update_thinking_time() >30: if msg.is_head and msg.update_thinking_time() > 30:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False, reply_message_id=msg.reply_message_id) await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False,
reply_message_id=msg.reply_message_id)
else: else:
await message_sender.send_group_message(group_id, msg.processed_plain_text, auto_escape=False) await message_sender.send_group_message(group_id, msg.processed_plain_text,
auto_escape=False)
#如果是表情包,则替换为"[表情包]" # 如果是表情包,则替换为"[表情包]"
if msg.is_emoji: if msg.is_emoji:
msg.processed_plain_text = "[表情包]" msg.processed_plain_text = "[表情包]"
await self.storage.store_message(msg, None) await self.storage.store_message(msg, None)
# 安全地移除消息 # 安全地移除消息
if not container.remove_message(msg): if not container.remove_message(msg):
print("\033[1;33m[警告]\033[0m 尝试删除不存在的消息") logger.warning("尝试删除不存在的消息")
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理超时消息时发生错误: {e}") logger.exception(f"处理超时消息时发生错误: {e}")
continue continue
async def start_processor(self): async def start_processor(self):
"""启动消息处理器""" """启动消息处理器"""
while self._running: while self._running:
@ -217,9 +224,10 @@ class MessageManager:
tasks = [] tasks = []
for group_id in self.containers.keys(): for group_id in self.containers.keys():
tasks.append(self.process_group_messages(group_id)) tasks.append(self.process_group_messages(group_id))
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# 创建全局消息管理器实例 # 创建全局消息管理器实例
message_manager = MessageManager() message_manager = MessageManager()
# 创建全局发送器实例 # 创建全局发送器实例

View File

@ -1,6 +1,7 @@
import random import random
import time import time
from typing import Optional from typing import Optional
from loguru import logger
from ...common.database import Database from ...common.database import Database
from ..memory_system.memory import hippocampus, memory_graph from ..memory_system.memory import hippocampus, memory_graph
@ -16,13 +17,11 @@ class PromptBuilder:
self.activate_messages = '' self.activate_messages = ''
self.db = Database.get_instance() self.db = Database.get_instance()
async def _build_prompt(self,
message_txt: str,
async def _build_prompt(self, sender_name: str = "某人",
message_txt: str, relationship_value: float = 0.0,
sender_name: str = "某人", group_id: Optional[int] = None) -> tuple[str, str]:
relationship_value: float = 0.0,
group_id: Optional[int] = None) -> tuple[str, str]:
"""构建prompt """构建prompt
Args: Args:
@ -33,57 +32,56 @@ class PromptBuilder:
Returns: Returns:
str: 构建好的prompt str: 构建好的prompt
""" """
#先禁用关系 # 先禁用关系
if 0 > 30: if 0 > 30:
relation_prompt = "关系特别特别好,你很喜欢喜欢他" relation_prompt = "关系特别特别好,你很喜欢喜欢他"
relation_prompt_2 = "热情发言或者回复" relation_prompt_2 = "热情发言或者回复"
elif 0 <-20: elif 0 < -20:
relation_prompt = "关系很差,你很讨厌他" relation_prompt = "关系很差,你很讨厌他"
relation_prompt_2 = "骂他" relation_prompt_2 = "骂他"
else: else:
relation_prompt = "关系一般" relation_prompt = "关系一般"
relation_prompt_2 = "发言或者回复" relation_prompt_2 = "发言或者回复"
#开始构建prompt # 开始构建prompt
# 心情
#心情
mood_manager = MoodManager.get_instance() mood_manager = MoodManager.get_instance()
mood_prompt = mood_manager.get_prompt() mood_prompt = mood_manager.get_prompt()
# 日程构建
#日程构建
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
#知识构建 # 知识构建
start_time = time.time() start_time = time.time()
prompt_info = '' prompt_info = ''
promt_info_prompt = '' promt_info_prompt = ''
prompt_info = await self.get_prompt_info(message_txt,threshold=0.5) prompt_info = await self.get_prompt_info(message_txt, threshold=0.5)
if prompt_info: if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' prompt_info = f'''你有以下这些[知识]{prompt_info}请你记住上面的[
知识]之后可能会用到-'''
end_time = time.time() end_time = time.time()
print(f"\033[1;32m[知识检索]\033[0m 耗时: {(end_time - start_time):.3f}") logger.debug(f"知识检索耗时: {(end_time - start_time):.3f}")
# 获取聊天上下文 # 获取聊天上下文
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# 使用新的记忆获取方法 # 使用新的记忆获取方法
memory_prompt = '' memory_prompt = ''
start_time = time.time() start_time = time.time()
# 调用 hippocampus 的 get_relevant_memories 方法 # 调用 hippocampus 的 get_relevant_memories 方法
relevant_memories = await hippocampus.get_relevant_memories( relevant_memories = await hippocampus.get_relevant_memories(
text=message_txt, text=message_txt,
@ -91,30 +89,28 @@ class PromptBuilder:
similarity_threshold=0.4, similarity_threshold=0.4,
max_memory_num=5 max_memory_num=5
) )
if relevant_memories: if relevant_memories:
# 格式化记忆内容 # 格式化记忆内容
memory_items = [] memory_items = []
for memory in relevant_memories: for memory in relevant_memories:
memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}") memory_items.append(f"关于「{memory['topic']}」的记忆:{memory['content']}")
memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n" memory_prompt = "看到这些聊天,你想起来:\n" + "\n".join(memory_items) + "\n"
# 打印调试信息 # 打印调试信息
print("\n\033[1;32m[记忆检索]\033[0m 找到以下相关记忆:") logger.debug("[记忆检索]找到以下相关记忆:")
for memory in relevant_memories: for memory in relevant_memories:
print(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}") logger.debug(f"- 主题「{memory['topic']}」[相似度: {memory['similarity']:.2f}]: {memory['content']}")
end_time = time.time() end_time = time.time()
print(f"\033[1;32m[回忆耗时]\033[0m 耗时: {(end_time - start_time):.3f}") logger.info(f"回忆耗时: {(end_time - start_time):.3f}")
# 激活prompt构建
#激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}" activate_prompt = f"以上是群里正在进行的聊天,{memory_prompt} 现在昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},{mood_prompt},你想要{relation_prompt_2}"
#检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中 # 检测机器人相关词汇,改为关键词检测与反应功能了,提取到全局配置中
# bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人'] # bot_keywords = ['人机', 'bot', '机器', '入机', 'robot', '机器人']
# is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords) # is_bot = any(keyword in message_txt.lower() for keyword in bot_keywords)
# if is_bot: # if is_bot:
@ -127,12 +123,11 @@ class PromptBuilder:
for rule in global_config.keywords_reaction_rules: for rule in global_config.keywords_reaction_rules:
if rule.get("enable", False): if rule.get("enable", False):
if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])): if any(keyword in message_txt.lower() for keyword in rule.get("keywords", [])):
print(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}") logger.info(f"检测到以下关键词之一:{rule.get('keywords', [])},触发反应:{rule.get('reaction', '')}")
keywords_reaction_prompt += rule.get("reaction", "") + '' keywords_reaction_prompt += rule.get("reaction", "") + ''
# 人格选择
#人格选择 personality = global_config.PROMPT_PERSONALITY
personality=global_config.PROMPT_PERSONALITY
probability_1 = global_config.PERSONALITY_1 probability_1 = global_config.PERSONALITY_1
probability_2 = global_config.PERSONALITY_2 probability_2 = global_config.PERSONALITY_2
probability_3 = global_config.PERSONALITY_3 probability_3 = global_config.PERSONALITY_3
@ -150,8 +145,8 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt}, prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt},
现在请你给出日常且口语化的回复请表现你自己的见解不要一昧迎合尽量简短一些{keywords_reaction_prompt} 现在请你给出日常且口语化的回复请表现你自己的见解不要一昧迎合尽量简短一些{keywords_reaction_prompt}
请你表达自己的见解和观点可以有个性''' 请你表达自己的见解和观点可以有个性'''
#中文高手(新加的好玩功能) # 中文高手(新加的好玩功能)
prompt_ger = '' prompt_ger = ''
if random.random() < 0.04: if random.random() < 0.04:
prompt_ger += '你喜欢用倒装句' prompt_ger += '你喜欢用倒装句'
@ -159,23 +154,23 @@ class PromptBuilder:
prompt_ger += '你喜欢用反问句' prompt_ger += '你喜欢用反问句'
if random.random() < 0.01: if random.random() < 0.01:
prompt_ger += '你喜欢用文言文' prompt_ger += '你喜欢用文言文'
#额外信息要求 # 额外信息要求
extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容''' extra_info = '''但是记得回复平淡一些,简短一些,尤其注意在没明确提到时不要过多提及自身的背景, 不要直接回复别人发的表情包,记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等),只需要输出回复内容就好,不要输出其他任何内容'''
#合并prompt # 合并prompt
prompt = "" prompt = ""
prompt += f"{prompt_info}\n" prompt += f"{prompt_info}\n"
prompt += f"{prompt_date}\n" prompt += f"{prompt_date}\n"
prompt += f"{chat_talking_prompt}\n" prompt += f"{chat_talking_prompt}\n"
prompt += f"{prompt_personality}\n" prompt += f"{prompt_personality}\n"
prompt += f"{prompt_ger}\n" prompt += f"{prompt_ger}\n"
prompt += f"{extra_info}\n" prompt += f"{extra_info}\n"
'''读空气prompt处理''' '''读空气prompt处理'''
activate_prompt_check=f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。" activate_prompt_check = f"以上是群里正在进行的聊天,昵称为 '{sender_name}' 的用户说的:{message_txt}。引起了你的注意,你和他{relation_prompt},你想要{relation_prompt_2},但是这不一定是合适的时机,请你决定是否要回应这条消息。"
prompt_personality_check = '' prompt_personality_check = ''
extra_check_info=f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。" extra_check_info = f"请注意把握群里的聊天内容的基础上,综合群内的氛围,例如,和{global_config.BOT_NICKNAME}相关的话题要积极回复,如果是at自己的消息一定要回复如果自己正在和别人聊天一定要回复其他话题如果合适搭话也可以回复如果认为应该回复请输出yes否则输出no请注意是决定是否需要回复而不是编写回复内容除了yes和no不要输出任何回复内容。"
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[0]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
elif personality_choice < probability_1 + probability_2: # 第二种人格 elif personality_choice < probability_1 + probability_2: # 第二种人格
@ -183,34 +178,36 @@ class PromptBuilder:
else: # 第三种人格 else: # 第三种人格
prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}''' prompt_personality_check = f'''你的网名叫{global_config.BOT_NICKNAME}{personality[2]}, 你正在浏览qq群{promt_info_prompt} {activate_prompt_check} {extra_check_info}'''
prompt_check_if_response=f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}" prompt_check_if_response = f"{prompt_info}\n{prompt_date}\n{chat_talking_prompt}\n{prompt_personality_check}"
return prompt,prompt_check_if_response return prompt, prompt_check_if_response
def _build_initiative_prompt_select(self,group_id): def _build_initiative_prompt_select(self, group_id, probability_1=0.8, probability_2=0.1):
current_date = time.strftime("%Y-%m-%d", time.localtime()) current_date = time.strftime("%Y-%m-%d", time.localtime())
current_time = time.strftime("%H:%M:%S", time.localtime()) current_time = time.strftime("%H:%M:%S", time.localtime())
bot_schedule_now_time,bot_schedule_now_activity = bot_schedule.get_current_task() bot_schedule_now_time, bot_schedule_now_activity = bot_schedule.get_current_task()
prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n''' prompt_date = f'''今天是{current_date},现在是{current_time},你今天的日程是:\n{bot_schedule.today_schedule}\n你现在正在{bot_schedule_now_activity}\n'''
chat_talking_prompt = '' chat_talking_prompt = ''
if group_id: if group_id:
chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id, limit=global_config.MAX_CONTEXT_SIZE,combine = True) chat_talking_prompt = get_recent_group_detailed_plain_text(self.db, group_id,
limit=global_config.MAX_CONTEXT_SIZE,
combine=True)
chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}" chat_talking_prompt = f"以下是群里正在聊天的内容:\n{chat_talking_prompt}"
# print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}") # print(f"\033[1;34m[调试]\033[0m 已从数据库获取群 {group_id} 的消息记录:{chat_talking_prompt}")
# 获取主动发言的话题 # 获取主动发言的话题
all_nodes=memory_graph.dots all_nodes = memory_graph.dots
all_nodes=filter(lambda dot:len(dot[1]['memory_items'])>3,all_nodes) all_nodes = filter(lambda dot: len(dot[1]['memory_items']) > 3, all_nodes)
nodes_for_select=random.sample(all_nodes,5) nodes_for_select = random.sample(all_nodes, 5)
topics=[info[0] for info in nodes_for_select] topics = [info[0] for info in nodes_for_select]
infos=[info[1] for info in nodes_for_select] infos = [info[1] for info in nodes_for_select]
#激活prompt构建 # 激活prompt构建
activate_prompt = '' activate_prompt = ''
activate_prompt = "以上是群里正在进行的聊天。" activate_prompt = "以上是群里正在进行的聊天。"
personality=global_config.PROMPT_PERSONALITY personality = global_config.PROMPT_PERSONALITY
prompt_personality = '' prompt_personality = ''
personality_choice = random.random() personality_choice = random.random()
if personality_choice < probability_1: # 第一种人格 if personality_choice < probability_1: # 第一种人格
@ -219,32 +216,31 @@ class PromptBuilder:
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}''' prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[1]}'''
else: # 第三种人格 else: # 第三种人格
prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}''' prompt_personality = f'''{activate_prompt}你的网名叫{global_config.BOT_NICKNAME}{personality[2]}'''
topics_str=','.join(f"\"{topics}\"")
prompt_for_select=f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
prompt_initiative_select=f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
prompt_regular=f"{prompt_date}\n{prompt_personality}"
return prompt_initiative_select,nodes_for_select,prompt_regular topics_str = ','.join(f"\"{topics}\"")
prompt_for_select = f"你现在想在群里发言,回忆了一下,想到几个话题,分别是{topics_str},综合当前状态以及群内气氛,请你在其中选择一个合适的话题,注意只需要输出话题,除了话题什么也不要输出(双引号也不要输出)"
def _build_initiative_prompt_check(self,selected_node,prompt_regular):
memory=random.sample(selected_node['memory_items'],3) prompt_initiative_select = f"{prompt_date}\n{prompt_personality}\n{prompt_for_select}"
memory='\n'.join(memory) prompt_regular = f"{prompt_date}\n{prompt_personality}"
prompt_for_check=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check,memory return prompt_initiative_select, nodes_for_select, prompt_regular
def _build_initiative_prompt(self,selected_node,prompt_regular,memory): def _build_initiative_prompt_check(self, selected_node, prompt_regular):
prompt_for_initiative=f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)" memory = random.sample(selected_node['memory_items'], 3)
memory = '\n'.join(memory)
prompt_for_check = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n以这个作为主题发言合适吗请在把握群里的聊天内容的基础上综合群内的氛围如果认为应该发言请输出yes否则输出no请注意是决定是否需要发言而不是编写回复内容除了yes和no不要输出任何回复内容。"
return prompt_for_check, memory
def _build_initiative_prompt(self, selected_node, prompt_regular, memory):
prompt_for_initiative = f"{prompt_regular}你现在想在群里发言,回忆了一下,想到一个话题,是{selected_node['concept']},关于这个话题的记忆有\n{memory}\n,请在把握群里的聊天内容的基础上,综合群内的氛围,以日常且口语化的口吻,简短且随意一点进行发言,不要说的太有条理,可以有个性。记住不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
return prompt_for_initiative return prompt_for_initiative
async def get_prompt_info(self,message:str,threshold:float): async def get_prompt_info(self, message: str, threshold: float):
related_info = '' related_info = ''
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
embedding = await get_embedding(message) embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding,threshold=threshold) related_info += self.get_info_from_db(embedding, threshold=threshold)
return related_info return related_info
def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str: def get_info_from_db(self, query_embedding: list, limit: int = 1, threshold: float = 0.5) -> str:
@ -305,14 +301,15 @@ class PromptBuilder:
{"$limit": limit}, {"$limit": limit},
{"$project": {"content": 1, "similarity": 1}} {"$project": {"content": 1, "similarity": 1}}
] ]
results = list(self.db.db.knowledges.aggregate(pipeline)) results = list(self.db.db.knowledges.aggregate(pipeline))
# print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}")
if not results: if not results:
return '' return ''
# 返回所有找到的内容,用换行分隔 # 返回所有找到的内容,用换行分隔
return '\n'.join(str(result['content']) for result in results) return '\n'.join(str(result['content']) for result in results)
prompt_builder = PromptBuilder()
prompt_builder = PromptBuilder()

View File

@ -4,9 +4,11 @@ from nonebot import get_driver
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from .config import global_config from .config import global_config
from loguru import logger
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
class TopicIdentifier: class TopicIdentifier:
def __init__(self): def __init__(self):
@ -23,19 +25,20 @@ class TopicIdentifier:
# 使用 LLM_request 类进行请求 # 使用 LLM_request 类进行请求
topic, _ = await self.llm_topic_judge.generate_response(prompt) topic, _ = await self.llm_topic_judge.generate_response(prompt)
if not topic: if not topic:
print("\033[1;31m[错误]\033[0m LLM API 返回为空") logger.error("LLM API 返回为空")
return None return None
# 直接在这里处理主题解析 # 直接在这里处理主题解析
if not topic or topic == "无主题": if not topic or topic == "无主题":
return None return None
# 解析主题字符串为列表 # 解析主题字符串为列表
topic_list = [t.strip() for t in topic.split(",") if t.strip()] topic_list = [t.strip() for t in topic.split(",") if t.strip()]
print(f"\033[1;32m[主题识别]\033[0m 主题: {topic_list}") logger.info(f"主题: {topic_list}")
return topic_list if topic_list else None return topic_list if topic_list else None
topic_identifier = TopicIdentifier()
topic_identifier = TopicIdentifier()

View File

@ -7,6 +7,7 @@ from typing import Dict, List
import jieba import jieba
import numpy as np import numpy as np
from nonebot import get_driver from nonebot import get_driver
from loguru import logger
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..utils.typo_generator import ChineseTypoGenerator from ..utils.typo_generator import ChineseTypoGenerator
@ -39,16 +40,16 @@ def combine_messages(messages: List[Message]) -> str:
def db_message_to_str(message_dict: Dict) -> str: def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}") logger.debug(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try: try:
name = "[(%s)%s]%s" % ( name = "[(%s)%s]%s" % (
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", "")) message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
except: except:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "") content = message_dict.get("processed_plain_text", "")
result = f"[{time_str}] {name}: {content}\n" result = f"[{time_str}] {name}: {content}\n"
print(f"result: {result}") logger.debug(f"result: {result}")
return result return result
@ -176,7 +177,7 @@ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
await msg.initialize() await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
print("[WARNING] 数据库中存在无效的消息") logger.warning("数据库中存在无效的消息")
continue continue
# 按时间正序排列 # 按时间正序排列
@ -292,11 +293,10 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentence = sentence.replace('', ' ').replace(',', ' ') sentence = sentence.replace('', ' ').replace(',', ' ')
sentences_done.append(sentence) sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}") logger.info(f"处理后的句子: {sentences_done}")
return sentences_done return sentences_done
def random_remove_punctuation(text: str) -> str: def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯 """随机处理标点符号,模拟人类打字习惯
@ -324,11 +324,10 @@ def random_remove_punctuation(text: str) -> str:
return result return result
def process_llm_response(text: str) -> List[str]: def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content) # processed_response = process_text_with_typos(content)
if len(text) > 200: if len(text) > 200:
print(f"回复过长 ({len(text)} 字符),返回默认回复") logger.warning(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说'] return ['懒得说']
# 处理长消息 # 处理长消息
typo_generator = ChineseTypoGenerator( typo_generator = ChineseTypoGenerator(
@ -348,9 +347,9 @@ def process_llm_response(text: str) -> List[str]:
else: else:
sentences.append(sentence) sentences.append(sentence)
# 检查分割后的消息数量是否过多超过3条 # 检查分割后的消息数量是否过多超过3条
if len(sentences) > 5: if len(sentences) > 5:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") logger.warning(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦'] return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences return sentences
@ -372,15 +371,15 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.4, english_
mood_arousal = mood_manager.current_mood.arousal mood_arousal = mood_manager.current_mood.arousal
# 映射到0.5到2倍的速度系数 # 映射到0.5到2倍的速度系数
typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半 typing_speed_multiplier = 1.5 ** mood_arousal # 唤醒度为1时速度翻倍,为-1时速度减半
chinese_time *= 1/typing_speed_multiplier chinese_time *= 1 / typing_speed_multiplier
english_time *= 1/typing_speed_multiplier english_time *= 1 / typing_speed_multiplier
# 计算中文字符数 # 计算中文字符数
chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff') chinese_chars = sum(1 for char in input_string if '\u4e00' <= char <= '\u9fff')
# 如果只有一个中文字符使用3倍时间 # 如果只有一个中文字符使用3倍时间
if chinese_chars == 1 and len(input_string.strip()) == 1: if chinese_chars == 1 and len(input_string.strip()) == 1:
return chinese_time * 3 + 0.3 # 加上回车时间 return chinese_time * 3 + 0.3 # 加上回车时间
# 正常计算所有字符的输入时间 # 正常计算所有字符的输入时间
total_time = 0.0 total_time = 0.0
for char in input_string: for char in input_string:

View File

@ -1,5 +1,6 @@
import asyncio import asyncio
from .config import global_config from .config import global_config
from loguru import logger
class WillingManager: class WillingManager:
@ -30,16 +31,16 @@ class WillingManager:
# print(f"初始意愿: {current_willing}") # print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0: if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9 current_willing += 0.9
print(f"被提及, 当前意愿: {current_willing}") logger.info(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot: elif is_mentioned_bot:
current_willing += 0.05 current_willing += 0.05
print(f"被重复提及, 当前意愿: {current_willing}") logger.info(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji: if is_emoji:
current_willing *= 0.1 current_willing *= 0.1
print(f"表情包, 当前意愿: {current_willing}") logger.info(f"表情包, 当前意愿: {current_willing}")
print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度 interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
if interested_rate > 0.4: if interested_rate > 0.4:
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")

View File

@ -224,7 +224,7 @@ class Hippocampus:
for msg in messages: for msg in messages:
input_text += f"{msg['text']}\n" input_text += f"{msg['text']}\n"
print(input_text) logger.debug(input_text)
topic_num = self.calculate_topic_num(input_text, compress_rate) topic_num = self.calculate_topic_num(input_text, compress_rate)
topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num)) topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
@ -235,7 +235,7 @@ class Hippocampus:
topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()] topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)] filtered_topics = [topic for topic in topics if not any(keyword in topic for keyword in filter_keywords)]
print(f"过滤后话题: {filtered_topics}") logger.info(f"过滤后话题: {filtered_topics}")
# 创建所有话题的请求任务 # 创建所有话题的请求任务
tasks = [] tasks = []
@ -259,8 +259,9 @@ class Hippocampus:
topic_by_length = text.count('\n') * compress_rate topic_by_length = text.count('\n') * compress_rate
topic_by_information_content = max(1, min(5, int((information_content - 3) * 2))) topic_by_information_content = max(1, min(5, int((information_content - 3) * 2)))
topic_num = int((topic_by_length + topic_by_information_content) / 2) topic_num = int((topic_by_length + topic_by_information_content) / 2)
print( logger.debug(
f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, topic_num: {topic_num}") f"topic_by_length: {topic_by_length}, topic_by_information_content: {topic_by_information_content}, "
f"topic_num: {topic_num}")
return topic_num return topic_num
async def operation_build_memory(self, chat_size=20): async def operation_build_memory(self, chat_size=20):
@ -275,22 +276,22 @@ class Hippocampus:
bar_length = 30 bar_length = 30
filled_length = int(bar_length * i // len(memory_sample)) filled_length = int(bar_length * i // len(memory_sample))
bar = '' * filled_length + '-' * (bar_length - filled_length) bar = '' * filled_length + '-' * (bar_length - filled_length)
print(f"\n进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})") logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_sample)})")
# 生成压缩后记忆 ,表现为 (话题,记忆) 的元组 # 生成压缩后记忆 ,表现为 (话题,记忆) 的元组
compressed_memory = set() compressed_memory = set()
compress_rate = 0.1 compress_rate = 0.1
compressed_memory = await self.memory_compress(input_text, compress_rate) compressed_memory = await self.memory_compress(input_text, compress_rate)
print(f"\033[1;33m压缩后记忆数量\033[0m: {len(compressed_memory)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)}")
# 将记忆加入到图谱中 # 将记忆加入到图谱中
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
print(f"\033[1;32m添加节点\033[0m: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory)
all_topics.append(topic) # 收集所有话题 all_topics.append(topic) # 收集所有话题
for i in range(len(all_topics)): for i in range(len(all_topics)):
for j in range(i + 1, len(all_topics)): for j in range(i + 1, len(all_topics)):
print(f"\033[1;32m连接节点\033[0m: {all_topics[i]}{all_topics[j]}") logger.info(f"连接节点: {all_topics[i]}{all_topics[j]}")
self.memory_graph.connect_dot(all_topics[i], all_topics[j]) self.memory_graph.connect_dot(all_topics[i], all_topics[j])
self.sync_memory_to_db() self.sync_memory_to_db()
@ -451,14 +452,14 @@ class Hippocampus:
removed_item = self.memory_graph.forget_topic(node) removed_item = self.memory_graph.forget_topic(node)
if removed_item: if removed_item:
forgotten_nodes.append((node, removed_item)) forgotten_nodes.append((node, removed_item))
print(f"遗忘节点 {node} 的记忆: {removed_item}") logger.debug(f"遗忘节点 {node} 的记忆: {removed_item}")
# 同步到数据库 # 同步到数据库
if forgotten_nodes: if forgotten_nodes:
self.sync_memory_to_db() self.sync_memory_to_db()
print(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆") logger.debug(f"完成遗忘操作,共遗忘 {len(forgotten_nodes)} 个节点的记忆")
else: else:
print("本次检查没有节点满足遗忘条件") logger.debug("本次检查没有节点满足遗忘条件")
async def merge_memory(self, topic): async def merge_memory(self, topic):
""" """
@ -481,8 +482,8 @@ class Hippocampus:
# 拼接成文本 # 拼接成文本
merged_text = "\n".join(selected_memories) merged_text = "\n".join(selected_memories)
print(f"\n[合并记忆] 话题: {topic}") logger.debug(f"\n[合并记忆] 话题: {topic}")
print(f"选择的记忆:\n{merged_text}") logger.debug(f"选择的记忆:\n{merged_text}")
# 使用memory_compress生成新的压缩记忆 # 使用memory_compress生成新的压缩记忆
compressed_memories = await self.memory_compress(selected_memories, 0.1) compressed_memories = await self.memory_compress(selected_memories, 0.1)
@ -494,11 +495,11 @@ class Hippocampus:
# 添加新的压缩记忆 # 添加新的压缩记忆
for _, compressed_memory in compressed_memories: for _, compressed_memory in compressed_memories:
memory_items.append(compressed_memory) memory_items.append(compressed_memory)
print(f"添加压缩记忆: {compressed_memory}") logger.info(f"添加压缩记忆: {compressed_memory}")
# 更新节点的记忆项 # 更新节点的记忆项
self.memory_graph.G.nodes[topic]['memory_items'] = memory_items self.memory_graph.G.nodes[topic]['memory_items'] = memory_items
print(f"完成记忆合并,当前记忆数量: {len(memory_items)}") logger.debug(f"完成记忆合并,当前记忆数量: {len(memory_items)}")
async def operation_merge_memory(self, percentage=0.1): async def operation_merge_memory(self, percentage=0.1):
""" """
@ -524,16 +525,16 @@ class Hippocampus:
# 如果内容数量超过100进行合并 # 如果内容数量超过100进行合并
if content_count > 100: if content_count > 100:
print(f"\n检查节点: {node}, 当前记忆数量: {content_count}") logger.debug(f"检查节点: {node}, 当前记忆数量: {content_count}")
await self.merge_memory(node) await self.merge_memory(node)
merged_nodes.append(node) merged_nodes.append(node)
# 同步到数据库 # 同步到数据库
if merged_nodes: if merged_nodes:
self.sync_memory_to_db() self.sync_memory_to_db()
print(f"\n完成记忆合并操作,共处理 {len(merged_nodes)} 个节点") logger.debug(f"完成记忆合并操作,共处理 {len(merged_nodes)} 个节点")
else: else:
print("\n本次检查没有需要合并的节点") logger.debug("本次检查没有需要合并的节点")
def find_topic_llm(self, text, topic_num): def find_topic_llm(self, text, topic_num):
prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。' prompt = f'这是一段文字:{text}。请你从这段话中总结出{topic_num}个关键的概念,可以是名词,动词,或者特定人物,帮我列出来,用逗号,隔开,尽可能精简。只需要列举{topic_num}个话题就好,不要有序号,不要告诉我其他内容。'
@ -628,7 +629,7 @@ class Hippocampus:
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度""" """计算输入文本对记忆的激活程度"""
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") logger.info(f"识别主题: {await self._identify_topics(text)}")
# 识别主题 # 识别主题
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
@ -659,8 +660,8 @@ class Hippocampus:
penalty = 1.0 / (1 + math.log(content_count + 1)) penalty = 1.0 / (1 + math.log(content_count + 1))
activation = int(score * 50 * penalty) activation = int(score * 50 * penalty)
print( logger.info(
f"\033[1;32m[记忆激活]\033[0m 单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}") f"[记忆激活]单主题「{topic}」- 相似度: {score:.3f}, 内容数: {content_count}, 激活值: {activation}")
return activation return activation
# 计算关键词匹配率,同时考虑内容数量 # 计算关键词匹配率,同时考虑内容数量
@ -687,8 +688,8 @@ class Hippocampus:
matched_topics.add(input_topic) matched_topics.add(input_topic)
adjusted_sim = sim * penalty adjusted_sim = sim * penalty
topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim) topic_similarities[input_topic] = max(topic_similarities.get(input_topic, 0), adjusted_sim)
print( logger.info(
f"\033[1;32m[记忆激活]\033[0m 主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})") f"[记忆激活]主题「{input_topic}」-> 「{memory_topic}」(内容数: {content_count}, 相似度: {adjusted_sim:.3f})")
# 计算主题匹配率和平均相似度 # 计算主题匹配率和平均相似度
topic_match = len(matched_topics) / len(identified_topics) topic_match = len(matched_topics) / len(identified_topics)
@ -696,8 +697,8 @@ class Hippocampus:
# 计算最终激活值 # 计算最终激活值
activation = int((topic_match + average_similarities) / 2 * 100) activation = int((topic_match + average_similarities) / 2 * 100)
print( logger.info(
f"\033[1;32m[记忆激活]\033[0m 匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}") f"[记忆激活]匹配率: {topic_match:.3f}, 平均相似度: {average_similarities:.3f}, 激活值: {activation}")
return activation return activation

View File

@ -743,7 +743,7 @@ class Hippocampus:
async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int: async def memory_activate_value(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.3) -> int:
"""计算输入文本对记忆的激活程度""" """计算输入文本对记忆的激活程度"""
print(f"\033[1;32m[记忆激活]\033[0m 识别主题: {await self._identify_topics(text)}") logger.info(f"[记忆激活]识别主题: {await self._identify_topics(text)}")
identified_topics = await self._identify_topics(text) identified_topics = await self._identify_topics(text)
if not identified_topics: if not identified_topics:

View File

@ -28,10 +28,10 @@ class LLM_request:
raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e raise ValueError(f"配置错误:找不到对应的配置项 - {str(e)}") from e
self.model_name = model["name"] self.model_name = model["name"]
self.params = kwargs self.params = kwargs
self.pri_in = model.get("pri_in", 0) self.pri_in = model.get("pri_in", 0)
self.pri_out = model.get("pri_out", 0) self.pri_out = model.get("pri_out", 0)
# 获取数据库实例 # 获取数据库实例
self.db = Database.get_instance() self.db = Database.get_instance()
self._init_database() self._init_database()
@ -47,9 +47,9 @@ class LLM_request:
except Exception as e: except Exception as e:
logger.error(f"创建数据库索引失败: {e}") logger.error(f"创建数据库索引失败: {e}")
def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int, def _record_usage(self, prompt_tokens: int, completion_tokens: int, total_tokens: int,
user_id: str = "system", request_type: str = "chat", user_id: str = "system", request_type: str = "chat",
endpoint: str = "/chat/completions"): endpoint: str = "/chat/completions"):
"""记录模型使用情况到数据库 """记录模型使用情况到数据库
Args: Args:
prompt_tokens: 输入token数 prompt_tokens: 输入token数
@ -140,12 +140,12 @@ class LLM_request:
} }
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
#判断是否为流式 # 判断是否为流式
stream_mode = self.params.get("stream", False) stream_mode = self.params.get("stream", False)
if self.params.get("stream", False) is True: if self.params.get("stream", False) is True:
logger.info(f"进入流式输出模式发送请求到URL: {api_url}") logger.debug(f"进入流式输出模式发送请求到URL: {api_url}")
else: else:
logger.info(f"发送请求到URL: {api_url}") logger.debug(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}") logger.info(f"使用模型: {self.model_name}")
# 构建请求体 # 构建请求体
@ -158,7 +158,7 @@ class LLM_request:
try: try:
# 使用上下文管理器处理会话 # 使用上下文管理器处理会话
headers = await self._build_headers() headers = await self._build_headers()
#似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响 # 似乎是openai流式必须要的东西,不过阿里云的qwq-plus加了这个没有影响
if stream_mode: if stream_mode:
headers["Accept"] = "text/event-stream" headers["Accept"] = "text/event-stream"
@ -184,29 +184,31 @@ class LLM_request:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
if response.status == 403: if response.status == 403:
# 尝试降级Pro模型 # 尝试降级Pro模型
if self.model_name.startswith("Pro/") and self.base_url == "https://api.siliconflow.cn/v1/": if self.model_name.startswith(
"Pro/") and self.base_url == "https://api.siliconflow.cn/v1/":
old_model_name = self.model_name old_model_name = self.model_name
self.model_name = self.model_name[4:] # 移除"Pro/"前缀 self.model_name = self.model_name[4:] # 移除"Pro/"前缀
logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}") logger.warning(f"检测到403错误模型从 {old_model_name} 降级为 {self.model_name}")
# 对全局配置进行更新 # 对全局配置进行更新
if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get('name') == old_model_name: if hasattr(global_config, 'llm_normal') and global_config.llm_normal.get(
'name') == old_model_name:
global_config.llm_normal['name'] = self.model_name global_config.llm_normal['name'] = self.model_name
logger.warning(f"已将全局配置中的 llm_normal 模型降级") logger.warning(f"已将全局配置中的 llm_normal 模型降级")
# 更新payload中的模型名 # 更新payload中的模型名
if payload and 'model' in payload: if payload and 'model' in payload:
payload['model'] = self.model_name payload['model'] = self.model_name
# 重新尝试请求 # 重新尝试请求
retry -= 1 # 不计入重试次数 retry -= 1 # 不计入重试次数
continue continue
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
response.raise_for_status() response.raise_for_status()
#将流式输出转化为非流式输出 # 将流式输出转化为非流式输出
if stream_mode: if stream_mode:
accumulated_content = "" accumulated_content = ""
async for line_bytes in response.content: async for line_bytes in response.content:
@ -233,12 +235,15 @@ class LLM_request:
reasoning_content = think_match.group(1).strip() reasoning_content = think_match.group(1).strip()
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip() content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
# 构造一个伪result以便调用自定义响应处理器或默认处理器 # 构造一个伪result以便调用自定义响应处理器或默认处理器
result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]} result = {
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) "choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
else: else:
result = await response.json() result = await response.json()
# 使用自定义处理器或默认处理 # 使用自定义处理器或默认处理
return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint) return response_handler(result) if response_handler else self._default_response_handler(
result, user_id, request_type, endpoint)
except Exception as e: except Exception as e:
if retry < policy["max_retries"] - 1: if retry < policy["max_retries"] - 1:
@ -252,8 +257,8 @@ class LLM_request:
logger.error("达到最大重试次数,请求仍然失败") logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败") raise RuntimeError("达到最大重试次数API请求仍然失败")
async def _transform_parameters(self, params: dict) ->dict: async def _transform_parameters(self, params: dict) -> dict:
""" """
根据模型名称转换参数 根据模型名称转换参数
- 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"删除 'temprature' 参数 - 对于需要转换的OpenAI CoT系列模型例如 "o3-mini"删除 'temprature' 参数
@ -262,7 +267,8 @@ class LLM_request:
# 复制一份参数,避免直接修改原始数据 # 复制一份参数,避免直接修改原始数据
new_params = dict(params) new_params = dict(params)
# 定义需要转换的模型列表 # 定义需要转换的模型列表
models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] models_needing_transformation = ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"]
if self.model_name.lower() in models_needing_transformation: if self.model_name.lower() in models_needing_transformation:
# 删除 'temprature' 参数(如果存在) # 删除 'temprature' 参数(如果存在)
new_params.pop("temperature", None) new_params.pop("temperature", None)
@ -298,13 +304,13 @@ class LLM_request:
**params_copy **params_copy
} }
# 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查 # 如果 payload 中依然存在 max_tokens 且需要转换,在这里进行再次检查
if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload: if self.model_name.lower() in ["o3-mini", "o1-mini", "o1-preview", "o1-2024-12-17", "o1-preview-2024-09-12",
"o3-mini-2025-01-31", "o1-mini-2024-09-12"] and "max_tokens" in payload:
payload["max_completion_tokens"] = payload.pop("max_tokens") payload["max_completion_tokens"] = payload.pop("max_tokens")
return payload return payload
def _default_response_handler(self, result: dict, user_id: str = "system", def _default_response_handler(self, result: dict, user_id: str = "system",
request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple: request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
"""默认响应解析""" """默认响应解析"""
if "choices" in result and result["choices"]: if "choices" in result and result["choices"]:
message = result["choices"][0]["message"] message = result["choices"][0]["message"]
@ -356,8 +362,8 @@ class LLM_request:
return { return {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json" "Content-Type": "application/json"
} }
# 防止小朋友们截图自己的key # 防止小朋友们截图自己的key
async def generate_response(self, prompt: str) -> Tuple[str, str]: async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的异步响应""" """根据输入的提示生成模型的异步响应"""
@ -404,6 +410,7 @@ class LLM_request:
Returns: Returns:
list: embedding向量如果失败则返回None list: embedding向量如果失败则返回None
""" """
def embedding_handler(result): def embedding_handler(result):
"""处理响应""" """处理响应"""
if "data" in result and len(result["data"]) > 0: if "data" in result and len(result["data"]) > 0:
@ -425,4 +432,3 @@ class LLM_request:
response_handler=embedding_handler response_handler=embedding_handler
) )
return embedding return embedding

View File

@ -4,7 +4,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from ..chat.config import global_config from ..chat.config import global_config
from loguru import logger
@dataclass @dataclass
class MoodState: class MoodState:
@ -210,7 +210,7 @@ class MoodManager:
def print_mood_status(self) -> None: def print_mood_status(self) -> None:
"""打印当前情绪状态""" """打印当前情绪状态"""
print(f"\033[1;35m[情绪状态]\033[0m 愉悦度: {self.current_mood.valence:.2f}, " logger.info(f"[情绪状态]愉悦度: {self.current_mood.valence:.2f}, "
f"唤醒度: {self.current_mood.arousal:.2f}, " f"唤醒度: {self.current_mood.arousal:.2f}, "
f"心情: {self.current_mood.text}") f"心情: {self.current_mood.text}")

View File

@ -11,12 +11,14 @@ from pathlib import Path
import random import random
import math import math
import time import time
from loguru import logger
class ChineseTypoGenerator: class ChineseTypoGenerator:
def __init__(self, def __init__(self,
error_rate=0.3, error_rate=0.3,
min_freq=5, min_freq=5,
tone_error_rate=0.2, tone_error_rate=0.2,
word_replace_rate=0.3, word_replace_rate=0.3,
max_freq_diff=200): max_freq_diff=200):
""" """
@ -34,27 +36,27 @@ class ChineseTypoGenerator:
self.tone_error_rate = tone_error_rate self.tone_error_rate = tone_error_rate
self.word_replace_rate = word_replace_rate self.word_replace_rate = word_replace_rate
self.max_freq_diff = max_freq_diff self.max_freq_diff = max_freq_diff
# 加载数据 # 加载数据
print("正在加载汉字数据库,请稍候...") logger.debug("正在加载汉字数据库,请稍候...")
self.pinyin_dict = self._create_pinyin_dict() self.pinyin_dict = self._create_pinyin_dict()
self.char_frequency = self._load_or_create_char_frequency() self.char_frequency = self._load_or_create_char_frequency()
def _load_or_create_char_frequency(self): def _load_or_create_char_frequency(self):
""" """
加载或创建汉字频率字典 加载或创建汉字频率字典
""" """
cache_file = Path("char_frequency.json") cache_file = Path("char_frequency.json")
# 如果缓存文件存在,直接加载 # 如果缓存文件存在,直接加载
if cache_file.exists(): if cache_file.exists():
with open(cache_file, 'r', encoding='utf-8') as f: with open(cache_file, 'r', encoding='utf-8') as f:
return json.load(f) return json.load(f)
# 使用内置的词频文件 # 使用内置的词频文件
char_freq = defaultdict(int) char_freq = defaultdict(int)
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
# 读取jieba的词典文件 # 读取jieba的词典文件
with open(dict_path, 'r', encoding='utf-8') as f: with open(dict_path, 'r', encoding='utf-8') as f:
for line in f: for line in f:
@ -63,15 +65,15 @@ class ChineseTypoGenerator:
for char in word: for char in word:
if self._is_chinese_char(char): if self._is_chinese_char(char):
char_freq[char] += int(freq) char_freq[char] += int(freq)
# 归一化频率值 # 归一化频率值
max_freq = max(char_freq.values()) max_freq = max(char_freq.values())
normalized_freq = {char: freq/max_freq * 1000 for char, freq in char_freq.items()} normalized_freq = {char: freq / max_freq * 1000 for char, freq in char_freq.items()}
# 保存到缓存文件 # 保存到缓存文件
with open(cache_file, 'w', encoding='utf-8') as f: with open(cache_file, 'w', encoding='utf-8') as f:
json.dump(normalized_freq, f, ensure_ascii=False, indent=2) json.dump(normalized_freq, f, ensure_ascii=False, indent=2)
return normalized_freq return normalized_freq
def _create_pinyin_dict(self): def _create_pinyin_dict(self):
@ -81,7 +83,7 @@ class ChineseTypoGenerator:
# 常用汉字范围 # 常用汉字范围
chars = [chr(i) for i in range(0x4e00, 0x9fff)] chars = [chr(i) for i in range(0x4e00, 0x9fff)]
pinyin_dict = defaultdict(list) pinyin_dict = defaultdict(list)
# 为每个汉字建立拼音映射 # 为每个汉字建立拼音映射
for char in chars: for char in chars:
try: try:
@ -89,7 +91,7 @@ class ChineseTypoGenerator:
pinyin_dict[py].append(char) pinyin_dict[py].append(char)
except Exception: except Exception:
continue continue
return pinyin_dict return pinyin_dict
def _is_chinese_char(self, char): def _is_chinese_char(self, char):
@ -107,7 +109,7 @@ class ChineseTypoGenerator:
""" """
# 将句子拆分成单个字符 # 将句子拆分成单个字符
characters = list(sentence) characters = list(sentence)
# 获取每个字符的拼音 # 获取每个字符的拼音
result = [] result = []
for char in characters: for char in characters:
@ -117,7 +119,7 @@ class ChineseTypoGenerator:
# 获取拼音(数字声调) # 获取拼音(数字声调)
py = pinyin(char, style=Style.TONE3)[0][0] py = pinyin(char, style=Style.TONE3)[0][0]
result.append((char, py)) result.append((char, py))
return result return result
def _get_similar_tone_pinyin(self, py): def _get_similar_tone_pinyin(self, py):
@ -127,19 +129,19 @@ class ChineseTypoGenerator:
# 检查拼音是否为空或无效 # 检查拼音是否为空或无效
if not py or len(py) < 1: if not py or len(py) < 1:
return py return py
# 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况 # 如果最后一个字符不是数字,说明可能是轻声或其他特殊情况
if not py[-1].isdigit(): if not py[-1].isdigit():
# 为非数字结尾的拼音添加数字声调1 # 为非数字结尾的拼音添加数字声调1
return py + '1' return py + '1'
base = py[:-1] # 去掉声调 base = py[:-1] # 去掉声调
tone = int(py[-1]) # 获取声调 tone = int(py[-1]) # 获取声调
# 处理轻声通常用5表示或无效声调 # 处理轻声通常用5表示或无效声调
if tone not in [1, 2, 3, 4]: if tone not in [1, 2, 3, 4]:
return base + str(random.choice([1, 2, 3, 4])) return base + str(random.choice([1, 2, 3, 4]))
# 正常处理声调 # 正常处理声调
possible_tones = [1, 2, 3, 4] possible_tones = [1, 2, 3, 4]
possible_tones.remove(tone) # 移除原声调 possible_tones.remove(tone) # 移除原声调
@ -152,11 +154,11 @@ class ChineseTypoGenerator:
""" """
if target_freq > orig_freq: if target_freq > orig_freq:
return 1.0 # 如果替换字频率更高,保持原有概率 return 1.0 # 如果替换字频率更高,保持原有概率
freq_diff = orig_freq - target_freq freq_diff = orig_freq - target_freq
if freq_diff > self.max_freq_diff: if freq_diff > self.max_freq_diff:
return 0.0 # 频率差太大,不替换 return 0.0 # 频率差太大,不替换
# 使用指数衰减函数计算概率 # 使用指数衰减函数计算概率
# 频率差为0时概率为1频率差为max_freq_diff时概率接近0 # 频率差为0时概率为1频率差为max_freq_diff时概率接近0
return math.exp(-3 * freq_diff / self.max_freq_diff) return math.exp(-3 * freq_diff / self.max_freq_diff)
@ -166,42 +168,42 @@ class ChineseTypoGenerator:
获取与给定字频率相近的同音字可能包含声调错误 获取与给定字频率相近的同音字可能包含声调错误
""" """
homophones = [] homophones = []
# 有一定概率使用错误声调 # 有一定概率使用错误声调
if random.random() < self.tone_error_rate: if random.random() < self.tone_error_rate:
wrong_tone_py = self._get_similar_tone_pinyin(py) wrong_tone_py = self._get_similar_tone_pinyin(py)
homophones.extend(self.pinyin_dict[wrong_tone_py]) homophones.extend(self.pinyin_dict[wrong_tone_py])
# 添加正确声调的同音字 # 添加正确声调的同音字
homophones.extend(self.pinyin_dict[py]) homophones.extend(self.pinyin_dict[py])
if not homophones: if not homophones:
return None return None
# 获取原字的频率 # 获取原字的频率
orig_freq = self.char_frequency.get(char, 0) orig_freq = self.char_frequency.get(char, 0)
# 计算所有同音字与原字的频率差,并过滤掉低频字 # 计算所有同音字与原字的频率差,并过滤掉低频字
freq_diff = [(h, self.char_frequency.get(h, 0)) freq_diff = [(h, self.char_frequency.get(h, 0))
for h in homophones for h in homophones
if h != char and self.char_frequency.get(h, 0) >= self.min_freq] if h != char and self.char_frequency.get(h, 0) >= self.min_freq]
if not freq_diff: if not freq_diff:
return None return None
# 计算每个候选字的替换概率 # 计算每个候选字的替换概率
candidates_with_prob = [] candidates_with_prob = []
for h, freq in freq_diff: for h, freq in freq_diff:
prob = self._calculate_replacement_probability(orig_freq, freq) prob = self._calculate_replacement_probability(orig_freq, freq)
if prob > 0: # 只保留有效概率的候选字 if prob > 0: # 只保留有效概率的候选字
candidates_with_prob.append((h, prob)) candidates_with_prob.append((h, prob))
if not candidates_with_prob: if not candidates_with_prob:
return None return None
# 根据概率排序 # 根据概率排序
candidates_with_prob.sort(key=lambda x: x[1], reverse=True) candidates_with_prob.sort(key=lambda x: x[1], reverse=True)
# 返回概率最高的几个字 # 返回概率最高的几个字
return [char for char, _ in candidates_with_prob[:num_candidates]] return [char for char, _ in candidates_with_prob[:num_candidates]]
@ -223,10 +225,10 @@ class ChineseTypoGenerator:
""" """
if len(word) == 1: if len(word) == 1:
return [] return []
# 获取词的拼音 # 获取词的拼音
word_pinyin = self._get_word_pinyin(word) word_pinyin = self._get_word_pinyin(word)
# 遍历所有可能的同音字组合 # 遍历所有可能的同音字组合
candidates = [] candidates = []
for py in word_pinyin: for py in word_pinyin:
@ -234,11 +236,11 @@ class ChineseTypoGenerator:
if not chars: if not chars:
return [] return []
candidates.append(chars) candidates.append(chars)
# 生成所有可能的组合 # 生成所有可能的组合
import itertools import itertools
all_combinations = itertools.product(*candidates) all_combinations = itertools.product(*candidates)
# 获取jieba词典和词频信息 # 获取jieba词典和词频信息
dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt') dict_path = os.path.join(os.path.dirname(jieba.__file__), 'dict.txt')
valid_words = {} # 改用字典存储词语及其频率 valid_words = {} # 改用字典存储词语及其频率
@ -249,11 +251,11 @@ class ChineseTypoGenerator:
word_text = parts[0] word_text = parts[0]
word_freq = float(parts[1]) # 获取词频 word_freq = float(parts[1]) # 获取词频
valid_words[word_text] = word_freq valid_words[word_text] = word_freq
# 获取原词的词频作为参考 # 获取原词的词频作为参考
original_word_freq = valid_words.get(word, 0) original_word_freq = valid_words.get(word, 0)
min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10% min_word_freq = original_word_freq * 0.1 # 设置最小词频为原词频的10%
# 过滤和计算频率 # 过滤和计算频率
homophones = [] homophones = []
for combo in all_combinations: for combo in all_combinations:
@ -268,7 +270,7 @@ class ChineseTypoGenerator:
combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3) combined_score = (new_word_freq * 0.7 + char_avg_freq * 0.3)
if combined_score >= self.min_freq: if combined_score >= self.min_freq:
homophones.append((new_word, combined_score)) homophones.append((new_word, combined_score))
# 按综合分数排序并限制返回数量 # 按综合分数排序并限制返回数量
sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True) sorted_homophones = sorted(homophones, key=lambda x: x[1], reverse=True)
return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果 return [word for word, _ in sorted_homophones[:5]] # 限制返回前5个结果
@ -286,19 +288,19 @@ class ChineseTypoGenerator:
""" """
result = [] result = []
typo_info = [] typo_info = []
# 分词 # 分词
words = self._segment_sentence(sentence) words = self._segment_sentence(sentence)
for word in words: for word in words:
# 如果是标点符号或空格,直接添加 # 如果是标点符号或空格,直接添加
if all(not self._is_chinese_char(c) for c in word): if all(not self._is_chinese_char(c) for c in word):
result.append(word) result.append(word)
continue continue
# 获取词语的拼音 # 获取词语的拼音
word_pinyin = self._get_word_pinyin(word) word_pinyin = self._get_word_pinyin(word)
# 尝试整词替换 # 尝试整词替换
if len(word) > 1 and random.random() < self.word_replace_rate: if len(word) > 1 and random.random() < self.word_replace_rate:
word_homophones = self._get_word_homophones(word) word_homophones = self._get_word_homophones(word)
@ -307,15 +309,15 @@ class ChineseTypoGenerator:
# 计算词的平均频率 # 计算词的平均频率
orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word) orig_freq = sum(self.char_frequency.get(c, 0) for c in word) / len(word)
typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word) typo_freq = sum(self.char_frequency.get(c, 0) for c in typo_word) / len(typo_word)
# 添加到结果中 # 添加到结果中
result.append(typo_word) result.append(typo_word)
typo_info.append((word, typo_word, typo_info.append((word, typo_word,
' '.join(word_pinyin), ' '.join(word_pinyin),
' '.join(self._get_word_pinyin(typo_word)), ' '.join(self._get_word_pinyin(typo_word)),
orig_freq, typo_freq)) orig_freq, typo_freq))
continue continue
# 如果不进行整词替换,则进行单字替换 # 如果不进行整词替换,则进行单字替换
if len(word) == 1: if len(word) == 1:
char = word char = word
@ -339,7 +341,7 @@ class ChineseTypoGenerator:
for i, (char, py) in enumerate(zip(word, word_pinyin)): for i, (char, py) in enumerate(zip(word, word_pinyin)):
# 词中的字替换概率降低 # 词中的字替换概率降低
word_error_rate = self.error_rate * (0.7 ** (len(word) - 1)) word_error_rate = self.error_rate * (0.7 ** (len(word) - 1))
if random.random() < word_error_rate: if random.random() < word_error_rate:
similar_chars = self._get_similar_frequency_chars(char, py) similar_chars = self._get_similar_frequency_chars(char, py)
if similar_chars: if similar_chars:
@ -354,7 +356,7 @@ class ChineseTypoGenerator:
continue continue
word_result.append(char) word_result.append(char)
result.append(''.join(word_result)) result.append(''.join(word_result))
return ''.join(result), typo_info return ''.join(result), typo_info
def format_typo_info(self, typo_info): def format_typo_info(self, typo_info):
@ -369,7 +371,7 @@ class ChineseTypoGenerator:
""" """
if not typo_info: if not typo_info:
return "未生成错别字" return "未生成错别字"
result = [] result = []
for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info: for orig, typo, orig_py, typo_py, orig_freq, typo_freq in typo_info:
# 判断是否为词语替换 # 判断是否为词语替换
@ -379,12 +381,12 @@ class ChineseTypoGenerator:
else: else:
tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1] tone_error = orig_py[:-1] == typo_py[:-1] and orig_py[-1] != typo_py[-1]
error_type = "声调错误" if tone_error else "同音字替换" error_type = "声调错误" if tone_error else "同音字替换"
result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> " result.append(f"原文:{orig}({orig_py}) [频率:{orig_freq:.2f}] -> "
f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]") f"替换:{typo}({typo_py}) [频率:{typo_freq:.2f}] [{error_type}]")
return "\n".join(result) return "\n".join(result)
def set_params(self, **kwargs): def set_params(self, **kwargs):
""" """
设置参数 设置参数
@ -399,9 +401,10 @@ class ChineseTypoGenerator:
for key, value in kwargs.items(): for key, value in kwargs.items():
if hasattr(self, key): if hasattr(self, key):
setattr(self, key, value) setattr(self, key, value)
print(f"参数 {key} 已设置为 {value}") logger.debug(f"参数 {key} 已设置为 {value}")
else: else:
print(f"警告: 参数 {key} 不存在") logger.warning(f"警告: 参数 {key} 不存在")
def main(): def main():
# 创建错别字生成器实例 # 创建错别字生成器实例
@ -411,27 +414,27 @@ def main():
tone_error_rate=0.02, tone_error_rate=0.02,
word_replace_rate=0.3 word_replace_rate=0.3
) )
# 获取用户输入 # 获取用户输入
sentence = input("请输入中文句子:") sentence = input("请输入中文句子:")
# 创建包含错别字的句子 # 创建包含错别字的句子
start_time = time.time() start_time = time.time()
typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence) typo_sentence, typo_info = typo_generator.create_typo_sentence(sentence)
# 打印结果 # 打印结果
print("\n原句:", sentence) logger.debug("原句:", sentence)
print("错字版:", typo_sentence) logger.debug("错字版:", typo_sentence)
# 打印错别字信息 # 打印错别字信息
if typo_info: if typo_info:
print("\n错别字信息:") logger.debug(f"错别字信息:{typo_generator.format_typo_info(typo_info)})")
print(typo_generator.format_typo_info(typo_info))
# 计算并打印总耗时 # 计算并打印总耗时
end_time = time.time() end_time = time.time()
total_time = end_time - start_time total_time = end_time - start_time
print(f"\n总耗时:{total_time:.2f}") logger.debug(f"总耗时:{total_time:.2f}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()