Merge pull request #253 from Cindy-Master/debug

Try to reduce synonym confusion between different groups, and allow users to mark the memory of specified groups as private
pull/256/head
Cindy-Master 2025-03-12 12:35:17 +08:00 committed by GitHub
commit a2c782cb7f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
33 changed files with 853 additions and 415 deletions


@@ -77,10 +77,6 @@
 - [🎀 Beginner setup guide](docs/installation_cute.md) - An easy-to-follow setup tutorial for catgirls using the bot for the first time
 - [⚙️ Standard setup guide](docs/installation_standard.md) - Concise, professional setup instructions for experienced users
-### FAQ
-- [❓ Quick Q & A ](docs/fast_q_a.md) - Troubleshooting for newcomers with no programming experience at all
 <div align="left">
 <h3>About MaiMai (麦麦) </h3>
 </div>

15
bot.py

@@ -12,8 +12,6 @@ from loguru import logger
 from nonebot.adapters.onebot.v11 import Adapter
 import platform
-from src.common.database import Database
 # Capture the environment variables as they are before any env file is loaded
 env_mask = {key: os.getenv(key) for key in os.environ}
@@ -111,17 +109,6 @@ def load_env():
         logger.error(f"ENVIRONMENT is misconfigured; check the ENVIRONMENT variable in .env and whether the matching .env.{env} file exists")
         RuntimeError(f"ENVIRONMENT is misconfigured; check the ENVIRONMENT variable in .env and whether the matching .env.{env} file exists")
-def init_database():
-    Database.initialize(
-        uri=os.getenv("MONGODB_URI"),
-        host=os.getenv("MONGODB_HOST", "127.0.0.1"),
-        port=int(os.getenv("MONGODB_PORT", "27017")),
-        db_name=os.getenv("DATABASE_NAME", "MegBot"),
-        username=os.getenv("MONGODB_USERNAME"),
-        password=os.getenv("MONGODB_PASSWORD"),
-        auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
-    )
 def load_logger():
     logger.remove()  # Remove the default configuration
@@ -223,7 +210,6 @@ def raw_main():
     init_config()
     init_env()
     load_env()
-    init_database()  # Initialize the database once the environment is loaded
     load_logger()
     env_config = {key: os.getenv(key) for key in os.environ}
@@ -249,6 +235,7 @@ def raw_main():
 if __name__ == "__main__":
     try:
         raw_main()


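@@ -0,0 +1,141 @@
+cbb569e - Create 如果你更新了版本,点我.txt
+a91ef7b - Script that upgrades config files automatically
+ed18f2e - Added a pretty one-click launch script for the knowledge base
+80ed568 - fix: remove print debugging code
+c681a82 - Fix nicknames having no effect
+e54038f - fix: add the numpy dependency from nixpkgs to avoid libc++.so not being found
+26782c9 - fix: fix the ENVIRONMENT variable not being overridable within the same terminal
+8c34637 - Improve robustness
+2688a96 - close SengokuCola/MaiMBot#225 Let MaiMai read share cards correctly
+cd16e68 - Fix a missing argument when sending emojis
+b362c35 - feat: update flake.nix to build the environment with venv, so NixOS users can also run the project locally
+3c8c897 - Suppress a bloated debug message
+9d0152a - Fix code duplication introduced during the merge
+956135c - Add some comments
+a412741 - Change print to logger.debug
+3180426 - Fix a typo'd field that had not been changed
+aea3bff - Add a private-chat filter switch, update config, add constraints
+cda6281 - chore: update emoji_manager.py
+baed856 - Fix the blocked-word output for private chats
+66a0f18 - Fix a bug that produced reply messages in private chats
+3bf5cd6 - feat: add runtime reloading of the config file; add different log levels per environment (dev; prod)
+33cd83b - Add private chat support
+aa41f0d - fix: had it backwards
+ef8691c - fix: change the message inheritance logic to fix reply messages not being recognized
+7d017be - fix: model downgrade
+e1019ad - fix: fix a misspelled variable and improve code readability
+c24bb70 - fix: add end-of-stream detection and token usage recording to streaming output mode
+60a9376 - Add a switch for the logger's debug output, off by default
+bfa9a3c - fix: add error handling for fetching group info (#173)
+4cc5c8e - Fix generation of .env.prod and .env.dev
+dea14c1 - fix: model downgrade currently only applies to SiliconFlow's V3 and R1
+b6edbea - fix: incorrect image save path
+01a6fa8 - fix: remove a mysterious test
+20f009d - Fix systemctl force-stopping maimbot
+af962c2 - Fix the mood manager not being imported correctly, which broke message sending
+0586700 - Revise the systemctl management guide following Sourcery's suggestions
+e48b32a - Add systemctl management to the manual deployment tutorial
+5760412 - fix: minor fix
+1c9b0cc - fix: fix some CQ code parsing errors merge
+b6867b9 - fix: use os.getenv consistently for database connection info, avoiding a KeyError when reading nonexistent values from the config object
+5e069f7 - Fix a bug where memories were saved without time information
+73a3e41 - Fix a memory update bug
+52c93ba - refactor: use Base64 for emoji CQ codes
+67f6d7c - fix: small changes to keep it runnable
+c32c4fb - refactor: change the config file version number
+a54ca8c - Merge remote-tracking branch 'upstream/debug' into feat_regix
+8cbf9bb - feat: the best message-flow refactor and image management ever
+9e41c4f - feat: update the changelog for bot_config version 0.0.5
+eede406 - fix: fix nonebot failing to load the project
+00e02ed - fix: add layered control options for version 0.0.5
+0f99d6a - Update docs/docker_deploy.md
+c789074 - feat: add the ruff dependency
+ff65ab8 - feat: change the default ruff config file and clean up every non-conforming spot in config
+bf97013 - feat: trim logging, disable the default Uvicorn/NoneBot logs, switch startup to explicitly loading uvicorn for a graceful shutdown
+d9a2863 - Improve the Docker deployment docs, update the container section
+efcf00f - Add an update section to the Docker deployment docs
+a63ce96 - fix: update the emotion-judgment model config (make llm_emotion_judge in the config file take effect)
+1294c88 - feat: add standardized formatting settings
+2e8cd47 - fix: avoid possible schedule parsing errors
+043a724 - Fix doc links, minor beautification (
+e4b8865 - Support aliases, so the bot can be summoned by different names
+7b35ddd - ruff has new ideas again
+7899e67 - feat: refactor finished, starting test and debug
+354d6d0 - Memory system optimization
+6cef8fd - Fix the time zone, remove ports napcat does not need
+cd96644 - Add usage instructions
+84495f8 - fix
+204744c - Rename the config option and change the filter target to raw_message
+a03b490 - Update README.md
+2b2b342 - feat: add the ruff dependency
+72a6749 - fix: fix time zone specification in Docker deployments
+ee579bc - Update README.md
+1b611ec - resolve SengokuCola/MaiMBot#167 Filter messages by regular expression
+6e2ea82 - refactor: almost done writing, entering the testing phase
+2ffdfef - More
+e680405 - fix: typo 'discription'
+68b3f57 - Minor Doc Update
+312f065 - Create linux_deploy_guide_for_beginners.md
+ed505a4 - fix: replace the hard-coded project path with a dynamic path
+8ff7bb6 - docs: update the docs, fix formatting and add needed line breaks
+6e36a56 - feat: add a MONGODB_URI config option and move all env-file comments onto their own lines; python's dotenv sometimes mishandles inline comments
+4baa6c6 - feat: implement MongoDB URI connections and unify the database connection code.
+8a32d18 - feat: improve the willing_manager logic, add a guaranteed minimum reply probability
+c9f1244 - docs: improve the formatting and layout of README.md
+e1b484a - docs: add the CLAUDE.md development guide for Claude Code
+a43f949 - fix: remove duplicate message(CR comments)
+fddb641 - fix: fix incorrect null-check logic
+8b7876c - fix: fix tags not being uploaded
+6b4130e - feat: add packaging for the stable-dev branch
+052e67b - refactor: log printing improvements (finally done, feels great
+a7f9d05 - Fix the input format passed to memory consolidation
+536bb1d - fix: update the emotion-judgment model config
+8d99592 - fix: logger initialization order
+052802c - refactor: logger promotion
+8661d94 - doc: README.md - telegram version information
+5746afa - refactor: logger in src\plugins\chat\bot.py
+288dbb6 - refactor: logger in src\plugins\chat\__init__.py
+8428a06 - fix: memory logger optimization (CR comment)
+665c459 - Improved the visualization script
+6c35704 - fix: called the wrong function
+3223153 - feat: add memory visualization to the one-click script
+3149dd3 - fix: mongodb.zip cannot be extracted fix: change how commands are executed fix: create the db automatically when it does not exist feat: start MaiMai after the one-click install finishes
+089d6a6 - feat: add automatic downgrade for SiliconFlow's Pro models
+c4b0917 - A small memory visualization script
+6a71ea4 - Fixed a memory timestamp bug; added memory blocklist keywords to config
+1b5344f - fix: improve logging & formatting of bot initialization
+41aa974 - fix: improve logging & formatting in chat/config.py
+980cde7 - fix: improve logging & formatting in scheduler_generator
+31a5514 - fix: adjust the global logger load order
+8baef07 - feat: add global logger initialization settings
+5566f17 - refactor: almost done writing, entering the testing phase
+6a66933 - feat: add initialization of the .env.dev development environment
+411ff1a - feat: install MongoDB Compass
+0de9eba - feat: add live updating of the contributor list
+f327f45 - fix: tidy the imports in src/plugins/chat/__init__.py
+826daa5 - fix: skip creation when the virtual environment already exists
+f54de42 - fix: time.tzset is only available on Unix-like systems
+47c4990 - fix: fix incorrect time in Docker deployments
+e23a371 - docs: add compose comments
+1002822 - docs: note the minimum Python version
+564350d - feat: validate the Python version
+4cc4482 - docs: add a foolproof script
+757173a - Took MaiMai to see a therapist so she does not slip into negative moods so easily
+39bb99c - Move typo generation into the config; one typo per sentence was too annoying!
+fe36847 - feat: a massive refactor
+e304dd7 - Update README.md
+b7cfe6d - feat: release version 0.0.2 of the config template
+ca929d5 - Add to the Docker deployment docs
+1e97120 - Add to the Docker deployment docs
+25f7052 - fix: fix the version gap (version 0.0.0) between the compatibility option and the current first version, and change all direct exits to raised exceptions
+c5bdc4f - Guard against IPv6 blowing things up, even though it is a rare event
+d86610d - fix: fix environment variables failing to load
+2306ebf - feat: since checking the boundary version range is cumbersome, add a notice field and remove the original check logic (it was faulty)
+dd09576 - fix: fix TypeError: BotConfig.convert_to_specifierset() takes 1 positional argument but 2 were given
+18f839b - fix: fix missing 1 required positional argument: 'INNER_VERSION'
+6adb5ed - Adjust some details; database credentials are optional for Docker deployments
+07f48e9 - fix: use filter to filter environment variables, avoiding the RuntimeError: dictionary changed size during iteration caused by deleting keys directly
+5856074 - fix: fix basic settings not being configurable
+32aa032 - feat: release version 0.0.1 of the config file
+edc07ac - feat: refactor the config loader, add config file version control and program compatibility
+0f492ed - fix: fix the BASE_URL/KEY pair check being thrown off by GPG_KEY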

18
run.bat

@@ -1,10 +1,10 @@
 @ECHO OFF
 chcp 65001
 if not exist "venv" (
     python -m venv venv
     call venv\Scripts\activate.bat
     pip install -i https://mirrors.aliyun.com/pypi/simple --upgrade -r requirements.txt
 ) else (
     call venv\Scripts\activate.bat
 )
 python run.py


@@ -1,29 +1,29 @@
 @echo on
 chcp 65001 > nul
 set /p CONDA_ENV="Enter the name of the conda environment to activate: "
 call conda activate %CONDA_ENV%
 if errorlevel 1 (
     echo Failed to activate the conda environment
     pause
     exit /b 1
 )
 echo Conda environment "%CONDA_ENV%" activated successfully
 set /p OPTION="Choose a run option (1: run the full drawing, 2: run the simple drawing): "
 if "%OPTION%"=="1" (
     python src/plugins/memory_system/memory_manual_build.py
 ) else if "%OPTION%"=="2" (
     python src/plugins/memory_system/draw_memory.py
 ) else (
     echo Invalid option
     pause
     exit /b 1
 )
 if errorlevel 1 (
     echo Command failed with error code %errorlevel%
     pause
     exit /b 1
 )
 echo Script completed successfully
 pause


@@ -1,7 +1,7 @@
 chcp 65001
 call conda activate maimbot
 cd .
 REM Run the nb run command
 nb run
 pause


@@ -1,5 +1,5 @@
 call conda activate niuniu
 cd src\gui
 start /b python reasoning_gui.py
 exit


@@ -1,68 +1,68 @@
 @echo off
 setlocal enabledelayedexpansion
 chcp 65001
 REM Corrected path-resolution logic
 cd /d "%~dp0" || (
     echo Error: failed to change directory
     exit /b 1
 )
 if not exist "venv\" (
     echo Initializing the virtual environment...
     where python >nul 2>&1
     if %errorlevel% neq 0 (
         echo Python interpreter not found
         exit /b 1
     )
     for /f "tokens=2" %%a in ('python --version 2^>^&1') do set version=%%a
     for /f "tokens=1,2 delims=." %%b in ("!version!") do (
         set major=%%b
         set minor=%%c
     )
     if !major! lss 3 (
         echo Python 3.0 or later is required, current version !version!
         exit /b 1
     )
     if !major! equ 3 if !minor! lss 9 (
         echo Python 3.9 or later is required, current version !version!
         exit /b 1
     )
     echo Installing virtualenv...
     python -m pip install virtualenv || (
         echo Failed to install virtualenv
         exit /b 1
     )
     echo Creating the virtual environment...
     python -m virtualenv venv || (
         echo Failed to create the virtual environment
         exit /b 1
     )
     call venv\Scripts\activate.bat
 ) else (
     call venv\Scripts\activate.bat
 )
 echo Updating dependencies...
 pip install -r requirements.txt
 echo Current proxy settings:
 echo HTTP_PROXY=%HTTP_PROXY%
 echo HTTPS_PROXY=%HTTPS_PROXY%
 set HTTP_PROXY=
 set HTTPS_PROXY=
 echo Proxy settings cleared.
 set no_proxy=0.0.0.0/32
 call nb run
 pause


@@ -1,6 +1,5 @@
 from typing import Optional
 from pymongo import MongoClient
-from pymongo.database import Database as MongoDatabase
 class Database:
     _instance: Optional["Database"] = None
@@ -26,7 +25,7 @@ class Database:
         else:
             # Otherwise connect without authentication
             self.client = MongoClient(host, port)
-        self.db: MongoDatabase = self.client[db_name]
+        self.db = self.client[db_name]
     @classmethod
     def initialize(
@@ -38,36 +37,15 @@ class Database:
         password: Optional[str] = None,
         auth_source: Optional[str] = None,
         uri: Optional[str] = None,
-    ) -> MongoDatabase:
+    ) -> "Database":
         if cls._instance is None:
             cls._instance = cls(
                 host, port, db_name, username, password, auth_source, uri
             )
-        return cls._instance.db
+        return cls._instance
     @classmethod
-    def get_instance(cls) -> MongoDatabase:
+    def get_instance(cls) -> "Database":
         if cls._instance is None:
             raise RuntimeError("Database not initialized")
-        return cls._instance.db
+        return cls._instance
-    # For testing
-    def get_random_group_messages(self, group_id: str, limit: int = 5):
-        # First pick one message at random
-        random_message = list(self.db.messages.aggregate([
-            {"$match": {"group_id": group_id}},
-            {"$sample": {"size": 1}}
-        ]))[0]
-        # Fetch the messages that follow it
-        subsequent_messages = list(self.db.messages.find({
-            "group_id": group_id,
-            "time": {"$gt": random_message["time"]}
-        }).sort("time", 1).limit(limit))
-        # Combine the random message with the following messages
-        messages = [random_message] + subsequent_messages
-        return messages
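
The hunks above change `initialize` and `get_instance` to return the `Database` wrapper rather than the underlying pymongo database, which is why callers elsewhere in this diff switch to `.db`. A minimal sketch of that singleton shape follows. It is an illustration only: a plain dict stands in for `MongoClient(...)[db_name]` so the example runs without MongoDB, and the single-argument constructor is a simplification of the real signature.

```python
from typing import Any, Dict, Optional


class Database:
    """Stand-in for the project's wrapper; a dict replaces the pymongo database."""

    _instance: Optional["Database"] = None

    def __init__(self, db_name: str) -> None:
        # Assumption: the real class stores MongoClient(...)[db_name] here.
        self.db: Dict[str, Any] = {"name": db_name}

    @classmethod
    def initialize(cls, db_name: str = "MegBot") -> "Database":
        # Only the first call constructs the instance; later calls reuse it.
        if cls._instance is None:
            cls._instance = cls(db_name)
        return cls._instance

    @classmethod
    def get_instance(cls) -> "Database":
        if cls._instance is None:
            raise RuntimeError("Database not initialized")
        return cls._instance


try:
    Database.get_instance()
except RuntimeError as exc:
    print(f"before initialize: {exc}")

first = Database.initialize("MegBot")
second = Database.initialize("ignored")  # returns the existing instance
print(first is second, Database.get_instance().db["name"])
```

With the wrapper returned, a caller that wants a collection has to go through the `db` attribute, for example `Database.get_instance().db`.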


@@ -46,7 +46,7 @@ class ReasoningGUI:
         # Initialize the database connection
         try:
-            self.db = Database.get_instance()
+            self.db = Database.get_instance().db
             logger.success("Database connected successfully")
         except RuntimeError:
             logger.warning("Database not initialized, attempting to initialize...")
@@ -60,7 +60,7 @@ class ReasoningGUI:
                     password=os.getenv("MONGODB_PASSWORD"),
                     auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
                 )
-                self.db = Database.get_instance()
+                self.db = Database.get_instance().db
                 logger.success("Database initialized successfully")
             except Exception:
                 logger.exception("Database initialization failed")


@@ -32,6 +32,18 @@ _message_manager_started = False
 driver = get_driver()
 config = driver.config
+Database.initialize(
+    uri=os.getenv("MONGODB_URI"),
+    host=os.getenv("MONGODB_HOST", "127.0.0.1"),
+    port=int(os.getenv("MONGODB_PORT", "27017")),
+    db_name=os.getenv("DATABASE_NAME", "MegBot"),
+    username=os.getenv("MONGODB_USERNAME"),
+    password=os.getenv("MONGODB_PASSWORD"),
+    auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
+)
+logger.success("Database initialized successfully")
 # Initialize the emoji manager
 emoji_manager.initialize()
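
The `Database.initialize(...)` call above reads every connection setting from the environment, falling back to defaults for host, port, and database name. The sketch below isolates that lookup so it can be tried without a database. `mongo_settings_from_env` is a hypothetical helper, not part of the project; the variable names and defaults are the ones shown in the hunk.

```python
import os
from typing import Any, Dict, Mapping, Optional


def mongo_settings_from_env(env: Optional[Mapping[str, str]] = None) -> Dict[str, Any]:
    """Hypothetical helper: collect the keyword arguments passed to Database.initialize."""
    env = os.environ if env is None else env
    return {
        "uri": env.get("MONGODB_URI"),
        "host": env.get("MONGODB_HOST", "127.0.0.1"),
        # The port arrives as a string and is converted, as in the hunk above.
        "port": int(env.get("MONGODB_PORT", "27017")),
        "db_name": env.get("DATABASE_NAME", "MegBot"),
        "username": env.get("MONGODB_USERNAME"),
        "password": env.get("MONGODB_PASSWORD"),
        "auth_source": env.get("MONGODB_AUTH_SOURCE"),
    }


# Passing an explicit mapping keeps the example independent of the real environment.
settings = mongo_settings_from_env({"MONGODB_PORT": "27018"})
print(settings["host"], settings["port"], settings["db_name"])
```

Unset optional variables come back as `None`, which matches the `Optional[str]` parameters in the `initialize` signature shown earlier in this diff.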


@@ -235,10 +235,10 @@ class ChatBot:
                     is_head=not mark_head,
                     is_emoji=False,
                 )
-                logger.debug(f"bot_message: {bot_message}")
+                print(f"bot_message: {bot_message}")
                 if not mark_head:
                     mark_head = True
-                logger.debug(f"Adding message to message_set: {bot_message}")
+                print(f"Adding message to message_set: {bot_message}")
                 message_set.add_message(bot_message)
                 # message_set can be added to message_manager directly


@@ -111,11 +111,11 @@ class ChatManager:
     def _ensure_collection(self):
         """Ensure the database collection exists and create its indexes"""
-        if "chat_streams" not in self.db.list_collection_names():
-            self.db.create_collection("chat_streams")
+        if "chat_streams" not in self.db.db.list_collection_names():
+            self.db.db.create_collection("chat_streams")
             # Create indexes
-            self.db.chat_streams.create_index([("stream_id", 1)], unique=True)
-            self.db.chat_streams.create_index(
+            self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True)
+            self.db.db.chat_streams.create_index(
                 [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)]
             )
@@ -168,7 +168,7 @@ class ChatManager:
             return stream
         # Check whether it exists in the database
-        data = self.db.chat_streams.find_one({"stream_id": stream_id})
+        data = self.db.db.chat_streams.find_one({"stream_id": stream_id})
         if data:
             stream = ChatStream.from_dict(data)
             # Update the user info and group info
@@ -204,7 +204,7 @@ class ChatManager:
     async def _save_stream(self, stream: ChatStream):
         """Save a chat stream to the database"""
         if not stream.saved:
-            self.db.chat_streams.update_one(
+            self.db.db.chat_streams.update_one(
                 {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True
             )
             stream.saved = True
@@ -216,7 +216,7 @@ class ChatManager:
     async def load_all_streams(self):
         """Load all chat streams from the database"""
-        all_streams = self.db.chat_streams.find({})
+        all_streams = self.db.db.chat_streams.find({})
         for data in all_streams:
             stream = ChatStream.from_dict(data)
             self.streams[stream.stream_id] = stream


@@ -105,6 +105,11 @@ class BotConfig:
         default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"]
     )  # Default value for the newly added config option
+    # Memory group configuration, used to define private memory groups
+    # Format: {"group1": ["123456", "234567"], "group2": ["345678", "456789"]}
+    # Group chat IDs inside one memory group share memories, but not with other memory groups
+    memory_private_groups: Dict[str, List[str]] = field(default_factory=dict)
     @staticmethod
     def get_config_dir() -> str:
         """Return the config file directory"""
@@ -304,6 +309,11 @@ class BotConfig:
             config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time)
             config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage)
             config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate)
+            # Load the private group memory configuration
+            if "memory_private_groups" in memory_config:
+                config.memory_private_groups = memory_config.get("memory_private_groups", config.memory_private_groups)
+                logger.info(f"Loaded private group memory configuration: {len(config.memory_private_groups)} groups")
         def mood(parent: dict):
             mood_config = parent["mood"]
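
The new `memory_private_groups` option maps a memory-group name to the group chat IDs that share memories with each other. The sketch below shows one way such a mapping could be queried. Both helper names are hypothetical, and the rule for chats that belong to no private group (they share a common pool) is an assumption; the diff only states that IDs inside one group share memories and do not share with other groups.

```python
from typing import Dict, List, Optional


def find_memory_group(private_groups: Dict[str, List[str]], chat_group_id: str) -> Optional[str]:
    """Return the name of the private memory group that contains chat_group_id, if any."""
    for name, members in private_groups.items():
        if chat_group_id in members:
            return name
    return None


def can_share_memory(private_groups: Dict[str, List[str]], a: str, b: str) -> bool:
    """Decide whether two group chats may see each other's memories."""
    group_a = find_memory_group(private_groups, a)
    group_b = find_memory_group(private_groups, b)
    if group_a is None and group_b is None:
        # Assumption: chats outside every private group use a shared pool.
        return True
    return group_a == group_b


# Same shape as the format comment in the config hunk.
groups = {"group1": ["123456", "234567"], "group2": ["345678", "456789"]}
print(can_share_memory(groups, "123456", "234567"))
print(can_share_memory(groups, "123456", "345678"))
```

IDs are compared as strings here, which lines up with the `str(...)` conversion applied to `group_id` in the storage and prompt-builder hunks of this diff.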


@ -4,8 +4,6 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union
import os
import requests import requests
# 解析各种CQ码 # 解析各种CQ码


@@ -76,16 +76,16 @@ class EmojiManager:
         Without indexes, every query has to scan all the data; with indexes in place, queries become much faster
         """
-        if 'emoji' not in self.db.list_collection_names():
-            self.db.create_collection('emoji')
-            self.db.emoji.create_index([('embedding', '2dsphere')])
-            self.db.emoji.create_index([('filename', 1)], unique=True)
+        if 'emoji' not in self.db.db.list_collection_names():
+            self.db.db.create_collection('emoji')
+            self.db.db.emoji.create_index([('embedding', '2dsphere')])
+            self.db.db.emoji.create_index([('filename', 1)], unique=True)
     def record_usage(self, emoji_id: str):
         """Record how many times an emoji has been used"""
         try:
             self._ensure_db()
-            self.db.emoji.update_one(
+            self.db.db.emoji.update_one(
                 {'_id': emoji_id},
                 {'$inc': {'usage_count': 1}}
             )
@@ -119,7 +119,7 @@ class EmojiManager:
         try:
             # Fetch all emojis
-            all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
+            all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1}))
             if not all_emojis:
                 logger.warning("There are no emojis in the database")
@@ -157,7 +157,7 @@ class EmojiManager:
             if selected_emoji and 'path' in selected_emoji:
                 # Update the usage count
-                self.db.emoji.update_one(
+                self.db.db.emoji.update_one(
                     {'_id': selected_emoji['_id']},
                     {'$inc': {'usage_count': 1}}
                 )
@@ -239,7 +239,7 @@ class EmojiManager:
             image_hash = hashlib.md5(image_bytes).hexdigest()
             # Check whether it has already been registered
-            existing_emoji = self.db['emoji'].find_one({'filename': filename})
+            existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
             description = None
             if existing_emoji:
@@ -305,7 +305,7 @@ class EmojiManager:
             }
             # Save to the emoji database
-            self.db['emoji'].insert_one(emoji_record)
+            self.db.db['emoji'].insert_one(emoji_record)
             logger.success(f"Registered new emoji: {filename}")
             logger.info(f"Description: {description}")
@@ -346,7 +346,7 @@ class EmojiManager:
         try:
             self._ensure_db()
             # Fetch all emoji records
-            all_emojis = list(self.db.emoji.find())
+            all_emojis = list(self.db.db.emoji.find())
             removed_count = 0
             total_count = len(all_emojis)
@@ -354,13 +354,13 @@ class EmojiManager:
                 try:
                     if 'path' not in emoji:
                         logger.warning(f"Found an invalid record missing the path field, ID: {emoji.get('_id', 'unknown')}")
-                        self.db.emoji.delete_one({'_id': emoji['_id']})
+                        self.db.db.emoji.delete_one({'_id': emoji['_id']})
                         removed_count += 1
                         continue
                     if 'embedding' not in emoji:
                         logger.warning(f"Found an outdated record missing the embedding field, ID: {emoji.get('_id', 'unknown')}")
-                        self.db.emoji.delete_one({'_id': emoji['_id']})
+                        self.db.db.emoji.delete_one({'_id': emoji['_id']})
                         removed_count += 1
                         continue
@@ -368,7 +368,7 @@ class EmojiManager:
                     if not os.path.exists(emoji['path']):
                         logger.warning(f"Emoji file has been deleted: {emoji['path']}")
                         # Delete the record from the database
-                        result = self.db.emoji.delete_one({'_id': emoji['_id']})
+                        result = self.db.db.emoji.delete_one({'_id': emoji['_id']})
                         if result.deleted_count > 0:
                             logger.debug(f"Successfully deleted database record: {emoji['_id']}")
                             removed_count += 1
@@ -379,7 +379,7 @@ class EmojiManager:
                     continue
             # Verify the cleanup result
-            remaining_count = self.db.emoji.count_documents({})
+            remaining_count = self.db.db.emoji.count_documents({})
             if removed_count > 0:
                 logger.success(f"Cleaned up {removed_count} stale emoji records")
                 logger.info(f"Total before cleanup: {total_count} | Total after cleanup: {remaining_count}")


@@ -154,7 +154,7 @@ class ResponseGenerator:
         reasoning_content: str,
     ):
         """Save the conversation record to the database"""
-        self.db.reasoning_logs.insert_one(
+        self.db.db.reasoning_logs.insert_one(
             {
                 "time": time.time(),
                 "chat_id": message.chat_stream.stream_id,


@@ -10,7 +10,6 @@ from .message import MessageSending, MessageThinking, MessageRecv, MessageSet
 from .storage import MessageStorage
 from .config import global_config
-from .utils import truncate_message
 class Message_Sender:
@@ -35,7 +34,6 @@ class Message_Sender:
             message_json = message.to_dict()
             message_send = MessageSendCQ(data=message_json)
             # logger.debug(message_send.message_info,message_send.raw_message)
-            message_preview = truncate_message(message.processed_plain_text)
             if (
                 message_send.message_info.group_info
                 and message_send.message_info.group_info.group_id
@@ -46,10 +44,10 @@ class Message_Sender:
                         message=message_send.raw_message,
                         auto_escape=False,
                     )
-                    logger.success(f"[debug] Sent message {message_preview} successfully")
+                    logger.success(f"[debug] Sent message {message.processed_plain_text} successfully")
                 except Exception as e:
                     logger.error(f"[debug] An error occurred {e}")
-                    logger.error(f"[debug] Failed to send message {message_preview}")
+                    logger.error(f"[debug] Failed to send message {message.processed_plain_text}")
             else:
                 try:
                     logger.debug(message.message_info.user_info)
@@ -58,10 +56,10 @@ class Message_Sender:
                         message=message_send.raw_message,
                         auto_escape=False,
                     )
-                    logger.success(f"[debug] Sent message {message_preview} successfully")
+                    logger.success(f"[debug] Sent message {message.processed_plain_text} successfully")
                 except Exception as e:
-                    logger.error(f"[debug] An error occurred {e}")
-                    logger.error(f"[debug] Failed to send message {message_preview}")
+                    logger.error(f"An error occurred {e}")
+                    logger.error(f"[debug] Failed to send message {message.processed_plain_text}")
 class MessageContainer:
@@ -186,7 +184,7 @@ class MessageManager:
                     await message_earliest.process()
                     print(
-                        f"\033[1;34m[debug]\033[0m Message {truncate_message(message_earliest.processed_plain_text)} is being sent"
+                        f"\033[1;34m[debug]\033[0m Message '{message_earliest.processed_plain_text}' is being sent"
                     )
                     await self.storage.store_message(

@@ -91,12 +91,20 @@ class PromptBuilder:
         memory_prompt = ''
         start_time = time.time()
-        # Call hippocampus's get_relevant_memories method
+        # Get the group chat ID
+        stream_group_id = None
+        if stream_id:
+            chat_stream = chat_manager.get_stream(stream_id)
+            if chat_stream and chat_stream.group_info:
+                stream_group_id = str(chat_stream.group_info.group_id)
+        # Call hippocampus's get_relevant_memories method, passing the group chat ID
         relevant_memories = await hippocampus.get_relevant_memories(
             text=message_txt,
             max_topics=5,
             similarity_threshold=0.4,
-            max_memory_num=5
+            max_memory_num=5,
+            group_id=stream_group_id
         )
         if relevant_memories:
@@ -311,7 +319,7 @@ class PromptBuilder:
             {"$project": {"content": 1, "similarity": 1}}
         ]
-        results = list(self.db.knowledges.aggregate(pipeline))
+        results = list(self.db.db.knowledges.aggregate(pipeline))
         # print(f"\033[1;34m[debug]\033[0m Knowledge base lookup result: {results}")
         if not results:


@@ -168,7 +168,7 @@ class RelationshipManager:
     async def load_all_relationships(self):
         """Load all relationship objects"""
         db = Database.get_instance()
-        all_relationships = db.relationships.find({})
+        all_relationships = db.db.relationships.find({})
         for data in all_relationships:
             await self.load_relationship(data)
@@ -176,7 +176,7 @@ class RelationshipManager:
         """Automatically save relationship data every 5 minutes"""
         db = Database.get_instance()
         # Fetch all relationship records
-        all_relationships = db.relationships.find({})
+        all_relationships = db.db.relationships.find({})
         # Load each record in turn
         for data in all_relationships:
             await self.load_relationship(data)
@@ -206,7 +206,7 @@ class RelationshipManager:
             saved = relationship.saved
             db = Database.get_instance()
-            db.relationships.update_one(
+            db.db.relationships.update_one(
                 {'user_id': user_id, 'platform': platform},
                 {'$set': {
                     'platform': platform,


@@ -13,17 +13,23 @@ class MessageStorage:
     async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None:
         """Store a message in the database"""
         try:
+            # Extract the group ID, if there is one
+            group_id = None
+            if chat_stream.group_info:
+                group_id = str(chat_stream.group_info.group_id)
             message_data = {
                 "message_id": message.message_info.message_id,
                 "time": message.message_info.time,
-                "chat_id":chat_stream.stream_id,
+                "chat_id": chat_stream.stream_id,
                 "chat_info": chat_stream.to_dict(),
                 "user_info": message.message_info.user_info.to_dict(),
                 "processed_plain_text": message.processed_plain_text,
                 "detailed_plain_text": message.detailed_plain_text,
                 "topic": topic,
+                "group_id": group_id,  # Add the group_id field explicitly
             }
-            self.db.messages.insert_one(message_data)
+            self.db.db.messages.insert_one(message_data)
         except Exception:
             logger.exception("Failed to store message")


@@ -104,11 +104,20 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str):
     # Convert the record format
     formatted_records = []
     for record in chat_records:
-        formatted_records.append({
+        formatted_record = {
             'time': record["time"],
             'chat_id': record["chat_id"],
             'detailed_plain_text': record.get("detailed_plain_text", "")  # Add the text content
-        })
+        }
+        # Add the group_id, if present
+        if 'group_id' in record:
+            formatted_record['group_id'] = record['group_id']
+        elif 'chat_info' in record and 'group_info' in record['chat_info'] and record['chat_info']['group_info']:
+            # Extract group_id from chat_info
+            formatted_record['group_id'] = record['chat_info']['group_info'].get('group_id')
+        formatted_records.append(formatted_record)
     return formatted_records
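
The hunk above reads `group_id` from the top-level field when a stored message has one, and otherwise falls back to `chat_info.group_info`, which covers messages saved before the explicit field existed. A small standalone sketch of that lookup order follows. `extract_group_id` is a hypothetical name, and unlike the hunk, which leaves the key out when nothing is found, this helper returns `None`.

```python
from typing import Any, Dict, Optional


def extract_group_id(record: Dict[str, Any]) -> Optional[Any]:
    """Return the group ID of a stored message record, or None for private chats."""
    # Newer records carry an explicit top-level field.
    if "group_id" in record:
        return record["group_id"]
    # Older records only have the nested chat_info structure.
    group_info = (record.get("chat_info") or {}).get("group_info")
    if group_info:
        return group_info.get("group_id")
    return None


new_style = {"time": 1.0, "chat_id": "s1", "group_id": "123456"}
old_style = {"time": 2.0, "chat_id": "s2", "chat_info": {"group_info": {"group_id": 234567}}}
private_chat = {"time": 3.0, "chat_id": "s3", "chat_info": {"group_info": None}}

for record in (new_style, old_style, private_chat):
    print(record["chat_id"], extract_group_id(record))
```

Note that the fallback path returns whatever type was stored in `chat_info`, so older records may yield an integer where newer ones yield a string; callers comparing IDs may want to normalize with `str(...)`.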
@@ -406,10 +415,3 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
     # Sort by similarity in descending order and return the top k
     return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
-def truncate_message(message: str, max_length=20) -> str:
-    """Truncate a message so that it does not exceed the given length"""
-    if len(message) > max_length:
-        return message[:max_length] + "..."
-    return message


@@ -40,20 +40,20 @@ class ImageManager:
     def _ensure_image_collection(self):
         """Ensure the images collection exists and create its indexes"""
-        if 'images' not in self.db.list_collection_names():
-            self.db.create_collection('images')
+        if 'images' not in self.db.db.list_collection_names():
+            self.db.db.create_collection('images')
             # Create indexes
-            self.db.images.create_index([('hash', 1)], unique=True)
-            self.db.images.create_index([('url', 1)])
-            self.db.images.create_index([('path', 1)])
+            self.db.db.images.create_index([('hash', 1)], unique=True)
+            self.db.db.images.create_index([('url', 1)])
+            self.db.db.images.create_index([('path', 1)])
     def _ensure_description_collection(self):
         """Ensure the image_descriptions collection exists and create its indexes"""
-        if 'image_descriptions' not in self.db.list_collection_names():
-            self.db.create_collection('image_descriptions')
+        if 'image_descriptions' not in self.db.db.list_collection_names():
+            self.db.db.create_collection('image_descriptions')
             # Create indexes
-            self.db.image_descriptions.create_index([('hash', 1)], unique=True)
-            self.db.image_descriptions.create_index([('type', 1)])
+            self.db.db.image_descriptions.create_index([('hash', 1)], unique=True)
+            self.db.db.image_descriptions.create_index([('type', 1)])
     def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]:
         """Get an image description from the database
@@ -65,7 +65,7 @@ class ImageManager:
         Returns:
             Optional[str]: the description text, or None if it does not exist
         """
-        result= self.db.image_descriptions.find_one({
+        result= self.db.db.image_descriptions.find_one({
             'hash': image_hash,
             'type': description_type
         })
@@ -79,7 +79,7 @@ class ImageManager:
             description: the description text
             description_type: the description type ('emoji' or 'image')
         """
-        self.db.image_descriptions.update_one(
+        self.db.db.image_descriptions.update_one(
             {'hash': image_hash, 'type': description_type},
             {
                 '$set': {
@@ -121,7 +121,7 @@ class ImageManager:
             image_hash = hashlib.md5(image_bytes).hexdigest()
             # Check for duplicates
-            existing = self.db.images.find_one({'hash': image_hash})
+            existing = self.db.db.images.find_one({'hash': image_hash})
             if existing:
                 return existing['path']
@@ -142,7 +142,7 @@ class ImageManager:
                 'description': description,
                 'timestamp': timestamp
             }
-            self.db.images.insert_one(image_doc)
+            self.db.db.images.insert_one(image_doc)
             return file_path
@@ -159,7 +159,7 @@ class ImageManager:
         """
         try:
             # First check whether it already exists
-            existing = self.db.images.find_one({'url': url})
+            existing = self.db.db.images.find_one({'url': url})
             if existing:
                 return existing['path']
@@ -203,7 +203,7 @@ class ImageManager:
         Returns:
             bool: whether it exists
         """
-        return self.db.images.find_one({'url': url}) is not None
+        return self.db.db.images.find_one({'url': url}) is not None
def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool:
"""检查图像是否已存在 """检查图像是否已存在
@ -226,7 +226,7 @@ class ImageManager:
return False return False
image_hash = hashlib.md5(image_bytes).hexdigest() image_hash = hashlib.md5(image_bytes).hexdigest()
return self.db.images.find_one({'hash': image_hash}) is not None return self.db.db.images.find_one({'hash': image_hash}) is not None
except Exception as e: except Exception as e:
logger.error(f"检查哈希失败: {str(e)}") logger.error(f"检查哈希失败: {str(e)}")
@ -269,7 +269,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.images.update_one( self.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True
@ -326,7 +326,7 @@ class ImageManager:
'description': description, 'description': description,
'timestamp': timestamp 'timestamp': timestamp
} }
self.db.images.update_one( self.db.db.images.update_one(
{'hash': image_hash}, {'hash': image_hash},
{'$set': image_doc}, {'$set': image_doc},
upsert=True upsert=True

@ -5,98 +5,101 @@ from typing import Dict
from .config import global_config from .config import global_config
from .chat_stream import ChatStream from .chat_stream import ChatStream
from loguru import logger
class WillingManager: class WillingManager:
def __init__(self): def __init__(self):
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿
self._decay_task = None self._decay_task = None
self._started = False self._started = False
async def _decay_reply_willing(self): async def _decay_reply_willing(self):
"""定期衰减回复意愿""" """定期衰减回复意愿"""
while True: while True:
await asyncio.sleep(5) await asyncio.sleep(5)
for chat_id in self.chat_reply_willing: for chat_id in self.chat_reply_willing:
self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
for chat_id in self.chat_reply_willing:
def get_willing(self, chat_stream: ChatStream) -> float: self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6)
def get_willing(self,chat_stream:ChatStream) -> float:
"""获取指定聊天流的回复意愿""" """获取指定聊天流的回复意愿"""
stream = chat_stream stream = chat_stream
if stream: if stream:
return self.chat_reply_willing.get(stream.stream_id, 0) return self.chat_reply_willing.get(stream.stream_id, 0)
return 0 return 0
def set_willing(self, chat_id: str, willing: float): def set_willing(self, chat_id: str, willing: float):
"""设置指定聊天流的回复意愿""" """设置指定聊天流的回复意愿"""
self.chat_reply_willing[chat_id] = willing self.chat_reply_willing[chat_id] = willing
def set_willing(self, chat_id: str, willing: float):
async def change_reply_willing_received( """设置指定聊天流的回复意愿"""
self, self.chat_reply_willing[chat_id] = willing
chat_stream: ChatStream,
topic: str = None, async def change_reply_willing_received(self,
is_mentioned_bot: bool = False, chat_stream:ChatStream,
config=None, topic: str = None,
is_emoji: bool = False, is_mentioned_bot: bool = False,
interested_rate: float = 0, config = None,
) -> float: is_emoji: bool = False,
interested_rate: float = 0) -> float:
"""改变指定聊天流的回复意愿并返回回复概率""" """改变指定聊天流的回复意愿并返回回复概率"""
# 获取或创建聊天流 # 获取或创建聊天流
stream = chat_stream stream = chat_stream
chat_id = stream.stream_id chat_id = stream.stream_id
current_willing = self.chat_reply_willing.get(chat_id, 0) current_willing = self.chat_reply_willing.get(chat_id, 0)
# print(f"初始意愿: {current_willing}")
if is_mentioned_bot and current_willing < 1.0: if is_mentioned_bot and current_willing < 1.0:
current_willing += 0.9 current_willing += 0.9
logger.debug(f"被提及, 当前意愿: {current_willing}") print(f"被提及, 当前意愿: {current_willing}")
elif is_mentioned_bot: elif is_mentioned_bot:
current_willing += 0.05 current_willing += 0.05
logger.debug(f"被重复提及, 当前意愿: {current_willing}") print(f"被重复提及, 当前意愿: {current_willing}")
if is_emoji: if is_emoji:
current_willing *= 0.1 current_willing *= 0.1
logger.debug(f"表情包, 当前意愿: {current_willing}") print(f"表情包, 当前意愿: {current_willing}")
logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}")
interested_rate *= global_config.response_interested_rate_amplifier # 放大回复兴趣度 interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度
if interested_rate > 0.4: if interested_rate > 0.4:
# print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
current_willing += interested_rate - 0.4 current_willing += interested_rate-0.4
current_willing *= global_config.response_willing_amplifier # 放大回复意愿 current_willing *= global_config.response_willing_amplifier #放大回复意愿
# print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}") # print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}")
reply_probability = max((current_willing - 0.45) * 2, 0) reply_probability = max((current_willing - 0.45) * 2, 0)
# 检查群组权限(如果是群聊) # 检查群组权限(如果是群聊)
if chat_stream.group_info: if chat_stream.group_info:
if chat_stream.group_info.group_id in config.talk_frequency_down_groups: if chat_stream.group_info.group_id in config.talk_frequency_down_groups:
reply_probability = reply_probability / global_config.down_frequency_rate reply_probability = reply_probability / global_config.down_frequency_rate
reply_probability = min(reply_probability, 1) reply_probability = min(reply_probability, 1)
if reply_probability < 0: if reply_probability < 0:
reply_probability = 0 reply_probability = 0
self.chat_reply_willing[chat_id] = min(current_willing, 3.0) self.chat_reply_willing[chat_id] = min(current_willing, 3.0)
return reply_probability return reply_probability
def change_reply_willing_sent(self, chat_stream: ChatStream): def change_reply_willing_sent(self, chat_stream:ChatStream):
"""开始思考后降低聊天流的回复意愿""" """开始思考后降低聊天流的回复意愿"""
stream = chat_stream stream = chat_stream
if stream: if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0) current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2) self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2)
def change_reply_willing_after_sent(self, chat_stream: ChatStream): def change_reply_willing_after_sent(self,chat_stream:ChatStream):
"""发送消息后提高聊天流的回复意愿""" """发送消息后提高聊天流的回复意愿"""
stream = chat_stream stream = chat_stream
if stream: if stream:
current_willing = self.chat_reply_willing.get(stream.stream_id, 0) current_willing = self.chat_reply_willing.get(stream.stream_id, 0)
if current_willing < 1: if current_willing < 1:
self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2) self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2)
async def ensure_started(self): async def ensure_started(self):
"""确保衰减任务已启动""" """确保衰减任务已启动"""
if not self._started: if not self._started:
@ -104,6 +107,5 @@ class WillingManager:
self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._decay_task = asyncio.create_task(self._decay_reply_willing())
self._started = True self._started = True
# 创建全局实例 # 创建全局实例
willing_manager = WillingManager() willing_manager = WillingManager()

@ -96,7 +96,7 @@ class Memory_graph:
dot_data = { dot_data = {
"concept": node "concept": node
} }
self.db.store_memory_dots.insert_one(dot_data) self.db.db.store_memory_dots.insert_one(dot_data)
@property @property
def dots(self): def dots(self):
@ -106,7 +106,7 @@ class Memory_graph:
def get_random_chat_from_db(self, length: int, timestamp: str): def get_random_chat_from_db(self, length: int, timestamp: str):
# 从数据库中根据时间戳获取离其最近的聊天记录 # 从数据库中根据时间戳获取离其最近的聊天记录
chat_text = '' chat_text = ''
closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出
logger.info( logger.info(
f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}")
@ -115,7 +115,7 @@ class Memory_graph:
group_id = closest_record['group_id'] # 获取groupid group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_record = list( chat_record = list(
self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit(
length)) length))
for record in chat_record: for record in chat_record:
time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time'])))
@ -130,34 +130,34 @@ class Memory_graph:
def save_graph_to_db(self): def save_graph_to_db(self):
# 清空现有的图数据 # 清空现有的图数据
self.db.graph_data.delete_many({}) self.db.db.graph_data.delete_many({})
# 保存节点 # 保存节点
for node in self.G.nodes(data=True): for node in self.G.nodes(data=True):
node_data = { node_data = {
'concept': node[0], 'concept': node[0],
'memory_items': node[1].get('memory_items', []) # 默认为空列表 'memory_items': node[1].get('memory_items', []) # 默认为空列表
} }
self.db.graph_data.nodes.insert_one(node_data) self.db.db.graph_data.nodes.insert_one(node_data)
# 保存边 # 保存边
for edge in self.G.edges(): for edge in self.G.edges():
edge_data = { edge_data = {
'source': edge[0], 'source': edge[0],
'target': edge[1] 'target': edge[1]
} }
self.db.graph_data.edges.insert_one(edge_data) self.db.db.graph_data.edges.insert_one(edge_data)
def load_graph_from_db(self): def load_graph_from_db(self):
# 清空当前图 # 清空当前图
self.G.clear() self.G.clear()
# 加载节点 # 加载节点
nodes = self.db.graph_data.nodes.find() nodes = self.db.db.graph_data.nodes.find()
for node in nodes: for node in nodes:
memory_items = node.get('memory_items', []) memory_items = node.get('memory_items', [])
if not isinstance(memory_items, list): if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else [] memory_items = [memory_items] if memory_items else []
self.G.add_node(node['concept'], memory_items=memory_items) self.G.add_node(node['concept'], memory_items=memory_items)
# 加载边 # 加载边
edges = self.db.graph_data.edges.find() edges = self.db.db.graph_data.edges.find()
for edge in edges: for edge in edges:
self.G.add_edge(edge['source'], edge['target']) self.G.add_edge(edge['source'], edge['target'])

@ -44,9 +44,19 @@ class Memory_graph:
created_time=current_time, # 添加创建时间 created_time=current_time, # 添加创建时间
last_modified=current_time) # 添加最后修改时间 last_modified=current_time) # 添加最后修改时间
def add_dot(self, concept, memory): def add_dot(self, concept, memory, group_id=None):
current_time = datetime.datetime.now().timestamp() current_time = datetime.datetime.now().timestamp()
# 如果memory不是字典格式将其转换为字典
if not isinstance(memory, dict):
memory = {
'content': memory,
'group_id': group_id
}
# 如果memory是字典但没有group_id添加group_id
elif 'group_id' not in memory and group_id is not None:
memory['group_id'] = group_id
if concept in self.G: if concept in self.G:
if 'memory_items' in self.G.nodes[concept]: if 'memory_items' in self.G.nodes[concept]:
if not isinstance(self.G.nodes[concept]['memory_items'], list): if not isinstance(self.G.nodes[concept]['memory_items'], list):
@ -218,6 +228,16 @@ class Hippocampus:
if not messages: if not messages:
return set(), {} return set(), {}
# 提取群聊ID信息
group_id = None
for msg in messages:
if 'group_id' in msg and msg['group_id']:
group_id = msg['group_id']
break
# 获取群组类型
group_type = self.get_group_type(group_id) if group_id else None
# 合并消息文本,同时保留时间信息 # 合并消息文本,同时保留时间信息
input_text = "" input_text = ""
time_info = "" time_info = ""
@ -264,17 +284,50 @@ class Hippocampus:
# 等待所有任务完成 # 等待所有任务完成
compressed_memory = set() compressed_memory = set()
similar_topics_dict = {} # 存储每个话题的相似主题列表 similar_topics_dict = {} # 存储每个话题的相似主题列表
for topic, task in tasks: for topic, task in tasks:
response = await task response = await task
if response: if response:
compressed_memory.add((topic, response[0])) # 为每个主题创建特定群聊/群组的主题节点名称
# 为每个话题查找相似的已存在主题 topic_node_name = topic
# 如果是私有群组的群聊,使用群组前缀
if group_type:
topic_node_name = f"{topic}_GT{group_type}"
# 否则使用群聊ID前缀如果有
elif group_id:
topic_node_name = f"{topic}_g{group_id}"
# 记录主题内容
memory_content = response[0]
compressed_memory.add((topic_node_name, memory_content))
# 为每个话题查找相似的已存在主题(不考虑群聊/群组前缀)
existing_topics = list(self.memory_graph.G.nodes()) existing_topics = list(self.memory_graph.G.nodes())
similar_topics = [] similar_topics = []
for existing_topic in existing_topics: for existing_topic in existing_topics:
# 提取基础主题名和群组/群聊信息
base_existing_topic = existing_topic
existing_group_type = None
existing_group_id = None
# 检查是否有群组前缀
if "_GT" in existing_topic:
parts = existing_topic.split("_GT")
base_existing_topic = parts[0]
if len(parts) > 1:
existing_group_type = parts[1]
# 检查是否有群聊前缀
elif "_g" in existing_topic:
parts = existing_topic.split("_g")
base_existing_topic = parts[0]
if len(parts) > 1:
existing_group_id = parts[1]
# 计算基础主题的相似度
topic_words = set(jieba.cut(topic)) topic_words = set(jieba.cut(topic))
existing_words = set(jieba.cut(existing_topic)) existing_words = set(jieba.cut(base_existing_topic))
all_words = topic_words | existing_words all_words = topic_words | existing_words
v1 = [1 if word in topic_words else 0 for word in all_words] v1 = [1 if word in topic_words else 0 for word in all_words]
@ -282,12 +335,27 @@ class Hippocampus:
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
if similarity >= 0.6: # 如果相似度高且不是完全相同的主题
similar_topics.append((existing_topic, similarity)) if similarity >= 0.6 and existing_topic != topic_node_name:
# 如果当前主题属于群组,只连接该群组内的主题或公共主题
if group_type:
# 只连接同群组主题或没有群组/群聊标识的通用主题
if (existing_group_type == group_type) or (not existing_group_type and not existing_group_id):
similar_topics.append((existing_topic, similarity))
# 如果当前主题不属于群组但有群聊ID
elif group_id:
# 只连接同群聊主题或没有群组/群聊标识的通用主题
if (existing_group_id == group_id) or (not existing_group_type and not existing_group_id):
similar_topics.append((existing_topic, similarity))
# 如果当前主题既不属于群组也没有群聊ID通用主题
else:
# 只连接没有群组标识的主题
if not existing_group_type:
similar_topics.append((existing_topic, similarity))
similar_topics.sort(key=lambda x: x[1], reverse=True) similar_topics.sort(key=lambda x: x[1], reverse=True)
similar_topics = similar_topics[:5] similar_topics = similar_topics[:5]
similar_topics_dict[topic] = similar_topics similar_topics_dict[topic_node_name] = similar_topics
return compressed_memory, similar_topics_dict return compressed_memory, similar_topics_dict
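The similarity check and the linking rule added to `memory_compress` can be illustrated in isolation. The sketch below is an assumption-laden reconstruction: it takes pre-tokenized word sets instead of calling `jieba.cut`, defines its own `cosine_similarity` (the project has its own helper), assumes `v2` is built the same way as `v1` (that line is hidden by the hunk boundary), and uses the hypothetical name `may_link` for the branch logic.

```python
import math


def cosine_similarity(v1, v2):
    """Plain cosine similarity; returns 0.0 for a zero vector."""
    dot = sum(a * b for a, b in zip(v1, v2))
    norm1 = math.sqrt(sum(a * a for a in v1))
    norm2 = math.sqrt(sum(b * b for b in v2))
    if norm1 == 0 or norm2 == 0:
        return 0.0
    return dot / (norm1 * norm2)


def word_set_similarity(topic_words, existing_words):
    """Binary bag-of-words similarity over the union of both word sets."""
    all_words = topic_words | existing_words
    v1 = [1 if word in topic_words else 0 for word in all_words]
    v2 = [1 if word in existing_words else 0 for word in all_words]
    return cosine_similarity(v1, v2)


def may_link(group_type, group_id, existing_group_type, existing_group_id):
    """Mirror of the three branches that decide whether two topics may be linked."""
    is_public = not existing_group_type and not existing_group_id
    if group_type:
        # Private group: same group or an untagged (public) topic.
        return existing_group_type == group_type or is_public
    if group_id:
        # Plain chat: same chat or an untagged (public) topic.
        return existing_group_id == group_id or is_public
    # Untagged topic: anything that carries no group tag.
    return not existing_group_type
```

Note that the last branch only tests the group tag, so an untagged topic can still be linked to a chat-tagged one, as in the diff.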
@ -315,6 +383,11 @@ class Hippocampus:
bar = '' * filled_length + '-' * (bar_length - filled_length) bar = '' * filled_length + '-' * (bar_length - filled_length)
logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})")
# 获取该批次消息的group_id
group_id = None
if messages and len(messages) > 0 and 'group_id' in messages[0]:
group_id = messages[0]['group_id']
compress_rate = global_config.memory_compress_rate compress_rate = global_config.memory_compress_rate
compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate)
logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}")
@ -323,15 +396,25 @@ class Hippocampus:
for topic, memory in compressed_memory: for topic, memory in compressed_memory:
logger.info(f"添加节点: {topic}") logger.info(f"添加节点: {topic}")
self.memory_graph.add_dot(topic, memory) self.memory_graph.add_dot(topic, memory) # 不再需要传递group_id因为已经包含在topic名称中
all_topics.append(topic) all_topics.append(topic)
# 连接相似的已存在主题 # 连接相似的已存在主题,但使用较弱的连接强度
if topic in similar_topics_dict: if topic in similar_topics_dict:
similar_topics = similar_topics_dict[topic] similar_topics = similar_topics_dict[topic]
for similar_topic, similarity in similar_topics: for similar_topic, similarity in similar_topics:
if topic != similar_topic: if topic != similar_topic:
# 如果是跨群聊的相似主题,使用较弱的连接强度
is_cross_group = False
if ("_g" in topic and "_g" in similar_topic and
topic.split("_g")[1] != similar_topic.split("_g")[1]):
is_cross_group = True
# 跨群聊的相似主题使用较弱的连接强度
strength = int(similarity * 10) strength = int(similarity * 10)
if is_cross_group:
strength = int(similarity * 5) # 降低跨群聊连接的强度
logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})") logger.info(f"连接相似节点: {topic}{similar_topic} (强度: {strength})")
self.memory_graph.G.add_edge(topic, similar_topic, self.memory_graph.G.add_edge(topic, similar_topic,
strength=strength, strength=strength,
@ -682,11 +765,12 @@ class Hippocampus:
prompt = f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' prompt = f'这是一段文字,{time_info}{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好'
return prompt return prompt
async def _identify_topics(self, text: str) -> list: async def _identify_topics(self, text: str, group_id: str = None) -> list:
"""从文本中识别可能的主题 """从文本中识别可能的主题
Args: Args:
text: 输入文本 text: 输入文本
group_id: 群聊ID用于生成群聊特定的主题名
Returns: Returns:
list: 识别出的主题列表 list: 识别出的主题列表
@ -697,6 +781,15 @@ class Hippocampus:
topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()] topics_response[0].replace("", ",").replace("", ",").replace(" ", ",").split(",") if topic.strip()]
# print(f"话题: {topics}") # print(f"话题: {topics}")
# 如果提供了群聊ID添加群聊/群组标识
if group_id:
# 检查群聊是否属于特定群组
group_type = self.get_group_type(group_id)
if group_type:
topics = [f"{topic}_GT{group_type}" for topic in topics]
else:
topics = [f"{topic}_g{group_id}" for topic in topics]
return topics return topics
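The suffix convention used here (`_GT<group type>` for private groups, `_g<chat ID>` otherwise) is parsed again in several later hunks. A minimal sketch of both directions follows; `tag_topic` and `parse_topic` are hypothetical helper names, and the parsing deliberately mirrors the `split`-based logic from the diff rather than proposing a different scheme.

```python
def tag_topic(topic, group_id=None, group_type=None):
    """Append the group or chat suffix the way the diff does."""
    if group_type:
        return f"{topic}_GT{group_type}"
    if group_id:
        return f"{topic}_g{group_id}"
    return topic


def parse_topic(name):
    """Return (base_topic, group_type, group_id) using the diff's split logic."""
    if "_GT" in name:
        parts = name.split("_GT")
        return parts[0], parts[1], None
    if "_g" in name:
        parts = name.split("_g")
        return parts[0], None, parts[1]
    return name, None, None
```

Because the split is not anchored to the end of the string, a base topic that itself contains `_g` is misread: `parse_topic("a_gb_g1")` yields `"b"` as the chat ID instead of `"1"`.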
def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list: def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list:
@ -719,11 +812,47 @@ class Hippocampus:
# print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}") # print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}")
pass pass
topic_vector = text_to_vector(topic) # 提取基础主题(去除群聊/群组标识)
base_topic = topic
topic_group_type = None
topic_group_id = None
# 检查是否有群组前缀
if "_GT" in topic:
parts = topic.split("_GT")
base_topic = parts[0]
if len(parts) > 1:
topic_group_type = parts[1]
# 检查是否有群聊前缀
elif "_g" in topic:
parts = topic.split("_g")
base_topic = parts[0]
if len(parts) > 1:
topic_group_id = parts[1]
topic_vector = text_to_vector(base_topic)
has_similar_topic = False has_similar_topic = False
for memory_topic in all_memory_topics: for memory_topic in all_memory_topics:
memory_vector = text_to_vector(memory_topic) # 提取记忆主题的基础主题和群聊/群组ID
base_memory_topic = memory_topic
memory_group_type = None
memory_group_id = None
# 检查是否有群组前缀
if "_GT" in memory_topic:
parts = memory_topic.split("_GT")
base_memory_topic = parts[0]
if len(parts) > 1:
memory_group_type = parts[1]
# 检查是否有群聊前缀
elif "_g" in memory_topic:
parts = memory_topic.split("_g")
base_memory_topic = parts[0]
if len(parts) > 1:
memory_group_id = parts[1]
memory_vector = text_to_vector(base_memory_topic)
# 获取所有唯一词 # 获取所有唯一词
all_words = set(topic_vector.keys()) | set(memory_vector.keys()) all_words = set(topic_vector.keys()) | set(memory_vector.keys())
# 构建向量 # 构建向量
@ -732,7 +861,32 @@ class Hippocampus:
# 计算相似度 # 计算相似度
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
if similarity >= similarity_threshold: # 检查是否应该考虑这个主题
should_consider = False
# 如果当前主题属于群组
if topic_group_type:
# 只考虑同群组主题或公共主题
if (memory_group_type == topic_group_type) or (not memory_group_type and not memory_group_id):
should_consider = True
# 如果当前主题属于特定群聊但不属于群组
elif topic_group_id:
# 只考虑同群聊主题或公共主题
if (memory_group_id == topic_group_id) or (not memory_group_type and not memory_group_id):
should_consider = True
# 如果当前主题是公共主题
else:
# 只考虑公共主题
if not memory_group_type and not memory_group_id:
should_consider = True
# 如果基础主题相似且应该考虑该主题
if similarity >= similarity_threshold and should_consider:
# 如果两个主题属于同一群组/群聊,提高相似度
if (topic_group_type and memory_group_type and topic_group_type == memory_group_type) or \
(topic_group_id and memory_group_id and topic_group_id == memory_group_id):
similarity *= 1.2 # 提高20%的相似度
has_similar_topic = True has_similar_topic = True
if debug_info: if debug_info:
# print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})") # print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})")
@ -841,10 +995,21 @@ class Hippocampus:
return activation return activation
async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4,
max_memory_num: int = 5) -> list: max_memory_num: int = 5, group_id: str = None) -> list:
"""根据输入文本获取相关的记忆内容""" """根据输入文本获取相关的记忆内容
# 识别主题
identified_topics = await self._identify_topics(text) Args:
text: 输入文本
max_topics: 最大主题数量
similarity_threshold: 相似度阈值
max_memory_num: 最大记忆数量
group_id: 群聊ID用于优先获取指定群聊的记忆
"""
# 获取群组类型
group_type = self.get_group_type(group_id) if group_id else None
# 识别主题传入群聊ID
identified_topics = await self._identify_topics(text, group_id)
# 查找相似主题 # 查找相似主题
all_similar_topics = self._find_similar_topics( all_similar_topics = self._find_similar_topics(
@ -857,30 +1022,140 @@ class Hippocampus:
relevant_topics = self._get_top_topics(all_similar_topics, max_topics) relevant_topics = self._get_top_topics(all_similar_topics, max_topics)
# 获取相关记忆内容 # 获取相关记忆内容
relevant_memories = [] current_group_memories = [] # 当前群组/群聊的记忆
other_memories = [] # 其他群聊/公共的记忆
for topic, score in relevant_topics: for topic, score in relevant_topics:
# 检查主题是否属于当前群组/群聊
topic_group_type = None
topic_group_id = None
# 检查是否有群组前缀
if "_GT" in topic:
parts = topic.split("_GT")
if len(parts) > 1:
topic_group_type = parts[1]
# 检查是否有群聊前缀
elif "_g" in topic:
parts = topic.split("_g")
if len(parts) > 1:
topic_group_id = parts[1]
# 获取该主题的记忆内容 # 获取该主题的记忆内容
first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) first_layer, _ = self.memory_graph.get_related_item(topic, depth=1)
if first_layer: if first_layer:
# 如果记忆条数超过限制,随机选择指定数量的记忆 # 构建记忆项
if len(first_layer) > max_memory_num / 2:
first_layer = random.sample(first_layer, max_memory_num // 2)
# 为每条记忆添加来源主题和相似度信息
for memory in first_layer: for memory in first_layer:
relevant_memories.append({ memory_item = {
'topic': topic, 'topic': topic,
'similarity': score, 'similarity': score,
'content': memory 'content': memory if not isinstance(memory, dict) else memory.get('content', memory),
}) 'group_id': topic_group_id,
'group_type': topic_group_type
# 如果记忆数量超过5个,随机选择5个 }
# 分类记忆
# 如果主题属于当前群组
if group_type and topic_group_type == group_type:
current_group_memories.append(memory_item)
# 如果主题属于当前群聊(且群聊不属于任何群组)
elif not group_type and topic_group_id == group_id:
current_group_memories.append(memory_item)
# 如果主题是公共主题(没有群组/群聊标识)
elif not topic_group_type and not topic_group_id:
other_memories.append(memory_item)
# 如果是其他群聊/群组的主题且不是私有群组,也可以添加
elif (topic_group_type and not topic_group_type in global_config.memory_private_groups) or \
(not topic_group_type and topic_group_id):
other_memories.append(memory_item)
# 按相似度排序 # 按相似度排序
relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) current_group_memories.sort(key=lambda x: x['similarity'], reverse=True)
other_memories.sort(key=lambda x: x['similarity'], reverse=True)
# 记录日志
logger.debug(f"[记忆检索] 当前群聊/群组找到 {len(current_group_memories)} 条记忆,其他群聊/公共找到 {len(other_memories)} 条记忆")
# 控制返回的记忆数量
# 优先添加当前群组/群聊的记忆,如果不足,再添加其他群聊/公共的记忆
final_memories = current_group_memories
remaining_slots = max_memory_num - len(final_memories)
if remaining_slots > 0 and other_memories:
# 添加其他记忆但最多添加remaining_slots个
final_memories.extend(other_memories[:remaining_slots])
# 如果记忆总数仍然超过限制,随机采样
if len(final_memories) > max_memory_num:
final_memories = random.sample(final_memories, max_memory_num)
# 记录日志,显示最终返回多少条记忆
logger.debug(f"[记忆检索] 最终返回 {len(final_memories)} 条记忆")
return final_memories
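The final selection step fills the result with memories from the current group or chat first and only then tops it up from the other pool. A small sketch of that step, assuming each memory is a dict with a `similarity` key as in the diff; the function name `select_memories` is hypothetical.

```python
import random


def select_memories(current_group_memories, other_memories, max_memory_num=5):
    """Prefer current-group memories, then fill remaining slots from the rest."""
    current = sorted(current_group_memories, key=lambda m: m["similarity"], reverse=True)
    others = sorted(other_memories, key=lambda m: m["similarity"], reverse=True)

    final_memories = list(current)
    remaining_slots = max_memory_num - len(final_memories)
    if remaining_slots > 0 and others:
        final_memories.extend(others[:remaining_slots])

    # Only reachable when the current group alone exceeds the limit.
    if len(final_memories) > max_memory_num:
        final_memories = random.sample(final_memories, max_memory_num)
    return final_memories
```

One consequence of this order is that when the current group has more than `max_memory_num` memories, the result is a random sample of them rather than the top-scoring ones.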
if len(relevant_memories) > max_memory_num: def get_group_memories(self, group_id: str) -> list:
relevant_memories = random.sample(relevant_memories, max_memory_num) """获取特定群聊的所有记忆
Args:
group_id: 群聊ID
Returns:
list: 该群聊的记忆列表每个记忆包含主题和内容
"""
all_memories = []
all_nodes = list(self.memory_graph.G.nodes(data=True))
# 获取群组类型
group_type = self.get_group_type(group_id)
for concept, data in all_nodes:
# 检查是否应该包含该主题的记忆
should_include = False
# 如果群聊属于群组
if group_type and f"_GT{group_type}" in concept:
should_include = True
# 如果是特定群聊的记忆
elif f"_g{group_id}" in concept:
should_include = True
# 如果是公共记忆(没有群组/群聊标识)
elif not "_GT" in concept and not "_g" in concept:
should_include = True
if should_include:
memory_items = data.get('memory_items', [])
if not isinstance(memory_items, list):
memory_items = [memory_items] if memory_items else []
# 添加所有记忆项
for memory in memory_items:
all_memories.append({
'topic': concept,
'content': memory if not isinstance(memory, dict) else memory.get('content', str(memory))
})
return all_memories
return relevant_memories def get_group_type(self, group_id: str) -> str:
"""获取群聊所属的群组类型
Args:
group_id: 群聊ID
Returns:
str: 群组类型名称如果不属于任何群组则返回None
"""
if not group_id:
return None
# 检查该群聊ID是否属于任何群组
for group_name, group_ids in global_config.memory_private_groups.items():
if group_id in group_ids:
return group_name
# 如果不属于任何群组返回None
return None
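The lookup in `get_group_type` is a reverse search over the `memory_private_groups` table. The standalone sketch below passes that table in as a parameter (an assumption made so the example runs without `global_config`) and shows one detail worth keeping in mind: the configured IDs are strings, so an integer chat ID never matches.

```python
from typing import Dict, List, Optional


def get_group_type(group_id, memory_private_groups: Dict[str, List[str]]) -> Optional[str]:
    """Return the name of the first private group that lists group_id, else None."""
    if not group_id:
        return None
    for group_name, group_ids in memory_private_groups.items():
        if group_id in group_ids:
            return group_name
    return None


# Hypothetical configuration in the shape documented in the config template.
example_groups = {
    "游戏群组": ["123456", "234567"],
    "工作群组": ["345678", "456789"],
}
```

Here `get_group_type("123456", example_groups)` returns `"游戏群组"`, while `get_group_type(123456, example_groups)` returns `None` because `123456 in ["123456", "234567"]` is false. Callers that receive numeric IDs would need to convert them with `str()` first.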
def segment_text(text): def segment_text(text):
@ -892,6 +1167,15 @@ config = driver.config
start_time = time.time() start_time = time.time()
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
# 创建记忆图 # 创建记忆图
memory_graph = Memory_graph() memory_graph = Memory_graph()
# 创建海马体 # 创建海马体

@ -41,10 +41,10 @@ class LLM_request:
"""初始化数据库集合""" """初始化数据库集合"""
try: try:
# 创建llm_usage集合的索引 # 创建llm_usage集合的索引
self.db.llm_usage.create_index([("timestamp", 1)]) self.db.db.llm_usage.create_index([("timestamp", 1)])
self.db.llm_usage.create_index([("model_name", 1)]) self.db.db.llm_usage.create_index([("model_name", 1)])
self.db.llm_usage.create_index([("user_id", 1)]) self.db.db.llm_usage.create_index([("user_id", 1)])
self.db.llm_usage.create_index([("request_type", 1)]) self.db.db.llm_usage.create_index([("request_type", 1)])
except Exception: except Exception:
logger.error("创建数据库索引失败") logger.error("创建数据库索引失败")
@ -73,7 +73,7 @@ class LLM_request:
"status": "success", "status": "success",
"timestamp": datetime.now() "timestamp": datetime.now()
} }
self.db.llm_usage.insert_one(usage_data) self.db.db.llm_usage.insert_one(usage_data)
logger.info( logger.info(
f"Token使用情况 - 模型: {self.model_name}, " f"Token使用情况 - 模型: {self.model_name}, "
f"用户: {user_id}, 类型: {request_type}, " f"用户: {user_id}, 类型: {request_type}, "
@ -235,7 +235,7 @@ class LLM_request:
delta_content = "" delta_content = ""
accumulated_content += delta_content accumulated_content += delta_content
# 检测流式输出文本是否结束 # 检测流式输出文本是否结束
finish_reason = chunk["choices"][0].get("finish_reason") finish_reason = chunk["choices"][0]["finish_reason"]
if finish_reason == "stop": if finish_reason == "stop":
usage = chunk.get("usage", None) usage = chunk.get("usage", None)
if usage: if usage:

@ -14,6 +14,16 @@ from ..models.utils_model import LLM_request
driver = get_driver() driver = get_driver()
config = driver.config config = driver.config
Database.initialize(
uri=os.getenv("MONGODB_URI"),
host=os.getenv("MONGODB_HOST", "127.0.0.1"),
port=int(os.getenv("MONGODB_PORT", "27017")),
db_name=os.getenv("DATABASE_NAME", "MegBot"),
username=os.getenv("MONGODB_USERNAME"),
password=os.getenv("MONGODB_PASSWORD"),
auth_source=os.getenv("MONGODB_AUTH_SOURCE"),
)
class ScheduleGenerator: class ScheduleGenerator:
def __init__(self): def __init__(self):
# 根据global_config.llm_normal这一字典配置指定模型 # 根据global_config.llm_normal这一字典配置指定模型
@ -46,7 +56,7 @@ class ScheduleGenerator:
schedule_text = str schedule_text = str
existing_schedule = self.db.schedule.find_one({"date": date_str}) existing_schedule = self.db.db.schedule.find_one({"date": date_str})
if existing_schedule: if existing_schedule:
logger.debug(f"{date_str}的日程已存在:") logger.debug(f"{date_str}的日程已存在:")
schedule_text = existing_schedule["schedule"] schedule_text = existing_schedule["schedule"]
@ -63,7 +73,7 @@ class ScheduleGenerator:
try: try:
schedule_text, _ = await self.llm_scheduler.generate_response(prompt) schedule_text, _ = await self.llm_scheduler.generate_response(prompt)
self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text})
except Exception as e: except Exception as e:
logger.error(f"生成日程失败: {str(e)}") logger.error(f"生成日程失败: {str(e)}")
schedule_text = "生成日程时出错了" schedule_text = "生成日程时出错了"
@ -143,7 +153,7 @@ class ScheduleGenerator:
"""打印完整的日程安排""" """打印完整的日程安排"""
if not self._parse_schedule(self.today_schedule_text): if not self._parse_schedule(self.today_schedule_text):
logger.warning("Today's schedule is invalid; it will be regenerated on the next run") logger.warning("Today's schedule is invalid; it will be regenerated on the next run")
self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")})
else: else:
logger.info("=== Today's schedule ===") logger.info("=== Today's schedule ===")
for time_str, activity in self.today_schedule.items(): for time_str, activity in self.today_schedule.items():


@ -53,7 +53,7 @@ class LLMStatistics:
"costs_by_model": defaultdict(float) "costs_by_model": defaultdict(float)
} }
cursor = self.db.llm_usage.find({ cursor = self.db.db.llm_usage.find({
"timestamp": {"$gte": start_time} "timestamp": {"$gte": start_time}
}) })


@ -13,8 +13,6 @@ from pathlib import Path
import jieba import jieba
from pypinyin import Style, pinyin from pypinyin import Style, pinyin
from loguru import logger
class ChineseTypoGenerator: class ChineseTypoGenerator:
def __init__(self, def __init__(self,
@ -40,9 +38,7 @@ class ChineseTypoGenerator:
self.max_freq_diff = max_freq_diff self.max_freq_diff = max_freq_diff
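# Load data # Load data
# print("Loading the Chinese character database, please wait...") print("Loading the Chinese character database, please wait...")
logger.info("Loading the Chinese character database, please wait...")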
self.pinyin_dict = self._create_pinyin_dict() self.pinyin_dict = self._create_pinyin_dict()
self.char_frequency = self._load_or_create_char_frequency() self.char_frequency = self._load_or_create_char_frequency()


@ -77,6 +77,16 @@ memory_ban_words = [ # words that should not be remembered
# "403","张三" # "403","张三"
] ]
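# Group-private memory configuration - chats within the same group share memory with each other, but not with other groups
# Format: { group name = [list of chat IDs] }
# The memory of chats not assigned to any group can be shared with all chats (if a group contains few chats or little chat history, consider adjusting the other memory parameters, raising the reply probability, etc.)
# For example:
# memory_private_groups = {
# "gaming group" = ["123456", "234567"],
# "work group" = ["345678", "456789"]
# }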
memory_private_groups = { }
[mood] [mood]
mood_update_interval = 1.0 # mood update interval, in seconds mood_update_interval = 1.0 # mood update interval, in seconds
mood_decay_rate = 0.95 # mood decay rate mood_decay_rate = 0.95 # mood decay rate
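The `memory_private_groups` setting added in the hunk above maps a group name to a list of chat IDs. The sketch below shows one way such a mapping could be consulted, following one reading of the config comments: memory from a grouped chat stays inside its group, while memory from an ungrouped chat is visible everywhere. The helper names `private_group_of` and `can_read_memory` and the sample mapping are hypothetical and not taken from the project's code.

```python
from typing import Dict, List, Optional

# Hypothetical configuration modeled on the commented example, with English group names.
memory_private_groups: Dict[str, List[str]] = {
    "gaming group": ["123456", "234567"],
    "work group": ["345678", "456789"],
}


def private_group_of(chat_id: str, groups: Dict[str, List[str]]) -> Optional[str]:
    """Return the name of the private group containing chat_id, or None if ungrouped."""
    for name, chat_ids in groups.items():
        if chat_id in chat_ids:
            return name
    return None


def can_read_memory(reader_chat: str, source_chat: str, groups: Dict[str, List[str]]) -> bool:
    """Decide whether reader_chat may use memory that originated in source_chat."""
    source_group = private_group_of(source_chat, groups)
    if source_group is None:
        # Memory from an ungrouped chat is treated as shared with every chat.
        return True
    # Memory from a grouped chat is only visible inside the same group.
    return private_group_of(reader_chat, groups) == source_group
```

Chat IDs are kept as strings, as in the commented example, so a lookup with an integer ID would not match without conversion.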
@ -93,7 +103,7 @@ reaction = "有人说你是人机或者机器人,否定这一事实,攻击
[[keywords_reaction.rules]] # copy it like this [[keywords_reaction.rules]] # copy it like this
enable = false # example only; will not trigger enable = false # example only; will not trigger
keywords = ["测试关键词回复","test",""] keywords = ["测试关键词回复","test",""]
reaction = "回答“测试成功”" reaction = "回答"""
[chinese_typo] [chinese_typo]
enable = true # whether to enable the Chinese typo generator enable = true # whether to enable the Chinese typo generator


@ -1,45 +1,45 @@
@echo off @echo off
setlocal enabledelayedexpansion setlocal enabledelayedexpansion
chcp 65001 chcp 65001
cd /d %~dp0 cd /d %~dp0
echo ===================================== echo =====================================
echo Select the Python environment: echo Select the Python environment:
echo 1 - venv (recommended) echo 1 - venv (recommended)
echo 2 - conda echo 2 - conda
echo ===================================== echo =====================================
choice /c 12 /n /m "Enter a number (1 or 2): " choice /c 12 /n /m "Enter a number (1 or 2): "
if errorlevel 2 ( if errorlevel 2 (
echo ===================================== echo =====================================
set "CONDA_ENV=" set "CONDA_ENV="
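set /p CONDA_ENV="Enter the name of the conda environment to activate: " set /p CONDA_ENV="Enter the name of the conda environment to activate: "
:: Check whether the input is empty :: Check whether the input is empty
if "!CONDA_ENV!"=="" ( if "!CONDA_ENV!"=="" (
echo Error: the environment name must not be empty echo Error: the environment name must not be empty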
pause pause
exit /b 1 exit /b 1
) )
call conda activate !CONDA_ENV! call conda activate !CONDA_ENV!
if errorlevel 1 ( if errorlevel 1 (
echo Failed to activate the conda environment echo Failed to activate the conda environment
pause pause
exit /b 1 exit /b 1
) )
echo Conda environment "!CONDA_ENV!" activated successfully echo Conda environment "!CONDA_ENV!" activated successfully
python config/auto_update.py python config/auto_update.py
) else ( ) else (
if exist "venv\Scripts\python.exe" ( if exist "venv\Scripts\python.exe" (
venv\Scripts\python config/auto_update.py venv\Scripts\python config/auto_update.py
) else ( ) else (
echo ===================================== echo =====================================
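echo Error: the venv environment does not exist; please create a virtual environment first echo Error: the venv environment does not exist; please create a virtual environment first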
pause pause
exit /b 1 exit /b 1
) )
) )
endlocal endlocal
pause pause


@ -1,45 +1,45 @@
@echo off @echo off
setlocal enabledelayedexpansion setlocal enabledelayedexpansion
chcp 65001 chcp 65001
cd /d %~dp0 cd /d %~dp0
echo ===================================== echo =====================================
echo Select the Python environment: echo Select the Python environment:
echo 1 - venv (recommended) echo 1 - venv (recommended)
echo 2 - conda echo 2 - conda
echo ===================================== echo =====================================
choice /c 12 /n /m "Enter a number (1 or 2): " choice /c 12 /n /m "Enter a number (1 or 2): "
if errorlevel 2 ( if errorlevel 2 (
echo ===================================== echo =====================================
set "CONDA_ENV=" set "CONDA_ENV="
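set /p CONDA_ENV="Enter the name of the conda environment to activate: " set /p CONDA_ENV="Enter the name of the conda environment to activate: "
:: Check whether the input is empty :: Check whether the input is empty
if "!CONDA_ENV!"=="" ( if "!CONDA_ENV!"=="" (
echo Error: the environment name must not be empty echo Error: the environment name must not be empty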
pause pause
exit /b 1 exit /b 1
) )
call conda activate !CONDA_ENV! call conda activate !CONDA_ENV!
if errorlevel 1 ( if errorlevel 1 (
echo Failed to activate the conda environment echo Failed to activate the conda environment
pause pause
exit /b 1 exit /b 1
) )
echo Conda environment "!CONDA_ENV!" activated successfully echo Conda environment "!CONDA_ENV!" activated successfully
python src/plugins/zhishi/knowledge_library.py python src/plugins/zhishi/knowledge_library.py
) else ( ) else (
if exist "venv\Scripts\python.exe" ( if exist "venv\Scripts\python.exe" (
venv\Scripts\python src/plugins/zhishi/knowledge_library.py venv\Scripts\python src/plugins/zhishi/knowledge_library.py
) else ( ) else (
echo ===================================== echo =====================================
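echo Error: the venv environment does not exist; please create a virtual environment first echo Error: the venv environment does not exist; please create a virtual environment first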
pause pause
exit /b 1 exit /b 1
) )
) )
endlocal endlocal
pause pause