diff --git a/README.md b/README.md index c14ac646..ad318aec 100644 --- a/README.md +++ b/README.md @@ -77,10 +77,6 @@ - [🎀 新手配置指南](docs/installation_cute.md) - 通俗易懂的配置教程,适合初次使用的猫娘 - [⚙️ 标准配置指南](docs/installation_standard.md) - 简明专业的配置说明,适合有经验的用户 -### 常见问题 - -- [❓ 快速 Q & A ](docs/fast_q_a.md) - 针对新手的疑难解答,适合完全没接触过编程的新手 -

了解麦麦

diff --git a/bot.py b/bot.py index 19ad8002..09d081be 100644 --- a/bot.py +++ b/bot.py @@ -12,8 +12,6 @@ from loguru import logger from nonebot.adapters.onebot.v11 import Adapter import platform -from src.common.database import Database - # 获取没有加载env时的环境变量 env_mask = {key: os.getenv(key) for key in os.environ} @@ -111,17 +109,6 @@ def load_env(): logger.error(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") RuntimeError(f"ENVIRONMENT 配置错误,请检查 .env 文件中的 ENVIRONMENT 变量及对应 .env.{env} 是否存在") -def init_database(): - Database.initialize( - uri=os.getenv("MONGODB_URI"), - host=os.getenv("MONGODB_HOST", "127.0.0.1"), - port=int(os.getenv("MONGODB_PORT", "27017")), - db_name=os.getenv("DATABASE_NAME", "MegBot"), - username=os.getenv("MONGODB_USERNAME"), - password=os.getenv("MONGODB_PASSWORD"), - auth_source=os.getenv("MONGODB_AUTH_SOURCE"), - ) - def load_logger(): logger.remove() # 移除默认配置 @@ -223,7 +210,6 @@ def raw_main(): init_config() init_env() load_env() - init_database() # 加载完成环境后初始化database load_logger() env_config = {key: os.getenv(key) for key in os.environ} @@ -249,6 +235,7 @@ def raw_main(): if __name__ == "__main__": + try: raw_main() diff --git a/hort --pretty=format-ad -s b/hort --pretty=format-ad -s new file mode 100644 index 00000000..faeacdd5 --- /dev/null +++ b/hort --pretty=format-ad -s @@ -0,0 +1,141 @@ +cbb569e - Create 如果你更新了版本,点我.txt +a91ef7b - 自动升级配置文件脚本 +ed18f2e - 新增了知识库一键启动漂亮脚本 +80ed568 - fix: 删除print调试代码 +c681a82 - 修复小名无效问题 +e54038f - fix: 从 nixpkgs 增加 numpy 依赖,以避免出现 libc++.so 找不到的问题 +26782c9 - fix: 修复 ENVIRONMENT 变量在同一终端下不能被覆盖的问题 +8c34637 - 提高健壮性 +2688a96 - close SengokuCola/MaiMBot#225 让麦麦可以正确读取分享卡片 +cd16e68 - 修复表情包发送时的缺失参数 +b362c35 - feat: 更新 flake.nix ,采用 venv 的方式生成环境,nixos用户也可以本机运行项目了 +3c8c897 - 屏蔽一个臃肿的debug信息 +9d0152a - 修复了合并过程中造成的代码重复 +956135c - 添加一些注释 +a412741 - 将print变为logger.debug +3180426 - 修复了没有改掉的typo字段 +aea3bff - 添加私聊过滤开关,更新config,增加约束 +cda6281 - chore: update emoji_manager.py +baed856 - 修正了私聊屏蔽词输出 +66a0f18 - 修复了私聊时产生reply消息的bug +3bf5cd6 - feat: 新增运行时重载配置文件;新增根据不同环境(dev;prod)显示不同级别的log +33cd83b - 添加私聊功能 +aa41f0d - fix: 放反了 +ef8691c - fix: 修改message继承逻辑,修复回复消息无法识别 +7d017be - fix:模型降级 +e1019ad - fix: 修复变量拼写错误并优化代码可读性 +c24bb70 - fix: 流式输出模式增加结束判断与token用量记录 +60a9376 - 添加logger的debug输出开关,默认为不开启 +bfa9a3c - fix: 添加群信息获取的错误处理 (#173) +4cc5c8e - 修正.env.prod和.env.dev的生成 +dea14c1 - fix: 模型降级目前只对硅基流动的V3和R1生效 +b6edbea - fix: 图片保存路径不正确 +01a6fa8 - fix: 删除神秘test +20f009d - 修复systemctl强制停止maimbot的问题 +af962c2 - 修复了情绪管理器没有正确导入导致发布出消息 +0586700 - 按照Sourcery提供的建议修改systemctl管理指南 +e48b32a - 在手动部署教程中增加使用systemctl管理 +5760412 - fix: 小修 +1c9b0cc - fix: 修复部分cq码解析错误,merge +b6867b9 - fix: 统一使用os.getenv获取数据库连接信息,避免从config对象获取不存在的值时出现KeyError +5e069f7 - 修复记忆保存时无时间信息的bug +73a3e41 - 修复记忆更新bug +52c93ba - refactor: use Base64 for emoji CQ codes +67f6d7c - fix: 保证能运行的小修改 +c32c4fb - refactor: 修改配置文件的版本号 +a54ca8c - Merge remote-tracking branch 'upstream/debug' into feat_regix +8cbf9bb - feat: 史上最好的消息流重构和图片管理 +9e41c4f - feat: 修改 bot_config 0.0.5 版本的变更日志 +eede406 - fix: 修复nonebot无法加载项目的问题 +00e02ed - fix: 0.0.5 版本的增加分层控制项 +0f99d6a - Update docs/docker_deploy.md +c789074 - feat: 增加ruff依赖 +ff65ab8 - feat: 修改默认的ruff配置文件,同时消除config的所有不符合规范的地方 +bf97013 - feat: 精简日志,禁用Uvicorn/NoneBot默认日志;启动方式改为显示加载uvicorn,以便优雅shutdown +d9a2863 - 优化Docker部署文档更新容器部分 +efcf00f - Docker部署文档追加更新部分 +a63ce96 - fix: 更新情感判断模型配置(使配置文件里的 llm_emotion_judge 生效) +1294c88 - feat: 增加标准化格式化设置 +2e8cd47 - fix: 避免可能出现的日程解析错误 +043a724 - 修一下文档跳转,小美化( +e4b8865 - 支持别名,可以用不同名称召唤机器人 +7b35ddd - ruff 哥又有新点子 +7899e67 - feat: 重构完成开始测试debug +354d6d0 - 记忆系统优化 +6cef8fd - 修复时区,删去napcat用不到的端口 +cd96644 - 添加使用说明 +84495f8 - fix +204744c - 修改配置名与修改过滤对象为raw_message +a03b490 - Update README.md +2b2b342 - feat: 增加 ruff 依赖 +72a6749 - fix: 修复docker部署时区指定问题 +ee579bc - Update README.md +1b611ec - resolve SengokuCola/MaiMBot#167 根据正则表达式过滤消息 +6e2ea82 - refractor: 几乎写完了,进入测试阶段 +2ffdfef - More +e680405 - fix: typo 'discription' +68b3f57 - Minor Doc Update +312f065 - Create linux_deploy_guide_for_beginners.md +ed505a4 - fix: 使用动态路径替换硬编码的项目路径 +8ff7bb6 - docs: 更新文档,修正格式并添加必要的换行符 +6e36a56 - feat: 增加 MONGODB_URI 的配置项,并将所有env文件的注释单独放在一行(python的dotenv有时无法正确处理行内注释) +4baa6c6 - feat: 实现MongoDB URI方式连接,并统一数据库连接代码。 +8a32d18 - feat: 优化willing_manager逻辑,增加回复保底概率 +c9f1244 - docs: 改进README.md文档格式和排版 +e1b484a - docs: 添加CLAUDE.md开发指南文件(用于Claude Code) +a43f949 - fix: remove duplicate message(CR comments) +fddb641 - fix: 修复错误的空值检测逻辑 +8b7876c - fix: 修复没有上传tag的问题 +6b4130e - feat: 增加stable-dev分支的打包 +052e67b - refactor: 日志打印优化(终于改完了,爽了 +a7f9d05 - 修复记忆整理传入格式问题 +536bb1d - fix: 更新情感判断模型配置 +8d99592 - fix: logger初始化顺序 +052802c - refactor: logger promotion +8661d94 - doc: README.md - telegram version information +5746afa - refactor: logger in src\plugins\chat\bot.py +288dbb6 - refactor: logger in src\plugins\chat\__init__.py +8428a06 - fix: memory logger optimization (CR comment) +665c459 - 改进了可视化脚本 +6c35704 - fix: 调用了错误的函数 +3223153 - feat: 一键脚本新增记忆可视化 +3149dd3 - fix: mongodb.zip 无法解压 fix:更换执行命令的方法 fix:当 db 不存在时自动创建 feat: 一键安装完成后启动麦麦 +089d6a6 - feat: 针对硅基流动的Pro模型添加了自动降级功能 +c4b0917 - 一个记忆可视化小脚本 +6a71ea4 - 修复了记忆时间bug,config添加了记忆屏蔽关键词 +1b5344f - fix: 优化bot初始化的日志&格式 +41aa974 - fix: 优化chat/config.py的日志&格式 +980cde7 - fix: 优化scheduler_generator日志&格式 +31a5514 - fix: 调整全局logger加载顺序 +8baef07 - feat: 添加全局logger初始化设置 +5566f17 - refractor: 几乎写完了,进入测试阶段 +6a66933 - feat: 添加开发环境.env.dev初始化 +411ff1a - feat: 安装 MongoDB Compass +0de9eba - feat: 增加实时更新贡献者列表的功能 +f327f45 - fix: 优化src/plugins/chat/__init__.py的import +826daa5 - fix: 当虚拟环境存在时跳过创建 +f54de42 - fix: time.tzset 仅在类 Unix 系统可用 +47c4990 - fix: 修复docker部署场景下时间错误的问题 +e23a371 - docs: 添加 compose 注释 +1002822 - docs: 标注 Python 最低版本 +564350d - feat: 校验 Python 版本 +4cc4482 - docs: 添加傻瓜式脚本 +757173a - 带麦麦看了心理医生,让她没那么容易陷入负面情绪 +39bb99c - 将错别字生成提取到配置,一句一个错别字太烦了! +fe36847 - feat: 超大型重构 +e304dd7 - Update README.md +b7cfe6d - feat: 发布第 0.0.2 版本配置模板 +ca929d5 - 补充Docker部署文档 +1e97120 - 补充Docker部署文档 +25f7052 - fix: 修复兼容性选项和目前第一个版本之间的版本间隙 0.0.0 版,并将所有的直接退出修改为抛出异常 +c5bdc4f - 防ipv6炸,虽然小概率事件 +d86610d - fix: 修复不能加载环境变量的问题 +2306ebf - feat: 因为判断临界版本范围比较麻烦,增加 notice 字段,删除原本的判断逻辑(存在故障) +dd09576 - fix: 修复 TypeError: BotConfig.convert_to_specifierset() takes 1 positional argument but 2 were given +18f839b - fix: 修复 missing 1 required positional argument: 'INNER_VERSION' +6adb5ed - 调整一些细节,docker部署时可选数据库账密 +07f48e9 - fix: 利用filter来过滤环境变量,避免直接删除key造成的 RuntimeError: dictionary changed size during iteration +5856074 - fix: 修复无法进行基础设置的问题 +32aa032 - feat: 发布 0.0.1 版本的配置文件 +edc07ac - feat: 重构配置加载器,增加配置文件版本控制和程序兼容能力 +0f492ed - fix: 修复 BASE_URL/KEY 组合检查中被 GPG_KEY 干扰的问题 \ No newline at end of file diff --git a/run.bat b/run.bat index 91904bc3..23c520e9 100644 --- a/run.bat +++ b/run.bat @@ -1,10 +1,10 @@ -@ECHO OFF -chcp 65001 -if not exist "venv" ( - python -m venv venv - call venv\Scripts\activate.bat - pip install -i https://mirrors.aliyun.com/pypi/simple --upgrade -r requirements.txt - ) else ( - call venv\Scripts\activate.bat -) +@ECHO OFF +chcp 65001 +if not exist "venv" ( + python -m venv venv + call venv\Scripts\activate.bat + pip install -i https://mirrors.aliyun.com/pypi/simple --upgrade -r requirements.txt + ) else ( + call venv\Scripts\activate.bat +) python run.py \ No newline at end of file diff --git a/run_memory_vis.bat b/run_memory_vis.bat index b1feb0cb..1b2b34a1 100644 --- a/run_memory_vis.bat +++ b/run_memory_vis.bat @@ -1,29 +1,29 @@ -@echo on -chcp 65001 > nul -set /p CONDA_ENV="请输入要激活的 conda 环境名称: " -call conda activate %CONDA_ENV% -if errorlevel 1 ( - echo 激活 conda 环境失败 - pause - exit /b 1 -) -echo Conda 环境 "%CONDA_ENV%" 激活成功 - -set /p OPTION="请选择运行选项 (1: 运行全部绘制, 2: 运行简单绘制): " -if "%OPTION%"=="1" ( - python src/plugins/memory_system/memory_manual_build.py -) else if "%OPTION%"=="2" ( - python src/plugins/memory_system/draw_memory.py -) else ( - echo 无效的选项 - pause - exit /b 1 -) - -if errorlevel 1 ( - echo 命令执行失败,错误代码 %errorlevel% - pause - exit /b 1 -) -echo 脚本成功完成 +@echo on +chcp 65001 > nul +set /p CONDA_ENV="请输入要激活的 conda 环境名称: " +call conda activate %CONDA_ENV% +if errorlevel 1 ( + echo 激活 conda 环境失败 + pause + exit /b 1 +) +echo Conda 环境 "%CONDA_ENV%" 激活成功 + +set /p OPTION="请选择运行选项 (1: 运行全部绘制, 2: 运行简单绘制): " +if "%OPTION%"=="1" ( + python src/plugins/memory_system/memory_manual_build.py +) else if "%OPTION%"=="2" ( + python src/plugins/memory_system/draw_memory.py +) else ( + echo 无效的选项 + pause + exit /b 1 +) + +if errorlevel 1 ( + echo 命令执行失败,错误代码 %errorlevel% + pause + exit /b 1 +) +echo 脚本成功完成 pause \ No newline at end of file diff --git a/script/run_maimai.bat b/script/run_maimai.bat index 3a099fd7..addcc052 100644 --- a/script/run_maimai.bat +++ b/script/run_maimai.bat @@ -1,7 +1,7 @@ -chcp 65001 -call conda activate maimbot -cd . - -REM 执行nb run命令 -nb run +chcp 65001 +call conda activate maimbot +cd . + +REM 执行nb run命令 +nb run pause \ No newline at end of file diff --git a/script/run_thingking.bat b/script/run_thingking.bat index a134da6f..c2a3e650 100644 --- a/script/run_thingking.bat +++ b/script/run_thingking.bat @@ -1,5 +1,5 @@ -call conda activate niuniu -cd src\gui -start /b python reasoning_gui.py -exit - +call conda activate niuniu +cd src\gui +start /b python reasoning_gui.py +exit + diff --git a/script/run_windows.bat b/script/run_windows.bat index bea397dd..a07f513f 100644 --- a/script/run_windows.bat +++ b/script/run_windows.bat @@ -1,68 +1,68 @@ -@echo off -setlocal enabledelayedexpansion -chcp 65001 - -REM 修正路径获取逻辑 -cd /d "%~dp0" || ( - echo 错误:切换目录失败 - exit /b 1 -) - -if not exist "venv\" ( - echo 正在初始化虚拟环境... - - where python >nul 2>&1 - if %errorlevel% neq 0 ( - echo 未找到Python解释器 - exit /b 1 - ) - - for /f "tokens=2" %%a in ('python --version 2^>^&1') do set version=%%a - for /f "tokens=1,2 delims=." %%b in ("!version!") do ( - set major=%%b - set minor=%%c - ) - - if !major! lss 3 ( - echo 需要Python大于等于3.0,当前版本 !version! - exit /b 1 - ) - - if !major! equ 3 if !minor! lss 9 ( - echo 需要Python大于等于3.9,当前版本 !version! - exit /b 1 - ) - - echo 正在安装virtualenv... - python -m pip install virtualenv || ( - echo virtualenv安装失败 - exit /b 1 - ) - - echo 正在创建虚拟环境... - python -m virtualenv venv || ( - echo 虚拟环境创建失败 - exit /b 1 - ) - - call venv\Scripts\activate.bat - -) else ( - call venv\Scripts\activate.bat -) - -echo 正在更新依赖... -pip install -r requirements.txt - -echo 当前代理设置: -echo HTTP_PROXY=%HTTP_PROXY% -echo HTTPS_PROXY=%HTTPS_PROXY% - -set HTTP_PROXY= -set HTTPS_PROXY= -echo 代理已取消。 - -set no_proxy=0.0.0.0/32 - -call nb run +@echo off +setlocal enabledelayedexpansion +chcp 65001 + +REM 修正路径获取逻辑 +cd /d "%~dp0" || ( + echo 错误:切换目录失败 + exit /b 1 +) + +if not exist "venv\" ( + echo 正在初始化虚拟环境... + + where python >nul 2>&1 + if %errorlevel% neq 0 ( + echo 未找到Python解释器 + exit /b 1 + ) + + for /f "tokens=2" %%a in ('python --version 2^>^&1') do set version=%%a + for /f "tokens=1,2 delims=." %%b in ("!version!") do ( + set major=%%b + set minor=%%c + ) + + if !major! lss 3 ( + echo 需要Python大于等于3.0,当前版本 !version! + exit /b 1 + ) + + if !major! equ 3 if !minor! lss 9 ( + echo 需要Python大于等于3.9,当前版本 !version! + exit /b 1 + ) + + echo 正在安装virtualenv... + python -m pip install virtualenv || ( + echo virtualenv安装失败 + exit /b 1 + ) + + echo 正在创建虚拟环境... + python -m virtualenv venv || ( + echo 虚拟环境创建失败 + exit /b 1 + ) + + call venv\Scripts\activate.bat + +) else ( + call venv\Scripts\activate.bat +) + +echo 正在更新依赖... +pip install -r requirements.txt + +echo 当前代理设置: +echo HTTP_PROXY=%HTTP_PROXY% +echo HTTPS_PROXY=%HTTPS_PROXY% + +set HTTP_PROXY= +set HTTPS_PROXY= +echo 代理已取消。 + +set no_proxy=0.0.0.0/32 + +call nb run pause \ No newline at end of file diff --git a/src/common/database.py b/src/common/database.py index c6cead22..d592b0f9 100644 --- a/src/common/database.py +++ b/src/common/database.py @@ -1,6 +1,5 @@ from typing import Optional from pymongo import MongoClient -from pymongo.database import Database as MongoDatabase class Database: _instance: Optional["Database"] = None @@ -26,7 +25,7 @@ class Database: else: # 否则使用无认证连接 self.client = MongoClient(host, port) - self.db: MongoDatabase = self.client[db_name] + self.db = self.client[db_name] @classmethod def initialize( @@ -38,36 +37,15 @@ class Database: password: Optional[str] = None, auth_source: Optional[str] = None, uri: Optional[str] = None, - ) -> MongoDatabase: + ) -> "Database": if cls._instance is None: cls._instance = cls( host, port, db_name, username, password, auth_source, uri ) - return cls._instance.db + return cls._instance @classmethod - def get_instance(cls) -> MongoDatabase: + def get_instance(cls) -> "Database": if cls._instance is None: raise RuntimeError("Database not initialized") - return cls._instance.db - - - #测试用 - - def get_random_group_messages(self, group_id: str, limit: int = 5): - # 先随机获取一条消息 - random_message = list(self.db.messages.aggregate([ - {"$match": {"group_id": group_id}}, - {"$sample": {"size": 1}} - ]))[0] - - # 获取该消息之后的消息 - subsequent_messages = list(self.db.messages.find({ - "group_id": group_id, - "time": {"$gt": random_message["time"]} - }).sort("time", 1).limit(limit)) - - # 将随机消息和后续消息合并 - messages = [random_message] + subsequent_messages - - return messages \ No newline at end of file + return cls._instance \ No newline at end of file diff --git a/src/gui/reasoning_gui.py b/src/gui/reasoning_gui.py index 84b95ada..e131658b 100644 --- a/src/gui/reasoning_gui.py +++ b/src/gui/reasoning_gui.py @@ -46,7 +46,7 @@ class ReasoningGUI: # 初始化数据库连接 try: - self.db = Database.get_instance() + self.db = Database.get_instance().db logger.success("数据库连接成功") except RuntimeError: logger.warning("数据库未初始化,正在尝试初始化...") @@ -60,7 +60,7 @@ class ReasoningGUI: password=os.getenv("MONGODB_PASSWORD"), auth_source=os.getenv("MONGODB_AUTH_SOURCE"), ) - self.db = Database.get_instance() + self.db = Database.get_instance().db logger.success("数据库初始化成功") except Exception: logger.exception("数据库初始化失败") diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 1c6bf3f3..38af5443 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -32,6 +32,18 @@ _message_manager_started = False driver = get_driver() config = driver.config +Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), +) +logger.success("初始化数据库成功") + + # 初始化表情管理器 emoji_manager.initialize() diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index f179e8ef..f335a2ba 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -235,10 +235,10 @@ class ChatBot: is_head=not mark_head, is_emoji=False, ) - logger.debug(f"bot_message: {bot_message}") + print(f"bot_message: {bot_message}") if not mark_head: mark_head = True - logger.debug(f"添加消息到message_set: {bot_message}") + print(f"添加消息到message_set: {bot_message}") message_set.add_message(bot_message) # message_set 可以直接加入 message_manager diff --git a/src/plugins/chat/chat_stream.py b/src/plugins/chat/chat_stream.py index 3ccd03f8..bee67917 100644 --- a/src/plugins/chat/chat_stream.py +++ b/src/plugins/chat/chat_stream.py @@ -111,11 +111,11 @@ class ChatManager: def _ensure_collection(self): """确保数据库集合存在并创建索引""" - if "chat_streams" not in self.db.list_collection_names(): - self.db.create_collection("chat_streams") + if "chat_streams" not in self.db.db.list_collection_names(): + self.db.db.create_collection("chat_streams") # 创建索引 - self.db.chat_streams.create_index([("stream_id", 1)], unique=True) - self.db.chat_streams.create_index( + self.db.db.chat_streams.create_index([("stream_id", 1)], unique=True) + self.db.db.chat_streams.create_index( [("platform", 1), ("user_info.user_id", 1), ("group_info.group_id", 1)] ) @@ -168,7 +168,7 @@ class ChatManager: return stream # 检查数据库中是否存在 - data = self.db.chat_streams.find_one({"stream_id": stream_id}) + data = self.db.db.chat_streams.find_one({"stream_id": stream_id}) if data: stream = ChatStream.from_dict(data) # 更新用户信息和群组信息 @@ -204,7 +204,7 @@ class ChatManager: async def _save_stream(self, stream: ChatStream): """保存聊天流到数据库""" if not stream.saved: - self.db.chat_streams.update_one( + self.db.db.chat_streams.update_one( {"stream_id": stream.stream_id}, {"$set": stream.to_dict()}, upsert=True ) stream.saved = True @@ -216,7 +216,7 @@ class ChatManager: async def load_all_streams(self): """从数据库加载所有聊天流""" - all_streams = self.db.chat_streams.find({}) + all_streams = self.db.db.chat_streams.find({}) for data in all_streams: stream = ChatStream.from_dict(data) self.streams[stream.stream_id] = stream diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 88cb31ed..86b6f6ba 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -105,6 +105,11 @@ class BotConfig: default_factory=lambda: ["表情包", "图片", "回复", "聊天记录"] ) # 添加新的配置项默认值 + # 记忆群组配置,用于定义私有记忆群组 + # 格式:{"group1": ["123456", "234567"], "group2": ["345678", "456789"]} + # 每个群组内的群聊ID共享记忆,但不与其他群组共享 + memory_private_groups: Dict[str, List[str]] = field(default_factory=dict) + @staticmethod def get_config_dir() -> str: """获取配置文件目录""" @@ -304,6 +309,11 @@ class BotConfig: config.memory_forget_time = memory_config.get("memory_forget_time", config.memory_forget_time) config.memory_forget_percentage = memory_config.get("memory_forget_percentage", config.memory_forget_percentage) config.memory_compress_rate = memory_config.get("memory_compress_rate", config.memory_compress_rate) + + # 加载群组私有记忆配置 + if "memory_private_groups" in memory_config: + config.memory_private_groups = memory_config.get("memory_private_groups", config.memory_private_groups) + logger.info(f"已加载群组私有记忆配置: {len(config.memory_private_groups)}个群组") def mood(parent: dict): mood_config = parent["mood"] diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index bc40cff8..0a8a71df 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -4,8 +4,6 @@ import time from dataclasses import dataclass from typing import Dict, List, Optional, Union -import os - import requests # 解析各种CQ码 diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index f1525107..98987c3e 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -76,16 +76,16 @@ class EmojiManager: 没有索引的话,数据库每次查询都需要扫描全部数据,建立索引后可以大大提高查询效率。 """ - if 'emoji' not in self.db.list_collection_names(): - self.db.create_collection('emoji') - self.db.emoji.create_index([('embedding', '2dsphere')]) - self.db.emoji.create_index([('filename', 1)], unique=True) + if 'emoji' not in self.db.db.list_collection_names(): + self.db.db.create_collection('emoji') + self.db.db.emoji.create_index([('embedding', '2dsphere')]) + self.db.db.emoji.create_index([('filename', 1)], unique=True) def record_usage(self, emoji_id: str): """记录表情使用次数""" try: self._ensure_db() - self.db.emoji.update_one( + self.db.db.emoji.update_one( {'_id': emoji_id}, {'$inc': {'usage_count': 1}} ) @@ -119,7 +119,7 @@ class EmojiManager: try: # 获取所有表情包 - all_emojis = list(self.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) + all_emojis = list(self.db.db.emoji.find({}, {'_id': 1, 'path': 1, 'embedding': 1, 'description': 1})) if not all_emojis: logger.warning("数据库中没有任何表情包") @@ -157,7 +157,7 @@ class EmojiManager: if selected_emoji and 'path' in selected_emoji: # 更新使用次数 - self.db.emoji.update_one( + self.db.db.emoji.update_one( {'_id': selected_emoji['_id']}, {'$inc': {'usage_count': 1}} ) @@ -239,7 +239,7 @@ class EmojiManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 检查是否已经注册过 - existing_emoji = self.db['emoji'].find_one({'filename': filename}) + existing_emoji = self.db.db['emoji'].find_one({'filename': filename}) description = None if existing_emoji: @@ -305,7 +305,7 @@ class EmojiManager: } # 保存到emoji数据库 - self.db['emoji'].insert_one(emoji_record) + self.db.db['emoji'].insert_one(emoji_record) logger.success(f"注册新表情包: {filename}") logger.info(f"描述: {description}") @@ -346,7 +346,7 @@ class EmojiManager: try: self._ensure_db() # 获取所有表情包记录 - all_emojis = list(self.db.emoji.find()) + all_emojis = list(self.db.db.emoji.find()) removed_count = 0 total_count = len(all_emojis) @@ -354,13 +354,13 @@ class EmojiManager: try: if 'path' not in emoji: logger.warning(f"发现无效记录(缺少path字段),ID: {emoji.get('_id', 'unknown')}") - self.db.emoji.delete_one({'_id': emoji['_id']}) + self.db.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue if 'embedding' not in emoji: logger.warning(f"发现过时记录(缺少embedding字段),ID: {emoji.get('_id', 'unknown')}") - self.db.emoji.delete_one({'_id': emoji['_id']}) + self.db.db.emoji.delete_one({'_id': emoji['_id']}) removed_count += 1 continue @@ -368,7 +368,7 @@ class EmojiManager: if not os.path.exists(emoji['path']): logger.warning(f"表情包文件已被删除: {emoji['path']}") # 从数据库中删除记录 - result = self.db.emoji.delete_one({'_id': emoji['_id']}) + result = self.db.db.emoji.delete_one({'_id': emoji['_id']}) if result.deleted_count > 0: logger.debug(f"成功删除数据库记录: {emoji['_id']}") removed_count += 1 @@ -379,7 +379,7 @@ class EmojiManager: continue # 验证清理结果 - remaining_count = self.db.emoji.count_documents({}) + remaining_count = self.db.db.emoji.count_documents({}) if removed_count > 0: logger.success(f"已清理 {removed_count} 个失效的表情包记录") logger.info(f"清理前总数: {total_count} | 清理后总数: {remaining_count}") diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py index 84e1937b..46dc34e9 100644 --- a/src/plugins/chat/llm_generator.py +++ b/src/plugins/chat/llm_generator.py @@ -154,7 +154,7 @@ class ResponseGenerator: reasoning_content: str, ): """保存对话记录到数据库""" - self.db.reasoning_logs.insert_one( + self.db.db.reasoning_logs.insert_one( { "time": time.time(), "chat_id": message.chat_stream.stream_id, diff --git a/src/plugins/chat/message_sender.py b/src/plugins/chat/message_sender.py index 5b580f24..584bf9c5 100644 --- a/src/plugins/chat/message_sender.py +++ b/src/plugins/chat/message_sender.py @@ -10,7 +10,6 @@ from .message import MessageSending, MessageThinking, MessageRecv, MessageSet from .storage import MessageStorage from .config import global_config -from .utils import truncate_message class Message_Sender: @@ -35,7 +34,6 @@ class Message_Sender: message_json = message.to_dict() message_send = MessageSendCQ(data=message_json) # logger.debug(message_send.message_info,message_send.raw_message) - message_preview = truncate_message(message.processed_plain_text) if ( message_send.message_info.group_info and message_send.message_info.group_info.group_id @@ -46,10 +44,10 @@ class Message_Sender: message=message_send.raw_message, auto_escape=False, ) - logger.success(f"[调试] 发送消息“{message_preview}”成功") + logger.success(f"[调试] 发送消息{message.processed_plain_text}成功") except Exception as e: logger.error(f"[调试] 发生错误 {e}") - logger.error(f"[调试] 发送消息“{message_preview}”失败") + logger.error(f"[调试] 发送消息{message.processed_plain_text}失败") else: try: logger.debug(message.message_info.user_info) @@ -58,10 +56,10 @@ class Message_Sender: message=message_send.raw_message, auto_escape=False, ) - logger.success(f"[调试] 发送消息“{message_preview}”成功") + logger.success(f"[调试] 发送消息{message.processed_plain_text}成功") except Exception as e: - logger.error(f"[调试] 发生错误 {e}") - logger.error(f"[调试] 发送消息“{message_preview}”失败") + logger.error(f"发生错误 {e}") + logger.error(f"[调试] 发送消息{message.processed_plain_text}失败") class MessageContainer: @@ -186,7 +184,7 @@ class MessageManager: await message_earliest.process() print( - f"\033[1;34m[调试]\033[0m 消息“{truncate_message(message_earliest.processed_plain_text)}”正在发送中" + f"\033[1;34m[调试]\033[0m 消息'{message_earliest.processed_plain_text}'正在发送中" ) await self.storage.store_message( diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index c89bf3e0..d8cfd552 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -91,12 +91,20 @@ class PromptBuilder: memory_prompt = '' start_time = time.time() - # 调用 hippocampus 的 get_relevant_memories 方法 + # 获取群聊ID + stream_group_id = None + if stream_id: + chat_stream = chat_manager.get_stream(stream_id) + if chat_stream and chat_stream.group_info: + stream_group_id = str(chat_stream.group_info.group_id) + + # 调用 hippocampus 的 get_relevant_memories 方法,添加群聊ID参数 relevant_memories = await hippocampus.get_relevant_memories( text=message_txt, max_topics=5, similarity_threshold=0.4, - max_memory_num=5 + max_memory_num=5, + group_id=stream_group_id ) if relevant_memories: @@ -311,7 +319,7 @@ class PromptBuilder: {"$project": {"content": 1, "similarity": 1}} ] - results = list(self.db.knowledges.aggregate(pipeline)) + results = list(self.db.db.knowledges.aggregate(pipeline)) # print(f"\033[1;34m[调试]\033[0m获取知识库内容结果: {results}") if not results: diff --git a/src/plugins/chat/relationship_manager.py b/src/plugins/chat/relationship_manager.py index fbd8cec5..90e92e7b 100644 --- a/src/plugins/chat/relationship_manager.py +++ b/src/plugins/chat/relationship_manager.py @@ -168,7 +168,7 @@ class RelationshipManager: async def load_all_relationships(self): """加载所有关系对象""" db = Database.get_instance() - all_relationships = db.relationships.find({}) + all_relationships = db.db.relationships.find({}) for data in all_relationships: await self.load_relationship(data) @@ -176,7 +176,7 @@ class RelationshipManager: """每5分钟自动保存一次关系数据""" db = Database.get_instance() # 获取所有关系记录 - all_relationships = db.relationships.find({}) + all_relationships = db.db.relationships.find({}) # 依次加载每条记录 for data in all_relationships: await self.load_relationship(data) @@ -206,7 +206,7 @@ class RelationshipManager: saved = relationship.saved db = Database.get_instance() - db.relationships.update_one( + db.db.relationships.update_one( {'user_id': user_id, 'platform': platform}, {'$set': { 'platform': platform, diff --git a/src/plugins/chat/storage.py b/src/plugins/chat/storage.py index ec155bbe..bd422dc5 100644 --- a/src/plugins/chat/storage.py +++ b/src/plugins/chat/storage.py @@ -13,17 +13,23 @@ class MessageStorage: async def store_message(self, message: Union[MessageSending, MessageRecv],chat_stream:ChatStream, topic: Optional[str] = None) -> None: """存储消息到数据库""" try: + # 提取群组ID信息,如果存在的话 + group_id = None + if chat_stream.group_info: + group_id = str(chat_stream.group_info.group_id) + message_data = { "message_id": message.message_info.message_id, "time": message.message_info.time, - "chat_id":chat_stream.stream_id, + "chat_id": chat_stream.stream_id, "chat_info": chat_stream.to_dict(), "user_info": message.message_info.user_info.to_dict(), "processed_plain_text": message.processed_plain_text, "detailed_plain_text": message.detailed_plain_text, "topic": topic, + "group_id": group_id, # 显式添加group_id字段 } - self.db.messages.insert_one(message_data) + self.db.db.messages.insert_one(message_data) except Exception: logger.exception("存储消息失败") diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index cf3e59f7..46f2f5ef 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -104,11 +104,20 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): # 转换记录格式 formatted_records = [] for record in chat_records: - formatted_records.append({ + formatted_record = { 'time': record["time"], 'chat_id': record["chat_id"], 'detailed_plain_text': record.get("detailed_plain_text", "") # 添加文本内容 - }) + } + + # 添加group_id信息,如果存在 + if 'group_id' in record: + formatted_record['group_id'] = record['group_id'] + elif 'chat_info' in record and 'group_info' in record['chat_info'] and record['chat_info']['group_info']: + # 从chat_info中提取group_id + formatted_record['group_id'] = record['chat_info']['group_info'].get('group_id') + + formatted_records.append(formatted_record) return formatted_records @@ -406,10 +415,3 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: # 按相似度降序排序并返回前k个 return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] - - -def truncate_message(message: str, max_length=20) -> str: - """截断消息,使其不超过指定长度""" - if len(message) > max_length: - return message[:max_length] + "..." - return message diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 8f09a21a..42d5f9ef 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -40,20 +40,20 @@ class ImageManager: def _ensure_image_collection(self): """确保images集合存在并创建索引""" - if 'images' not in self.db.list_collection_names(): - self.db.create_collection('images') + if 'images' not in self.db.db.list_collection_names(): + self.db.db.create_collection('images') # 创建索引 - self.db.images.create_index([('hash', 1)], unique=True) - self.db.images.create_index([('url', 1)]) - self.db.images.create_index([('path', 1)]) + self.db.db.images.create_index([('hash', 1)], unique=True) + self.db.db.images.create_index([('url', 1)]) + self.db.db.images.create_index([('path', 1)]) def _ensure_description_collection(self): """确保image_descriptions集合存在并创建索引""" - if 'image_descriptions' not in self.db.list_collection_names(): - self.db.create_collection('image_descriptions') + if 'image_descriptions' not in self.db.db.list_collection_names(): + self.db.db.create_collection('image_descriptions') # 创建索引 - self.db.image_descriptions.create_index([('hash', 1)], unique=True) - self.db.image_descriptions.create_index([('type', 1)]) + self.db.db.image_descriptions.create_index([('hash', 1)], unique=True) + self.db.db.image_descriptions.create_index([('type', 1)]) def _get_description_from_db(self, image_hash: str, description_type: str) -> Optional[str]: """从数据库获取图片描述 @@ -65,7 +65,7 @@ class ImageManager: Returns: Optional[str]: 描述文本,如果不存在则返回None """ - result= self.db.image_descriptions.find_one({ + result= self.db.db.image_descriptions.find_one({ 'hash': image_hash, 'type': description_type }) @@ -79,7 +79,7 @@ class ImageManager: description: 描述文本 description_type: 描述类型 ('emoji' 或 'image') """ - self.db.image_descriptions.update_one( + self.db.db.image_descriptions.update_one( {'hash': image_hash, 'type': description_type}, { '$set': { @@ -121,7 +121,7 @@ class ImageManager: image_hash = hashlib.md5(image_bytes).hexdigest() # 查重 - existing = self.db.images.find_one({'hash': image_hash}) + existing = self.db.db.images.find_one({'hash': image_hash}) if existing: return existing['path'] @@ -142,7 +142,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.insert_one(image_doc) + self.db.db.images.insert_one(image_doc) return file_path @@ -159,7 +159,7 @@ class ImageManager: """ try: # 先查找是否已存在 - existing = self.db.images.find_one({'url': url}) + existing = self.db.db.images.find_one({'url': url}) if existing: return existing['path'] @@ -203,7 +203,7 @@ class ImageManager: Returns: bool: 是否存在 """ - return self.db.images.find_one({'url': url}) is not None + return self.db.db.images.find_one({'url': url}) is not None def check_hash_exists(self, image_data: Union[str, bytes], is_base64: bool = False) -> bool: """检查图像是否已存在 @@ -226,7 +226,7 @@ class ImageManager: return False image_hash = hashlib.md5(image_bytes).hexdigest() - return self.db.images.find_one({'hash': image_hash}) is not None + return self.db.db.images.find_one({'hash': image_hash}) is not None except Exception as e: logger.error(f"检查哈希失败: {str(e)}") @@ -269,7 +269,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.update_one( + self.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True @@ -326,7 +326,7 @@ class ImageManager: 'description': description, 'timestamp': timestamp } - self.db.images.update_one( + self.db.db.images.update_one( {'hash': image_hash}, {'$set': image_doc}, upsert=True diff --git a/src/plugins/chat/willing_manager.py b/src/plugins/chat/willing_manager.py index 773d40c6..f34afb74 100644 --- a/src/plugins/chat/willing_manager.py +++ b/src/plugins/chat/willing_manager.py @@ -5,98 +5,101 @@ from typing import Dict from .config import global_config from .chat_stream import ChatStream -from loguru import logger - class WillingManager: def __init__(self): + self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self.chat_reply_willing: Dict[str, float] = {} # 存储每个聊天流的回复意愿 self._decay_task = None self._started = False - + async def _decay_reply_willing(self): """定期衰减回复意愿""" while True: await asyncio.sleep(5) for chat_id in self.chat_reply_willing: self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) - - def get_willing(self, chat_stream: ChatStream) -> float: + for chat_id in self.chat_reply_willing: + self.chat_reply_willing[chat_id] = max(0, self.chat_reply_willing[chat_id] * 0.6) + + def get_willing(self,chat_stream:ChatStream) -> float: """获取指定聊天流的回复意愿""" stream = chat_stream if stream: return self.chat_reply_willing.get(stream.stream_id, 0) return 0 - + def set_willing(self, chat_id: str, willing: float): """设置指定聊天流的回复意愿""" self.chat_reply_willing[chat_id] = willing - - async def change_reply_willing_received( - self, - chat_stream: ChatStream, - topic: str = None, - is_mentioned_bot: bool = False, - config=None, - is_emoji: bool = False, - interested_rate: float = 0, - ) -> float: + def set_willing(self, chat_id: str, willing: float): + """设置指定聊天流的回复意愿""" + self.chat_reply_willing[chat_id] = willing + + async def change_reply_willing_received(self, + chat_stream:ChatStream, + topic: str = None, + is_mentioned_bot: bool = False, + config = None, + is_emoji: bool = False, + interested_rate: float = 0) -> float: """改变指定聊天流的回复意愿并返回回复概率""" # 获取或创建聊天流 stream = chat_stream chat_id = stream.stream_id - + current_willing = self.chat_reply_willing.get(chat_id, 0) - + + # print(f"初始意愿: {current_willing}") if is_mentioned_bot and current_willing < 1.0: current_willing += 0.9 - logger.debug(f"被提及, 当前意愿: {current_willing}") + print(f"被提及, 当前意愿: {current_willing}") elif is_mentioned_bot: current_willing += 0.05 - logger.debug(f"被重复提及, 当前意愿: {current_willing}") - + print(f"被重复提及, 当前意愿: {current_willing}") + if is_emoji: current_willing *= 0.1 - logger.debug(f"表情包, 当前意愿: {current_willing}") - - logger.debug(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") - interested_rate *= global_config.response_interested_rate_amplifier # 放大回复兴趣度 + print(f"表情包, 当前意愿: {current_willing}") + + print(f"放大系数_interested_rate: {global_config.response_interested_rate_amplifier}") + interested_rate *= global_config.response_interested_rate_amplifier #放大回复兴趣度 if interested_rate > 0.4: # print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}") - current_willing += interested_rate - 0.4 - - current_willing *= global_config.response_willing_amplifier # 放大回复意愿 + current_willing += interested_rate-0.4 + + current_willing *= global_config.response_willing_amplifier #放大回复意愿 # print(f"放大系数_willing: {global_config.response_willing_amplifier}, 当前意愿: {current_willing}") - + reply_probability = max((current_willing - 0.45) * 2, 0) - + # 检查群组权限(如果是群聊) - if chat_stream.group_info: + if chat_stream.group_info: if chat_stream.group_info.group_id in config.talk_frequency_down_groups: reply_probability = reply_probability / global_config.down_frequency_rate reply_probability = min(reply_probability, 1) if reply_probability < 0: reply_probability = 0 - + self.chat_reply_willing[chat_id] = min(current_willing, 3.0) return reply_probability - - def change_reply_willing_sent(self, chat_stream: ChatStream): + + def change_reply_willing_sent(self, chat_stream:ChatStream): """开始思考后降低聊天流的回复意愿""" stream = chat_stream if stream: current_willing = self.chat_reply_willing.get(stream.stream_id, 0) self.chat_reply_willing[stream.stream_id] = max(0, current_willing - 2) - - def change_reply_willing_after_sent(self, chat_stream: ChatStream): + + def change_reply_willing_after_sent(self,chat_stream:ChatStream): """发送消息后提高聊天流的回复意愿""" stream = chat_stream if stream: current_willing = self.chat_reply_willing.get(stream.stream_id, 0) if current_willing < 1: self.chat_reply_willing[stream.stream_id] = min(1, current_willing + 0.2) - + async def ensure_started(self): """确保衰减任务已启动""" if not self._started: @@ -104,6 +107,5 @@ class WillingManager: self._decay_task = asyncio.create_task(self._decay_reply_willing()) self._started = True - # 创建全局实例 -willing_manager = WillingManager() +willing_manager = WillingManager() \ No newline at end of file diff --git a/src/plugins/memory_system/draw_memory.py b/src/plugins/memory_system/draw_memory.py index d6ba8f3b..9f15164f 100644 --- a/src/plugins/memory_system/draw_memory.py +++ b/src/plugins/memory_system/draw_memory.py @@ -96,7 +96,7 @@ class Memory_graph: dot_data = { "concept": node } - self.db.store_memory_dots.insert_one(dot_data) + self.db.db.store_memory_dots.insert_one(dot_data) @property def dots(self): @@ -106,7 +106,7 @@ class Memory_graph: def get_random_chat_from_db(self, length: int, timestamp: str): # 从数据库中根据时间戳获取离其最近的聊天记录 chat_text = '' - closest_record = self.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 + closest_record = self.db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) # 调试输出 logger.info( f"距离time最近的消息时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(closest_record['time'])))}") @@ -115,7 +115,7 @@ class Memory_graph: group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_record = list( - self.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( + self.db.db.messages.find({"time": {"$gt": closest_time}, "group_id": group_id}).sort('time', 1).limit( length)) for record in chat_record: time_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(int(record['time']))) @@ -130,34 +130,34 @@ class Memory_graph: def save_graph_to_db(self): # 清空现有的图数据 - self.db.graph_data.delete_many({}) + self.db.db.graph_data.delete_many({}) # 保存节点 for node in self.G.nodes(data=True): node_data = { 'concept': node[0], 'memory_items': node[1].get('memory_items', []) # 默认为空列表 } - self.db.graph_data.nodes.insert_one(node_data) + self.db.db.graph_data.nodes.insert_one(node_data) # 保存边 for edge in self.G.edges(): edge_data = { 'source': edge[0], 'target': edge[1] } - self.db.graph_data.edges.insert_one(edge_data) + self.db.db.graph_data.edges.insert_one(edge_data) def load_graph_from_db(self): # 清空当前图 self.G.clear() # 加载节点 - nodes = self.db.graph_data.nodes.find() + nodes = self.db.db.graph_data.nodes.find() for node in nodes: memory_items = node.get('memory_items', []) if not isinstance(memory_items, list): memory_items = [memory_items] if memory_items else [] self.G.add_node(node['concept'], memory_items=memory_items) # 加载边 - edges = self.db.graph_data.edges.find() + edges = self.db.db.graph_data.edges.find() for edge in edges: self.G.add_edge(edge['source'], edge['target']) diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 48fc1926..f9773ef1 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -44,9 +44,19 @@ class Memory_graph: created_time=current_time, # 添加创建时间 last_modified=current_time) # 添加最后修改时间 - def add_dot(self, concept, memory): + def add_dot(self, concept, memory, group_id=None): current_time = datetime.datetime.now().timestamp() + # 如果memory不是字典格式,将其转换为字典 + if not isinstance(memory, dict): + memory = { + 'content': memory, + 'group_id': group_id + } + # 如果memory是字典但没有group_id,添加group_id + elif 'group_id' not in memory and group_id is not None: + memory['group_id'] = group_id + if concept in self.G: if 'memory_items' in self.G.nodes[concept]: if not isinstance(self.G.nodes[concept]['memory_items'], list): @@ -218,6 +228,16 @@ class Hippocampus: if not messages: return set(), {} + # 提取群聊ID信息 + group_id = None + for msg in messages: + if 'group_id' in msg and msg['group_id']: + group_id = msg['group_id'] + break + + # 获取群组类型 + group_type = self.get_group_type(group_id) if group_id else None + # 合并消息文本,同时保留时间信息 input_text = "" time_info = "" @@ -264,17 +284,50 @@ class Hippocampus: # 等待所有任务完成 compressed_memory = set() similar_topics_dict = {} # 存储每个话题的相似主题列表 + for topic, task in tasks: response = await task if response: - compressed_memory.add((topic, response[0])) - # 为每个话题查找相似的已存在主题 + # 为每个主题创建特定群聊/群组的主题节点名称 + topic_node_name = topic + + # 如果是私有群组的群聊,使用群组前缀 + if group_type: + topic_node_name = f"{topic}_GT{group_type}" + # 否则使用群聊ID前缀(如果有) + elif group_id: + topic_node_name = f"{topic}_g{group_id}" + + # 记录主题内容 + memory_content = response[0] + compressed_memory.add((topic_node_name, memory_content)) + + # 为每个话题查找相似的已存在主题(不考虑群聊/群组前缀) existing_topics = list(self.memory_graph.G.nodes()) similar_topics = [] for existing_topic in existing_topics: + # 提取基础主题名和群组/群聊信息 + base_existing_topic = existing_topic + existing_group_type = None + existing_group_id = None + + # 检查是否有群组前缀 + if "_GT" in existing_topic: + parts = existing_topic.split("_GT") + base_existing_topic = parts[0] + if len(parts) > 1: + existing_group_type = parts[1] + # 检查是否有群聊前缀 + elif "_g" in existing_topic: + parts = existing_topic.split("_g") + base_existing_topic = parts[0] + if len(parts) > 1: + existing_group_id = parts[1] + + # 计算基础主题的相似度 topic_words = set(jieba.cut(topic)) - existing_words = set(jieba.cut(existing_topic)) + existing_words = set(jieba.cut(base_existing_topic)) all_words = topic_words | existing_words v1 = [1 if word in topic_words else 0 for word in all_words] @@ -282,12 +335,27 @@ class Hippocampus: similarity = cosine_similarity(v1, v2) - if similarity >= 0.6: - similar_topics.append((existing_topic, similarity)) + # 如果相似度高且不是完全相同的主题 + if similarity >= 0.6 and existing_topic != topic_node_name: + # 如果当前主题属于群组,只连接该群组内的主题或公共主题 + if group_type: + # 只连接同群组主题或没有群组/群聊标识的通用主题 + if (existing_group_type == group_type) or (not existing_group_type and not existing_group_id): + similar_topics.append((existing_topic, similarity)) + # 如果当前主题不属于群组但有群聊ID + elif group_id: + # 只连接同群聊主题或没有群组/群聊标识的通用主题 + if (existing_group_id == group_id) or (not existing_group_type and not existing_group_id): + similar_topics.append((existing_topic, similarity)) + # 如果当前主题既不属于群组也没有群聊ID(通用主题) + else: + # 只连接没有群组标识的主题 + if not existing_group_type: + similar_topics.append((existing_topic, similarity)) similar_topics.sort(key=lambda x: x[1], reverse=True) similar_topics = similar_topics[:5] - similar_topics_dict[topic] = similar_topics + similar_topics_dict[topic_node_name] = similar_topics return compressed_memory, similar_topics_dict @@ -315,6 +383,11 @@ class Hippocampus: bar = '█' * filled_length + '-' * (bar_length - filled_length) logger.debug(f"进度: [{bar}] {progress:.1f}% ({i}/{len(memory_samples)})") + # 获取该批次消息的group_id + group_id = None + if messages and len(messages) > 0 and 'group_id' in messages[0]: + group_id = messages[0]['group_id'] + compress_rate = global_config.memory_compress_rate compressed_memory, similar_topics_dict = await self.memory_compress(messages, compress_rate) logger.info(f"压缩后记忆数量: {len(compressed_memory)},似曾相识的话题: {len(similar_topics_dict)}") @@ -323,15 +396,25 @@ class Hippocampus: for topic, memory in compressed_memory: logger.info(f"添加节点: {topic}") - self.memory_graph.add_dot(topic, memory) + self.memory_graph.add_dot(topic, memory) # 不再需要传递group_id,因为已经包含在topic名称中 all_topics.append(topic) - # 连接相似的已存在主题 + # 连接相似的已存在主题,但使用较弱的连接强度 if topic in similar_topics_dict: similar_topics = similar_topics_dict[topic] for similar_topic, similarity in similar_topics: if topic != similar_topic: + # 如果是跨群聊的相似主题,使用较弱的连接强度 + is_cross_group = False + if ("_g" in topic and "_g" in similar_topic and + topic.split("_g")[1] != similar_topic.split("_g")[1]): + is_cross_group = True + + # 跨群聊的相似主题使用较弱的连接强度 strength = int(similarity * 10) + if is_cross_group: + strength = int(similarity * 5) # 降低跨群聊连接的强度 + logger.info(f"连接相似节点: {topic} 和 {similar_topic} (强度: {strength})") self.memory_graph.G.add_edge(topic, similar_topic, strength=strength, @@ -682,11 +765,12 @@ class Hippocampus: prompt = f'这是一段文字,{time_info}:{text}。我想让你基于这段文字来概括"{topic}"这个概念,帮我总结成一句自然的话,可以包含时间和人物,以及具体的观点。只输出这句话就好' return prompt - async def _identify_topics(self, text: str) -> list: + async def _identify_topics(self, text: str, group_id: str = None) -> list: """从文本中识别可能的主题 Args: text: 输入文本 + group_id: 群聊ID,用于生成群聊特定的主题名 Returns: list: 识别出的主题列表 @@ -697,6 +781,15 @@ class Hippocampus: topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()] # print(f"话题: {topics}") + # 如果提供了群聊ID,添加群聊/群组标识 + if group_id: + # 检查群聊是否属于特定群组 + group_type = self.get_group_type(group_id) + if group_type: + topics = [f"{topic}_GT{group_type}" for topic in topics] + else: + topics = [f"{topic}_g{group_id}" for topic in topics] + return topics def _find_similar_topics(self, topics: list, similarity_threshold: float = 0.4, debug_info: str = "") -> list: @@ -719,11 +812,47 @@ class Hippocampus: # print(f"\033[1;32m[{debug_info}]\033[0m 正在思考有没有见过: {topic}") pass - topic_vector = text_to_vector(topic) + # 提取基础主题(去除群聊/群组标识) + base_topic = topic + topic_group_type = None + topic_group_id = None + + # 检查是否有群组前缀 + if "_GT" in topic: + parts = topic.split("_GT") + base_topic = parts[0] + if len(parts) > 1: + topic_group_type = parts[1] + # 检查是否有群聊前缀 + elif "_g" in topic: + parts = topic.split("_g") + base_topic = parts[0] + if len(parts) > 1: + topic_group_id = parts[1] + + topic_vector = text_to_vector(base_topic) has_similar_topic = False for memory_topic in all_memory_topics: - memory_vector = text_to_vector(memory_topic) + # 提取记忆主题的基础主题和群聊/群组ID + base_memory_topic = memory_topic + memory_group_type = None + memory_group_id = None + + # 检查是否有群组前缀 + if "_GT" in memory_topic: + parts = memory_topic.split("_GT") + base_memory_topic = parts[0] + if len(parts) > 1: + memory_group_type = parts[1] + # 检查是否有群聊前缀 + elif "_g" in memory_topic: + parts = memory_topic.split("_g") + base_memory_topic = parts[0] + if len(parts) > 1: + memory_group_id = parts[1] + + memory_vector = text_to_vector(base_memory_topic) # 获取所有唯一词 all_words = set(topic_vector.keys()) | set(memory_vector.keys()) # 构建向量 @@ -732,7 +861,32 @@ class Hippocampus: # 计算相似度 similarity = cosine_similarity(v1, v2) - if similarity >= similarity_threshold: + # 检查是否应该考虑这个主题 + should_consider = False + + # 如果当前主题属于群组 + if topic_group_type: + # 只考虑同群组主题或公共主题 + if (memory_group_type == topic_group_type) or (not memory_group_type and not memory_group_id): + should_consider = True + # 如果当前主题属于特定群聊但不属于群组 + elif topic_group_id: + # 只考虑同群聊主题或公共主题 + if (memory_group_id == topic_group_id) or (not memory_group_type and not memory_group_id): + should_consider = True + # 如果当前主题是公共主题 + else: + # 只考虑公共主题 + if not memory_group_type and not memory_group_id: + should_consider = True + + # 如果基础主题相似且应该考虑该主题 + if similarity >= similarity_threshold and should_consider: + # 如果两个主题属于同一群组/群聊,提高相似度 + if (topic_group_type and memory_group_type and topic_group_type == memory_group_type) or \ + (topic_group_id and memory_group_id and topic_group_id == memory_group_id): + similarity *= 1.2 # 提高20%的相似度 + has_similar_topic = True if debug_info: # print(f"\033[1;32m[{debug_info}]\033[0m 找到相似主题: {topic} -> {memory_topic} (相似度: {similarity:.2f})") @@ -841,10 +995,21 @@ class Hippocampus: return activation async def get_relevant_memories(self, text: str, max_topics: int = 5, similarity_threshold: float = 0.4, - max_memory_num: int = 5) -> list: - """根据输入文本获取相关的记忆内容""" - # 识别主题 - identified_topics = await self._identify_topics(text) + max_memory_num: int = 5, group_id: str = None) -> list: + """根据输入文本获取相关的记忆内容 + + Args: + text: 输入文本 + max_topics: 最大主题数量 + similarity_threshold: 相似度阈值 + max_memory_num: 最大记忆数量 + group_id: 群聊ID,用于优先获取指定群聊的记忆 + """ + # 获取群组类型 + group_type = self.get_group_type(group_id) if group_id else None + + # 识别主题,传入群聊ID + identified_topics = await self._identify_topics(text, group_id) # 查找相似主题 all_similar_topics = self._find_similar_topics( @@ -857,30 +1022,140 @@ class Hippocampus: relevant_topics = self._get_top_topics(all_similar_topics, max_topics) # 获取相关记忆内容 - relevant_memories = [] + current_group_memories = [] # 当前群组/群聊的记忆 + other_memories = [] # 其他群聊/公共的记忆 + for topic, score in relevant_topics: + # 检查主题是否属于当前群组/群聊 + topic_group_type = None + topic_group_id = None + + # 检查是否有群组前缀 + if "_GT" in topic: + parts = topic.split("_GT") + if len(parts) > 1: + topic_group_type = parts[1] + # 检查是否有群聊前缀 + elif "_g" in topic: + parts = topic.split("_g") + if len(parts) > 1: + topic_group_id = parts[1] + # 获取该主题的记忆内容 first_layer, _ = self.memory_graph.get_related_item(topic, depth=1) if first_layer: - # 如果记忆条数超过限制,随机选择指定数量的记忆 - if len(first_layer) > max_memory_num / 2: - first_layer = random.sample(first_layer, max_memory_num // 2) - # 为每条记忆添加来源主题和相似度信息 + # 构建记忆项 for memory in first_layer: - relevant_memories.append({ + memory_item = { 'topic': topic, 'similarity': score, - 'content': memory - }) - - # 如果记忆数量超过5个,随机选择5个 + 'content': memory if not isinstance(memory, dict) else memory.get('content', memory), + 'group_id': topic_group_id, + 'group_type': topic_group_type + } + + # 分类记忆 + # 如果主题属于当前群组 + if group_type and topic_group_type == group_type: + current_group_memories.append(memory_item) + # 如果主题属于当前群聊(且群聊不属于任何群组) + elif not group_type and topic_group_id == group_id: + current_group_memories.append(memory_item) + # 如果主题是公共主题(没有群组/群聊标识) + elif not topic_group_type and not topic_group_id: + other_memories.append(memory_item) + # 如果是其他群聊/群组的主题且不是私有群组,也可以添加 + elif (topic_group_type and not topic_group_type in global_config.memory_private_groups) or \ + (not topic_group_type and topic_group_id): + other_memories.append(memory_item) + # 按相似度排序 - relevant_memories.sort(key=lambda x: x['similarity'], reverse=True) + current_group_memories.sort(key=lambda x: x['similarity'], reverse=True) + other_memories.sort(key=lambda x: x['similarity'], reverse=True) + + # 记录日志 + logger.debug(f"[记忆检索] 当前群聊/群组找到 {len(current_group_memories)} 条记忆,其他群聊/公共找到 {len(other_memories)} 条记忆") + + # 控制返回的记忆数量 + # 优先添加当前群组/群聊的记忆,如果不足,再添加其他群聊/公共的记忆 + final_memories = current_group_memories + remaining_slots = max_memory_num - len(final_memories) + + if remaining_slots > 0 and other_memories: + # 添加其他记忆,但最多添加remaining_slots个 + final_memories.extend(other_memories[:remaining_slots]) + + # 如果记忆总数仍然超过限制,随机采样 + if len(final_memories) > max_memory_num: + final_memories = random.sample(final_memories, max_memory_num) + + # 记录日志,显示最终返回多少条记忆 + logger.debug(f"[记忆检索] 最终返回 {len(final_memories)} 条记忆") + + return final_memories - if len(relevant_memories) > max_memory_num: - relevant_memories = random.sample(relevant_memories, max_memory_num) + def get_group_memories(self, group_id: str) -> list: + """获取特定群聊的所有记忆 + + Args: + group_id: 群聊ID + + Returns: + list: 该群聊的记忆列表,每个记忆包含主题和内容 + """ + all_memories = [] + all_nodes = list(self.memory_graph.G.nodes(data=True)) + + # 获取群组类型 + group_type = self.get_group_type(group_id) + + for concept, data in all_nodes: + # 检查是否应该包含该主题的记忆 + should_include = False + + # 如果群聊属于群组 + if group_type and f"_GT{group_type}" in concept: + should_include = True + # 如果是特定群聊的记忆 + elif f"_g{group_id}" in concept: + should_include = True + # 如果是公共记忆(没有群组/群聊标识) + elif not "_GT" in concept and not "_g" in concept: + should_include = True + + if should_include: + memory_items = data.get('memory_items', []) + if not isinstance(memory_items, list): + memory_items = [memory_items] if memory_items else [] + + # 添加所有记忆项 + for memory in memory_items: + all_memories.append({ + 'topic': concept, + 'content': memory if not isinstance(memory, dict) else memory.get('content', str(memory)) + }) + + return all_memories - return relevant_memories + def get_group_type(self, group_id: str) -> str: + """获取群聊所属的群组类型 + + Args: + group_id: 群聊ID + + Returns: + str: 群组类型名称,如果不属于任何群组则返回None + """ + if not group_id: + return None + + # 检查该群聊ID是否属于任何群组 + for group_name, group_ids in global_config.memory_private_groups.items(): + if group_id in group_ids: + return group_name + + # 如果不属于任何群组,返回None + return None def segment_text(text): @@ -892,6 +1167,15 @@ config = driver.config start_time = time.time() +Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), +) # 创建记忆图 memory_graph = Memory_graph() # 创建海马体 diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 5335e3d6..3424d662 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -41,10 +41,10 @@ class LLM_request: """初始化数据库集合""" try: # 创建llm_usage集合的索引 - self.db.llm_usage.create_index([("timestamp", 1)]) - self.db.llm_usage.create_index([("model_name", 1)]) - self.db.llm_usage.create_index([("user_id", 1)]) - self.db.llm_usage.create_index([("request_type", 1)]) + self.db.db.llm_usage.create_index([("timestamp", 1)]) + self.db.db.llm_usage.create_index([("model_name", 1)]) + self.db.db.llm_usage.create_index([("user_id", 1)]) + self.db.db.llm_usage.create_index([("request_type", 1)]) except Exception: logger.error("创建数据库索引失败") @@ -73,7 +73,7 @@ class LLM_request: "status": "success", "timestamp": datetime.now() } - self.db.llm_usage.insert_one(usage_data) + self.db.db.llm_usage.insert_one(usage_data) logger.info( f"Token使用情况 - 模型: {self.model_name}, " f"用户: {user_id}, 类型: {request_type}, " @@ -235,7 +235,7 @@ class LLM_request: delta_content = "" accumulated_content += delta_content # 检测流式输出文本是否结束 - finish_reason = chunk["choices"][0].get("finish_reason") + finish_reason = chunk["choices"][0]["finish_reason"] if finish_reason == "stop": usage = chunk.get("usage", None) if usage: diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index bde59389..12c6ce3b 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -14,6 +14,16 @@ from ..models.utils_model import LLM_request driver = get_driver() config = driver.config +Database.initialize( + uri=os.getenv("MONGODB_URI"), + host=os.getenv("MONGODB_HOST", "127.0.0.1"), + port=int(os.getenv("MONGODB_PORT", "27017")), + db_name=os.getenv("DATABASE_NAME", "MegBot"), + username=os.getenv("MONGODB_USERNAME"), + password=os.getenv("MONGODB_PASSWORD"), + auth_source=os.getenv("MONGODB_AUTH_SOURCE"), +) + class ScheduleGenerator: def __init__(self): # 根据global_config.llm_normal这一字典配置指定模型 @@ -46,7 +56,7 @@ class ScheduleGenerator: schedule_text = str - existing_schedule = self.db.schedule.find_one({"date": date_str}) + existing_schedule = self.db.db.schedule.find_one({"date": date_str}) if existing_schedule: logger.debug(f"{date_str}的日程已存在:") schedule_text = existing_schedule["schedule"] @@ -63,7 +73,7 @@ class ScheduleGenerator: try: schedule_text, _ = await self.llm_scheduler.generate_response(prompt) - self.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) + self.db.db.schedule.insert_one({"date": date_str, "schedule": schedule_text}) except Exception as e: logger.error(f"生成日程失败: {str(e)}") schedule_text = "生成日程时出错了" @@ -143,7 +153,7 @@ class ScheduleGenerator: """打印完整的日程安排""" if not self._parse_schedule(self.today_schedule_text): logger.warning("今日日程有误,将在下次运行时重新生成") - self.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) + self.db.db.schedule.delete_one({"date": datetime.datetime.now().strftime("%Y-%m-%d")}) else: logger.info("=== 今日日程安排 ===") for time_str, activity in self.today_schedule.items(): diff --git a/src/plugins/utils/statistic.py b/src/plugins/utils/statistic.py index 4629f0e0..2974389e 100644 --- a/src/plugins/utils/statistic.py +++ b/src/plugins/utils/statistic.py @@ -53,7 +53,7 @@ class LLMStatistics: "costs_by_model": defaultdict(float) } - cursor = self.db.llm_usage.find({ + cursor = self.db.db.llm_usage.find({ "timestamp": {"$gte": start_time} }) diff --git a/src/plugins/utils/typo_generator.py b/src/plugins/utils/typo_generator.py index f99a7ab2..aa72c387 100644 --- a/src/plugins/utils/typo_generator.py +++ b/src/plugins/utils/typo_generator.py @@ -13,8 +13,6 @@ from pathlib import Path import jieba from pypinyin import Style, pinyin -from loguru import logger - class ChineseTypoGenerator: def __init__(self, @@ -40,9 +38,7 @@ class ChineseTypoGenerator: self.max_freq_diff = max_freq_diff # 加载数据 - # print("正在加载汉字数据库,请稍候...") - logger.info("正在加载汉字数据库,请稍候...") - + print("正在加载汉字数据库,请稍候...") self.pinyin_dict = self._create_pinyin_dict() self.char_frequency = self._load_or_create_char_frequency() diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 089be69b..637c0b34 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -77,6 +77,16 @@ memory_ban_words = [ #不希望记忆的词 # "403","张三" ] +# 群组私有记忆配置 - 同一群组内的群聊共享记忆,但不与其他群组共享 +# 格式为 { 群组名称 = [群聊ID列表] } +# 未配置在任何群组中的群聊记忆可以与所有群聊共享(群组内群数量过少 聊天记录过少的情况下 建议修改其他记忆参数 加强回复概率等) +# 例如: +# memory_private_groups = { +# "游戏群组" = ["123456", "234567"], +# "工作群组" = ["345678", "456789"] +# } +memory_private_groups = { } + [mood] mood_update_interval = 1.0 # 情绪更新间隔 单位秒 mood_decay_rate = 0.95 # 情绪衰减率 @@ -93,7 +103,7 @@ reaction = "有人说你是人机或者机器人,否定这一事实,攻击 [[keywords_reaction.rules]] # 就像这样复制 enable = false # 仅作示例,不会触发 keywords = ["测试关键词回复","test",""] -reaction = "回答“测试成功”" +reaction = "回答"测试成功"" [chinese_typo] enable = true # 是否启用中文错别字生成器 diff --git a/如果你的配置文件版本太老就点我.bat b/如果你的配置文件版本太老就点我.bat index fec1f4cd..f33d96ec 100644 --- a/如果你的配置文件版本太老就点我.bat +++ b/如果你的配置文件版本太老就点我.bat @@ -1,45 +1,45 @@ -@echo off -setlocal enabledelayedexpansion -chcp 65001 -cd /d %~dp0 - -echo ===================================== -echo 选择Python环境: -echo 1 - venv (推荐) -echo 2 - conda -echo ===================================== -choice /c 12 /n /m "输入数字(1或2): " - -if errorlevel 2 ( - echo ===================================== - set "CONDA_ENV=" - set /p CONDA_ENV="请输入要激活的 conda 环境名称: " - - :: 检查输入是否为空 - if "!CONDA_ENV!"=="" ( - echo 错误:环境名称不能为空 - pause - exit /b 1 - ) - - call conda activate !CONDA_ENV! - if errorlevel 1 ( - echo 激活 conda 环境失败 - pause - exit /b 1 - ) - - echo Conda 环境 "!CONDA_ENV!" 激活成功 - python config/auto_update.py -) else ( - if exist "venv\Scripts\python.exe" ( - venv\Scripts\python config/auto_update.py - ) else ( - echo ===================================== - echo 错误: venv环境不存在,请先创建虚拟环境 - pause - exit /b 1 - ) -) -endlocal -pause +@echo off +setlocal enabledelayedexpansion +chcp 65001 +cd /d %~dp0 + +echo ===================================== +echo 选择Python环境: +echo 1 - venv (推荐) +echo 2 - conda +echo ===================================== +choice /c 12 /n /m "输入数字(1或2): " + +if errorlevel 2 ( + echo ===================================== + set "CONDA_ENV=" + set /p CONDA_ENV="请输入要激活的 conda 环境名称: " + + :: 检查输入是否为空 + if "!CONDA_ENV!"=="" ( + echo 错误:环境名称不能为空 + pause + exit /b 1 + ) + + call conda activate !CONDA_ENV! + if errorlevel 1 ( + echo 激活 conda 环境失败 + pause + exit /b 1 + ) + + echo Conda 环境 "!CONDA_ENV!" 激活成功 + python config/auto_update.py +) else ( + if exist "venv\Scripts\python.exe" ( + venv\Scripts\python config/auto_update.py + ) else ( + echo ===================================== + echo 错误: venv环境不存在,请先创建虚拟环境 + pause + exit /b 1 + ) +) +endlocal +pause diff --git a/麦麦开始学习.bat b/麦麦开始学习.bat index f7391150..dfdf3ccc 100644 --- a/麦麦开始学习.bat +++ b/麦麦开始学习.bat @@ -1,45 +1,45 @@ -@echo off -setlocal enabledelayedexpansion -chcp 65001 -cd /d %~dp0 - -echo ===================================== -echo 选择Python环境: -echo 1 - venv (推荐) -echo 2 - conda -echo ===================================== -choice /c 12 /n /m "输入数字(1或2): " - -if errorlevel 2 ( - echo ===================================== - set "CONDA_ENV=" - set /p CONDA_ENV="请输入要激活的 conda 环境名称: " - - :: 检查输入是否为空 - if "!CONDA_ENV!"=="" ( - echo 错误:环境名称不能为空 - pause - exit /b 1 - ) - - call conda activate !CONDA_ENV! - if errorlevel 1 ( - echo 激活 conda 环境失败 - pause - exit /b 1 - ) - - echo Conda 环境 "!CONDA_ENV!" 激活成功 - python src/plugins/zhishi/knowledge_library.py -) else ( - if exist "venv\Scripts\python.exe" ( - venv\Scripts\python src/plugins/zhishi/knowledge_library.py - ) else ( - echo ===================================== - echo 错误: venv环境不存在,请先创建虚拟环境 - pause - exit /b 1 - ) -) -endlocal -pause +@echo off +setlocal enabledelayedexpansion +chcp 65001 +cd /d %~dp0 + +echo ===================================== +echo 选择Python环境: +echo 1 - venv (推荐) +echo 2 - conda +echo ===================================== +choice /c 12 /n /m "输入数字(1或2): " + +if errorlevel 2 ( + echo ===================================== + set "CONDA_ENV=" + set /p CONDA_ENV="请输入要激活的 conda 环境名称: " + + :: 检查输入是否为空 + if "!CONDA_ENV!"=="" ( + echo 错误:环境名称不能为空 + pause + exit /b 1 + ) + + call conda activate !CONDA_ENV! + if errorlevel 1 ( + echo 激活 conda 环境失败 + pause + exit /b 1 + ) + + echo Conda 环境 "!CONDA_ENV!" 激活成功 + python src/plugins/zhishi/knowledge_library.py +) else ( + if exist "venv\Scripts\python.exe" ( + venv\Scripts\python src/plugins/zhishi/knowledge_library.py + ) else ( + echo ===================================== + echo 错误: venv环境不存在,请先创建虚拟环境 + pause + exit /b 1 + ) +) +endlocal +pause